Line data Source code
1 : //
2 : // Class FFT
3 : // The FFT class performs complex-to-complex, real-to-complex,
4 : // sine and cosine transforms on IPPL Fields.
5 : // FFT is templated on the type of transform to be performed,
6 : // the dimensionality of the Field to transform, and the
7 : // floating-point precision type of the Field (float or double).
8 : // Currently, we use heffte for taking the transforms and the class FFT
9 : // serves as an interface between IPPL and heffte. In making this interface,
10 : // we have referred to the Cabana library
11 : // https://github.com/ECP-copa/Cabana.
12 : //
13 : //
14 :
15 : #ifndef IPPL_FFT_FFT_H
16 : #define IPPL_FFT_FFT_H
17 :
18 : #include <Kokkos_Complex.hpp>
19 : #include <array>
20 : #include <heffte_fft3d.h>
21 : #include <heffte_fft3d_r2c.h>
22 : #include <memory>
23 : #include <type_traits>
24 :
25 : #include "Utility/IpplException.h"
26 : #include "Utility/ParameterList.h"
27 :
28 : #include "Field/Field.h"
29 :
30 : #include "FieldLayout/FieldLayout.h"
31 : #include "Index/NDIndex.h"
32 :
33 : namespace heffte {
34 : template <>
35 : struct is_ccomplex<Kokkos::complex<float>> : std::true_type {};
36 :
37 : template <>
38 : struct is_zcomplex<Kokkos::complex<double>> : std::true_type {};
39 : } // namespace heffte
40 :
41 : namespace ippl {
42 :
/**
   Tag types for Fourier transforms. They carry no data; each one is
   used as the first template argument of FFT to select the matching
   specialization.
*/
struct CCTransform {};    ///< complex-to-complex transform
struct RCTransform {};    ///< real-to-complex transform
struct SineTransform {};  ///< sine transform
struct CosTransform {};   ///< cosine transform
struct Cos1Transform {};  ///< cosine transform of type 1
54 :
/**
   Communication options for the FFT.
   NOTE(review): presumably these select the communication scheme handed
   to heFFTe -- confirm against FFT.hpp.

   The enumerators take the values 0, 1, 2, 3 in declaration order;
   do not reorder them or insert new ones in between.
*/
enum FFTComm {
    a2av,   // == 0
    a2a,    // == 1
    p2p,    // == 2
    p2p_pl  // == 3
};
61 :
/**
   Direction of a transform, passed to the transform() member of the
   FFT classes.
*/
enum TransformDirection {
    FORWARD  = 0,
    BACKWARD = 1
};
66 :
67 : namespace detail {
68 : /*!
69 : * Wrapper type for heFFTe backends, templated
70 : * on the Kokkos memory space
71 : */
72 : template <typename>
73 : struct HeffteBackendType;
74 :
75 : #if defined(Heffte_ENABLE_FFTW)
76 : template <>
77 : struct HeffteBackendType<Kokkos::HostSpace> {
78 : using backend = heffte::backend::fftw;
79 : using backendSine = heffte::backend::fftw_sin;
80 : using backendCos = heffte::backend::fftw_cos;
81 : using backendCos1 = heffte::backend::fftw_cos1;
82 : };
83 : #elif defined(Heffte_ENABLE_MKL)
84 : template <>
85 : struct HeffteBackendType<Kokkos::HostSpace> {
86 : using backend = heffte::backend::mkl;
87 : using backendSine = heffte::backend::mkl_sin;
88 : using backendCos = heffte::backend::mkl_cos;
89 : };
90 : #endif
91 :
92 : #ifdef Heffte_ENABLE_CUDA
93 : #ifdef KOKKOS_ENABLE_CUDA
94 : template <>
95 : struct HeffteBackendType<Kokkos::CudaSpace> {
96 : using backend = heffte::backend::cufft;
97 : using backendSine = heffte::backend::cufft_sin;
98 : using backendCos = heffte::backend::cufft_cos;
99 : using backendCos1 = heffte::backend::cufft_cos1;
100 : };
101 : #else
102 : #error cuFFT backend is enabled for heFFTe but CUDA is not enabled for Kokkos!
103 : #endif
104 : #endif
105 :
106 : #ifdef KOKKOS_ENABLE_HIP
107 : #ifdef Heffte_ENABLE_ROCM
108 : template <>
109 : struct HeffteBackendType<Kokkos::HIPSpace> {
110 : using backend = heffte::backend::rocfft;
111 : using backendSine = heffte::backend::rocfft_sin;
112 : using backendCos = heffte::backend::rocfft_cos;
113 : using backendCos1 = heffte::backend::rocfft_cos1;
114 : };
115 : #else
116 : template <>
117 : struct HeffteBackendType<Kokkos::HIPSpace> {
118 : using backend = heffte::backend::stock;
119 : using backendSine = heffte::backend::stock_sin;
120 : using backendCos = heffte::backend::stock_cos;
121 : using backendCos1 = heffte::backend::stock_cos1;
122 : };
123 : #endif
124 : #endif
125 :
126 : #if !defined(Heffte_ENABLE_MKL) && !defined(Heffte_ENABLE_FFTW)
127 : /**
128 : * Use heFFTe's inbuilt 1D fft computation on CPUs if no
129 : * vendor specific or optimized backend is found
130 : */
131 : template <>
132 : struct HeffteBackendType<Kokkos::HostSpace> {
133 : using backend = heffte::backend::stock;
134 : using backendSine = heffte::backend::stock_sin;
135 : using backendCos = heffte::backend::stock_cos;
136 : using backendCos1 = heffte::backend::stock_cos1;
137 : };
138 : #endif
139 :
140 : } // namespace detail
141 :
142 : template <typename Field, template <typename...> class FFT, typename Backend,
143 : typename BufferType = typename Field::value_type>
144 : class FFTBase {
145 : constexpr static unsigned Dim = Field::dim;
146 :
147 : public:
148 : using heffteBackend = Backend;
149 : using workspace_t = typename FFT<heffteBackend>::template buffer_container<BufferType>;
150 : using Layout_t = FieldLayout<Dim>;
151 :
152 : FFTBase(const Layout_t& layout, const ParameterList& params);
153 8 : ~FFTBase() = default;
154 :
155 : protected:
156 4 : FFTBase() = default;
157 :
158 : void domainToBounds(const NDIndex<Dim>& domain, std::array<long long, 3>& low,
159 : std::array<long long, 3>& high);
160 : void setup(const heffte::box3d<long long>& inbox, const heffte::box3d<long long>& outbox,
161 : const ParameterList& params);
162 :
163 : std::shared_ptr<FFT<heffteBackend, long long>> heffte_m;
164 : workspace_t workspace_m;
165 :
166 : template <typename FieldType>
167 : using temp_view_type =
168 : typename Kokkos::View<typename FieldType::view_type::data_type, Kokkos::LayoutLeft,
169 : typename FieldType::memory_space>::uniform_type;
170 : temp_view_type<Field> tempField;
171 : };
172 :
// Base class of the transforms that take a single field: FFTBase on
// heffte::fft3d, with the backend alias BACKEND_ALIAS looked up in
// detail::HeffteBackendType for the memory space of FIELD_T.
#define IN_PLACE_FFT_BASE_CLASS(FIELD_T, BACKEND_ALIAS) \
    FFTBase<FIELD_T, heffte::fft3d,                     \
            typename detail::HeffteBackendType<typename FIELD_T::memory_space>::BACKEND_ALIAS>
// Base class of the transform that takes separate input and output fields:
// FFTBase on heffte::fft3d_r2c, with BUFFER_T as the workspace element type.
#define EXT_FFT_BASE_CLASS(FIELD_T, BACKEND_ALIAS, BUFFER_T)                                   \
    FFTBase<FIELD_T, heffte::fft3d_r2c,                                                        \
            typename detail::HeffteBackendType<typename FIELD_T::memory_space>::BACKEND_ALIAS, \
            typename BUFFER_T>
180 :
/**
   Primary FFT template, intentionally empty. The usable classes are the
   partial specializations on the Transform tag class.
*/
template <typename Transform, typename Field>
class FFT {};
186 :
187 : /**
188 : complex-to-complex FFT class
189 : */
190 : template <typename ComplexField>
191 : class FFT<CCTransform, ComplexField> : public IN_PLACE_FFT_BASE_CLASS(ComplexField, backend) {
192 : constexpr static unsigned Dim = ComplexField::dim;
193 : using Base = IN_PLACE_FFT_BASE_CLASS(ComplexField, backend);
194 :
195 : public:
196 : using Complex_t = typename ComplexField::value_type;
197 :
198 : using Base::Base;
199 : using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
200 :
201 : /*!
202 : * Warmup the FFT object by forward & backward FFT on an empty field
203 : * @param f Field whose transformation to compute (and overwrite)
204 : */
205 : void warmup(ComplexField& f);
206 :
207 : /*!
208 : * Perform in-place FFT
209 : * @param direction Forward or backward transformation
210 : * @param f Field whose transformation to compute (and overwrite)
211 : */
212 : void transform(TransformDirection direction, ComplexField& f);
213 : };
214 :
215 : /**
216 : real-to-complex FFT class
217 : */
218 : template <typename RealField>
219 : class FFT<RCTransform, RealField>
220 : : public EXT_FFT_BASE_CLASS(RealField, backend,
221 : Kokkos::complex<typename RealField::value_type>) {
222 : constexpr static unsigned Dim = RealField::dim;
223 : using Real_t = typename RealField::value_type;
224 : using Base = EXT_FFT_BASE_CLASS(RealField, backend,
225 : Kokkos::complex<typename RealField::value_type>);
226 :
227 : public:
228 : using Complex_t = Kokkos::complex<Real_t>;
229 : using ComplexField = typename Field<Complex_t, Dim, typename RealField::Mesh_t,
230 : typename RealField::Centering_t,
231 : typename RealField::execution_space>::uniform_type;
232 :
233 : using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
234 :
235 : /** Create a new FFT object with the layout for the input and output Fields
236 : * and parameters for heffte.
237 : */
238 : FFT(const Layout_t& layoutInput, const Layout_t& layoutOutput, const ParameterList& params);
239 :
240 : /*!
241 : * Warmup the FFT object by forward & backward FFT on an empty field
242 : * @param f Field whose transformation to compute
243 : * @param g Field in which to store the transformation
244 : */
245 : void warmup(RealField& f, ComplexField& g);
246 :
247 : /*!
248 : * Perform FFT
249 : * @param direction Forward or backward transformation
250 : * @param f Field whose transformation to compute
251 : * @param g Field in which to store the transformation
252 : */
253 : void transform(TransformDirection direction, RealField& f, ComplexField& g);
254 :
255 : private:
256 : typename Base::template temp_view_type<ComplexField> tempFieldComplex;
257 : };
258 :
259 : /**
260 : Sine transform class
261 : */
262 : template <typename Field>
263 : class FFT<SineTransform, Field> : public IN_PLACE_FFT_BASE_CLASS(Field, backendSine) {
264 : constexpr static unsigned Dim = Field::dim;
265 : using Base = IN_PLACE_FFT_BASE_CLASS(Field, backendSine);
266 :
267 : public:
268 : using Base::Base;
269 : using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
270 :
271 : /*!
272 : * Warmup the FFT object by forward & backward FFT on an empty field
273 : * @param f Field whose transformation to compute (and overwrite)
274 : */
275 : void warmup(Field& f);
276 :
277 : /*!
278 : * Perform in-place FFT
279 : * @param direction Forward or backward transformation
280 : * @param f Field whose transformation to compute (and overwrite)
281 : */
282 : void transform(TransformDirection direction, Field& f);
283 : };
284 : /**
285 : Cosine transform class
286 : */
287 : template <typename Field>
288 : class FFT<CosTransform, Field> : public IN_PLACE_FFT_BASE_CLASS(Field, backendCos) {
289 : constexpr static unsigned Dim = Field::dim;
290 : using Base = IN_PLACE_FFT_BASE_CLASS(Field, backendCos);
291 :
292 : public:
293 : using Base::Base;
294 : using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
295 :
296 : /*!
297 : * Warmup the FFT object by forward & backward FFT on an empty field
298 : * @param f Field whose transformation to compute (and overwrite)
299 : */
300 : void warmup(Field& f);
301 :
302 : /*!
303 : * Perform in-place FFT
304 : * @param direction Forward or backward transformation
305 : * @param f Field whose transformation to compute (and overwrite)
306 : */
307 : void transform(TransformDirection direction, Field& f);
308 : };
309 : /**
310 : Cosine type 1 transform class
311 : */
312 : template <typename Field>
313 : class FFT<Cos1Transform, Field> : public IN_PLACE_FFT_BASE_CLASS(Field, backendCos1) {
314 : constexpr static unsigned Dim = Field::dim;
315 : using Base = IN_PLACE_FFT_BASE_CLASS(Field, backendCos1);
316 :
317 : public:
318 : using Base::Base;
319 : using typename Base::heffteBackend, typename Base::workspace_t, typename Base::Layout_t;
320 :
321 : /*!
322 : * Warmup the FFT object by forward & backward FFT on an empty field
323 : * @param f Field whose transformation to compute (and overwrite)
324 : */
325 : void warmup(Field& f);
326 :
327 : /*!
328 : * Perform in-place FFT
329 : * @param direction Forward or backward transformation
330 : * @param f Field whose transformation to compute (and overwrite)
331 : */
332 : void transform(TransformDirection direction, Field& f);
333 : };
334 : } // namespace ippl
335 :
336 : #include "FFT/FFT.hpp"
337 :
338 : #endif // IPPL_FFT_FFT_H
339 :
340 : // vi: set et ts=4 sw=4 sts=4:
341 : // Local Variables:
342 : // mode:c
343 : // c-basic-offset: 4
344 : // indent-tabs-mode: nil
345 : // require-final-newline: nil
346 : // End:
|