LCOV - code coverage report
Current view: top level - src/PoissonSolvers - PoissonCG.h (source / functions) Coverage Total Hit
Test: report.info Lines: 0.0 % 68 0
Test Date: 2025-05-21 12:58:26 Functions: 0.0 % 18 0
Branches: 0.0 % 136 0

             Branch data     Line data    Source code
       1                 :             : //
       2                 :             : // Class PoissonCG
       3                 :             : //   Solves the Poisson problem with the CG algorithm
       4                 :             : //
       5                 :             : 
       6                 :             : #ifndef IPPL_POISSON_CG_H
       7                 :             : #define IPPL_POISSON_CG_H
       8                 :             : 
       9                 :             : #include "LaplaceHelpers.h"
      10                 :             : #include "LinearSolvers/PCG.h"
      11                 :             : #include "Poisson.h"
      12                 :             : namespace ippl {
      13                 :             : 
      // IPPL_SOLVER_OPERATOR_WRAPPER(fun, type)
// Produces a capture-less lambda that takes a single argument of `type` and
// returns `fun(argument)`. It is used to hand a differential operator to a
// linear solver, e.g.  IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type)
// fun:  the callable's name, pasted verbatim in front of the argument list.
//       A leading unary minus therefore negates the *result*:
//       "-laplace" expands to -laplace(argument). For that reason `fun` must
//       not be wrapped in parentheses by the macro or by callers.
// type: parameter type of the generated lambda; it should match the LHS field
//       type of the solver that receives the wrapper.
#define IPPL_SOLVER_OPERATOR_WRAPPER(fun, type) \
    [](type operand) {                          \
        return fun(operand);                    \
    }
      21                 :             : 
      22                 :             :     template <typename FieldLHS, typename FieldRHS = FieldLHS>
      23                 :             :     class PoissonCG : public Poisson<FieldLHS, FieldRHS> {
      24                 :             :         using Tlhs = typename FieldLHS::value_type;
      25                 :             : 
      26                 :             :     public:
      27                 :             :         using Base                    = Poisson<FieldLHS, FieldRHS>;
      28                 :             :         constexpr static unsigned Dim = FieldLHS::dim;
      29                 :             :         using typename Base::lhs_type, typename Base::rhs_type;
      30                 :             :         using OperatorRet        = UnaryMinus<detail::meta_laplace<lhs_type>>;
      31                 :             :         using LowerRet           = UnaryMinus<detail::meta_lower_laplace<lhs_type>>;
      32                 :             :         using UpperRet           = UnaryMinus<detail::meta_upper_laplace<lhs_type>>;
      33                 :             :         using UpperAndLowerRet   = UnaryMinus<detail::meta_upper_and_lower_laplace<lhs_type>>;
      34                 :             :         using InverseDiagonalRet = double;
      35                 :             :         using DiagRet            = double;
      36                 :             : 
      37                 :           0 :         PoissonCG()
      38                 :             :             : Base()
      39                 :           0 :             , algo_m(nullptr) {
      40                 :             :             static_assert(std::is_floating_point<Tlhs>::value, "Not a floating point type");
      41         [ #  # ]:           0 :             setDefaultParameters();
      42                 :           0 :         }
      43                 :             : 
      44                 :             :         PoissonCG(lhs_type& lhs, rhs_type& rhs)
      45                 :             :             : Base(lhs, rhs)
      46                 :             :             , algo_m(nullptr) {
      47                 :             :             static_assert(std::is_floating_point<Tlhs>::value, "Not a floating point type");
      48                 :             :             setDefaultParameters();
      49                 :             :         }
      50                 :             : 
      51                 :           0 :         void setSolver(lhs_type lhs) {
      52   [ #  #  #  # ]:           0 :             std::string solver_type            = this->params_m.template get<std::string>("solver");
      53                 :           0 :             typename lhs_type::Mesh_t mesh     = lhs.get_mesh();
      54   [ #  #  #  # ]:           0 :             typename lhs_type::Layout_t layout = lhs.getLayout();
      55                 :           0 :             double beta                        = 0;
      56                 :           0 :             double alpha                       = 0;
      57   [ #  #  #  # ]:           0 :             if (solver_type == "preconditioned") {
      58                 :           0 :                 algo_m = std::move(
      59                 :             :                     std::make_unique<PCG<OperatorRet, LowerRet, UpperRet, UpperAndLowerRet,
      60         [ #  # ]:           0 :                                          InverseDiagonalRet, DiagRet, FieldLHS, FieldRHS>>());
      61         [ #  # ]:           0 :                 std::string preconditioner_type =
      62         [ #  # ]:           0 :                     this->params_m.template get<std::string>("preconditioner_type");
      63   [ #  #  #  # ]:           0 :                 int level    = this->params_m.template get<int>("newton_level");
      64   [ #  #  #  # ]:           0 :                 int degree   = this->params_m.template get<int>("chebyshev_degree");
      65   [ #  #  #  # ]:           0 :                 int inner    = this->params_m.template get<int>("gauss_seidel_inner_iterations");
      66   [ #  #  #  # ]:           0 :                 int outer    = this->params_m.template get<int>("gauss_seidel_outer_iterations");
      67   [ #  #  #  # ]:           0 :                 double omega = this->params_m.template get<double>("ssor_omega");
      68                 :             :                 int richardson_iterations =
      69   [ #  #  #  # ]:           0 :                     this->params_m.template get<int>("richardson_iterations");
      70   [ #  #  #  # ]:           0 :                 int communication = this->params_m.template get<int>("communication");
      71                 :             :                 // Analytical eigenvalues for the d dimensional laplace operator
      72                 :             :                 // Going brute force through all possible eigenvalues seems to be the only way to
      73                 :             :                 // find max and min
      74                 :             : 
      75                 :             :                 unsigned long n;
      76                 :             :                 double h;
      77         [ #  # ]:           0 :                 for (unsigned int d = 0; d < Dim; ++d) {
      78                 :           0 :                     n                = mesh.getGridsize(d);
      79         [ #  # ]:           0 :                     h                = mesh.getMeshSpacing(d);
      80                 :           0 :                     double local_min = 4 / std::pow(h, 2);  // theoretical maximum
      81                 :           0 :                     double local_max = 0;
      82                 :             :                     double test;
      83         [ #  # ]:           0 :                     for (unsigned int i = 1; i < n; ++i) {
      84                 :           0 :                         test = 4. / std::pow(h, 2) * std::pow(std::sin(i * M_PI * h / 2.), 2);
      85         [ #  # ]:           0 :                         if (test > local_max) {
      86                 :           0 :                             local_max = test;
      87                 :             :                         }
      88         [ #  # ]:           0 :                         if (test < local_min) {
      89                 :           0 :                             local_min = test;
      90                 :             :                         }
      91                 :             :                     }
      92                 :           0 :                     beta += local_max;
      93                 :           0 :                     alpha += local_min;
      94                 :             :                 }
      95         [ #  # ]:           0 :                 if (communication) {
      96   [ #  #  #  # ]:           0 :                     algo_m->setPreconditioner(
      97   [ #  #  #  # ]:           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type),
      98   [ #  #  #  # ]:           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(-lower_laplace, lhs_type),
      99   [ #  #  #  # ]:           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(-upper_laplace, lhs_type),
     100   [ #  #  #  # ]:           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(-upper_and_lower_laplace, lhs_type),
     101                 :           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(negative_inverse_diagonal_laplace, lhs_type),
     102                 :           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(diagonal_laplace, lhs_type), alpha, beta,
     103                 :             :                         preconditioner_type, level, degree, richardson_iterations, inner, outer,
     104                 :             :                         omega);
     105                 :             :                 } else {
     106   [ #  #  #  # ]:           0 :                     algo_m->setPreconditioner(
     107   [ #  #  #  # ]:           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type),
     108   [ #  #  #  # ]:           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(-lower_laplace_no_comm, lhs_type),
     109   [ #  #  #  # ]:           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(-upper_laplace_no_comm, lhs_type),
     110   [ #  #  #  # ]:           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(-upper_and_lower_laplace_no_comm, lhs_type),
     111                 :           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(negative_inverse_diagonal_laplace, lhs_type),
     112                 :           0 :                         IPPL_SOLVER_OPERATOR_WRAPPER(diagonal_laplace, lhs_type), alpha, beta,
     113                 :             :                         preconditioner_type, level, degree, richardson_iterations, inner, outer,
     114                 :             :                         omega);
     115                 :             :                 }
     116                 :           0 :             } else {
     117                 :           0 :                 algo_m = std::move(
     118                 :             :                     std::make_unique<CG<OperatorRet, LowerRet, UpperRet, UpperAndLowerRet,
     119         [ #  # ]:           0 :                                         InverseDiagonalRet, DiagRet, FieldLHS, FieldRHS>>());
     120                 :             :             }
     121                 :           0 :         }
     122                 :             : 
     123                 :           0 :         void solve() override {
     124   [ #  #  #  # ]:           0 :             setSolver(*(this->lhs_mp));
     125   [ #  #  #  #  :           0 :             algo_m->setOperator(IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type));
                   #  # ]
     126                 :           0 :             algo_m->operator()(*(this->lhs_mp), *(this->rhs_mp), this->params_m);
     127                 :             : 
     128   [ #  #  #  # ]:           0 :             int output = this->params_m.template get<int>("output_type");
     129         [ #  # ]:           0 :             if (output & Base::GRAD) {
     130   [ #  #  #  #  :           0 :                 *(this->grad_mp) = -grad(*(this->lhs_mp));
                   #  # ]
     131                 :             :             }
     132                 :           0 :         }
     133                 :             : 
     134                 :             :         /*!
     135                 :             :          * Query how many iterations were required to obtain the solution
     136                 :             :          * the last time this solver was used
     137                 :             :          * @return Iteration count of last solve
     138                 :             :          */
     139                 :           0 :         int getIterationCount() { return algo_m->getIterationCount(); }
     140                 :             : 
     141                 :             :         /*!
     142                 :             :          * Query the residue
     143                 :             :          * @return Residue norm from last solve
     144                 :             :          */
     145                 :             :         Tlhs getResidue() const { return algo_m->getResidue(); }
     146                 :             : 
     147                 :             :     protected:
     148                 :             :         std::unique_ptr<CG<OperatorRet, LowerRet, UpperRet, UpperAndLowerRet, InverseDiagonalRet,
     149                 :             :                            DiagRet, FieldLHS, FieldRHS>>
     150                 :             :             algo_m;
     151                 :             : 
     152                 :           0 :         void setDefaultParameters() override {
     153   [ #  #  #  # ]:           0 :             this->params_m.add("max_iterations", 2000);
     154   [ #  #  #  # ]:           0 :             this->params_m.add("tolerance", (Tlhs)1e-13);
     155   [ #  #  #  # ]:           0 :             this->params_m.add("solver", "non-preconditioned");
     156                 :           0 :         }
     157                 :             :     };
     158                 :             : 
     159                 :             : }  // namespace ippl
     160                 :             : 
     161                 :             : #endif
        

Generated by: LCOV version 2.0-1