Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FFT Radix 2 implementation #15

Merged
merged 30 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/ghworkflow.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
name: CI

on: push
on:
workflow_dispatch:
pull_request:
push:
branches: [master]

jobs:

Expand Down
11 changes: 10 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# The full license is in the file LICENSE, distributed with this software. #
############################################################################

cmake_minimum_required(VERSION 3.1)
cmake_minimum_required(VERSION 3.5)
project(xtensor-signal)

set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH})
Expand Down Expand Up @@ -67,6 +67,8 @@ set(XTENSOR_SIGNAL_HEADERS
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/xtensor_signal.hpp
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/find_peaks.hpp
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/lfilter.hpp
${XTENSOR_SIGNAL_INCLUDE_DIR}/xtensor-signal/fft.hpp

)

add_library(xtensor-signal INTERFACE)
Expand All @@ -78,11 +80,18 @@ target_include_directories(xtensor-signal INTERFACE
target_link_libraries(xtensor-signal INTERFACE xtensor xsimd)

OPTION(BUILD_TESTS "xtensor test suite" OFF)
OPTION(XTENSOR_USE_TBB "Use tbb libraries" OFF)

if(BUILD_TESTS)
add_subdirectory(test)
endif()

if(XTENSOR_USE_TBB)
find_package(TBB REQUIRED)
message(STATUS "Found intel TBB: ${TBB_INCLUDE_DIRS}")
endif()


# Installation
# ============

Expand Down
99 changes: 99 additions & 0 deletions include/xtensor-signal/fft.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@

#ifdef XTENSOR_USE_TBB
#include <oneapi/tbb.h>
#endif
#include <stdexcept>
#include <xtensor/xarray.hpp>
#include <xtensor/xaxis_slice_iterator.hpp>
#include <xtensor/xbuilder.hpp>
#include <xtensor/xnoalias.hpp>
#include <xtensor/xview.hpp>
#include <xtl/xcomplex.hpp>

namespace xt::fft {
namespace detail {
template <class E,
typename std::enable_if<
xtl::is_complex<typename std::decay<E>::type::value_type>::value,
bool>::type = true>
inline auto fft(E &&e) {
using namespace xt::placeholders;
using namespace std::complex_literals;
using value_type = typename std::decay_t<E>::value_type;
using precision = typename value_type::value_type;
auto N = e.size();
const bool powerOfTwo = !(N == 0) && !(N & (N - 1));
// check for power of 2
if (!powerOfTwo || N == 0) {
// TODO: Replace implementation with dft
XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2");
}
auto pi = xt::numeric_constants<precision>::PI;
xt::xtensor<value_type, 1> ev = e;
if (N <= 1) {
return ev;
} else {
#ifdef XTENSOR_USE_TBB
xt::xtensor<value_type, 1> even;
xt::xtensor<value_type, 1> odd;
oneapi::tbb::parallel_invoke(
[&] { even = fft(xt::view(ev, xt::range(0, _, 2))); },
[&] { odd = fft(xt::view(ev, xt::range(1, _, 2))); });
#else
auto even = fft(xt::view(ev, xt::range(0, _, 2)));
auto odd = fft(xt::view(ev, xt::range(1, _, 2)));
#endif

auto range = xt::arange<double>(N / 2);
auto exp = xt::exp(static_cast<value_type>(-2i) * pi * range / N);
auto t = exp * odd;
auto first_half = even + t;
auto second_half = even - t;
// TODO: should be a call to stack if performance was improved
auto spectrum = xt::xtensor<value_type, 1>::from_shape({N});
xt::view(spectrum, xt::range(0, N / 2)) = first_half;
xt::view(spectrum, xt::range(N / 2, N)) = second_half;
return spectrum;
}
}
} // namespace detail

/**
* @breif 1D FFT of an Nd array along a specified axis
* @param e an Nd expression to be transformed to the fourier domain
* @param axis the axis along which to perform the 1D FFT
* @return a transformed xarray of the specified precision
*/
template <class E,
typename std::enable_if<
xtl::is_complex<typename std::decay<E>::type::value_type>::value,
bool>::type = true>
inline auto fft(E &&e, std::ptrdiff_t axis = -1) {
using value_type = typename std::decay_t<E>::value_type;
using precision = typename value_type::value_type;
xt::xarray<std::complex<precision>> out = xt::eval(e);
auto saxis = xt::normalize_axis(e.dimension(), axis);
auto begin = xt::axis_slice_begin(out, saxis);
auto end = xt::axis_slice_end(out, saxis);
for (auto iter = begin; iter != end; iter++) {
xt::noalias(*iter) = detail::fft(*iter);
}
return out;
}

/**
* @breif 1D FFT of an Nd array along a specified axis
* @param e an Nd expression to be transformed to the fourier domain
* @param axis the axis along which to perform the 1D FFT
* @return a transformed xarray of the specified precision
*/
template <class E,
typename std::enable_if<
!xtl::is_complex<typename std::decay<E>::type::value_type>::value,
bool>::type = true>
inline auto fft(E &&e, std::ptrdiff_t axis = -1) {
using value_type = typename std::decay<E>::type::value_type;
return fft(xt::cast<std::complex<value_type>>(e), axis);
}

} // namespace xt::fft
11 changes: 10 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# The full license is in the file LICENSE, distributed with this software. #
############################################################################

cmake_minimum_required(VERSION 3.1)
cmake_minimum_required(VERSION 3.5)

if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
project(xtensor-signal-test)
Expand Down Expand Up @@ -40,15 +40,24 @@ set(XTENSOR_SIGNAL_TESTS
test_config.cpp
find_peaks_test.cpp
lfilter_test.cpp
fft_test.cpp
)

if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${_cxx_std_flag} /MP /bigobj")
endif()



file(COPY "test_data" DESTINATION "${CMAKE_BINARY_DIR}/test")

add_executable(test_xtensor_signal ${XTENSOR_SIGNAL_TESTS} ${XTENSOR_SIGNAL_HEADERS})
if(XTENSOR_USE_TBB)
target_compile_definitions(test_xtensor_signal PRIVATE XTENSOR_USE_TBB)
target_include_directories(test_xtensor_signal PRIVATE ${TBB_INCLUDE_DIRS})
target_link_libraries(test_xtensor_signal PRIVATE ${TBB_LIBRARIES})
endif()

target_link_libraries(test_xtensor_signal PRIVATE ZLIB::ZLIB xtensor-signal doctest::doctest ${CMAKE_THREAD_LIBS_INIT})

add_custom_target(xtest COMMAND ./test_xtensor_signal DEPENDS test_xtensor_signal)
49 changes: 49 additions & 0 deletions test/fft_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@

#include "doctest/doctest.h"
#include "xtensor-signal/fft.hpp"
#include <xtensor/xio.hpp>

TEST_SUITE("fft") {

TEST_CASE("fft_single") {
bool powerOfTwo = !(8 == 0) && !(8 & (8 - 1));
xt::xtensor<float, 1> input = {1, 1, 1, 1, 0, 0, 0, 0};
xt::xtensor<float, 1> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input);
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001)));
}
TEST_CASE("fft_double") {
xt::xtensor<double, 1> input = {1, 1, 1, 1, 0, 0, 0, 0};
xt::xtensor<double, 1> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input);
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001)));
}
TEST_CASE("fft_csingle") {
xt::xtensor<std::complex<float>, 1> input = {1, 1, 1, 1, 0, 0, 0, 0};
xt::xtensor<float, 1> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input);
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001)));
}
TEST_CASE("fft_cdouble") {
xt::xtensor<std::complex<double>, 1> input = {1, 1, 1, 1, 0, 0, 0, 0};
xt::xtensor<double, 1> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input);
REQUIRE(xt::all(xt::isclose(xt::abs(result), expectation, .001)));
}

TEST_CASE("fft_double_axis0") {
xt::xarray<double> input = {{1, 1}, {1, 1}, {1, 1}, {1, 1},
{0, 0}, {0, 0}, {0, 0}, {0, 0}};
xt::xarray<double> expectation = {4.000, 2.613, 0.000, 1.082,
0.000, 1.082, 0.000, 2.613};
auto result = xt::fft::fft(input, 0);
auto first_column = xt::view(result, xt::all(), 0);
REQUIRE(xt::all(xt::isclose(xt::abs(first_column), expectation, .001)));
auto second_column = xt::view(result, xt::all(), 1);
REQUIRE(xt::all(xt::isclose(xt::abs(second_column), expectation, .001)));
}
}
Loading