LCOV - code coverage report
Current view: top level - src/PoissonSolvers - FFTPeriodicPoissonSolver.hpp (source / functions) Coverage Total Hit
Test: report.info Lines: 0.0 % 98 0
Test Date: 2025-05-12 09:25:18 Functions: 0.0 % 6 0
Branches: 0.0 % 147 0

             Branch data     Line data    Source code
       1                 :             : //
       2                 :             : // Class FFTPeriodicPoissonSolver
       3                 :             : //   Solves the periodic Poisson problem using Fourier transforms
       4                 :             : //   cf. https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 5
       5                 :             : //
       6                 :             : //
       7                 :             : 
       8                 :             : namespace ippl {
       9                 :             : 
      10                 :             :     template <typename FieldLHS, typename FieldRHS>
      11                 :           0 :     void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::setRhs(rhs_type& rhs) {
      12                 :           0 :         Base::setRhs(rhs);
      13                 :           0 :         initialize();
      14                 :           0 :     }
      15                 :             : 
      16                 :             :     template <typename FieldLHS, typename FieldRHS>
      17                 :           0 :     void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::initialize() {
      18         [ #  # ]:           0 :         const Layout_t& layout_r = this->rhs_mp->getLayout();
      19                 :           0 :         domain_m                 = layout_r.getDomain();
      20                 :             : 
      21                 :           0 :         NDIndex<Dim> domainComplex;
      22                 :             : 
      23         [ #  # ]:           0 :         vector_type hComplex;
      24         [ #  # ]:           0 :         vector_type originComplex;
      25                 :             : 
      26                 :           0 :         std::array<bool, Dim> isParallel = layout_r.isParallel();
      27         [ #  # ]:           0 :         for (unsigned d = 0; d < Dim; ++d) {
      28                 :           0 :             hComplex[d]      = 1.0;
      29                 :           0 :             originComplex[d] = 0.0;
      30                 :             : 
      31   [ #  #  #  #  :           0 :             if (this->params_m.template get<int>("r2c_direction") == (int)d) {
                   #  # ]
      32                 :           0 :                 domainComplex[d] = Index(domain_m[d].length() / 2 + 1);
      33                 :             :             } else {
      34                 :           0 :                 domainComplex[d] = Index(domain_m[d].length());
      35                 :             :             }
      36                 :             :         }
      37                 :             : 
      38         [ #  # ]:           0 :         layoutComplex_mp = std::make_shared<Layout_t>(layout_r.comm, domainComplex, isParallel);
      39                 :             : 
      40         [ #  # ]:           0 :         mesh_type meshComplex(domainComplex, hComplex, originComplex);
      41                 :             : 
      42         [ #  # ]:           0 :         fieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
      43                 :             : 
      44   [ #  #  #  #  :           0 :         if (this->params_m.template get<int>("output_type") == Base::GRAD) {
                   #  # ]
      45         [ #  # ]:           0 :             tempFieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
      46                 :             :         }
      47                 :             : 
      48         [ #  # ]:           0 :         fft_mp = std::make_shared<FFT_t>(layout_r, *layoutComplex_mp, this->params_m);
      49         [ #  # ]:           0 :         fft_mp->warmup(*this->rhs_mp, fieldComplex_m);  // warmup the FFT object
      50                 :           0 :     }
      51                 :             : 
      52                 :             :     template <typename FieldLHS, typename FieldRHS>
      53                 :           0 :     void FFTPeriodicPoissonSolver<FieldLHS, FieldRHS>::solve() {
      54         [ #  # ]:           0 :         fft_mp->transform(FORWARD, *this->rhs_mp, fieldComplex_m);
      55                 :             : 
      56         [ #  # ]:           0 :         auto view        = fieldComplex_m.getView();
      57                 :           0 :         const int nghost = fieldComplex_m.getNghost();
      58                 :             : 
      59                 :           0 :         scalar_type pi            = Kokkos::numbers::pi_v<scalar_type>;
      60                 :           0 :         const mesh_type& mesh     = this->rhs_mp->get_mesh();
      61         [ #  # ]:           0 :         const auto& lDomComplex   = layoutComplex_mp->getLocalNDIndex();
      62                 :             :         using vector_type         = typename mesh_type::vector_type;
      63                 :           0 :         const vector_type& origin = mesh.getOrigin();
      64                 :           0 :         const vector_type& hx     = mesh.getMeshSpacing();
      65                 :             : 
      66         [ #  # ]:           0 :         vector_type rmax;
      67         [ #  # ]:           0 :         Vector<int, Dim> N;
      68         [ #  # ]:           0 :         for (size_t d = 0; d < Dim; ++d) {
      69                 :           0 :             N[d]    = domain_m[d].length();
      70                 :           0 :             rmax[d] = origin[d] + (N[d] * hx[d]);
      71                 :             :         }
      72                 :             : 
      73                 :             :         // Based on output_type calculate either solution
      74                 :             :         // or gradient
      75                 :             : 
      76                 :             :         using index_array_type = typename RangePolicy<Dim>::index_array_type;
      77   [ #  #  #  #  :           0 :         switch (this->params_m.template get<int>("output_type")) {
                #  #  # ]
      78                 :           0 :             case Base::SOL: {
      79   [ #  #  #  # ]:           0 :                 ippl::parallel_for(
      80                 :           0 :                     "Solution FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
      81   [ #  #  #  #  :           0 :                     KOKKOS_LAMBDA(const index_array_type& args) {
          #  #  #  #  #  
                      # ]
      82   [ #  #  #  # ]:           0 :                         Vector<int, Dim> iVec = args - nghost;
      83         [ #  # ]:           0 :                         for (unsigned d = 0; d < Dim; ++d) {
      84                 :           0 :                             iVec[d] += lDomComplex[d].first();
      85                 :             :                         }
      86                 :             : 
      87         [ #  # ]:           0 :                         Vector_t kVec;
      88                 :             : 
      89         [ #  # ]:           0 :                         for (size_t d = 0; d < Dim; ++d) {
      90                 :           0 :                             const scalar_type Len = rmax[d] - origin[d];
      91                 :           0 :                             bool shift            = (iVec[d] > (N[d] / 2));
      92                 :           0 :                             kVec[d]               = 2 * pi / Len * (iVec[d] - shift * N[d]);
      93                 :             :                         }
      94                 :             : 
      95                 :           0 :                         scalar_type Dr = 0;
      96         [ #  # ]:           0 :                         for (unsigned d = 0; d < Dim; ++d) {
      97                 :           0 :                             Dr += kVec[d] * kVec[d];
      98                 :             :                         }
      99                 :             : 
     100                 :           0 :                         bool isNotZero     = (Dr != 0.0);
     101                 :           0 :                         scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
     102                 :             : 
     103         [ #  # ]:           0 :                         apply(view, args) *= factor;
     104                 :           0 :                     });
     105                 :             : 
     106         [ #  # ]:           0 :                 fft_mp->transform(BACKWARD, *this->rhs_mp, fieldComplex_m);
     107                 :             : 
     108                 :           0 :                 break;
     109                 :             :             }
     110                 :           0 :             case Base::GRAD: {
     111                 :             :                 // Compute gradient in Fourier space and then
     112                 :             :                 // take inverse FFT.
     113                 :             : 
     114                 :           0 :                 Complex_t imag    = {0.0, 1.0};
     115         [ #  # ]:           0 :                 auto tempview     = tempFieldComplex_m.getView();
     116         [ #  # ]:           0 :                 auto viewRhs      = this->rhs_mp->getView();
     117         [ #  # ]:           0 :                 auto viewLhs      = this->lhs_mp->getView();
     118                 :           0 :                 const int nghostL = this->lhs_mp->getNghost();
     119                 :             : 
     120         [ #  # ]:           0 :                 for (size_t gd = 0; gd < Dim; ++gd) {
     121   [ #  #  #  # ]:           0 :                     ippl::parallel_for(
     122                 :           0 :                         "Gradient FFTPeriodicPoissonSolver", getRangePolicy(view, nghost),
     123   [ #  #  #  #  :           0 :                         KOKKOS_LAMBDA(const index_array_type& args) {
          #  #  #  #  #  
          #  #  #  #  #  
                   #  # ]
     124   [ #  #  #  # ]:           0 :                             Vector<int, Dim> iVec = args - nghost;
     125         [ #  # ]:           0 :                             for (unsigned d = 0; d < Dim; ++d) {
     126                 :           0 :                                 iVec[d] += lDomComplex[d].first();
     127                 :             :                             }
     128                 :             : 
     129         [ #  # ]:           0 :                             Vector_t kVec;
     130                 :             : 
     131         [ #  # ]:           0 :                             for (size_t d = 0; d < Dim; ++d) {
     132                 :           0 :                                 const scalar_type Len = rmax[d] - origin[d];
     133                 :           0 :                                 bool shift            = (iVec[d] > (N[d] / 2));
     134                 :           0 :                                 bool notMid           = (iVec[d] != (N[d] / 2));
     135                 :             :                                 // For the noMid part see
     136                 :             :                                 // https://math.mit.edu/~stevenj/fft-deriv.pdf Algorithm 1
     137                 :           0 :                                 kVec[d] = notMid * 2 * pi / Len * (iVec[d] - shift * N[d]);
     138                 :             :                             }
     139                 :             : 
     140                 :           0 :                             scalar_type Dr = 0;
     141         [ #  # ]:           0 :                             for (unsigned d = 0; d < Dim; ++d) {
     142                 :           0 :                                 Dr += kVec[d] * kVec[d];
     143                 :             :                             }
     144                 :             : 
     145   [ #  #  #  # ]:           0 :                             apply(tempview, args) = apply(view, args);
     146                 :             : 
     147                 :           0 :                             bool isNotZero     = (Dr != 0.0);
     148                 :           0 :                             scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
     149                 :             : 
     150         [ #  # ]:           0 :                             apply(tempview, args) *= -(imag * kVec[gd] * factor);
     151                 :           0 :                         });
     152                 :             : 
     153         [ #  # ]:           0 :                     fft_mp->transform(BACKWARD, *this->rhs_mp, tempFieldComplex_m);
     154                 :             : 
     155   [ #  #  #  # ]:           0 :                     ippl::parallel_for(
     156                 :             :                         "Assign Gradient FFTPeriodicPoissonSolver",
     157                 :           0 :                         getRangePolicy(viewLhs, nghostL),
     158   [ #  #  #  #  :           0 :                         KOKKOS_LAMBDA(const index_array_type& args) {
             #  #  #  # ]
     159                 :           0 :                             apply(viewLhs, args)[gd] = apply(viewRhs, args);
     160                 :             :                         });
     161                 :             :                 }
     162                 :             : 
     163                 :           0 :                 break;
     164                 :           0 :             }
     165                 :             : 
     166                 :           0 :             default:
     167   [ #  #  #  #  :           0 :                 throw IpplException("FFTPeriodicPoissonSolver::solve", "Unrecognized output_type");
                   #  # ]
     168                 :             :         }
     169                 :           0 :     }
     170                 :             : }  // namespace ippl
        

Generated by: LCOV version 2.0-1