LCOV - code coverage report
Current view: top level - src/PoissonSolvers - FFTPeriodicPoissonSolver.hpp (source / functions) Coverage Total Hit
Test: final_report.info Lines: 0.0 % 98 0
Test Date: 2025-07-18 09:50:00 Functions: 0.0 % 6 0

            Line data    Source code
       1              : //
       2              : // Class FFTPeriodicPoissonSolver
       3              : //   Solves the periodic Poisson problem using Fourier transforms
       4              : //   cf. https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 5
       5              : //
       6              : //
       7              : 
       8              : namespace ippl {
       9              : 
      10              :     template <typename FieldLHS, typename FieldRHS>
      11            0 :     void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::setRhs(rhs_type& rhs) {  // Set the right-hand side field and rebuild the solver's FFT state for it.
      12            0 :         Base::setRhs(rhs);  // delegate storage of rhs to the base class (presumably sets rhs_mp -- confirm in Base)
      13            0 :         initialize();  // initialize() reads the layout from rhs_mp, so it is re-run for every new rhs
      14            0 :     }
      15              : 
      16              :     template <typename FieldLHS, typename FieldRHS>
      17            0 :     void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::initialize() {  // Build the complex-space layout, mesh, field(s) and FFT object from the current rhs layout.
      18            0 :         const Layout_t& layout_r = this->rhs_mp->getLayout();
      19            0 :         domain_m                 = layout_r.getDomain();  // cache the rhs domain; solve() reads the per-dimension lengths from it
      20              : 
      21            0 :         NDIndex<Dim> domainComplex;
      22              : 
      23            0 :         vector_type hComplex;
      24            0 :         vector_type originComplex;
      25              : 
      26            0 :         std::array<bool, Dim> isParallel = layout_r.isParallel();  // reuse the rhs layout's per-dimension parallel flags for the complex layout
      27            0 :         for (unsigned d = 0; d < Dim; ++d) {
      28            0 :             hComplex[d]      = 1.0;  // the complex mesh is given unit spacing ...
      29            0 :             originComplex[d] = 0.0;  // ... and a zero origin in every dimension
      30              : 
      31            0 :             if (this->params_m.template get<int>("r2c_direction") == (int)d) {
      32            0 :                 domainComplex[d] = Index(domain_m[d].length() / 2 + 1);  // along the r2c direction the complex domain has length/2 + 1 points
      33              :             } else {
      34            0 :                 domainComplex[d] = Index(domain_m[d].length());  // all other directions keep the full length
      35              :             }
      36              :         }
      37              : 
      38            0 :         layoutComplex_mp = std::make_shared<Layout_t>(layout_r.comm, domainComplex, isParallel);
      39              : 
      40            0 :         mesh_type meshComplex(domainComplex, hComplex, originComplex);
      41              : 
      42            0 :         fieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
      43              : 
      44            0 :         if (this->params_m.template get<int>("output_type") == Base::GRAD) {  // the temporary complex field is only set up for gradient output
      45            0 :             tempFieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
      46              :         }
      47              : 
      48            0 :         fft_mp = std::make_shared<FFT_t>(layout_r, *layoutComplex_mp, this->params_m);
      49            0 :         fft_mp->warmup(*this->rhs_mp, fieldComplex_m);  // warmup the FFT object
      50            0 :     }
      51              : 
      52              :     template <typename FieldLHS, typename FieldRHS>
      53            0 :     void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::solve() {  // Solve in Fourier space: forward FFT of rhs, scale each mode, backward FFT; throws IpplException on an unknown output_type.
      54            0 :         fft_mp->transform(FORWARD, *this->rhs_mp, fieldComplex_m);  // forward transform of rhs into fieldComplex_m
      55              : 
      56            0 :         auto view        = fieldComplex_m.getView();
      57            0 :         const int nghost = fieldComplex_m.getNghost();
      58              : 
      59            0 :         scalar_type pi            = Kokkos::numbers::pi_v<scalar_type>;
      60            0 :         const mesh_type& mesh     = this->rhs_mp->get_mesh();
      61            0 :         const auto& lDomComplex   = layoutComplex_mp->getLocalNDIndex();
      62              :         using vector_type         = typename mesh_type::vector_type;
      63            0 :         const vector_type& origin = mesh.getOrigin();
      64            0 :         const vector_type& hx     = mesh.getMeshSpacing();
      65              : 
      66            0 :         vector_type rmax;
      67            0 :         Vector<int, Dim> N;
      68            0 :         for (size_t d = 0; d < Dim; ++d) {
      69            0 :             N[d]    = domain_m[d].length();  // number of points of the rhs domain along d
      70            0 :             rmax[d] = origin[d] + (N[d] * hx[d]);  // upper end of the domain, so rmax - origin is the length N*h used below
      71              :         }
      72              : 
      73              :         // Based on output_type calculate either solution
      74              :         // or gradient
      75              : 
      76              :         using index_array_type = typename RangePolicy<Dim>::index_array_type;
      77            0 :         switch (this->params_m.template get<int>("output_type")) {
      78            0 :             case Base::SOL: {
      79            0 :                 ippl::parallel_for(
      80            0 :                     "Solution FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
      81            0 :                     KOKKOS_LAMBDA(const index_array_type& args) {
      82            0 :                         Vector<int, Dim> iVec = args - nghost;  // strip the ghost offset from the view index
      83            0 :                         for (unsigned d = 0; d < Dim; ++d) {
      84            0 :                             iVec[d] += lDomComplex[d].first();  // add the first index of the local complex domain along d
      85              :                         }
      86              : 
      87            0 :                         Vector_t kVec;
      88              : 
      89            0 :                         for (size_t d = 0; d < Dim; ++d) {
      90            0 :                             const scalar_type Len = rmax[d] - origin[d];
      91            0 :                             bool shift            = (iVec[d] > (N[d] / 2));  // indices above N/2 are shifted down by N below
      92            0 :                             kVec[d]               = 2 * pi / Len * (iVec[d] - shift * N[d]);  // k_d = 2*pi/Len * (shifted index)
      93              :                         }
      94              : 
      95            0 :                         scalar_type Dr = 0;
      96            0 :                         for (unsigned d = 0; d < Dim; ++d) {
      97            0 :                             Dr += kVec[d] * kVec[d];  // Dr accumulates |k|^2
      98              :                         }
      99              : 
     100            0 :                         bool isNotZero     = (Dr != 0.0);
     101            0 :                         scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));  // branch-free: 1/Dr when Dr != 0, otherwise 0 (denominator becomes 1 to avoid dividing by zero)
     102              : 
     103            0 :                         apply(view, args) *= factor;  // scale this Fourier coefficient in place
     104            0 :                     });
     105              : 
     106            0 :                 fft_mp->transform(BACKWARD, *this->rhs_mp, fieldComplex_m);  // backward transform; presumably writes the result into rhs -- confirm against FFT_t::transform
     107              : 
     108            0 :                 break;
     109              :             }
     110            0 :             case Base::GRAD: {
     111              :                 // Compute gradient in Fourier space and then
     112              :                 // take inverse FFT.
     113              : 
     114            0 :                 Complex_t imag    = {0.0, 1.0};
     115            0 :                 auto tempview     = tempFieldComplex_m.getView();
     116            0 :                 auto viewRhs      = this->rhs_mp->getView();
     117            0 :                 auto viewLhs      = this->lhs_mp->getView();
     118            0 :                 const int nghostL = this->lhs_mp->getNghost();
     119              : 
     120            0 :                 for (size_t gd = 0; gd < Dim; ++gd) {  // one scale + backward-transform + copy pass per gradient component gd
     121            0 :                     ippl::parallel_for(
     122            0 :                         "Gradient FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
     123            0 :                         KOKKOS_LAMBDA(const index_array_type& args) {
     124            0 :                             Vector<int, Dim> iVec = args - nghost;  // strip the ghost offset from the view index
     125            0 :                             for (unsigned d = 0; d < Dim; ++d) {
     126            0 :                                 iVec[d] += lDomComplex[d].first();  // add the first index of the local complex domain along d
     127              :                             }
     128              : 
     129            0 :                             Vector_t kVec;
     130              : 
     131            0 :                             for (size_t d = 0; d < Dim; ++d) {
     132            0 :                                 const scalar_type Len = rmax[d] - origin[d];
     133            0 :                                 bool shift            = (iVec[d] > (N[d] / 2));
     134            0 :                                 bool notMid           = (iVec[d] != (N[d] / 2));
     135              :                                 // For the notMid part (index N/2 gets a zero wavenumber) see
     136              :                                 // https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 1
     137            0 :                                 kVec[d] = notMid * 2 * pi / Len * (iVec[d] - shift * N[d]);
     138              :                             }
     139              : 
     140            0 :                             scalar_type Dr = 0;
     141            0 :                             for (unsigned d = 0; d < Dim; ++d) {
     142            0 :                                 Dr += kVec[d] * kVec[d];  // Dr accumulates |k|^2
     143              :                             }
     144              : 
     145            0 :                             apply(tempview, args) = apply(view, args);  // work on a copy so fieldComplex_m stays unchanged for the next gd
     146              : 
     147            0 :                             bool isNotZero     = (Dr != 0.0);
     148            0 :                             scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));  // branch-free: 1/Dr when Dr != 0, otherwise 0
     149              : 
     150            0 :                             apply(tempview, args) *= -(imag * kVec[gd] * factor);  // multiply the coefficient by -(i * k_gd / Dr)
     151            0 :                         });
     152              : 
     153            0 :                     fft_mp->transform(BACKWARD, *this->rhs_mp, tempFieldComplex_m);  // backward transform of the scaled copy; presumably writes into rhs -- confirm
     154              : 
     155            0 :                     ippl::parallel_for(
     156              :                         "Assign Gradient FFTPeriodicPoissonSolver",
     157            0 :                         getRangePolicy(viewLhs, nghostL),
     158            0 :                         KOKKOS_LAMBDA(const index_array_type& args) {
     159            0 :                             apply(viewLhs, args)[gd] = apply(viewRhs, args);  // copy the rhs view into component gd of lhs; NOTE(review): same index used for both views, assumes equal ghost widths -- confirm
     160              :                         });
     161              :                 }
     162              : 
     163            0 :                 break;
     164            0 :             }
     165              : 
     166            0 :             default:
     167            0 :                 throw IpplException("FFTPeriodicPoissonSolver::solve", "Unrecognized output_type");  // any output_type other than Base::SOL / Base::GRAD is rejected
     168              :         }
     169            0 :     }
     170              : }  // namespace ippl
        

Generated by: LCOV version 2.0-1