Branch data Line data Source code
1 : : //
2 : : // Class FFTPeriodicPoissonSolver
3 : : // Solves the periodic Poisson problem using Fourier transforms
4 : : // cf. https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 5
5 : : //
6 : : //
7 : :
8 : : namespace ippl {
9 : :
10 : : template <typename FieldLHS, typename FieldRHS>
11 : 0 : void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::setRhs(rhs_type& rhs) {
12 : 0 : Base::setRhs(rhs);
13 : 0 : initialize();
14 : 0 : }
15 : :
16 : : template <typename FieldLHS, typename FieldRHS>
17 : 0 : void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::initialize() {
18 [ # # ]: 0 : const Layout_t& layout_r = this->rhs_mp->getLayout();
19 : 0 : domain_m = layout_r.getDomain();
20 : :
21 : 0 : NDIndex<Dim> domainComplex;
22 : :
23 [ # # ]: 0 : vector_type hComplex;
24 [ # # ]: 0 : vector_type originComplex;
25 : :
26 : 0 : std::array<bool, Dim> isParallel = layout_r.isParallel();
27 [ # # ]: 0 : for (unsigned d = 0; d < Dim; ++d) {
28 : 0 : hComplex[d] = 1.0;
29 : 0 : originComplex[d] = 0.0;
30 : :
31 [ # # # # : 0 : if (this->params_m.template get<int>("r2c_direction") == (int)d) {
# # ]
32 : 0 : domainComplex[d] = Index(domain_m[d].length() / 2 + 1);
33 : : } else {
34 : 0 : domainComplex[d] = Index(domain_m[d].length());
35 : : }
36 : : }
37 : :
38 [ # # ]: 0 : layoutComplex_mp = std::make_shared<Layout_t>(layout_r.comm, domainComplex, isParallel);
39 : :
40 [ # # ]: 0 : mesh_type meshComplex(domainComplex, hComplex, originComplex);
41 : :
42 [ # # ]: 0 : fieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
43 : :
44 [ # # # # : 0 : if (this->params_m.template get<int>("output_type") == Base::GRAD) {
# # ]
45 [ # # ]: 0 : tempFieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
46 : : }
47 : :
48 [ # # ]: 0 : fft_mp = std::make_shared<FFT_t>(layout_r, *layoutComplex_mp, this->params_m);
49 [ # # ]: 0 : fft_mp->warmup(*this->rhs_mp, fieldComplex_m); // warmup the FFT object
50 : 0 : }
51 : :
52 : : template <typename FieldLHS, typename FieldRHS>
53 : 0 : void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::solve() {
54 [ # # ]: 0 : fft_mp->transform(FORWARD, *this->rhs_mp, fieldComplex_m);
55 : :
56 [ # # ]: 0 : auto view = fieldComplex_m.getView();
57 : 0 : const int nghost = fieldComplex_m.getNghost();
58 : :
59 : 0 : scalar_type pi = Kokkos::numbers::pi_v<scalar_type>;
60 : 0 : const mesh_type& mesh = this->rhs_mp->get_mesh();
61 [ # # ]: 0 : const auto& lDomComplex = layoutComplex_mp->getLocalNDIndex();
62 : : using vector_type = typename mesh_type::vector_type;
63 : 0 : const vector_type& origin = mesh.getOrigin();
64 [ # # ]: 0 : const vector_type& hx = mesh.getMeshSpacing();
65 : :
66 [ # # ]: 0 : vector_type rmax;
67 [ # # ]: 0 : Vector<int, Dim> N;
68 [ # # ]: 0 : for (size_t d = 0; d < Dim; ++d) {
69 : 0 : N[d] = domain_m[d].length();
70 : 0 : rmax[d] = origin[d] + (N[d] * hx[d]);
71 : : }
72 : :
73 : : // Based on output_type calculate either solution
74 : : // or gradient
75 : :
76 : : using index_array_type = typename RangePolicy<Dim>::index_array_type;
77 [ # # # # : 0 : switch (this->params_m.template get<int>("output_type")) {
# # # ]
78 : 0 : case Base::SOL: {
79 [ # # # # ]: 0 : ippl::parallel_for(
80 : 0 : "Solution FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
81 [ # # # # : 0 : KOKKOS_LAMBDA(const index_array_type& args) {
# # # # #
# ]
82 [ # # # # ]: 0 : Vector<int, Dim> iVec = args - nghost;
83 [ # # ]: 0 : for (unsigned d = 0; d < Dim; ++d) {
84 : 0 : iVec[d] += lDomComplex[d].first();
85 : : }
86 : :
87 [ # # ]: 0 : Vector_t kVec;
88 : :
89 [ # # ]: 0 : for (size_t d = 0; d < Dim; ++d) {
90 : 0 : const scalar_type Len = rmax[d] - origin[d];
91 : 0 : bool shift = (iVec[d] > (N[d] / 2));
92 : 0 : kVec[d] = 2 * pi / Len * (iVec[d] - shift * N[d]);
93 : : }
94 : :
95 : 0 : scalar_type Dr = 0;
96 [ # # ]: 0 : for (unsigned d = 0; d < Dim; ++d) {
97 : 0 : Dr += kVec[d] * kVec[d];
98 : : }
99 : :
100 : 0 : bool isNotZero = (Dr != 0.0);
101 : 0 : scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
102 : :
103 [ # # ]: 0 : apply(view, args) *= factor;
104 : 0 : });
105 : :
106 [ # # ]: 0 : fft_mp->transform(BACKWARD, *this->rhs_mp, fieldComplex_m);
107 : :
108 : 0 : break;
109 : : }
110 : 0 : case Base::GRAD: {
111 : : // Compute gradient in Fourier space and then
112 : : // take inverse FFT.
113 : :
114 : 0 : Complex_t imag = {0.0, 1.0};
115 [ # # ]: 0 : auto tempview = tempFieldComplex_m.getView();
116 [ # # ]: 0 : auto viewRhs = this->rhs_mp->getView();
117 [ # # ]: 0 : auto viewLhs = this->lhs_mp->getView();
118 : 0 : const int nghostL = this->lhs_mp->getNghost();
119 : :
120 [ # # ]: 0 : for (size_t gd = 0; gd < Dim; ++gd) {
121 [ # # # # ]: 0 : ippl::parallel_for(
122 : 0 : "Gradient FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
123 [ # # # # : 0 : KOKKOS_LAMBDA(const index_array_type& args) {
# # # # #
# # # # #
# # ]
124 [ # # # # ]: 0 : Vector<int, Dim> iVec = args - nghost;
125 [ # # ]: 0 : for (unsigned d = 0; d < Dim; ++d) {
126 : 0 : iVec[d] += lDomComplex[d].first();
127 : : }
128 : :
129 [ # # ]: 0 : Vector_t kVec;
130 : :
131 [ # # ]: 0 : for (size_t d = 0; d < Dim; ++d) {
132 : 0 : const scalar_type Len = rmax[d] - origin[d];
133 : 0 : bool shift = (iVec[d] > (N[d] / 2));
134 : 0 : bool notMid = (iVec[d] != (N[d] / 2));
135 : : // For the noMid part see
136 : : // https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 1
137 : 0 : kVec[d] = notMid * 2 * pi / Len * (iVec[d] - shift * N[d]);
138 : : }
139 : :
140 : 0 : scalar_type Dr = 0;
141 [ # # ]: 0 : for (unsigned d = 0; d < Dim; ++d) {
142 : 0 : Dr += kVec[d] * kVec[d];
143 : : }
144 : :
145 [ # # # # ]: 0 : apply(tempview, args) = apply(view, args);
146 : :
147 : 0 : bool isNotZero = (Dr != 0.0);
148 : 0 : scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
149 : :
150 [ # # ]: 0 : apply(tempview, args) *= -(imag * kVec[gd] * factor);
151 : 0 : });
152 : :
153 [ # # ]: 0 : fft_mp->transform(BACKWARD, *this->rhs_mp, tempFieldComplex_m);
154 : :
155 [ # # # # ]: 0 : ippl::parallel_for(
156 : : "Assign Gradient FFTPeriodicPoissonSolver",
157 : 0 : getRangePolicy(viewLhs, nghostL),
158 [ # # # # : 0 : KOKKOS_LAMBDA(const index_array_type& args) {
# # # # ]
159 : 0 : apply(viewLhs, args)[gd] = apply(viewRhs, args);
160 : : });
161 : : }
162 : :
163 : 0 : break;
164 : 0 : }
165 : :
166 : 0 : default:
167 [ # # # # : 0 : throw IpplException("FFTPeriodicPoissonSolver::solve", "Unrecognized output_type");
# # ]
168 : : }
169 : 0 : }
170 : : } // namespace ippl
|