Line data Source code
1 : //
2 : // Class FFTTruncatedGreenPeriodicPoissonSolver
3 : // Poisson solver for periodic boundaries, based on FFTs.
4 : // Solves laplace(phi) = -rho, and E = -grad(phi).
5 : //
6 : // Uses a convolution with a Green's function given by:
7 : // G(r) = forceConstant * erf(alpha * r) / r,
// alpha (inverse length) sets the smoothing scale: for r >> 1/alpha, G(r) ~ forceConstant / r.
9 : //
10 : //
11 :
12 : namespace ippl {
13 :
14 : /////////////////////////////////////////////////////////////////////////
15 : // constructor and destructor
16 :
17 : template <typename FieldLHS, typename FieldRHS>
18 0 : FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::FFTTruncatedGreenPeriodicPoissonSolver()
19 : : Base()
20 0 : , mesh_mp(nullptr)
21 0 : , layout_mp(nullptr)
22 0 : , meshComplex_m(nullptr)
23 0 : , layoutComplex_m(nullptr) {
24 0 : FFTTruncatedGreenPeriodicPoissonSolver::setDefaultParameters();
25 0 : }
26 :
27 : template <typename FieldLHS, typename FieldRHS>
28 : FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::FFTTruncatedGreenPeriodicPoissonSolver(rhs_type& rhs, ParameterList& params)
29 : : mesh_mp(nullptr)
30 : , layout_mp(nullptr)
31 : , meshComplex_m(nullptr)
32 : , layoutComplex_m(nullptr) {
33 : FFTTruncatedGreenPeriodicPoissonSolver::setDefaultParameters();
34 :
35 : this->params_m.merge(params);
36 : this->params_m.update("output_type", Base::SOL);
37 :
38 : FFTTruncatedGreenPeriodicPoissonSolver::setRhs(rhs);
39 : }
40 :
41 : template <typename FieldLHS, typename FieldRHS>
42 : FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::FFTTruncatedGreenPeriodicPoissonSolver(lhs_type& lhs, rhs_type& rhs, ParameterList& params)
43 : : mesh_mp(nullptr)
44 : , layout_mp(nullptr)
45 : , meshComplex_m(nullptr)
46 : , layoutComplex_m(nullptr) {
47 : FFTTruncatedGreenPeriodicPoissonSolver::setDefaultParameters();
48 :
49 : this->params_m.merge(params);
50 :
51 : this->setLhs(lhs);
52 : FFTTruncatedGreenPeriodicPoissonSolver::setRhs(rhs);
53 : }
54 :
55 : template <typename FieldLHS, typename FieldRHS>
56 0 : void FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::setRhs(rhs_type& rhs) {
57 0 : Base::setRhs(rhs);
58 0 : initializeFields();
59 0 : }
60 :
61 : /////////////////////////////////////////////////////////////////////////
// initializeFields method, called from setRhs (and therefore from the constructors that take an RHS)
63 :
64 : template <typename FieldLHS, typename FieldRHS>
65 0 : void FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::initializeFields() {
66 : static_assert(Dim == 3, "Dimension other than 3 not supported in FFTTruncatedGreenPeriodicPoissonSolver!");
67 :
68 : // get layout and mesh
69 0 : layout_mp = &(this->rhs_mp->getLayout());
70 0 : mesh_mp = &(this->rhs_mp->get_mesh());
71 0 : mpi::Communicator comm = layout_mp->comm;
72 :
73 : // get mesh spacing
74 0 : hr_m = mesh_mp->getMeshSpacing();
75 :
76 : // get origin
77 0 : Vector_t origin = mesh_mp->getOrigin();
78 :
79 : // create domain for the real fields
80 0 : domain_m = layout_mp->getDomain();
81 :
82 : // get the mesh spacings and sizes for each dimension
83 0 : for (unsigned int i = 0; i < Dim; ++i) {
84 0 : nr_m[i] = domain_m[i].length();
85 : }
86 :
87 : // define decomposition (parallel / serial)
88 0 : std::array<bool, Dim> isParallel = layout_mp->isParallel();
89 :
90 : // create the domain for the transformed (complex) fields
91 : // since we use HeFFTe for the transforms it doesn't require permuting to the right
92 : // one of the dimensions has only (n/2 +1) as our original fields are fully real
93 : // the dimension is given by the user via r2c_direction
94 0 : unsigned int RCDirection = this->params_m.template get<int>("r2c_direction");
95 0 : for (unsigned int i = 0; i < Dim; ++i) {
96 0 : if (i == RCDirection)
97 0 : domainComplex_m[RCDirection] = Index(nr_m[RCDirection] / 2 + 1);
98 : else
99 0 : domainComplex_m[i] = Index(nr_m[i]);
100 : }
101 :
102 : // create mesh and layout for the real to complex FFT transformed fields
103 : using mesh_type = typename lhs_type::Mesh_t;
104 0 : meshComplex_m = std::unique_ptr<mesh_type>(new mesh_type(domainComplex_m, hr_m, origin));
105 0 : layoutComplex_m =
106 0 : std::unique_ptr<FieldLayout_t>(new FieldLayout_t(comm, domainComplex_m, isParallel));
107 :
108 : // initialize fields
109 0 : grn_m.initialize(*mesh_mp, *layout_mp);
110 0 : rhotr_m.initialize(*meshComplex_m, *layoutComplex_m);
111 0 : grntr_m.initialize(*meshComplex_m, *layoutComplex_m);
112 0 : tempFieldComplex_m.initialize(*meshComplex_m, *layoutComplex_m);
113 :
114 : // create the FFT object
115 0 : fft_m = std::make_unique<FFT_t>(*layout_mp, *layoutComplex_m, this->params_m);
116 0 : fft_m->warmup(grn_m, grntr_m); // warmup the FFT object
117 :
118 : // these are fields that are used for calculating the Green's function
119 0 : for (unsigned int d = 0; d < Dim; ++d) {
120 0 : grnIField_m[d].initialize(*mesh_mp, *layout_mp);
121 :
122 : // get number of ghost points and the Kokkos view to iterate over field
123 0 : auto view = grnIField_m[d].getView();
124 0 : const int nghost = grnIField_m[d].getNghost();
125 0 : const auto& ldom = layout_mp->getLocalNDIndex();
126 :
127 : // the length of the physical domain
128 0 : const int size = nr_m[d];
129 :
130 : // Kokkos parallel for loop to initialize grnIField[d]
131 0 : switch (d) {
132 0 : case 0:
133 0 : Kokkos::parallel_for(
134 : "Helper index Green field initialization",
135 0 : ippl::getRangePolicy(view, nghost),
136 0 : KOKKOS_LAMBDA(const int i, const int j, const int k) {
137 : // go from local indices to global
138 0 : const int ig = i + ldom[0].first() - nghost;
139 0 : const int jg = j + ldom[1].first() - nghost;
140 0 : const int kg = k + ldom[2].first() - nghost;
141 :
142 : // assign (index)^2 if 0 <= index < N, and (2N-index)^2 elsewhere
143 0 : const bool outsideN = (ig >= size / 2);
144 0 : view(i, j, k) = (size * outsideN - ig) * (size * outsideN - ig);
145 :
146 : // add 1.0 if at (0,0,0) to avoid singularity
147 0 : const bool isOrig = ((ig == 0) && (jg == 0) && (kg == 0));
148 0 : view(i, j, k) += isOrig * 1.0;
149 : });
150 0 : break;
151 0 : case 1:
152 0 : Kokkos::parallel_for(
153 : "Helper index Green field initialization",
154 0 : ippl::getRangePolicy(view, nghost),
155 0 : KOKKOS_LAMBDA(const int i, const int j, const int k) {
156 : // go from local indices to global
157 0 : const int jg = j + ldom[1].first() - nghost;
158 :
159 : // assign (index)^2 if 0 <= index < N, and (2N-index)^2 elsewhere
160 0 : const bool outsideN = (jg >= size / 2);
161 0 : view(i, j, k) = (size * outsideN - jg) * (size * outsideN - jg);
162 : });
163 0 : break;
164 0 : case 2:
165 0 : Kokkos::parallel_for(
166 : "Helper index Green field initialization",
167 0 : ippl::getRangePolicy(view, nghost),
168 0 : KOKKOS_LAMBDA(const int i, const int j, const int k) {
169 : // go from local indices to global
170 0 : const int kg = k + ldom[2].first() - nghost;
171 :
172 : // assign (index)^2 if 0 <= index < N, and (2N-index)^2 elsewhere
173 0 : const bool outsideN = (kg >= size / 2);
174 0 : view(i, j, k) = (size * outsideN - kg) * (size * outsideN - kg);
175 : });
176 0 : break;
177 : }
178 : }
179 :
180 : // call greensFunction and we will get the transformed G in the class attribute
181 : // this is done in initialization so that we already have the precomputed fct
182 : // for all timesteps (green's fct will only change if mesh size changes)
183 :
184 0 : greensFunction();
185 0 : };
186 :
187 : /////////////////////////////////////////////////////////////////////////
// solve Poisson's equation for rho stored in the RHS field; depending on output_type the potential overwrites the RHS and/or the gradient is written into the LHS
189 : template <typename FieldLHS, typename FieldRHS>
190 0 : void FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::solve() {
191 : // get the output type (sol, grad, or sol & grad)
192 0 : const int out = this->params_m.template get<int>("output_type");
193 :
194 : // set the mesh & spacing, which may change each timestep
195 0 : mesh_mp = &(this->rhs_mp->get_mesh());
196 :
197 : // check whether the mesh spacing has changed with respect to the old one
198 : // if yes, update and set green flag to true
199 0 : bool green = false;
200 0 : for (unsigned int i = 0; i < Dim; ++i) {
201 0 : if (hr_m[i] != mesh_mp->getMeshSpacing(i)) {
202 0 : hr_m[i] = mesh_mp->getMeshSpacing(i);
203 0 : green = true;
204 : }
205 : }
206 :
207 : // set mesh spacing on the other grids again
208 0 : meshComplex_m->setMeshSpacing(hr_m);
209 :
210 : // forward FFT of the charge density field on doubled grid
211 0 : rhotr_m = 0.0;
212 0 : fft_m->transform(FORWARD, *(this->rhs_mp), rhotr_m);
213 :
214 : // call greensFunction to recompute if the mesh spacing has changed
215 0 : if (green) {
216 0 : greensFunction();
217 : }
218 :
219 : // multiply FFT(rho2)*FFT(green)
220 : // convolution becomes multiplication in FFT
221 0 : rhotr_m = -rhotr_m * grntr_m;
222 :
223 : using index_array_type = typename RangePolicy<Dim>::index_array_type;
224 0 : if ((out == Base::GRAD) || (out == Base::SOL_AND_GRAD)) {
225 : // Compute gradient in Fourier space and then
226 : // take inverse FFT.
227 :
228 0 : const Trhs pi = Kokkos::numbers::pi_v<Trhs>;
229 0 : Kokkos::complex<Trhs> imag = {0.0, 1.0};
230 :
231 0 : auto view = rhotr_m.getView();
232 0 : const int nghost = rhotr_m.getNghost();
233 0 : auto tempview = tempFieldComplex_m.getView();
234 0 : auto viewRhs = this->rhs_mp->getView();
235 0 : auto viewLhs = this->lhs_mp->getView();
236 0 : const int nghostL = this->lhs_mp->getNghost();
237 0 : const auto& lDomComplex = layoutComplex_m->getLocalNDIndex();
238 :
239 : // define some member variables in local scope for the parallel_for
240 0 : Vector_t hsize = hr_m;
241 0 : Vector<int, Dim> N = nr_m;
242 0 : Vector_t origin = mesh_mp->getOrigin();
243 :
244 0 : for (size_t gd = 0; gd < Dim; ++gd) {
245 0 : ippl::parallel_for(
246 0 : "Gradient FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
247 0 : KOKKOS_LAMBDA(const index_array_type& args) {
248 0 : Vector<int, Dim> iVec = args - nghost;
249 :
250 0 : for (unsigned d = 0; d < Dim; ++d) {
251 0 : iVec[d] += lDomComplex[d].first();
252 : }
253 :
254 0 : Vector_t kVec;
255 :
256 0 : for (size_t d = 0; d < Dim; ++d) {
257 0 : const Trhs Len = N[d] * hsize[d];
258 0 : bool shift = (iVec[d] > (N[d] / 2));
259 0 : bool notMid = (iVec[d] != (N[d] / 2));
260 : // For the noMid part see
261 : // https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 1
262 0 : kVec[d] = notMid * 2 * pi / Len * (iVec[d] - shift * N[d]);
263 : }
264 :
265 0 : Trhs Dr = 0;
266 0 : for (unsigned d = 0; d < Dim; ++d) {
267 0 : Dr += kVec[d] * kVec[d];
268 : }
269 :
270 0 : apply(tempview, args) = apply(view, args);
271 :
272 0 : bool isNotZero = (Dr != 0.0);
273 :
274 0 : apply(tempview, args) *= -(isNotZero * imag * kVec[gd]);
275 0 : });
276 :
277 0 : fft_m->transform(BACKWARD, *this->rhs_mp, tempFieldComplex_m);
278 :
279 0 : ippl::parallel_for(
280 0 : "Assign Gradient FFTPeriodicPoissonSolver", getRangePolicy(viewLhs, nghostL),
281 0 : KOKKOS_LAMBDA(const index_array_type& args) {
282 0 : apply(viewLhs, args)[gd] = apply(viewRhs, args);
283 : });
284 : }
285 :
286 : // normalization is double counted due to 2 transforms
287 0 : *(this->lhs_mp) = *(this->lhs_mp) * nr_m[0] * nr_m[1] * nr_m[2];
288 : // discretization of integral requires h^3 factor
289 0 : *(this->lhs_mp) = *(this->lhs_mp) * hr_m[0] * hr_m[1] * hr_m[2];
290 0 : }
291 :
292 0 : if ((out == Base::SOL) || (out == Base::SOL_AND_GRAD)) {
293 : // inverse FFT of the product and store the electrostatic potential in rho2_mr
294 0 : fft_m->transform(BACKWARD, *(this->rhs_mp), rhotr_m);
295 :
296 : // normalization is double counted due to 2 transforms
297 0 : *(this->rhs_mp) = *(this->rhs_mp) * nr_m[0] * nr_m[1] * nr_m[2];
298 : // discretization of integral requires h^3 factor
299 0 : *(this->rhs_mp) = *(this->rhs_mp) * hr_m[0] * hr_m[1] * hr_m[2];
300 : }
301 0 : };
302 :
303 : ////////////////////////////////////////////////////////////////////////
304 : // calculate FFT of the Green's function
305 :
306 : template <typename FieldLHS, typename FieldRHS>
307 0 : void FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::greensFunction() {
308 0 : grn_m = 0.0;
309 :
310 : // This alpha parameter is a choice for the Green's function
311 : // it controls the "range" of the Green's function (e.g.
312 : // for the collision modelling method, it indicates
313 : // the splitting between Particle-Particle interactions
314 : // and the Particle-Mesh computations).
315 0 : const Trhs alpha = this->params_m. template get<Trhs>("alpha");
316 0 : const Trhs forceConstant = this->params_m. template get<Trhs>("force_constant");
317 :
318 : // calculate square of the mesh spacing for each dimension
319 0 : Vector_t hrsq(hr_m * hr_m);
320 :
321 : // use the grnIField_m helper field to compute Green's function
322 0 : for (unsigned int i = 0; i < Dim; ++i) {
323 0 : grn_m = grn_m + grnIField_m[i] * hrsq[i];
324 : }
325 :
326 0 : typename Field_t::view_type view = grn_m.getView();
327 0 : const int nghost = grn_m.getNghost();
328 0 : const auto& ldom = layout_mp->getLocalNDIndex();
329 :
330 : // Kokkos parallel for loop to find (0,0,0) point and regularize
331 0 : Kokkos::parallel_for(
332 0 : "Assign Green's function ", ippl::getRangePolicy(view, nghost),
333 0 : KOKKOS_LAMBDA(const int i, const int j, const int k) {
334 : // go from local indices to global
335 0 : const int ig = i + ldom[0].first() - nghost;
336 0 : const int jg = j + ldom[1].first() - nghost;
337 0 : const int kg = k + ldom[2].first() - nghost;
338 :
339 0 : const bool isOrig = (ig == 0 && jg == 0 && kg == 0);
340 :
341 0 : Trhs r = Kokkos::real(Kokkos::sqrt(view(i, j, k)));
342 0 : view(i, j, k) = (!isOrig) * forceConstant * (Kokkos::erf(alpha * r) / r);
343 : });
344 :
345 : // perform the FFT of the Green's function for the convolution
346 0 : fft_m->transform(FORWARD, grn_m, grntr_m);
347 0 : };
348 :
349 : } // namespace ippl
|