Line data Source code
1 : //
2 : // Class FFTPeriodicPoissonSolver
3 : // Solves the periodic Poisson problem using Fourier transforms
4 : // cf. https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 5
5 : //
6 : //
7 :
8 : #ifndef IPPL_FFT_PERIODIC_POISSON_SOLVER_H
9 : #define IPPL_FFT_PERIODIC_POISSON_SOLVER_H
10 :
11 : #include <Kokkos_MathematicalConstants.hpp>
12 :
13 : #include "Types/ViewTypes.h"
14 :
15 : #include "FFT/FFT.h"
16 : #include "FieldLayout/FieldLayout.h"
17 : #include "Index/NDIndex.h"
18 : #include "Poisson.h"
19 :
20 : namespace ippl {
21 :
22 : template <typename FieldLHS, typename FieldRHS>
23 : class FFTPeriodicPoissonSolver : public Poisson<FieldLHS, FieldRHS> {
24 : constexpr static unsigned Dim = FieldLHS::dim;
25 : using Trhs = typename FieldRHS::value_type;
26 : using mesh_type = typename FieldRHS::Mesh_t;
27 :
28 : public:
29 : using Field_t = FieldRHS;
30 : using FFT_t = FFT<RCTransform, FieldRHS>;
31 : using Complex_t = typename FFT_t::Complex_t;
32 : using CxField_t = typename FFT_t::ComplexField;
33 : using Layout_t = FieldLayout<Dim>;
34 : using Vector_t = Vector<Trhs, Dim>;
35 :
36 : using Base = Poisson<FieldLHS, FieldRHS>;
37 : using typename Base::lhs_type, typename Base::rhs_type;
38 : using scalar_type = typename FieldLHS::Mesh_t::value_type;
39 : using vector_type = typename FieldLHS::Mesh_t::vector_type;
40 :
41 0 : FFTPeriodicPoissonSolver()
42 0 : : Base() {
43 : using T = typename FieldLHS::value_type::value_type;
44 : static_assert(std::is_floating_point<T>::value, "Not a floating point type");
45 :
46 0 : setDefaultParameters();
47 0 : }
48 :
49 : FFTPeriodicPoissonSolver(lhs_type& lhs, rhs_type& rhs)
50 : : Base(lhs, rhs) {
51 : using T = typename FieldLHS::value_type::value_type;
52 : static_assert(std::is_floating_point<T>::value, "Not a floating point type");
53 :
54 : setDefaultParameters();
55 : }
56 :
57 : //~FFTPeriodicPoissonSolver() {}
58 :
59 : void setRhs(rhs_type& rhs) override;
60 :
61 : void solve() override;
62 :
63 : private:
64 : void initialize();
65 :
66 : std::shared_ptr<FFT_t> fft_mp;
67 : CxField_t fieldComplex_m;
68 : CxField_t tempFieldComplex_m;
69 : NDIndex<Dim> domain_m;
70 : std::shared_ptr<Layout_t> layoutComplex_mp;
71 :
72 : protected:
73 0 : virtual void setDefaultParameters() override {
74 : using heffteBackend = typename FFT_t::heffteBackend;
75 0 : heffte::plan_options opts = heffte::default_options<heffteBackend>();
76 0 : this->params_m.add("use_pencils", opts.use_pencils);
77 0 : this->params_m.add("use_reorder", opts.use_reorder);
78 0 : this->params_m.add("use_gpu_aware", opts.use_gpu_aware);
79 0 : this->params_m.add("r2c_direction", 0);
80 :
81 0 : switch (opts.algorithm) {
82 0 : case heffte::reshape_algorithm::alltoall:
83 0 : this->params_m.add("comm", a2a);
84 0 : break;
85 0 : case heffte::reshape_algorithm::alltoallv:
86 0 : this->params_m.add("comm", a2av);
87 0 : break;
88 0 : case heffte::reshape_algorithm::p2p:
89 0 : this->params_m.add("comm", p2p);
90 0 : break;
91 0 : case heffte::reshape_algorithm::p2p_plined:
92 0 : this->params_m.add("comm", p2p_pl);
93 0 : break;
94 0 : default:
95 0 : throw IpplException("FFTPeriodicPoissonSolver::setDefaultParameters",
96 : "Unrecognized heffte communication type");
97 : }
98 0 : }
99 : };
100 : } // namespace ippl
101 :
102 : #include "PoissonSolvers/FFTPeriodicPoissonSolver.hpp"
103 : #endif
|