LCOV - code coverage report
Current view: top level - src/PoissonSolvers - P3MSolver.hpp (source / functions) Coverage Total Hit
Test: report.info Lines: 0.0 % 153 0
Test Date: 2025-05-15 13:47:30 Functions: 0.0 % 11 0
Branches: 0.0 % 268 0

             Branch data     Line data    Source code
       1                 :             : //
       2                 :             : // Class P3MSolver
       3                 :             : //   Poisson solver for periodic boundaries, based on FFTs.
       4                 :             : //   Solves laplace(phi) = -rho, and E = -grad(phi).
       5                 :             : //
       6                 :             : //   Uses a convolution with a Green's function given by:
       7                 :             : //      G(r) = ke * erf(alpha * r) / r,
       8                 :             : //   where ke = Coulomb constant,
       9                 :             : //         alpha = controls long-range interaction.
      10                 :             : //
      11                 :             : //
      12                 :             : 
      13                 :             : namespace ippl {
      14                 :             : 
      15                 :             :     /////////////////////////////////////////////////////////////////////////
      16                 :             :     // constructor and destructor
      17                 :             : 
      18                 :             :     template <typename FieldLHS, typename FieldRHS>
      19                 :           0 :     P3MSolver<FieldLHS, FieldRHS>::P3MSolver()
      20                 :             :         : Base()
      21                 :           0 :         , mesh_mp(nullptr)
      22                 :           0 :         , layout_mp(nullptr)
      23                 :           0 :         , meshComplex_m(nullptr)
      24   [ #  #  #  #  :           0 :         , layoutComplex_m(nullptr) {
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
                #  #  # ]
      25         [ #  # ]:           0 :         setDefaultParameters();
      26         [ #  # ]:           0 :     }
      27                 :             : 
      28                 :             :     template <typename FieldLHS, typename FieldRHS>
      29                 :             :     P3MSolver<FieldLHS, FieldRHS>::P3MSolver(rhs_type& rhs, ParameterList& params)
      30                 :             :         : mesh_mp(nullptr)
      31                 :             :         , layout_mp(nullptr)
      32                 :             :         , meshComplex_m(nullptr)
      33                 :             :         , layoutComplex_m(nullptr) {
      34                 :             :         setDefaultParameters();
      35                 :             : 
      36                 :             :         this->params_m.merge(params);
      37                 :             :         this->params_m.update("output_type", Base::SOL);
      38                 :             : 
      39                 :             :         this->setRhs(rhs);
      40                 :             :     }
      41                 :             : 
      42                 :             :     template <typename FieldLHS, typename FieldRHS>
      43                 :             :     P3MSolver<FieldLHS, FieldRHS>::P3MSolver(lhs_type& lhs, rhs_type& rhs, ParameterList& params)
      44                 :             :         : mesh_mp(nullptr)
      45                 :             :         , layout_mp(nullptr)
      46                 :             :         , meshComplex_m(nullptr)
      47                 :             :         , layoutComplex_m(nullptr) {
      48                 :             :         setDefaultParameters();
      49                 :             : 
      50                 :             :         this->params_m.merge(params);
      51                 :             : 
      52                 :             :         this->setLhs(lhs);
      53                 :             :         this->setRhs(rhs);
      54                 :             :     }
      55                 :             : 
      56                 :             :     template <typename FieldLHS, typename FieldRHS>
      57                 :           0 :     void P3MSolver<FieldLHS, FieldRHS>::setRhs(rhs_type& rhs) {
      58                 :           0 :         Base::setRhs(rhs);
      59                 :           0 :         initializeFields();
      60                 :           0 :     }
      61                 :             : 
      62                 :             :     /////////////////////////////////////////////////////////////////////////
      63                 :             :     // initializeFields method, called in constructor
      64                 :             : 
      65                 :             :     template <typename FieldLHS, typename FieldRHS>
      66                 :           0 :     void P3MSolver<FieldLHS, FieldRHS>::initializeFields() {
      67                 :             :         static_assert(Dim == 3, "Dimension other than 3 not supported in P3MSolver!");
      68                 :             : 
      69                 :             :         // get layout and mesh
      70         [ #  # ]:           0 :         layout_mp              = &(this->rhs_mp->getLayout());
      71                 :           0 :         mesh_mp                = &(this->rhs_mp->get_mesh());
      72         [ #  # ]:           0 :         mpi::Communicator comm = layout_mp->comm;
      73                 :             : 
      74                 :             :         // get mesh spacing
      75                 :           0 :         hr_m = mesh_mp->getMeshSpacing();
      76                 :             : 
      77                 :             :         // get origin
      78                 :           0 :         Vector_t origin = mesh_mp->getOrigin();
      79                 :             : 
      80                 :             :         // create domain for the real fields
      81                 :           0 :         domain_m = layout_mp->getDomain();
      82                 :             : 
      83                 :             :         // get the mesh spacings and sizes for each dimension
      84         [ #  # ]:           0 :         for (unsigned int i = 0; i < Dim; ++i) {
      85                 :           0 :             nr_m[i] = domain_m[i].length();
      86                 :             :         }
      87                 :             : 
      88                 :             :         // define decomposition (parallel / serial)
      89                 :           0 :         std::array<bool, Dim> isParallel = layout_mp->isParallel();
      90                 :             : 
      91                 :             :         // create the domain for the transformed (complex) fields
      92                 :             :         // since we use HeFFTe for the transforms it doesn't require permuting to the right
      93                 :             :         // one of the dimensions has only (n/2 +1) as our original fields are fully real
      94                 :             :         // the dimension is given by the user via r2c_direction
      95   [ #  #  #  # ]:           0 :         unsigned int RCDirection = this->params_m.template get<int>("r2c_direction");
      96         [ #  # ]:           0 :         for (unsigned int i = 0; i < Dim; ++i) {
      97         [ #  # ]:           0 :             if (i == RCDirection)
      98                 :           0 :                 domainComplex_m[RCDirection] = Index(nr_m[RCDirection] / 2 + 1);
      99                 :             :             else
     100                 :           0 :                 domainComplex_m[i] = Index(nr_m[i]);
     101                 :             :         }
     102                 :             : 
     103                 :             :         // create mesh and layout for the real to complex FFT transformed fields
     104                 :             :         using mesh_type = typename lhs_type::Mesh_t;
     105   [ #  #  #  #  :           0 :         meshComplex_m   = std::unique_ptr<mesh_type>(new mesh_type(domainComplex_m, hr_m, origin));
             #  #  #  # ]
     106                 :           0 :         layoutComplex_m =
     107   [ #  #  #  #  :           0 :             std::unique_ptr<FieldLayout_t>(new FieldLayout_t(comm, domainComplex_m, isParallel));
          #  #  #  #  #  
                      # ]
     108                 :             : 
     109                 :             :         // initialize fields
     110         [ #  # ]:           0 :         grn_m.initialize(*mesh_mp, *layout_mp);
     111         [ #  # ]:           0 :         rhotr_m.initialize(*meshComplex_m, *layoutComplex_m);
     112         [ #  # ]:           0 :         grntr_m.initialize(*meshComplex_m, *layoutComplex_m);
     113         [ #  # ]:           0 :         tempFieldComplex_m.initialize(*meshComplex_m, *layoutComplex_m);
     114                 :             : 
     115                 :             :         // create the FFT object
     116         [ #  # ]:           0 :         fft_m = std::make_unique<FFT_t>(*layout_mp, *layoutComplex_m, this->params_m);
     117         [ #  # ]:           0 :         fft_m->warmup(grn_m, grntr_m);  // warmup the FFT object
     118                 :             : 
     119                 :             :         // these are fields that are used for calculating the Green's function
     120         [ #  # ]:           0 :         for (unsigned int d = 0; d < Dim; ++d) {
     121         [ #  # ]:           0 :             grnIField_m[d].initialize(*mesh_mp, *layout_mp);
     122                 :             : 
     123                 :             :             // get number of ghost points and the Kokkos view to iterate over field
     124         [ #  # ]:           0 :             auto view        = grnIField_m[d].getView();
     125                 :           0 :             const int nghost = grnIField_m[d].getNghost();
     126         [ #  # ]:           0 :             const auto& ldom = layout_mp->getLocalNDIndex();
     127                 :             : 
     128                 :             :             // the length of the physical domain
     129                 :           0 :             const int size = nr_m[d];
     130                 :             : 
     131                 :             :             // Kokkos parallel for loop to initialize grnIField[d]
     132   [ #  #  #  # ]:           0 :             switch (d) {
     133                 :           0 :                 case 0:
     134   [ #  #  #  # ]:           0 :                     Kokkos::parallel_for(
     135                 :             :                         "Helper index Green field initialization",
     136         [ #  # ]:           0 :                         ippl::getRangePolicy(view, nghost),
     137         [ #  # ]:           0 :                         KOKKOS_LAMBDA(const int i, const int j, const int k) {
     138                 :             :                             // go from local indices to global
     139                 :           0 :                             const int ig = i + ldom[0].first() - nghost;
     140                 :           0 :                             const int jg = j + ldom[1].first() - nghost;
     141                 :           0 :                             const int kg = k + ldom[2].first() - nghost;
     142                 :             : 
     143                 :             :                             // assign (index)^2 if 0 <= index < N, and (2N-index)^2 elsewhere
     144                 :           0 :                             const bool outsideN = (ig >= size / 2);
     145                 :           0 :                             view(i, j, k)       = (size * outsideN - ig) * (size * outsideN - ig);
     146                 :             : 
     147                 :             :                             // add 1.0 if at (0,0,0) to avoid singularity
     148   [ #  #  #  #  :           0 :                             const bool isOrig = ((ig == 0) && (jg == 0) && (kg == 0));
                   #  # ]
     149                 :           0 :                             view(i, j, k) += isOrig * 1.0;
     150                 :             :                         });
     151                 :           0 :                     break;
     152                 :           0 :                 case 1:
     153   [ #  #  #  # ]:           0 :                     Kokkos::parallel_for(
     154                 :             :                         "Helper index Green field initialization",
     155         [ #  # ]:           0 :                         ippl::getRangePolicy(view, nghost),
     156         [ #  # ]:           0 :                         KOKKOS_LAMBDA(const int i, const int j, const int k) {
     157                 :             :                             // go from local indices to global
     158                 :           0 :                             const int jg = j + ldom[1].first() - nghost;
     159                 :             : 
     160                 :             :                             // assign (index)^2 if 0 <= index < N, and (2N-index)^2 elsewhere
     161                 :           0 :                             const bool outsideN = (jg >= size / 2);
     162                 :           0 :                             view(i, j, k)       = (size * outsideN - jg) * (size * outsideN - jg);
     163                 :             :                         });
     164                 :           0 :                     break;
     165                 :           0 :                 case 2:
     166   [ #  #  #  # ]:           0 :                     Kokkos::parallel_for(
     167                 :             :                         "Helper index Green field initialization",
     168         [ #  # ]:           0 :                         ippl::getRangePolicy(view, nghost),
     169         [ #  # ]:           0 :                         KOKKOS_LAMBDA(const int i, const int j, const int k) {
     170                 :             :                             // go from local indices to global
     171                 :           0 :                             const int kg = k + ldom[2].first() - nghost;
     172                 :             : 
     173                 :             :                             // assign (index)^2 if 0 <= index < N, and (2N-index)^2 elsewhere
     174                 :           0 :                             const bool outsideN = (kg >= size / 2);
     175                 :           0 :                             view(i, j, k)       = (size * outsideN - kg) * (size * outsideN - kg);
     176                 :             :                         });
     177                 :           0 :                     break;
     178                 :             :             }
     179                 :             :         }
     180                 :             : 
     181                 :             :         // call greensFunction and we will get the transformed G in the class attribute
     182                 :             :         // this is done in initialization so that we already have the precomputed fct
     183                 :             :         // for all timesteps (green's fct will only change if mesh size changes)
     184                 :             : 
     185         [ #  # ]:           0 :         greensFunction();
     186                 :           0 :     };
     187                 :             : 
     188                 :             :     /////////////////////////////////////////////////////////////////////////
     189                 :             :     // compute electric potential by solving Poisson's eq given a field rho and mesh spacings hr
     190                 :             :     template <typename FieldLHS, typename FieldRHS>
     191                 :           0 :     void P3MSolver<FieldLHS, FieldRHS>::solve() {
     192                 :             :         // get the output type (sol, grad, or sol & grad)
     193   [ #  #  #  # ]:           0 :         const int out = this->params_m.template get<int>("output_type");
     194                 :             : 
     195                 :             :         // set the mesh & spacing, which may change each timestep
     196                 :           0 :         mesh_mp = &(this->rhs_mp->get_mesh());
     197                 :             : 
     198                 :             :         // check whether the mesh spacing has changed with respect to the old one
     199                 :             :         // if yes, update and set green flag to true
     200                 :           0 :         bool green = false;
     201         [ #  # ]:           0 :         for (unsigned int i = 0; i < Dim; ++i) {
     202         [ #  # ]:           0 :             if (hr_m[i] != mesh_mp->getMeshSpacing(i)) {
     203                 :           0 :                 hr_m[i] = mesh_mp->getMeshSpacing(i);
     204                 :           0 :                 green   = true;
     205                 :             :             }
     206                 :             :         }
     207                 :             : 
     208                 :             :         // set mesh spacing on the other grids again
     209                 :           0 :         meshComplex_m->setMeshSpacing(hr_m);
     210                 :             : 
     211                 :             :         // forward FFT of the charge density field on doubled grid
     212         [ #  # ]:           0 :         rhotr_m = 0.0;
     213                 :           0 :         fft_m->transform(FORWARD, *(this->rhs_mp), rhotr_m);
     214                 :             : 
     215                 :             :         // call greensFunction to recompute if the mesh spacing has changed
     216         [ #  # ]:           0 :         if (green) {
     217                 :           0 :             greensFunction();
     218                 :             :         }
     219                 :             : 
     220                 :             :         // multiply FFT(rho2)*FFT(green)
     221                 :             :         // convolution becomes multiplication in FFT
     222   [ #  #  #  #  :           0 :         rhotr_m = -rhotr_m * grntr_m;
                   #  # ]
     223                 :             : 
     224                 :             :         using index_array_type = typename RangePolicy<Dim>::index_array_type;
     225   [ #  #  #  # ]:           0 :         if ((out == Base::GRAD) || (out == Base::SOL_AND_GRAD)) {
     226                 :             :             // Compute gradient in Fourier space and then
     227                 :             :             // take inverse FFT.
     228                 :             : 
     229                 :           0 :             const Trhs pi              = Kokkos::numbers::pi_v<Trhs>;
     230                 :           0 :             Kokkos::complex<Trhs> imag = {0.0, 1.0};
     231                 :             : 
     232         [ #  # ]:           0 :             auto view               = rhotr_m.getView();
     233                 :           0 :             const int nghost        = rhotr_m.getNghost();
     234         [ #  # ]:           0 :             auto tempview           = tempFieldComplex_m.getView();
     235         [ #  # ]:           0 :             auto viewRhs            = this->rhs_mp->getView();
     236         [ #  # ]:           0 :             auto viewLhs            = this->lhs_mp->getView();
     237                 :           0 :             const int nghostL       = this->lhs_mp->getNghost();
     238         [ #  # ]:           0 :             const auto& lDomComplex = layoutComplex_m->getLocalNDIndex();
     239                 :             : 
     240                 :             :             // define some member variables in local scope for the parallel_for
     241                 :           0 :             Vector_t hsize     = hr_m;
     242                 :           0 :             Vector<int, Dim> N = nr_m;
     243                 :           0 :             Vector_t origin    = mesh_mp->getOrigin();
     244                 :             : 
     245         [ #  # ]:           0 :             for (size_t gd = 0; gd < Dim; ++gd) {
     246   [ #  #  #  # ]:           0 :                 ippl::parallel_for(
     247                 :           0 :                     "Gradient FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
     248   [ #  #  #  #  :           0 :                     KOKKOS_LAMBDA(const index_array_type& args) {
          #  #  #  #  #  
             #  #  #  #  
                      # ]
     249   [ #  #  #  # ]:           0 :                         Vector<int, Dim> iVec = args - nghost;
     250                 :             : 
     251         [ #  # ]:           0 :                         for (unsigned d = 0; d < Dim; ++d) {
     252                 :           0 :                             iVec[d] += lDomComplex[d].first();
     253                 :             :                         }
     254                 :             : 
     255         [ #  # ]:           0 :                         Vector_t kVec;
     256                 :             : 
     257         [ #  # ]:           0 :                         for (size_t d = 0; d < Dim; ++d) {
     258                 :           0 :                             const Trhs Len = N[d] * hsize[d];
     259                 :           0 :                             bool shift     = (iVec[d] > (N[d] / 2));
     260                 :           0 :                             bool notMid    = (iVec[d] != (N[d] / 2));
     261                 :             :                             // For the noMid part see
     262                 :             :                             // https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 1
     263                 :           0 :                             kVec[d] = notMid * 2 * pi / Len * (iVec[d] - shift * N[d]);
     264                 :             :                         }
     265                 :             : 
     266                 :           0 :                         Trhs Dr = 0;
     267         [ #  # ]:           0 :                         for (unsigned d = 0; d < Dim; ++d) {
     268                 :           0 :                             Dr += kVec[d] * kVec[d];
     269                 :             :                         }
     270                 :             : 
     271   [ #  #  #  # ]:           0 :                         apply(tempview, args) = apply(view, args);
     272                 :             : 
     273                 :           0 :                         bool isNotZero = (Dr != 0.0);
     274                 :             : 
     275         [ #  # ]:           0 :                         apply(tempview, args) *= -(isNotZero * imag * kVec[gd]);
     276                 :           0 :                     });
     277                 :             : 
     278         [ #  # ]:           0 :                 fft_m->transform(BACKWARD, *this->rhs_mp, tempFieldComplex_m);
     279                 :             : 
     280   [ #  #  #  # ]:           0 :                 ippl::parallel_for(
     281                 :           0 :                     "Assign Gradient FFTPeriodicPoissonSolver", getRangePolicy(viewLhs, nghostL),
     282   [ #  #  #  #  :           0 :                     KOKKOS_LAMBDA(const index_array_type& args) {
             #  #  #  # ]
     283                 :           0 :                         apply(viewLhs, args)[gd] = apply(viewRhs, args);
     284                 :             :                     });
     285                 :             :             }
     286                 :             : 
     287                 :             :             // normalization is double counted due to 2 transforms
     288   [ #  #  #  #  :           0 :             *(this->lhs_mp) = *(this->lhs_mp) * nr_m[0] * nr_m[1] * nr_m[2];
             #  #  #  # ]
     289                 :             :             // discretization of integral requires h^3 factor
     290   [ #  #  #  #  :           0 :             *(this->lhs_mp) = *(this->lhs_mp) * hr_m[0] * hr_m[1] * hr_m[2];
             #  #  #  # ]
     291                 :           0 :         }
     292                 :             : 
     293   [ #  #  #  # ]:           0 :         if ((out == Base::SOL) || (out == Base::SOL_AND_GRAD)) {
     294                 :             :             // inverse FFT of the product and store the electrostatic potential in rho2_mr
     295                 :           0 :             fft_m->transform(BACKWARD, *(this->rhs_mp), rhotr_m);
     296                 :             : 
     297                 :             :             // normalization is double counted due to 2 transforms
     298   [ #  #  #  #  :           0 :             *(this->rhs_mp) = *(this->rhs_mp) * nr_m[0] * nr_m[1] * nr_m[2];
             #  #  #  # ]
     299                 :             :             // discretization of integral requires h^3 factor
     300   [ #  #  #  #  :           0 :             *(this->rhs_mp) = *(this->rhs_mp) * hr_m[0] * hr_m[1] * hr_m[2];
             #  #  #  # ]
     301                 :             :         }
     302                 :           0 :     };
     303                 :             : 
     304                 :             :     ////////////////////////////////////////////////////////////////////////
     305                 :             :     // calculate FFT of the Green's function
     306                 :             : 
     307                 :             :     template <typename FieldLHS, typename FieldRHS>
     308                 :           0 :     void P3MSolver<FieldLHS, FieldRHS>::greensFunction() {
     309         [ #  # ]:           0 :         grn_m = 0.0;
     310                 :             : 
     311                 :             :         // define pi
     312                 :           0 :         Trhs pi = Kokkos::numbers::pi_v<Trhs>;
     313                 :             : 
     314                 :             :         // This alpha parameter is a choice for the Green's function
     315                 :             :         // it controls the "range" of the Green's function (e.g.
     316                 :             :         // for the P3M collision modelling method, it indicates
     317                 :             :         // the splitting between Particle-Particle interactions
     318                 :             :         // and the Particle-Mesh computations).
     319                 :           0 :         Trhs alpha = 1e6;
     320                 :             : 
     321                 :             :         // calculate square of the mesh spacing for each dimension
     322         [ #  # ]:           0 :         Vector_t hrsq(hr_m * hr_m);
     323                 :             : 
     324                 :             :         // use the grnIField_m helper field to compute Green's function
     325         [ #  # ]:           0 :         for (unsigned int i = 0; i < Dim; ++i) {
     326   [ #  #  #  #  :           0 :             grn_m = grn_m + grnIField_m[i] * hrsq[i];
                   #  # ]
     327                 :             :         }
     328                 :             : 
     329         [ #  # ]:           0 :         typename Field_t::view_type view = grn_m.getView();
     330                 :           0 :         const int nghost                 = grn_m.getNghost();
     331         [ #  # ]:           0 :         const auto& ldom                 = layout_mp->getLocalNDIndex();
     332                 :             : 
     333                 :             :         // Kokkos parallel for loop to find (0,0,0) point and regularize
     334   [ #  #  #  # ]:           0 :         Kokkos::parallel_for(
     335                 :           0 :             "Assign Green's function ", ippl::getRangePolicy(view, nghost),
     336   [ #  #  #  #  :           0 :             KOKKOS_LAMBDA(const int i, const int j, const int k) {
                   #  # ]
     337                 :             :                 // go from local indices to global
     338                 :           0 :                 const int ig = i + ldom[0].first() - nghost;
     339                 :           0 :                 const int jg = j + ldom[1].first() - nghost;
     340                 :           0 :                 const int kg = k + ldom[2].first() - nghost;
     341                 :             : 
     342   [ #  #  #  #  :           0 :                 const bool isOrig = (ig == 0 && jg == 0 && kg == 0);
                   #  # ]
     343                 :             : 
     344                 :           0 :                 Trhs r        = Kokkos::real(Kokkos::sqrt(view(i, j, k)));
     345                 :           0 :                 view(i, j, k) = (!isOrig) * (-1.0 / (4.0 * pi)) * (Kokkos::erf(alpha * r) / r);
     346                 :             :             });
     347                 :             : 
     348                 :             :         // perform the FFT of the Green's function for the convolution
     349         [ #  # ]:           0 :         fft_m->transform(FORWARD, grn_m, grntr_m);
     350                 :           0 :     };
     351                 :             : 
     352                 :             : }  // namespace ippl
        

Generated by: LCOV version 2.0-1