LCOV - code coverage report
Current view: top level - src/PoissonSolvers - FEMPoissonSolver.h (source / functions) Coverage Total Hit
Test: report.info Lines: 0.0 % 74 0
Test Date: 2025-05-21 16:07:51 Functions: 0.0 % 36 0
Branches: 0.0 % 160 0

             Branch data     Line data    Source code
       1                 :             : // Class FEMPoissonSolver
       2                 :             : //   Solves the poisson equation using finite element methods and Conjugate
       3                 :             : //   Gradient
       4                 :             : 
       5                 :             : #ifndef IPPL_FEMPOISSONSOLVER_H
       6                 :             : #define IPPL_FEMPOISSONSOLVER_H
       7                 :             : 
       8                 :             : #include "LinearSolvers/PCG.h"
       9                 :             : #include "Poisson.h"
      10                 :             : 
      11                 :             : namespace ippl {
      12                 :             : 
      13                 :             :     template <typename Tlhs, unsigned Dim, unsigned numElemDOFs>
      14                 :             :     struct EvalFunctor {
      15                 :             :         const Vector<Tlhs, Dim> DPhiInvT;
      16                 :             :         const Tlhs absDetDPhi;
      17                 :             : 
      18                 :           0 :         EvalFunctor(Vector<Tlhs, Dim> DPhiInvT, Tlhs absDetDPhi)
      19                 :           0 :             : DPhiInvT(DPhiInvT)
      20                 :           0 :             , absDetDPhi(absDetDPhi) {}
      21                 :             : 
      22                 :           0 :         KOKKOS_FUNCTION const auto operator()(
      23                 :             :             const size_t& i, const size_t& j,
      24                 :             :             const Vector<Vector<Tlhs, Dim>, numElemDOFs>& grad_b_q_k) const {
      25   [ #  #  #  #  :           0 :             return dot((DPhiInvT * grad_b_q_k[j]), (DPhiInvT * grad_b_q_k[i])).apply() * absDetDPhi;
             #  #  #  # ]
           [ #  #  #  #  
                   #  # ]
      26                 :             :         }
      27                 :             :     };
      28                 :             : 
      29                 :             :     /**
      30                 :             :      * @brief A solver for the poisson equation using finite element methods and
      31                 :             :      * Conjugate Gradient (CG)
      32                 :             :      *
      33                 :             :      * @tparam FieldLHS field type for the left hand side
      34                 :             :      * @tparam FieldRHS field type for the right hand side
      35                 :             :      */
      36                 :             :     template <typename FieldLHS, typename FieldRHS = FieldLHS, unsigned Order = 1, unsigned QuadNumNodes = 5>
      37                 :             :     class FEMPoissonSolver : public Poisson<FieldLHS, FieldRHS> {
      38                 :             :         constexpr static unsigned Dim = FieldLHS::dim;
      39                 :             :         using Tlhs                    = typename FieldLHS::value_type;
      40                 :             : 
      41                 :             :     public:
      42                 :             :         using Base = Poisson<FieldLHS, FieldRHS>;
      43                 :             :         using typename Base::lhs_type, typename Base::rhs_type;
      44                 :             :         using MeshType = typename FieldRHS::Mesh_t;
      45                 :             : 
      46                 :             :         // PCG (Preconditioned Conjugate Gradient) is the solver algorithm used
      47                 :             :         using PCGSolverAlgorithm_t =
      48                 :             :             CG<lhs_type, lhs_type, lhs_type, lhs_type, lhs_type, FieldLHS, FieldRHS>;
      49                 :             : 
      50                 :             :         // FEM Space types
      51                 :             :         using ElementType =
      52                 :             :             std::conditional_t<Dim == 1, ippl::EdgeElement<Tlhs>,
      53                 :             :                                std::conditional_t<Dim == 2, ippl::QuadrilateralElement<Tlhs>,
      54                 :             :                                                   ippl::HexahedralElement<Tlhs>>>;
      55                 :             : 
      56                 :             :         using QuadratureType = GaussJacobiQuadrature<Tlhs, QuadNumNodes, ElementType>;
      57                 :             : 
      58                 :             :         using LagrangeType = LagrangeSpace<Tlhs, Dim, Order, ElementType, QuadratureType, FieldLHS, FieldRHS>;
      59                 :             : 
      60                 :             :         // default constructor (compatibility with Alpine)
      61                 :             :         FEMPoissonSolver() 
      62                 :             :             : Base()
      63                 :             :             , refElement_m()
      64                 :             :             , quadrature_m(refElement_m, 0.0, 0.0)
      65                 :             :             , lagrangeSpace_m(*(new MeshType(NDIndex<Dim>(Vector<unsigned, Dim>(0)), Vector<Tlhs, Dim>(0),
      66                 :             :                                 Vector<Tlhs, Dim>(0))), refElement_m, quadrature_m)
      67                 :             :         {}
      68                 :             : 
      69                 :           0 :         FEMPoissonSolver(lhs_type& lhs, rhs_type& rhs)
      70                 :             :             : Base(lhs, rhs)
      71                 :             :             , refElement_m()
      72   [ #  #  #  # ]:           0 :             , quadrature_m(refElement_m, 0.0, 0.0)
      73   [ #  #  #  # ]:           0 :             , lagrangeSpace_m(rhs.get_mesh(), refElement_m, quadrature_m, rhs.getLayout()) {
      74                 :             :             static_assert(std::is_floating_point<Tlhs>::value, "Not a floating point type");
      75         [ #  # ]:           0 :             setDefaultParameters();
      76                 :             : 
      77                 :             :             // start a timer
      78   [ #  #  #  #  :           0 :             static IpplTimings::TimerRef init = IpplTimings::getTimer("initFEM");
             #  #  #  # ]
      79         [ #  # ]:           0 :             IpplTimings::startTimer(init);
      80                 :             :             
      81         [ #  # ]:           0 :             rhs.fillHalo();
      82                 :             : 
      83         [ #  # ]:           0 :             lagrangeSpace_m.evaluateLoadVector(rhs);
      84                 :             : 
      85         [ #  # ]:           0 :             rhs.fillHalo();
      86                 :             :             
      87         [ #  # ]:           0 :             IpplTimings::stopTimer(init);
      88                 :           0 :         }
      89                 :             : 
      90                 :           0 :         void setRhs(rhs_type& rhs) override {
      91                 :           0 :             Base::setRhs(rhs);
      92                 :             : 
      93                 :           0 :             lagrangeSpace_m.initialize(rhs.get_mesh(), rhs.getLayout());
      94                 :             : 
      95                 :           0 :             rhs.fillHalo();
      96                 :             : 
      97                 :           0 :             lagrangeSpace_m.evaluateLoadVector(rhs);
      98                 :             : 
      99                 :           0 :             rhs.fillHalo();
     100                 :           0 :         }
     101                 :             : 
     102                 :             :         /**
     103                 :             :          * @brief Solve the poisson equation using finite element methods.
     104                 :             :          * The problem is described by -laplace(lhs) = rhs
     105                 :             :          */
     106                 :           0 :         void solve() override {
     107                 :             :             // start a timer
     108   [ #  #  #  #  :           0 :             static IpplTimings::TimerRef solve = IpplTimings::getTimer("solve");
             #  #  #  # ]
     109         [ #  # ]:           0 :             IpplTimings::startTimer(solve);
     110                 :             : 
     111         [ #  # ]:           0 :             const Vector<size_t, Dim> zeroNdIndex = Vector<size_t, Dim>(0);
     112                 :             : 
     113                 :             :             // We can pass the zeroNdIndex here, since the transformation jacobian does not depend
     114                 :             :             // on translation
     115         [ #  # ]:           0 :             const auto firstElementVertexPoints =
     116                 :             :                 lagrangeSpace_m.getElementMeshVertexPoints(zeroNdIndex);
     117                 :             : 
     118                 :             :             // Compute Inverse Transpose Transformation Jacobian ()
     119         [ #  # ]:           0 :             const Vector<Tlhs, Dim> DPhiInvT =
     120                 :             :                 refElement_m.getInverseTransposeTransformationJacobian(firstElementVertexPoints);
     121                 :             : 
     122                 :             :             // Compute absolute value of the determinant of the transformation jacobian (|det D
     123                 :             :             // Phi_K|)
     124         [ #  # ]:           0 :             const Tlhs absDetDPhi = Kokkos::abs(
     125                 :             :                 refElement_m.getDeterminantOfTransformationJacobian(firstElementVertexPoints));
     126                 :             : 
     127                 :           0 :             EvalFunctor<Tlhs, Dim, this->lagrangeSpace_m.numElementDOFs> poissonEquationEval(
     128                 :             :                 DPhiInvT, absDetDPhi);
     129                 :             : 
     130                 :             :             // get BC type of our RHS
     131                 :           0 :             BConds<FieldRHS, Dim>& bcField = (this->rhs_mp)->getFieldBC();
     132         [ #  # ]:           0 :             FieldBC bcType = bcField[0]->getBCType();
     133                 :             : 
     134                 :           0 :             const auto algoOperator = [poissonEquationEval, &bcField, this](rhs_type field) -> lhs_type {
     135                 :             :                 // start a timer
     136   [ #  #  #  #  :           0 :                 static IpplTimings::TimerRef opTimer = IpplTimings::getTimer("operator");
             #  #  #  # ]
           [ #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
             #  #  #  #  
                      # ]
     137                 :           0 :                 IpplTimings::startTimer(opTimer);
     138                 :             : 
     139                 :             :                 // set appropriate BCs for the field as the info gets lost in the CG iteration
     140                 :           0 :                 field.setFieldBC(bcField);
     141                 :             : 
     142                 :           0 :                 field.fillHalo();
     143                 :             : 
     144                 :           0 :                 auto return_field = lagrangeSpace_m.evaluateAx(field, poissonEquationEval);
     145                 :             : 
     146 [ #  # ][ #  #  :           0 :                 IpplTimings::stopTimer(opTimer);
             #  #  #  # ]
     147                 :             : 
     148                 :           0 :                 return return_field;
     149                 :           0 :             };
     150                 :             : 
     151         [ #  # ]:           0 :             pcg_algo_m.setOperator(algoOperator);
     152                 :             : 
     153                 :             :             // send boundary values to RHS (load vector) i.e. lifting (Dirichlet BCs)
     154         [ #  # ]:           0 :             if (bcType == CONSTANT_FACE) {
     155   [ #  #  #  # ]:           0 :                 *(this->rhs_mp) = *(this->rhs_mp) -
     156         [ #  # ]:           0 :                     lagrangeSpace_m.evaluateAx_lift(*(this->rhs_mp), poissonEquationEval);
     157                 :             :             }
     158                 :             : 
     159                 :             :             // start a timer
     160   [ #  #  #  #  :           0 :             static IpplTimings::TimerRef pcgTimer = IpplTimings::getTimer("pcg");
             #  #  #  # ]
     161         [ #  # ]:           0 :             IpplTimings::startTimer(pcgTimer);
     162                 :             : 
     163         [ #  # ]:           0 :             pcg_algo_m(*(this->lhs_mp), *(this->rhs_mp), this->params_m);
     164                 :             : 
     165         [ #  # ]:           0 :             (this->lhs_mp)->fillHalo();
     166                 :             : 
     167         [ #  # ]:           0 :             IpplTimings::stopTimer(pcgTimer);
     168                 :             : 
     169   [ #  #  #  # ]:           0 :             int output = this->params_m.template get<int>("output_type");
     170         [ #  # ]:           0 :             if (output & Base::GRAD) {
     171   [ #  #  #  #  :           0 :                 *(this->grad_mp) = -grad(*(this->lhs_mp));
                   #  # ]
     172                 :             :             }
     173                 :             : 
     174         [ #  # ]:           0 :             IpplTimings::stopTimer(solve);
     175                 :           0 :         }
     176                 :             : 
     177                 :             :         /**
     178                 :             :          * Query how many iterations were required to obtain the solution
     179                 :             :          * the last time this solver was used
     180                 :             :          * @return Iteration count of last solve
     181                 :             :          */
     182                 :           0 :         int getIterationCount() { return pcg_algo_m.getIterationCount(); }
     183                 :             : 
     184                 :             :         /**
     185                 :             :          * Query the residue
     186                 :             :          * @return Residue norm from last solve
     187                 :             :          */
     188                 :           0 :         Tlhs getResidue() const { return pcg_algo_m.getResidue(); }
     189                 :             : 
     190                 :             :         /**
     191                 :             :          * Query the L2-norm error compared to a given (analytical) sol
     192                 :             :          * @return L2 error after last solve
     193                 :             :          */
     194                 :             :         template <typename F>
     195                 :           0 :         Tlhs getL2Error(const F& analytic) {
     196                 :           0 :             Tlhs error_norm = this->lagrangeSpace_m.computeErrorL2(*(this->lhs_mp), analytic);
     197                 :           0 :             return error_norm;
     198                 :             :         }
     199                 :             : 
     200                 :             :         /**
     201                 :             :          * Query the average of the solution
     202                 :             :          * @param vol Boolean indicating whether we divide by volume or not
     203                 :             :          * @return avg (offset for null space test cases if divided by volume)
     204                 :             :          */
     205                 :           0 :         Tlhs getAvg(bool Vol = false) {
     206                 :           0 :             Tlhs avg = this->lagrangeSpace_m.computeAvg(*(this->lhs_mp));
     207         [ #  # ]:           0 :             if (Vol) {
     208   [ #  #  #  # ]:           0 :                 lhs_type unit((this->lhs_mp)->get_mesh(), (this->lhs_mp)->getLayout());
     209         [ #  # ]:           0 :                 unit = 1.0;
     210         [ #  # ]:           0 :                 Tlhs vol = this->lagrangeSpace_m.computeAvg(unit);
     211                 :           0 :                 return avg/vol;
     212                 :           0 :             } else {
     213                 :           0 :                 return avg;
     214                 :             :             }
     215                 :             :         }
     216                 :             : 
     217                 :             :     protected:
     218                 :             :         PCGSolverAlgorithm_t pcg_algo_m;
     219                 :             : 
     220                 :           0 :         virtual void setDefaultParameters() override {
     221   [ #  #  #  # ]:           0 :             this->params_m.add("max_iterations", 1000);
     222   [ #  #  #  # ]:           0 :             this->params_m.add("tolerance", (Tlhs)1e-13);
     223                 :           0 :         }
     224                 :             : 
     225                 :             :         ElementType refElement_m;
     226                 :             :         QuadratureType quadrature_m;
     227                 :             :         LagrangeType lagrangeSpace_m;
     228                 :             :     };
     229                 :             : 
     230                 :             : }  // namespace ippl
     231                 :             : 
     232                 :             : #endif
        

Generated by: LCOV version 2.0-1