Line data Source code
1 : //
2 : // Class FFTPeriodicPoissonSolver
3 : // Solves the periodic Poisson problem using Fourier transforms
4 : // cf. https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 5
5 : //
6 : //
7 :
8 : namespace ippl {
9 :
10 : template <typename FieldLHS, typename FieldRHS>
11 0 : void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::setRhs(rhs_type& rhs) {
12 0 : Base::setRhs(rhs);
13 0 : initialize();
14 0 : }
15 :
16 : template <typename FieldLHS, typename FieldRHS>
17 0 : void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::initialize() {
18 0 : const Layout_t& layout_r = this->rhs_mp->getLayout();
19 0 : domain_m = layout_r.getDomain();
20 :
21 0 : NDIndex<Dim> domainComplex;
22 :
23 0 : vector_type hComplex;
24 0 : vector_type originComplex;
25 :
26 0 : std::array<bool, Dim> isParallel = layout_r.isParallel();
27 0 : for (unsigned d = 0; d < Dim; ++d) {
28 0 : hComplex[d] = 1.0;
29 0 : originComplex[d] = 0.0;
30 :
31 0 : if (this->params_m.template get<int>("r2c_direction") == (int)d) {
32 0 : domainComplex[d] = Index(domain_m[d].length() / 2 + 1);
33 : } else {
34 0 : domainComplex[d] = Index(domain_m[d].length());
35 : }
36 : }
37 :
38 0 : layoutComplex_mp = std::make_shared<Layout_t>(layout_r.comm, domainComplex, isParallel);
39 :
40 0 : mesh_type meshComplex(domainComplex, hComplex, originComplex);
41 :
42 0 : fieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
43 :
44 0 : if (this->params_m.template get<int>("output_type") == Base::GRAD) {
45 0 : tempFieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
46 : }
47 :
48 0 : fft_mp = std::make_shared<FFT_t>(layout_r, *layoutComplex_mp, this->params_m);
49 0 : fft_mp->warmup(*this->rhs_mp, fieldComplex_m); // warmup the FFT object
50 0 : }
51 :
52 : template <typename FieldLHS, typename FieldRHS>
53 0 : void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::solve() {
54 0 : fft_mp->transform(FORWARD, *this->rhs_mp, fieldComplex_m);
55 :
56 0 : auto view = fieldComplex_m.getView();
57 0 : const int nghost = fieldComplex_m.getNghost();
58 :
59 0 : scalar_type pi = Kokkos::numbers::pi_v<scalar_type>;
60 0 : const mesh_type& mesh = this->rhs_mp->get_mesh();
61 0 : const auto& lDomComplex = layoutComplex_mp->getLocalNDIndex();
62 : using vector_type = typename mesh_type::vector_type;
63 0 : const vector_type& origin = mesh.getOrigin();
64 0 : const vector_type& hx = mesh.getMeshSpacing();
65 :
66 0 : vector_type rmax;
67 0 : Vector<int, Dim> N;
68 0 : for (size_t d = 0; d < Dim; ++d) {
69 0 : N[d] = domain_m[d].length();
70 0 : rmax[d] = origin[d] + (N[d] * hx[d]);
71 : }
72 :
73 : // Based on output_type calculate either solution
74 : // or gradient
75 :
76 : using index_array_type = typename RangePolicy<Dim>::index_array_type;
77 0 : switch (this->params_m.template get<int>("output_type")) {
78 0 : case Base::SOL: {
79 0 : ippl::parallel_for(
80 0 : "Solution FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
81 0 : KOKKOS_LAMBDA(const index_array_type& args) {
82 0 : Vector<int, Dim> iVec = args - nghost;
83 0 : for (unsigned d = 0; d < Dim; ++d) {
84 0 : iVec[d] += lDomComplex[d].first();
85 : }
86 :
87 0 : Vector_t kVec;
88 :
89 0 : for (size_t d = 0; d < Dim; ++d) {
90 0 : const scalar_type Len = rmax[d] - origin[d];
91 0 : bool shift = (iVec[d] > (N[d] / 2));
92 0 : kVec[d] = 2 * pi / Len * (iVec[d] - shift * N[d]);
93 : }
94 :
95 0 : scalar_type Dr = 0;
96 0 : for (unsigned d = 0; d < Dim; ++d) {
97 0 : Dr += kVec[d] * kVec[d];
98 : }
99 :
100 0 : bool isNotZero = (Dr != 0.0);
101 0 : scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
102 :
103 0 : apply(view, args) *= factor;
104 0 : });
105 :
106 0 : fft_mp->transform(BACKWARD, *this->rhs_mp, fieldComplex_m);
107 :
108 0 : break;
109 : }
110 0 : case Base::GRAD: {
111 : // Compute gradient in Fourier space and then
112 : // take inverse FFT.
113 :
114 0 : Complex_t imag = {0.0, 1.0};
115 0 : auto tempview = tempFieldComplex_m.getView();
116 0 : auto viewRhs = this->rhs_mp->getView();
117 0 : auto viewLhs = this->lhs_mp->getView();
118 0 : const int nghostL = this->lhs_mp->getNghost();
119 :
120 0 : for (size_t gd = 0; gd < Dim; ++gd) {
121 0 : ippl::parallel_for(
122 0 : "Gradient FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
123 0 : KOKKOS_LAMBDA(const index_array_type& args) {
124 0 : Vector<int, Dim> iVec = args - nghost;
125 0 : for (unsigned d = 0; d < Dim; ++d) {
126 0 : iVec[d] += lDomComplex[d].first();
127 : }
128 :
129 0 : Vector_t kVec;
130 :
131 0 : for (size_t d = 0; d < Dim; ++d) {
132 0 : const scalar_type Len = rmax[d] - origin[d];
133 0 : bool shift = (iVec[d] > (N[d] / 2));
134 0 : bool notMid = (iVec[d] != (N[d] / 2));
135 : // For the noMid part see
136 : // https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 1
137 0 : kVec[d] = notMid * 2 * pi / Len * (iVec[d] - shift * N[d]);
138 : }
139 :
140 0 : scalar_type Dr = 0;
141 0 : for (unsigned d = 0; d < Dim; ++d) {
142 0 : Dr += kVec[d] * kVec[d];
143 : }
144 :
145 0 : apply(tempview, args) = apply(view, args);
146 :
147 0 : bool isNotZero = (Dr != 0.0);
148 0 : scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
149 :
150 0 : apply(tempview, args) *= -(imag * kVec[gd] * factor);
151 0 : });
152 :
153 0 : fft_mp->transform(BACKWARD, *this->rhs_mp, tempFieldComplex_m);
154 :
155 0 : ippl::parallel_for(
156 : "Assign Gradient FFTPeriodicPoissonSolver",
157 0 : getRangePolicy(viewLhs, nghostL),
158 0 : KOKKOS_LAMBDA(const index_array_type& args) {
159 0 : apply(viewLhs, args)[gd] = apply(viewRhs, args);
160 : });
161 : }
162 :
163 0 : break;
164 0 : }
165 :
166 0 : default:
167 0 : throw IpplException("FFTPeriodicPoissonSolver::solve", "Unrecognized output_type");
168 : }
169 0 : }
170 : } // namespace ippl
|