LCOV - code coverage report
Current view: top level - src/FFT - FFT.h (source / functions) Coverage Total Hit
Test: final_report.info Lines: 100.0 % 2 2
Test Date: 2025-07-17 08:40:11 Functions: 54.5 % 22 12

            Line data    Source code
       1              : //
       2              : // Class FFT
       3              : //   The FFT class performs complex-to-complex, real-to-complex,
       4              : //   sine and cosine transforms on IPPL Fields.
       5              : //   FFT is templated on the type of transform to be performed and on
       6              : //   the type of the Field to transform, which fixes the dimensionality and
       7              : //   the floating-point precision (float or double).
       8              : //   Currently, we use heFFTe for taking the transforms, and the class FFT
       9              : //   serves as an interface between IPPL and heFFTe. In making this interface,
      10              : //   we have referred to the Cabana library,
      11              : //   https://github.com/ECP-copa/Cabana.
      12              : //
      13              : //
      14              : 
      15              : #ifndef IPPL_FFT_FFT_H
      16              : #define IPPL_FFT_FFT_H
      17              : 
      18              : #include <Kokkos_Complex.hpp>
      19              : #include <array>
      20              : #include <heffte_fft3d.h>
      21              : #include <heffte_fft3d_r2c.h>
      22              : #include <memory>
      23              : #include <type_traits>
      24              : 
      25              : #include "Utility/IpplException.h"
      26              : #include "Utility/ParameterList.h"
      27              : 
      28              : #include "Field/Field.h"
      29              : 
      30              : #include "FieldLayout/FieldLayout.h"
      31              : #include "Index/NDIndex.h"
      32              : 
      33              : namespace heffte {  // Specializes heFFTe's complex-type traits for Kokkos::complex
      34              :     template <>
      35              :     struct is_ccomplex<Kokkos::complex<float>> : std::true_type {};  // single-precision complex
      36              : 
      37              :     template <>
      38              :     struct is_zcomplex<Kokkos::complex<double>> : std::true_type {};  // double-precision complex
      39              : }  // namespace heffte
      40              : 
      41              : namespace ippl {
      42              : 
      43              :     /**
      44              :        Empty tag classes; each one selects an FFT<Tag, Field> specialization below
      45              :     */
      46              :     class CCTransform {};    // complex-to-complex transform
      47              :     class RCTransform {};    // real-to-complex transform
      48              :     class SineTransform {};  // sine transform
      49              :     class CosTransform {};   // cosine transform
      50              :     /**
      51              :        Tag class for the cosine type 1 transform
      52              :     */
      53              :     class Cos1Transform {};
      54              : 
      55              :     enum FFTComm {  // NOTE(review): presumably selects the heFFTe communication scheme; confirm in FFT.hpp
      56              :         a2av   = 0,
      57              :         a2a    = 1,
      58              :         p2p    = 2,
      59              :         p2p_pl = 3
      60              :     };
      61              : 
      62              :     enum TransformDirection {  // Direction argument taken by every FFT<...>::transform()
      63              :         FORWARD,
      64              :         BACKWARD
      65              :     };
      66              : 
      67              :     namespace detail {
      68              :         /*!
      69              :          * Maps a Kokkos memory space to the heFFTe backend tags used for it.
      70              :          * Only declared here; each supported memory space is specialized below.
      71              :          */
      72              :         template <typename>
      73              :         struct HeffteBackendType;
      74              : 
      75              : #if defined(Heffte_ENABLE_FFTW)  // Host backend: FFTW is preferred over MKL when both are enabled
      76              :         template <>
      77              :         struct HeffteBackendType<Kokkos::HostSpace> {
      78              :             using backend     = heffte::backend::fftw;       // used by the CC and RC transforms
      79              :             using backendSine = heffte::backend::fftw_sin;   // used by the sine transform
      80              :             using backendCos  = heffte::backend::fftw_cos;   // used by the cosine transform
      81              :             using backendCos1 = heffte::backend::fftw_cos1;  // used by the cosine type 1 transform
      82              :         };
      83              : #elif defined(Heffte_ENABLE_MKL)
      84              :         template <>
      85              :         struct HeffteBackendType<Kokkos::HostSpace> {
      86              :             using backend     = heffte::backend::mkl;
      87              :             using backendSine = heffte::backend::mkl_sin;
      88              :             using backendCos  = heffte::backend::mkl_cos;
      89              :         };  // NOTE(review): no backendCos1 alias, so FFT<Cos1Transform, ...> has no base on HostSpace with MKL
      90              : #endif
      91              : 
      92              : #ifdef Heffte_ENABLE_CUDA  // cuFFT backend; requires Kokkos to be built with CUDA as well
      93              : #ifdef KOKKOS_ENABLE_CUDA
      94              :         template <>
      95              :         struct HeffteBackendType<Kokkos::CudaSpace> {
      96              :             using backend     = heffte::backend::cufft;
      97              :             using backendSine = heffte::backend::cufft_sin;
      98              :             using backendCos  = heffte::backend::cufft_cos;
      99              :             using backendCos1 = heffte::backend::cufft_cos1;
     100              :         };
     101              : #else
     102              : #error cuFFT backend is enabled for heFFTe but CUDA is not enabled for Kokkos!
     103              : #endif
     104              : #endif
     105              : 
     106              : #ifdef KOKKOS_ENABLE_HIP  // HIP memory space: rocFFT if heFFTe provides it, otherwise the stock backend
     107              : #ifdef Heffte_ENABLE_ROCM
     108              :         template <>
     109              :         struct HeffteBackendType<Kokkos::HIPSpace> {
     110              :             using backend     = heffte::backend::rocfft;
     111              :             using backendSine = heffte::backend::rocfft_sin;
     112              :             using backendCos  = heffte::backend::rocfft_cos;
     113              :             using backendCos1 = heffte::backend::rocfft_cos1;
     114              :         };
     115              : #else
     116              :         template <>
     117              :         struct HeffteBackendType<Kokkos::HIPSpace> {
     118              :             using backend     = heffte::backend::stock;
     119              :             using backendSine = heffte::backend::stock_sin;
     120              :             using backendCos  = heffte::backend::stock_cos;
     121              :             using backendCos1 = heffte::backend::stock_cos1;
     122              :         };
     123              : #endif
     124              : #endif
     125              : 
     126              : #if !defined(Heffte_ENABLE_MKL) && !defined(Heffte_ENABLE_FFTW)
     127              :         /**
     128              :          * Use heFFTe's built-in (stock) FFT implementation on the host when
     129              :          * neither the FFTW nor the MKL backend is enabled
     130              :          */
     131              :         template <>
     132              :         struct HeffteBackendType<Kokkos::HostSpace> {
     133              :             using backend     = heffte::backend::stock;
     134              :             using backendSine = heffte::backend::stock_sin;
     135              :             using backendCos  = heffte::backend::stock_cos;
     136              :             using backendCos1 = heffte::backend::stock_cos1;
     137              :         };
     138              : #endif
     139              : 
     140              :     }  // namespace detail
     141              : 
     142              :     template <typename Field, template <typename...> class FFT, typename Backend,
     143              :               typename BufferType = typename Field::value_type>
     144              :     class FFTBase {  // Common base of all FFT specializations: holds the heFFTe object, workspace and temp view
     145              :         constexpr static unsigned Dim = Field::dim;
     146              : 
     147              :     public:
     148              :         using heffteBackend = Backend;  // heFFTe backend tag chosen by the derived class
     149              :         using workspace_t   = typename FFT<heffteBackend>::template buffer_container<BufferType>;
     150              :         using Layout_t      = FieldLayout<Dim>;
     151              : 
     152              :         FFTBase(const Layout_t& layout, const ParameterList& params);  // defined outside this header section
     153            8 :         ~FFTBase() = default;
     154              : 
     155              :     protected:
     156            4 :         FFTBase() = default;  // for derived classes that do not use the (layout, params) constructor
     157              : 
     158              :         void domainToBounds(const NDIndex<Dim>& domain, std::array<long long, 3>& low,
     159              :                             std::array<long long, 3>& high);  // presumably fills low/high from domain; verify in FFT.hpp
     160              :         void setup(const heffte::box3d<long long>& inbox, const heffte::box3d<long long>& outbox,
     161              :                    const ParameterList& params);  // presumably creates heffte_m for the given boxes; verify in FFT.hpp
     162              : 
     163              :         std::shared_ptr<FFT<heffteBackend, long long>> heffte_m;  // heFFTe transform object, long long indices
     164              :         workspace_t workspace_m;
     165              : 
     166              :         template <typename FieldType>
     167              :         using temp_view_type =  // LayoutLeft view with FieldType's data type and memory space
     168              :             typename Kokkos::View<typename FieldType::view_type::data_type, Kokkos::LayoutLeft,
     169              :                                   typename FieldType::memory_space>::uniform_type;
     170              :         temp_view_type<Field> tempField;
     171              :     };
     172              : 
     173              : #define IN_PLACE_FFT_BASE_CLASS(Field, Backend) \
     174              :     FFTBase<Field, heffte::fft3d,               \
     175              :             typename detail::HeffteBackendType<typename Field::memory_space>::Backend>  // fft3d base; Backend names an alias of HeffteBackendType
     176              : #define EXT_FFT_BASE_CLASS(Field, Backend, Type)                                       \
     177              :     FFTBase<Field, heffte::fft3d_r2c,                                                  \
     178              :             typename detail::HeffteBackendType<typename Field::memory_space>::Backend, \
     179              :             typename Type>  // fft3d_r2c base; Type becomes FFTBase's BufferType
     180              : 
     181              :     /**
     182              :        Primary FFT template, intentionally empty. The usable interface lives in
     183              :        the partial specializations on the Transform tag class
     184              :     */
     185              :     template <class Transform, typename Field>
     186              :     class FFT {};
     186              : 
     187              :     /**
     188              :        complex-to-complex FFT class (in-place, built on heffte::fft3d)
     189              :     */
     190              :     template <typename ComplexField>
     191              :     class FFT<CCTransform, ComplexField> : public IN_PLACE_FFT_BASE_CLASS(ComplexField, backend) {
     192              :         constexpr static unsigned Dim = ComplexField::dim;
     193              :         using Base                    = IN_PLACE_FFT_BASE_CLASS(ComplexField, backend);
     194              : 
     195              :     public:
     196              :         using Complex_t = typename ComplexField::value_type;
     197              : 
     198              :         using Base::Base;  // inherit the FFTBase constructors
     199              :         using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
     200              : 
     201              :         /*!
     202              :          * Warm up the FFT object by a forward & backward FFT on an empty field
     203              :          * @param f Field whose transformation to compute (and overwrite)
     204              :          */
     205              :         void warmup(ComplexField& f);
     206              : 
     207              :         /*!
     208              :          * Perform in-place complex-to-complex FFT
     209              :          * @param direction Forward or backward transformation
     210              :          * @param f Field whose transformation to compute (and overwrite)
     211              :          */
     212              :         void transform(TransformDirection direction, ComplexField& f);
     213              :     };
     214              : 
     215              :     /**
     216              :        real-to-complex FFT class (built on heffte::fft3d_r2c)
     217              :     */
     218              :     template <typename RealField>
     219              :     class FFT<RCTransform, RealField>
     220              :         : public EXT_FFT_BASE_CLASS(RealField, backend,
     221              :                                     Kokkos::complex<typename RealField::value_type>) {
     222              :         constexpr static unsigned Dim = RealField::dim;
     223              :         using Real_t                  = typename RealField::value_type;
     224              :         using Base                    = EXT_FFT_BASE_CLASS(RealField, backend,
     225              :                                                            Kokkos::complex<typename RealField::value_type>);
     226              : 
     227              :     public:
     228              :         using Complex_t    = Kokkos::complex<Real_t>;
     229              :         using ComplexField = typename Field<Complex_t, Dim, typename RealField::Mesh_t,  // complex counterpart of RealField
     230              :                                             typename RealField::Centering_t,
     231              :                                             typename RealField::execution_space>::uniform_type;
     232              : 
     233              :         using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
     234              : 
     235              :         /** Create a new FFT object with the layouts of the input and output Fields
     236              :          * and the parameters for heFFTe.
     237              :          */
     238              :         FFT(const Layout_t& layoutInput, const Layout_t& layoutOutput, const ParameterList& params);
     239              : 
     240              :         /*!
     241              :          * Warm up the FFT object by a forward & backward FFT on an empty field
     242              :          * @param f Field whose transformation to compute
     243              :          * @param g Field in which to store the transformation
     244              :          */
     245              :         void warmup(RealField& f, ComplexField& g);
     246              : 
     247              :         /*!
     248              :          * Perform real-to-complex FFT
     249              :          * @param direction Forward or backward transformation
     250              :          * @param f Field whose transformation to compute
     251              :          * @param g Field in which to store the transformation
     252              :          */
     253              :         void transform(TransformDirection direction, RealField& f, ComplexField& g);
     254              : 
     255              :     private:
     256              :         typename Base::template temp_view_type<ComplexField> tempFieldComplex;  // LayoutLeft view typed for ComplexField
     257              :     };
     258              : 
     259              :     /**
     260              :        Sine transform class (in-place, uses the backendSine alias)
     261              :     */
     262              :     template <typename Field>
     263              :     class FFT<SineTransform, Field> : public IN_PLACE_FFT_BASE_CLASS(Field, backendSine) {
     264              :         constexpr static unsigned Dim = Field::dim;
     265              :         using Base                    = IN_PLACE_FFT_BASE_CLASS(Field, backendSine);
     266              : 
     267              :     public:
     268              :         using Base::Base;  // inherit the FFTBase constructors
     269              :         using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
     270              : 
     271              :         /*!
     272              :          * Warm up the FFT object by a forward & backward transform on an empty field
     273              :          * @param f Field whose transformation to compute (and overwrite)
     274              :          */
     275              :         void warmup(Field& f);
     276              : 
     277              :         /*!
     278              :          * Perform in-place sine transform
     279              :          * @param direction Forward or backward transformation
     280              :          * @param f Field whose transformation to compute (and overwrite)
     281              :          */
     282              :         void transform(TransformDirection direction, Field& f);
     283              :     };
     284              :     /**
     285              :        Cosine transform class (in-place, uses the backendCos alias)
     286              :     */
     287              :     template <typename Field>
     288              :     class FFT<CosTransform, Field> : public IN_PLACE_FFT_BASE_CLASS(Field, backendCos) {
     289              :         constexpr static unsigned Dim = Field::dim;
     290              :         using Base                    = IN_PLACE_FFT_BASE_CLASS(Field, backendCos);
     291              : 
     292              :     public:
     293              :         using Base::Base;  // inherit the FFTBase constructors
     294              :         using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
     295              : 
     296              :         /*!
     297              :          * Warm up the FFT object by a forward & backward transform on an empty field
     298              :          * @param f Field whose transformation to compute (and overwrite)
     299              :          */
     300              :         void warmup(Field& f);
     301              : 
     302              :         /*!
     303              :          * Perform in-place cosine transform
     304              :          * @param direction Forward or backward transformation
     305              :          * @param f Field whose transformation to compute (and overwrite)
     306              :          */
     307              :         void transform(TransformDirection direction, Field& f);
     308              :     };
     309              :     /**
     310              :        Cosine type 1 transform class (in-place, uses the backendCos1 alias)
     311              :     */
     312              :     template <typename Field>
     313              :     class FFT<Cos1Transform, Field> : public IN_PLACE_FFT_BASE_CLASS(Field, backendCos1) {
     314              :         constexpr static unsigned Dim = Field::dim;
     315              :         using Base                    = IN_PLACE_FFT_BASE_CLASS(Field, backendCos1);
     316              : 
     317              :     public:
     318              :         using Base::Base;  // inherit the FFTBase constructors
     319              :         using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
     320              : 
     321              :         /*!
     322              :          * Warm up the FFT object by a forward & backward transform on an empty field
     323              :          * @param f Field whose transformation to compute (and overwrite)
     324              :          */
     325              :         void warmup(Field& f);
     326              : 
     327              :         /*!
     328              :          * Perform in-place cosine type 1 transform
     329              :          * @param direction Forward or backward transformation
     330              :          * @param f Field whose transformation to compute (and overwrite)
     331              :          */
     332              :         void transform(TransformDirection direction, Field& f);
     333              :     };
     334              : }  // namespace ippl
     335              : 
     336              : #include "FFT/FFT.hpp"
     337              : 
     338              : #endif  // IPPL_FFT_FFT_H
     339              : 
     340              : // vi: set et ts=4 sw=4 sts=4:
     341              : // Local Variables:
     342              : // mode:c
     343              : // c-basic-offset: 4
     344              : // indent-tabs-mode: nil
     345              : // require-final-newline: nil
     346              : // End:
        

Generated by: LCOV version 2.0-1