LCOV - code coverage report
Current view: top level - src/PoissonSolvers - FFTTruncatedGreenPeriodicPoissonSolver.hpp (source / functions) Coverage Total Hit
Test: final_report.info Lines: 0.0 % 153 0
Test Date: 2025-07-18 17:15:09 Functions: 0.0 % 11 0

            Line data    Source code
       1              : //
       2              : // Class FFTTruncatedGreenPeriodicPoissonSolver
       3              : //   Poisson solver for periodic boundaries, based on FFTs.
       4              : //   Solves laplace(phi) = -rho, and E = -grad(phi).
       5              : //
       6              : //   Uses a convolution with a Green's function given by:
       7              : //      G(r) = forceConstant * erf(alpha * r) / r,
       8              : //         alpha = controls long-range interaction.
       9              : //
      10              : //
      11              : 
      12              : namespace ippl {
      13              : 
      14              :     /////////////////////////////////////////////////////////////////////////
      15              :     // constructor and destructor
      16              : 
      17              :     template <typename FieldLHS, typename FieldRHS>
      18            0 :     FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::FFTTruncatedGreenPeriodicPoissonSolver()
      19              :         : Base()
      20            0 :         , mesh_mp(nullptr)
      21            0 :         , layout_mp(nullptr)
      22            0 :         , meshComplex_m(nullptr)
      23            0 :         , layoutComplex_m(nullptr) {
      24            0 :         FFTTruncatedGreenPeriodicPoissonSolver::setDefaultParameters();
      25            0 :     }
      26              : 
      27              :     template <typename FieldLHS, typename FieldRHS>
      28              :     FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::FFTTruncatedGreenPeriodicPoissonSolver(rhs_type& rhs, ParameterList& params)
      29              :         : mesh_mp(nullptr)
      30              :         , layout_mp(nullptr)
      31              :         , meshComplex_m(nullptr)
      32              :         , layoutComplex_m(nullptr) {
      33              :         FFTTruncatedGreenPeriodicPoissonSolver::setDefaultParameters();
      34              : 
      35              :         this->params_m.merge(params);
      36              :         this->params_m.update("output_type", Base::SOL);
      37              : 
      38              :         FFTTruncatedGreenPeriodicPoissonSolver::setRhs(rhs);
      39              :     }
      40              : 
      41              :     template <typename FieldLHS, typename FieldRHS>
      42              :     FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::FFTTruncatedGreenPeriodicPoissonSolver(lhs_type& lhs, rhs_type& rhs, ParameterList& params)
      43              :         : mesh_mp(nullptr)
      44              :         , layout_mp(nullptr)
      45              :         , meshComplex_m(nullptr)
      46              :         , layoutComplex_m(nullptr) {
      47              :         FFTTruncatedGreenPeriodicPoissonSolver::setDefaultParameters();
      48              : 
      49              :         this->params_m.merge(params);
      50              : 
      51              :         this->setLhs(lhs);
      52              :         FFTTruncatedGreenPeriodicPoissonSolver::setRhs(rhs);
      53              :     }
      54              : 
      55              :     template <typename FieldLHS, typename FieldRHS>
      56            0 :     void FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::setRhs(rhs_type& rhs) {
      57            0 :         Base::setRhs(rhs);
      58            0 :         initializeFields();
      59            0 :     }
      60              : 
      61              :     /////////////////////////////////////////////////////////////////////////
      62              :     // initializeFields method, called in constructor
      63              : 
      64              :     template <typename FieldLHS, typename FieldRHS>
      65            0 :     void FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::initializeFields() {
      66              :         static_assert(Dim == 3, "Dimension other than 3 not supported in FFTTruncatedGreenPeriodicPoissonSolver!");
      67              : 
      68              :         // get layout and mesh
      69            0 :         layout_mp              = &(this->rhs_mp->getLayout());
      70            0 :         mesh_mp                = &(this->rhs_mp->get_mesh());
      71            0 :         mpi::Communicator comm = layout_mp->comm;
      72              : 
      73              :         // get mesh spacing
      74            0 :         hr_m = mesh_mp->getMeshSpacing();
      75              : 
      76              :         // get origin
      77            0 :         Vector_t origin = mesh_mp->getOrigin();
      78              : 
      79              :         // create domain for the real fields
      80            0 :         domain_m = layout_mp->getDomain();
      81              : 
      82              :         // get the mesh spacings and sizes for each dimension
      83            0 :         for (unsigned int i = 0; i < Dim; ++i) {
      84            0 :             nr_m[i] = domain_m[i].length();
      85              :         }
      86              : 
      87              :         // define decomposition (parallel / serial)
      88            0 :         std::array<bool, Dim> isParallel = layout_mp->isParallel();
      89              : 
      90              :         // create the domain for the transformed (complex) fields
      91              :         // since we use HeFFTe for the transforms it doesn't require permuting to the right
      92              :         // one of the dimensions has only (n/2 +1) as our original fields are fully real
      93              :         // the dimension is given by the user via r2c_direction
      94            0 :         unsigned int RCDirection = this->params_m.template get<int>("r2c_direction");
      95            0 :         for (unsigned int i = 0; i < Dim; ++i) {
      96            0 :             if (i == RCDirection)
      97            0 :                 domainComplex_m[RCDirection] = Index(nr_m[RCDirection] / 2 + 1);
      98              :             else
      99            0 :                 domainComplex_m[i] = Index(nr_m[i]);
     100              :         }
     101              : 
     102              :         // create mesh and layout for the real to complex FFT transformed fields
     103              :         using mesh_type = typename lhs_type::Mesh_t;
     104            0 :         meshComplex_m   = std::unique_ptr<mesh_type>(new mesh_type(domainComplex_m, hr_m, origin));
     105            0 :         layoutComplex_m =
     106            0 :             std::unique_ptr<FieldLayout_t>(new FieldLayout_t(comm, domainComplex_m, isParallel));
     107              : 
     108              :         // initialize fields
     109            0 :         grn_m.initialize(*mesh_mp, *layout_mp);
     110            0 :         rhotr_m.initialize(*meshComplex_m, *layoutComplex_m);
     111            0 :         grntr_m.initialize(*meshComplex_m, *layoutComplex_m);
     112            0 :         tempFieldComplex_m.initialize(*meshComplex_m, *layoutComplex_m);
     113              : 
     114              :         // create the FFT object
     115            0 :         fft_m = std::make_unique<FFT_t>(*layout_mp, *layoutComplex_m, this->params_m);
     116            0 :         fft_m->warmup(grn_m, grntr_m);  // warmup the FFT object
     117              : 
     118              :         // these are fields that are used for calculating the Green's function
     119            0 :         for (unsigned int d = 0; d < Dim; ++d) {
     120            0 :             grnIField_m[d].initialize(*mesh_mp, *layout_mp);
     121              : 
     122              :             // get number of ghost points and the Kokkos view to iterate over field
     123            0 :             auto view        = grnIField_m[d].getView();
     124            0 :             const int nghost = grnIField_m[d].getNghost();
     125            0 :             const auto& ldom = layout_mp->getLocalNDIndex();
     126              : 
     127              :             // the length of the physical domain
     128            0 :             const int size = nr_m[d];
     129              : 
     130              :             // Kokkos parallel for loop to initialize grnIField[d]
     131            0 :             switch (d) {
     132            0 :                 case 0:
     133            0 :                     Kokkos::parallel_for(
     134              :                         "Helper index Green field initialization",
     135            0 :                         ippl::getRangePolicy(view, nghost),
     136            0 :                         KOKKOS_LAMBDA(const int i, const int j, const int k) {
     137              :                             // go from local indices to global
     138            0 :                             const int ig = i + ldom[0].first() - nghost;
     139            0 :                             const int jg = j + ldom[1].first() - nghost;
     140            0 :                             const int kg = k + ldom[2].first() - nghost;
     141              : 
     142              :                             // assign (index)^2 if 0 <= index < N, and (2N-index)^2 elsewhere
     143            0 :                             const bool outsideN = (ig >= size / 2);
     144            0 :                             view(i, j, k)       = (size * outsideN - ig) * (size * outsideN - ig);
     145              : 
     146              :                             // add 1.0 if at (0,0,0) to avoid singularity
     147            0 :                             const bool isOrig = ((ig == 0) && (jg == 0) && (kg == 0));
     148            0 :                             view(i, j, k) += isOrig * 1.0;
     149              :                         });
     150            0 :                     break;
     151            0 :                 case 1:
     152            0 :                     Kokkos::parallel_for(
     153              :                         "Helper index Green field initialization",
     154            0 :                         ippl::getRangePolicy(view, nghost),
     155            0 :                         KOKKOS_LAMBDA(const int i, const int j, const int k) {
     156              :                             // go from local indices to global
     157            0 :                             const int jg = j + ldom[1].first() - nghost;
     158              : 
     159              :                             // assign (index)^2 if 0 <= index < N, and (2N-index)^2 elsewhere
     160            0 :                             const bool outsideN = (jg >= size / 2);
     161            0 :                             view(i, j, k)       = (size * outsideN - jg) * (size * outsideN - jg);
     162              :                         });
     163            0 :                     break;
     164            0 :                 case 2:
     165            0 :                     Kokkos::parallel_for(
     166              :                         "Helper index Green field initialization",
     167            0 :                         ippl::getRangePolicy(view, nghost),
     168            0 :                         KOKKOS_LAMBDA(const int i, const int j, const int k) {
     169              :                             // go from local indices to global
     170            0 :                             const int kg = k + ldom[2].first() - nghost;
     171              : 
     172              :                             // assign (index)^2 if 0 <= index < N, and (2N-index)^2 elsewhere
     173            0 :                             const bool outsideN = (kg >= size / 2);
     174            0 :                             view(i, j, k)       = (size * outsideN - kg) * (size * outsideN - kg);
     175              :                         });
     176            0 :                     break;
     177              :             }
     178              :         }
     179              : 
     180              :         // call greensFunction and we will get the transformed G in the class attribute
     181              :         // this is done in initialization so that we already have the precomputed fct
     182              :         // for all timesteps (green's fct will only change if mesh size changes)
     183              : 
     184            0 :         greensFunction();
     185            0 :     };
     186              : 
     187              :     /////////////////////////////////////////////////////////////////////////
     188              :     // compute electric potential by solving Poisson's eq given a field rho and mesh spacings hr
     189              :     template <typename FieldLHS, typename FieldRHS>
     190            0 :     void FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::solve() {
     191              :         // get the output type (sol, grad, or sol & grad)
     192            0 :         const int out = this->params_m.template get<int>("output_type");
     193              : 
     194              :         // set the mesh & spacing, which may change each timestep
     195            0 :         mesh_mp = &(this->rhs_mp->get_mesh());
     196              : 
     197              :         // check whether the mesh spacing has changed with respect to the old one
     198              :         // if yes, update and set green flag to true
     199            0 :         bool green = false;
     200            0 :         for (unsigned int i = 0; i < Dim; ++i) {
     201            0 :             if (hr_m[i] != mesh_mp->getMeshSpacing(i)) {
     202            0 :                 hr_m[i] = mesh_mp->getMeshSpacing(i);
     203            0 :                 green   = true;
     204              :             }
     205              :         }
     206              : 
     207              :         // set mesh spacing on the other grids again
     208            0 :         meshComplex_m->setMeshSpacing(hr_m);
     209              : 
     210              :         // forward FFT of the charge density field on doubled grid
     211            0 :         rhotr_m = 0.0;
     212            0 :         fft_m->transform(FORWARD, *(this->rhs_mp), rhotr_m);
     213              : 
     214              :         // call greensFunction to recompute if the mesh spacing has changed
     215            0 :         if (green) {
     216            0 :             greensFunction();
     217              :         }
     218              : 
     219              :         // multiply FFT(rho2)*FFT(green)
     220              :         // convolution becomes multiplication in FFT
     221            0 :         rhotr_m = -rhotr_m * grntr_m;
     222              : 
     223              :         using index_array_type = typename RangePolicy<Dim>::index_array_type;
     224            0 :         if ((out == Base::GRAD) || (out == Base::SOL_AND_GRAD)) {
     225              :             // Compute gradient in Fourier space and then
     226              :             // take inverse FFT.
     227              : 
     228            0 :             const Trhs pi              = Kokkos::numbers::pi_v<Trhs>;
     229            0 :             Kokkos::complex<Trhs> imag = {0.0, 1.0};
     230              : 
     231            0 :             auto view               = rhotr_m.getView();
     232            0 :             const int nghost        = rhotr_m.getNghost();
     233            0 :             auto tempview           = tempFieldComplex_m.getView();
     234            0 :             auto viewRhs            = this->rhs_mp->getView();
     235            0 :             auto viewLhs            = this->lhs_mp->getView();
     236            0 :             const int nghostL       = this->lhs_mp->getNghost();
     237            0 :             const auto& lDomComplex = layoutComplex_m->getLocalNDIndex();
     238              : 
     239              :             // define some member variables in local scope for the parallel_for
     240            0 :             Vector_t hsize     = hr_m;
     241            0 :             Vector<int, Dim> N = nr_m;
     242            0 :             Vector_t origin    = mesh_mp->getOrigin();
     243              : 
     244            0 :             for (size_t gd = 0; gd < Dim; ++gd) {
     245            0 :                 ippl::parallel_for(
     246            0 :                     "Gradient FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
     247            0 :                     KOKKOS_LAMBDA(const index_array_type& args) {
     248            0 :                         Vector<int, Dim> iVec = args - nghost;
     249              : 
     250            0 :                         for (unsigned d = 0; d < Dim; ++d) {
     251            0 :                             iVec[d] += lDomComplex[d].first();
     252              :                         }
     253              : 
     254            0 :                         Vector_t kVec;
     255              : 
     256            0 :                         for (size_t d = 0; d < Dim; ++d) {
     257            0 :                             const Trhs Len = N[d] * hsize[d];
     258            0 :                             bool shift     = (iVec[d] > (N[d] / 2));
     259            0 :                             bool notMid    = (iVec[d] != (N[d] / 2));
     260              :                             // For the noMid part see
     261              :                             // https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 1
     262            0 :                             kVec[d] = notMid * 2 * pi / Len * (iVec[d] - shift * N[d]);
     263              :                         }
     264              : 
     265            0 :                         Trhs Dr = 0;
     266            0 :                         for (unsigned d = 0; d < Dim; ++d) {
     267            0 :                             Dr += kVec[d] * kVec[d];
     268              :                         }
     269              : 
     270            0 :                         apply(tempview, args) = apply(view, args);
     271              : 
     272            0 :                         bool isNotZero = (Dr != 0.0);
     273              : 
     274            0 :                         apply(tempview, args) *= -(isNotZero * imag * kVec[gd]);
     275            0 :                     });
     276              : 
     277            0 :                 fft_m->transform(BACKWARD, *this->rhs_mp, tempFieldComplex_m);
     278              : 
     279            0 :                 ippl::parallel_for(
     280            0 :                     "Assign Gradient FFTPeriodicPoissonSolver", getRangePolicy(viewLhs, nghostL),
     281            0 :                     KOKKOS_LAMBDA(const index_array_type& args) {
     282            0 :                         apply(viewLhs, args)[gd] = apply(viewRhs, args);
     283              :                     });
     284              :             }
     285              : 
     286              :             // normalization is double counted due to 2 transforms
     287            0 :             *(this->lhs_mp) = *(this->lhs_mp) * nr_m[0] * nr_m[1] * nr_m[2];
     288              :             // discretization of integral requires h^3 factor
     289            0 :             *(this->lhs_mp) = *(this->lhs_mp) * hr_m[0] * hr_m[1] * hr_m[2];
     290            0 :         }
     291              : 
     292            0 :         if ((out == Base::SOL) || (out == Base::SOL_AND_GRAD)) {
     293              :             // inverse FFT of the product and store the electrostatic potential in rho2_mr
     294            0 :             fft_m->transform(BACKWARD, *(this->rhs_mp), rhotr_m);
     295              : 
     296              :             // normalization is double counted due to 2 transforms
     297            0 :             *(this->rhs_mp) = *(this->rhs_mp) * nr_m[0] * nr_m[1] * nr_m[2];
     298              :             // discretization of integral requires h^3 factor
     299            0 :             *(this->rhs_mp) = *(this->rhs_mp) * hr_m[0] * hr_m[1] * hr_m[2];
     300              :         }
     301            0 :     };
     302              : 
     303              :     ////////////////////////////////////////////////////////////////////////
     304              :     // calculate FFT of the Green's function
     305              : 
     306              :     template <typename FieldLHS, typename FieldRHS>
     307            0 :     void FFTTruncatedGreenPeriodicPoissonSolver<FieldLHS, FieldRHS>::greensFunction() {
     308            0 :         grn_m = 0.0;
     309              : 
     310              :         // This alpha parameter is a choice for the Green's function
     311              :         // it controls the "range" of the Green's function (e.g.
     312              :         // for the collision modelling method, it indicates
     313              :         // the splitting between Particle-Particle interactions
     314              :         // and the Particle-Mesh computations).
     315            0 :         const Trhs alpha = this->params_m. template get<Trhs>("alpha");
     316            0 :         const Trhs forceConstant = this->params_m. template get<Trhs>("force_constant");
     317              : 
     318              :         // calculate square of the mesh spacing for each dimension
     319            0 :         Vector_t hrsq(hr_m * hr_m);
     320              : 
     321              :         // use the grnIField_m helper field to compute Green's function
     322            0 :         for (unsigned int i = 0; i < Dim; ++i) {
     323            0 :             grn_m = grn_m + grnIField_m[i] * hrsq[i];
     324              :         }
     325              : 
     326            0 :         typename Field_t::view_type view = grn_m.getView();
     327            0 :         const int nghost                 = grn_m.getNghost();
     328            0 :         const auto& ldom                 = layout_mp->getLocalNDIndex();
     329              : 
     330              :         // Kokkos parallel for loop to find (0,0,0) point and regularize
     331            0 :         Kokkos::parallel_for(
     332            0 :             "Assign Green's function ", ippl::getRangePolicy(view, nghost),
     333            0 :             KOKKOS_LAMBDA(const int i, const int j, const int k) {
     334              :                 // go from local indices to global
     335            0 :                 const int ig = i + ldom[0].first() - nghost;
     336            0 :                 const int jg = j + ldom[1].first() - nghost;
     337            0 :                 const int kg = k + ldom[2].first() - nghost;
     338              : 
     339            0 :                 const bool isOrig = (ig == 0 && jg == 0 && kg == 0);
     340              : 
     341            0 :                 Trhs r        = Kokkos::real(Kokkos::sqrt(view(i, j, k)));
     342            0 :                 view(i, j, k) = (!isOrig) * forceConstant * (Kokkos::erf(alpha * r) / r);
     343              :             });
     344              : 
     345              :         // perform the FFT of the Green's function for the convolution
     346            0 :         fft_m->transform(FORWARD, grn_m, grntr_m);
     347            0 :     };
     348              : 
     349              : }  // namespace ippl
        

Generated by: LCOV version 2.0-1