LCOV - code coverage report
Current view: top level - src/PoissonSolvers - P3MSolver.hpp (source / functions) Coverage Total Hit
Test: final_report.info Lines: 0.0 % 153 0
Test Date: 2025-07-17 08:40:11 Functions: 0.0 % 11 0

            Line data    Source code
       1              : //
       2              : // Class P3MSolver
       3              : //   Poisson solver for periodic boundaries, based on FFTs.
       4              : //   Solves laplace(phi) = -rho, and E = -grad(phi).
       5              : //
       6              : //   Uses a convolution with a Green's function given by:
       7              : //      G(r) = ke * erf(alpha * r) / r,
       8              : //   where ke = Coulomb constant (the implementation uses a prefactor of magnitude 1 / (4 * pi)),
       9              : //         alpha = controls long-range interaction.
      10              : //
      11              : //
      12              : 
      13              : namespace ippl {
      14              : 
      15              :     /////////////////////////////////////////////////////////////////////////
      16              :     // constructor and destructor
      17              : 
      18              :     template <typename FieldLHS, typename FieldRHS>
      19            0 :     P3MSolver<FieldLHS, FieldRHS>::P3MSolver()  // default constructor: no rhs/lhs field is attached here
      20              :         : Base()
      21            0 :         , mesh_mp(nullptr)
      22            0 :         , layout_mp(nullptr)
      23            0 :         , meshComplex_m(nullptr)
      24            0 :         , layoutComplex_m(nullptr) {  // all mesh/layout pointers start out null
      25            0 :         setDefaultParameters();  // defined outside this listing
      26            0 :     }
      27              : 
      28              :     template <typename FieldLHS, typename FieldRHS>
      29              :     P3MSolver<FieldLHS, FieldRHS>::P3MSolver(rhs_type& rhs, ParameterList& params)  // constructor taking only a rhs field plus parameters
      30              :         : mesh_mp(nullptr)
      31              :         , layout_mp(nullptr)
      32              :         , meshComplex_m(nullptr)
      33              :         , layoutComplex_m(nullptr) {
      34              :         setDefaultParameters();  // defaults first, then the caller-supplied values are merged in below
      35              : 
      36              :         this->params_m.merge(params);
      37              :         this->params_m.update("output_type", Base::SOL);  // output type is forced to SOL here (presumably because no lhs field is given - confirm)
      38              : 
      39              :         this->setRhs(rhs);  // setRhs also runs initializeFields()
      40              :     }
      41              : 
      42              :     template <typename FieldLHS, typename FieldRHS>
      43              :     P3MSolver<FieldLHS, FieldRHS>::P3MSolver(lhs_type& lhs, rhs_type& rhs, ParameterList& params)  // constructor taking both lhs and rhs fields plus parameters
      44              :         : mesh_mp(nullptr)
      45              :         , layout_mp(nullptr)
      46              :         , meshComplex_m(nullptr)
      47              :         , layoutComplex_m(nullptr) {
      48              :         setDefaultParameters();  // defaults first, then the caller-supplied values are merged in below
      49              : 
      50              :         this->params_m.merge(params);  // unlike the rhs-only constructor, output_type is not overridden here
      51              : 
      52              :         this->setLhs(lhs);
      53              :         this->setRhs(rhs);  // setRhs also runs initializeFields()
      54              :     }
      55              : 
      56              :     template <typename FieldLHS, typename FieldRHS>
      57            0 :     void P3MSolver<FieldLHS, FieldRHS>::setRhs(rhs_type& rhs) {  // attach a rhs field and rebuild all solver state that depends on it
      58            0 :         Base::setRhs(rhs);
      59            0 :         initializeFields();  // re-derives mesh, layout, FFT object and Green's function from the new rhs
      60            0 :     }
      61              : 
      62              :     /////////////////////////////////////////////////////////////////////////
      63              :     // initializeFields method, called from setRhs (which the field-taking constructors invoke)
      64              : 
      65              :     template <typename FieldLHS, typename FieldRHS>
      66            0 :     void P3MSolver<FieldLHS, FieldRHS>::initializeFields() {  // builds the complex mesh/layout, the fields, the FFT object and the index helper fields, then calls greensFunction()
      67              :         static_assert(Dim == 3, "Dimension other than 3 not supported in P3MSolver!");
      68              : 
      69              :         // get layout and mesh from the rhs field
      70            0 :         layout_mp              = &(this->rhs_mp->getLayout());
      71            0 :         mesh_mp                = &(this->rhs_mp->get_mesh());
      72            0 :         mpi::Communicator comm = layout_mp->comm;
      73              : 
      74              :         // get mesh spacing
      75            0 :         hr_m = mesh_mp->getMeshSpacing();
      76              : 
      77              :         // get origin
      78            0 :         Vector_t origin = mesh_mp->getOrigin();
      79              : 
      80              :         // get the domain of the real fields from the layout
      81            0 :         domain_m = layout_mp->getDomain();
      82              : 
      83              :         // get the number of grid points in each dimension
      84            0 :         for (unsigned int i = 0; i < Dim; ++i) {
      85            0 :             nr_m[i] = domain_m[i].length();
      86              :         }
      87              : 
      88              :         // define decomposition (parallel / serial)
      89            0 :         std::array<bool, Dim> isParallel = layout_mp->isParallel();
      90              : 
      91              :         // create the domain for the transformed (complex) fields
      92              :         // since we use HeFFTe for the transforms it doesn't require permuting to the right
      93              :         // one of the dimensions has only (n/2 +1) points as our original fields are fully real
      94              :         // that dimension is read from the r2c_direction parameter
      95            0 :         unsigned int RCDirection = this->params_m.template get<int>("r2c_direction");  // read as int, stored as unsigned
      96            0 :         for (unsigned int i = 0; i < Dim; ++i) {
      97            0 :             if (i == RCDirection)
      98            0 :                 domainComplex_m[RCDirection] = Index(nr_m[RCDirection] / 2 + 1);
      99              :             else
     100            0 :                 domainComplex_m[i] = Index(nr_m[i]);
     101              :         }
     102              : 
     103              :         // create mesh and layout for the real to complex FFT transformed fields
     104              :         using mesh_type = typename lhs_type::Mesh_t;
     105            0 :         meshComplex_m   = std::unique_ptr<mesh_type>(new mesh_type(domainComplex_m, hr_m, origin));
     106            0 :         layoutComplex_m =
     107            0 :             std::unique_ptr<FieldLayout_t>(new FieldLayout_t(comm, domainComplex_m, isParallel));
     108              : 
     109              :         // initialize fields: grn_m on the real mesh, the transformed fields on the complex mesh
     110            0 :         grn_m.initialize(*mesh_mp, *layout_mp);
     111            0 :         rhotr_m.initialize(*meshComplex_m, *layoutComplex_m);
     112            0 :         grntr_m.initialize(*meshComplex_m, *layoutComplex_m);
     113            0 :         tempFieldComplex_m.initialize(*meshComplex_m, *layoutComplex_m);
     114              : 
     115              :         // create the FFT object
     116            0 :         fft_m = std::make_unique<FFT_t>(*layout_mp, *layoutComplex_m, this->params_m);
     117            0 :         fft_m->warmup(grn_m, grntr_m);  // warmup the FFT object
     118              : 
     119              :         // these are fields that are used for calculating the Green's function:
     120            0 :         for (unsigned int d = 0; d < Dim; ++d) {  // grnIField_m[d] holds the squared wrapped grid index along dimension d
     121            0 :             grnIField_m[d].initialize(*mesh_mp, *layout_mp);
     122              : 
     123              :             // get number of ghost points and the Kokkos view to iterate over field
     124            0 :             auto view        = grnIField_m[d].getView();
     125            0 :             const int nghost = grnIField_m[d].getNghost();
     126            0 :             const auto& ldom = layout_mp->getLocalNDIndex();
     127              : 
     128              :             // number of grid points in dimension d
     129            0 :             const int size = nr_m[d];
     130              : 
     131              :             // Kokkos parallel for loop to initialize grnIField[d]
     132            0 :             switch (d) {  // one kernel per dimension, each using the global index along that dimension
     133            0 :                 case 0:
     134            0 :                     Kokkos::parallel_for(
     135              :                         "Helper index Green field initialization",
     136            0 :                         ippl::getRangePolicy(view, nghost),
     137            0 :                         KOKKOS_LAMBDA(const int i, const int j, const int k) {
     138              :                             // go from local indices to global
     139            0 :                             const int ig = i + ldom[0].first() - nghost;
     140            0 :                             const int jg = j + ldom[1].first() - nghost;
     141            0 :                             const int kg = k + ldom[2].first() - nghost;
     142              : 
     143              :                             // assign (index)^2 if index < N/2 (integer division), and (N-index)^2 elsewhere
     144            0 :                             const bool outsideN = (ig >= size / 2);
     145            0 :                             view(i, j, k)       = (size * outsideN - ig) * (size * outsideN - ig);
     146              : 
     147              :                             // add 1.0 if at (0,0,0) to avoid singularity (only this dimension-0 kernel does so)
     148            0 :                             const bool isOrig = ((ig == 0) && (jg == 0) && (kg == 0));
     149            0 :                             view(i, j, k) += isOrig * 1.0;
     150              :                         });
     151            0 :                     break;
     152            0 :                 case 1:
     153            0 :                     Kokkos::parallel_for(
     154              :                         "Helper index Green field initialization",
     155            0 :                         ippl::getRangePolicy(view, nghost),
     156            0 :                         KOKKOS_LAMBDA(const int i, const int j, const int k) {
     157              :                             // go from local indices to global
     158            0 :                             const int jg = j + ldom[1].first() - nghost;
     159              : 
     160              :                             // assign (index)^2 if index < N/2 (integer division), and (N-index)^2 elsewhere
     161            0 :                             const bool outsideN = (jg >= size / 2);
     162            0 :                             view(i, j, k)       = (size * outsideN - jg) * (size * outsideN - jg);
     163              :                         });
     164            0 :                     break;
     165            0 :                 case 2:
     166            0 :                     Kokkos::parallel_for(
     167              :                         "Helper index Green field initialization",
     168            0 :                         ippl::getRangePolicy(view, nghost),
     169            0 :                         KOKKOS_LAMBDA(const int i, const int j, const int k) {
     170              :                             // go from local indices to global
     171            0 :                             const int kg = k + ldom[2].first() - nghost;
     172              : 
     173              :                             // assign (index)^2 if index < N/2 (integer division), and (N-index)^2 elsewhere
     174            0 :                             const bool outsideN = (kg >= size / 2);
     175            0 :                             view(i, j, k)       = (size * outsideN - kg) * (size * outsideN - kg);
     176              :                         });
     177            0 :                     break;
     178              :             }
     179              :         }
     180              : 
     181              :         // call greensFunction and we will get the transformed G in the class attribute
     182              :         // this is done in initialization so that we already have the precomputed function
     183              :         // for all timesteps (Green's function will only change if the mesh spacing changes)
     184              : 
     185            0 :         greensFunction();
     186            0 :     };
     187              : 
     188              :     /////////////////////////////////////////////////////////////////////////
     189              :     // compute electric potential by solving Poisson's eq given a field rho and mesh spacings hr
     190              :     template <typename FieldLHS, typename FieldRHS>
     191            0 :     void P3MSolver<FieldLHS, FieldRHS>::solve() {  // FFT-based solve; writes the gradient into lhs and/or the potential into rhs depending on output_type
     192              :         // get the output type (sol, grad, or sol & grad)
     193            0 :         const int out = this->params_m.template get<int>("output_type");
     194              : 
     195              :         // set the mesh & spacing, which may change each timestep
     196            0 :         mesh_mp = &(this->rhs_mp->get_mesh());
     197              : 
     198              :         // check whether the mesh spacing has changed with respect to the old one
     199              :         // if yes, update and set green flag to true
     200            0 :         bool green = false;
     201            0 :         for (unsigned int i = 0; i < Dim; ++i) {
     202            0 :             if (hr_m[i] != mesh_mp->getMeshSpacing(i)) {  // exact inequality test against the stored spacing
     203            0 :                 hr_m[i] = mesh_mp->getMeshSpacing(i);
     204            0 :                 green   = true;
     205              :             }
     206              :         }
     207              : 
     208              :         // propagate the (possibly updated) mesh spacing to the complex mesh
     209            0 :         meshComplex_m->setMeshSpacing(hr_m);
     210              : 
     211              :         // forward FFT of the charge density field (rhs) into rhotr_m
     212            0 :         rhotr_m = 0.0;
     213            0 :         fft_m->transform(FORWARD, *(this->rhs_mp), rhotr_m);
     214              : 
     215              :         // call greensFunction to recompute if the mesh spacing has changed
     216            0 :         if (green) {
     217            0 :             greensFunction();
     218              :         }
     219              : 
     220              :         // multiply FFT(rho) * FFT(green) and flip the sign
     221              :         // convolution becomes multiplication in FFT
     222            0 :         rhotr_m = -rhotr_m * grntr_m;
     223              : 
     224              :         using index_array_type = typename RangePolicy<Dim>::index_array_type;
     225            0 :         if ((out == Base::GRAD) || (out == Base::SOL_AND_GRAD)) {
     226              :             // Compute gradient in Fourier space and then
     227              :             // take inverse FFT, one component (gd) at a time.
     228              : 
     229            0 :             const Trhs pi              = Kokkos::numbers::pi_v<Trhs>;
     230            0 :             Kokkos::complex<Trhs> imag = {0.0, 1.0};
     231              : 
     232            0 :             auto view               = rhotr_m.getView();
     233            0 :             const int nghost        = rhotr_m.getNghost();
     234            0 :             auto tempview           = tempFieldComplex_m.getView();
     235            0 :             auto viewRhs            = this->rhs_mp->getView();
     236            0 :             auto viewLhs            = this->lhs_mp->getView();
     237            0 :             const int nghostL       = this->lhs_mp->getNghost();
     238            0 :             const auto& lDomComplex = layoutComplex_m->getLocalNDIndex();
     239              : 
     240              :             // define some member variables in local scope for the parallel_for
     241            0 :             Vector_t hsize     = hr_m;
     242            0 :             Vector<int, Dim> N = nr_m;
     243            0 :             Vector_t origin    = mesh_mp->getOrigin();  // NOTE(review): origin is not referenced again in solve()
     244              : 
     245            0 :             for (size_t gd = 0; gd < Dim; ++gd) {
     246            0 :                 ippl::parallel_for(
     247            0 :                     "Gradient FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
     248            0 :                     KOKKOS_LAMBDA(const index_array_type& args) {
     249            0 :                         Vector<int, Dim> iVec = args - nghost;  // local index without ghost offset
     250              : 
     251            0 :                         for (unsigned d = 0; d < Dim; ++d) {
     252            0 :                             iVec[d] += lDomComplex[d].first();  // shift to the global index
     253              :                         }
     254              : 
     255            0 :                         Vector_t kVec;
     256              : 
     257            0 :                         for (size_t d = 0; d < Dim; ++d) {
     258            0 :                             const Trhs Len = N[d] * hsize[d];
     259            0 :                             bool shift     = (iVec[d] > (N[d] / 2));  // indices above N/2 get N subtracted (negative wavenumbers)
     260            0 :                             bool notMid    = (iVec[d] != (N[d] / 2));  // the wavenumber at index N/2 is set to zero
     261              :                             // For the notMid part see
     262              :                             // https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 1
     263            0 :                             kVec[d] = notMid * 2 * pi / Len * (iVec[d] - shift * N[d]);
     264              :                         }
     265              : 
     266            0 :                         Trhs Dr = 0;  // squared norm of kVec
     267            0 :                         for (unsigned d = 0; d < Dim; ++d) {
     268            0 :                             Dr += kVec[d] * kVec[d];
     269              :                         }
     270              : 
     271            0 :                         apply(tempview, args) = apply(view, args);
     272              : 
     273            0 :                         bool isNotZero = (Dr != 0.0);  // false only where every kVec component is zero
     274              : 
     275            0 :                         apply(tempview, args) *= -(isNotZero * imag * kVec[gd]);  // multiply by -i * kVec[gd]: negated Fourier-space derivative along gd
     276            0 :                     });
     277              : 
     278            0 :                 fft_m->transform(BACKWARD, *this->rhs_mp, tempFieldComplex_m);  // the rhs field receives the result (it is read back through viewRhs below)
     279              : 
     280            0 :                 ippl::parallel_for(
     281            0 :                     "Assign Gradient FFTPeriodicPoissonSolver", getRangePolicy(viewLhs, nghostL),
     282            0 :                     KOKKOS_LAMBDA(const index_array_type& args) {
     283            0 :                         apply(viewLhs, args)[gd] = apply(viewRhs, args);  // copy into component gd of the lhs field
     284              :                     });
     285              :             }
     286              : 
     287              :             // normalization is double counted due to 2 transforms
     288            0 :             *(this->lhs_mp) = *(this->lhs_mp) * nr_m[0] * nr_m[1] * nr_m[2];
     289              :             // discretization of integral requires h^3 factor
     290            0 :             *(this->lhs_mp) = *(this->lhs_mp) * hr_m[0] * hr_m[1] * hr_m[2];
     291            0 :         }
     292              : 
     293            0 :         if ((out == Base::SOL) || (out == Base::SOL_AND_GRAD)) {
     294              :             // inverse FFT of the product and store the electrostatic potential in the rhs field
     295            0 :             fft_m->transform(BACKWARD, *(this->rhs_mp), rhotr_m);
     296              : 
     297              :             // normalization is double counted due to 2 transforms
     298            0 :             *(this->rhs_mp) = *(this->rhs_mp) * nr_m[0] * nr_m[1] * nr_m[2];
     299              :             // discretization of integral requires h^3 factor
     300            0 :             *(this->rhs_mp) = *(this->rhs_mp) * hr_m[0] * hr_m[1] * hr_m[2];
     301              :         }
     302            0 :     };
     303              : 
     304              :     ////////////////////////////////////////////////////////////////////////
     305              :     // calculate FFT of the Green's function
     306              : 
     307              :     template <typename FieldLHS, typename FieldRHS>
     308            0 :     void P3MSolver<FieldLHS, FieldRHS>::greensFunction() {  // fills grn_m with the real-space Green's function and stores its FFT in grntr_m
     309            0 :         grn_m = 0.0;
     310              : 
     311              :         // define pi
     312            0 :         Trhs pi = Kokkos::numbers::pi_v<Trhs>;
     313              : 
     314              :         // This alpha parameter is a choice for the Green's function
     315              :         // it controls the "range" of the Green's function (e.g.
     316              :         // for the P3M collision modelling method, it indicates
     317              :         // the splitting between Particle-Particle interactions
     318              :         // and the Particle-Mesh computations).
     319            0 :         Trhs alpha = 1e6;  // hard-coded here; not read from params_m
     320              : 
     321              :         // calculate square of the mesh spacing for each dimension
     322            0 :         Vector_t hrsq(hr_m * hr_m);
     323              : 
     324              :         // accumulate r^2 = sum_i (squared index along i) * h_i^2 from the grnIField_m helper fields
     325            0 :         for (unsigned int i = 0; i < Dim; ++i) {
     326            0 :             grn_m = grn_m + grnIField_m[i] * hrsq[i];
     327              :         }
     328              : 
     329            0 :         typename Field_t::view_type view = grn_m.getView();
     330            0 :         const int nghost                 = grn_m.getNghost();
     331            0 :         const auto& ldom                 = layout_mp->getLocalNDIndex();
     332              : 
     333              :         // Kokkos parallel for loop to evaluate G from r^2 and zero it at the (0,0,0) point
     334            0 :         Kokkos::parallel_for(
     335            0 :             "Assign Green's function ", ippl::getRangePolicy(view, nghost),
     336            0 :             KOKKOS_LAMBDA(const int i, const int j, const int k) {
     337              :                 // go from local indices to global
     338            0 :                 const int ig = i + ldom[0].first() - nghost;
     339            0 :                 const int jg = j + ldom[1].first() - nghost;
     340            0 :                 const int kg = k + ldom[2].first() - nghost;
     341              : 
     342            0 :                 const bool isOrig = (ig == 0 && jg == 0 && kg == 0);
     343              : 
     344            0 :                 Trhs r        = Kokkos::real(Kokkos::sqrt(view(i, j, k)));  // view holds r^2 at this point
     345            0 :                 view(i, j, k) = (!isOrig) * (-1.0 / (4.0 * pi)) * (Kokkos::erf(alpha * r) / r);  // the (!isOrig) factor forces the origin entry to zero
     346              :             });
     347              : 
     348              :         // perform the FFT of the Green's function for the convolution
     349            0 :         fft_m->transform(FORWARD, grn_m, grntr_m);
     350            0 :     };
     351              : 
     352              : }  // namespace ippl
        

Generated by: LCOV version 2.0-1