|             Line data    Source code 
       1              : // Class FEMPoissonSolver
       2              : //   Solves the poisson equation using finite element methods and Conjugate
       3              : //   Gradient
       4              : 
       5              : #ifndef IPPL_FEMPOISSONSOLVER_H
       6              : #define IPPL_FEMPOISSONSOLVER_H
       7              : 
       8              : #include "LinearSolvers/PCG.h"
       9              : #include "Poisson.h"
      10              : 
      11              : namespace ippl {
      12              : 
      13              :     template <typename Tlhs, unsigned Dim, unsigned numElemDOFs>
      14              :     struct EvalFunctor {
      15              :         const Vector<Tlhs, Dim> DPhiInvT;
      16              :         const Tlhs absDetDPhi;
      17              : 
      18            0 :         EvalFunctor(Vector<Tlhs, Dim> DPhiInvT, Tlhs absDetDPhi)
      19            0 :             : DPhiInvT(DPhiInvT)
      20            0 :             , absDetDPhi(absDetDPhi) {}
      21              : 
      22            0 :         KOKKOS_FUNCTION auto operator()(
      23              :             const size_t& i, const size_t& j,
      24              :             const Vector<Vector<Tlhs, Dim>, numElemDOFs>& grad_b_q_k) const {
      25            0 :             return dot((DPhiInvT * grad_b_q_k[j]), (DPhiInvT * grad_b_q_k[i])).apply() * absDetDPhi;
      26              :         }
      27              :     };
      28              : 
      29              :     /**
      30              :      * @brief A solver for the poisson equation using finite element methods and
      31              :      * Conjugate Gradient (CG)
      32              :      *
      33              :      * @tparam FieldLHS field type for the left hand side
      34              :      * @tparam FieldRHS field type for the right hand side
      35              :      */
      36              :     template <typename FieldLHS, typename FieldRHS = FieldLHS, unsigned Order = 1, unsigned QuadNumNodes = 5>
      37              :     class FEMPoissonSolver : public Poisson<FieldLHS, FieldRHS> {
      38              :         constexpr static unsigned Dim = FieldLHS::dim;
      39              :         using Tlhs                    = typename FieldLHS::value_type;
      40              : 
      41              :     public:
      42              :         using Base = Poisson<FieldLHS, FieldRHS>;
      43              :         using typename Base::lhs_type, typename Base::rhs_type;
      44              :         using MeshType = typename FieldRHS::Mesh_t;
      45              : 
      46              :         // PCG (Preconditioned Conjugate Gradient) is the solver algorithm used
      47              :         using PCGSolverAlgorithm_t =
      48              :             CG<lhs_type, lhs_type, lhs_type, lhs_type, lhs_type, FieldLHS, FieldRHS>;
      49              : 
      50              :         // FEM Space types
      51              :         using ElementType =
      52              :             std::conditional_t<Dim == 1, ippl::EdgeElement<Tlhs>,
      53              :                                std::conditional_t<Dim == 2, ippl::QuadrilateralElement<Tlhs>,
      54              :                                                   ippl::HexahedralElement<Tlhs>>>;
      55              : 
      56              :         using QuadratureType = GaussJacobiQuadrature<Tlhs, QuadNumNodes, ElementType>;
      57              : 
      58              :         using LagrangeType = LagrangeSpace<Tlhs, Dim, Order, ElementType, QuadratureType, FieldLHS, FieldRHS>;
      59              : 
      60              :         // default constructor (compatibility with Alpine)
      61              :         FEMPoissonSolver() 
      62              :             : Base()
      63              :             , refElement_m()
      64              :             , quadrature_m(refElement_m, 0.0, 0.0)
      65              :             , lagrangeSpace_m(*(new MeshType(NDIndex<Dim>(Vector<unsigned, Dim>(0)), Vector<Tlhs, Dim>(0),
      66              :                                 Vector<Tlhs, Dim>(0))), refElement_m, quadrature_m)
      67              :         {}
      68              : 
      69            0 :         FEMPoissonSolver(lhs_type& lhs, rhs_type& rhs)
      70              :             : Base(lhs, rhs)
      71              :             , refElement_m()
      72            0 :             , quadrature_m(refElement_m, 0.0, 0.0)
      73            0 :             , lagrangeSpace_m(rhs.get_mesh(), refElement_m, quadrature_m, rhs.getLayout()) {
      74              :             static_assert(std::is_floating_point<Tlhs>::value, "Not a floating point type");
      75            0 :             setDefaultParameters();
      76              : 
      77              :             // start a timer
      78            0 :             static IpplTimings::TimerRef init = IpplTimings::getTimer("initFEM");
      79            0 :             IpplTimings::startTimer(init);
      80              :             
      81            0 :             rhs.fillHalo();
      82              : 
      83            0 :             lagrangeSpace_m.evaluateLoadVector(rhs);
      84              : 
      85            0 :             rhs.fillHalo();
      86              :             
      87            0 :             IpplTimings::stopTimer(init);
      88            0 :         }
      89              : 
      90            0 :         void setRhs(rhs_type& rhs) override {
      91            0 :             Base::setRhs(rhs);
      92              : 
      93            0 :             lagrangeSpace_m.initialize(rhs.get_mesh(), rhs.getLayout());
      94              : 
      95            0 :             rhs.fillHalo();
      96              : 
      97            0 :             lagrangeSpace_m.evaluateLoadVector(rhs);
      98              : 
      99            0 :             rhs.fillHalo();
     100            0 :         }
     101              : 
     102              :         /**
     103              :          * @brief Solve the poisson equation using finite element methods.
     104              :          * The problem is described by -laplace(lhs) = rhs
     105              :          */
     106            0 :         void solve() override {
     107              :             // start a timer
     108            0 :             static IpplTimings::TimerRef solve = IpplTimings::getTimer("solve");
     109            0 :             IpplTimings::startTimer(solve);
     110              : 
     111            0 :             const Vector<size_t, Dim> zeroNdIndex = Vector<size_t, Dim>(0);
     112              : 
     113              :             // We can pass the zeroNdIndex here, since the transformation jacobian does not depend
     114              :             // on translation
     115            0 :             const auto firstElementVertexPoints =
     116              :                 lagrangeSpace_m.getElementMeshVertexPoints(zeroNdIndex);
     117              : 
     118              :             // Compute Inverse Transpose Transformation Jacobian ()
     119            0 :             const Vector<Tlhs, Dim> DPhiInvT =
     120              :                 refElement_m.getInverseTransposeTransformationJacobian(firstElementVertexPoints);
     121              : 
     122              :             // Compute absolute value of the determinant of the transformation jacobian (|det D
     123              :             // Phi_K|)
     124            0 :             const Tlhs absDetDPhi = Kokkos::abs(
     125              :                 refElement_m.getDeterminantOfTransformationJacobian(firstElementVertexPoints));
     126              : 
     127            0 :             EvalFunctor<Tlhs, Dim, LagrangeType::numElementDOFs> poissonEquationEval(
     128              :                 DPhiInvT, absDetDPhi);
     129              : 
     130              :             // get BC type of our RHS
     131            0 :             BConds<FieldRHS, Dim>& bcField = (this->rhs_mp)->getFieldBC();
     132            0 :             FieldBC bcType = bcField[0]->getBCType();
     133              : 
     134            0 :             const auto algoOperator = [poissonEquationEval, &bcField, this](rhs_type field) -> lhs_type {
     135              :                 // set appropriate BCs for the field as the info gets lost in the CG iteration
     136            0 :                 field.setFieldBC(bcField);
     137              : 
     138            0 :                 field.fillHalo();
     139              : 
     140            0 :                 auto return_field = lagrangeSpace_m.evaluateAx(field, poissonEquationEval);
     141              : 
     142            0 :                 return return_field;
     143              :             };
     144              : 
     145            0 :             pcg_algo_m.setOperator(algoOperator);
     146              : 
     147              :             // send boundary values to RHS (load vector) i.e. lifting (Dirichlet BCs)
     148            0 :             if (bcType == CONSTANT_FACE) {
     149            0 :                 *(this->rhs_mp) = *(this->rhs_mp) -
     150            0 :                     lagrangeSpace_m.evaluateAx_lift(*(this->rhs_mp), poissonEquationEval);
     151              :             }
     152              : 
     153              :             // start a timer
     154            0 :             static IpplTimings::TimerRef pcgTimer = IpplTimings::getTimer("pcg");
     155            0 :             IpplTimings::startTimer(pcgTimer);
     156              : 
     157            0 :             pcg_algo_m(*(this->lhs_mp), *(this->rhs_mp), this->params_m);
     158              : 
     159            0 :             (this->lhs_mp)->fillHalo();
     160              : 
     161            0 :             IpplTimings::stopTimer(pcgTimer);
     162              : 
     163            0 :             int output = this->params_m.template get<int>("output_type");
     164            0 :             if (output & Base::GRAD) {
     165            0 :                 *(this->grad_mp) = -grad(*(this->lhs_mp));
     166              :             }
     167              : 
     168            0 :             IpplTimings::stopTimer(solve);
     169            0 :         }
     170              : 
     171              :         /**
     172              :          * Query how many iterations were required to obtain the solution
     173              :          * the last time this solver was used
     174              :          * @return Iteration count of last solve
     175              :          */
     176            0 :         int getIterationCount() { return pcg_algo_m.getIterationCount(); }
     177              : 
     178              :         /**
     179              :          * Query the residue
     180              :          * @return Residue norm from last solve
     181              :          */
     182            0 :         Tlhs getResidue() const { return pcg_algo_m.getResidue(); }
     183              : 
     184              :         /**
     185              :          * Query the L2-norm error compared to a given (analytical) sol
     186              :          * @return L2 error after last solve
     187              :          */
     188              :         template <typename F>
     189            0 :         Tlhs getL2Error(const F& analytic) {
     190            0 :             Tlhs error_norm = this->lagrangeSpace_m.computeErrorL2(*(this->lhs_mp), analytic);
     191            0 :             return error_norm;
     192              :         }
     193              : 
     194              :         /**
     195              :          * Query the average of the solution
     196              :          * @param vol Boolean indicating whether we divide by volume or not
     197              :          * @return avg (offset for null space test cases if divided by volume)
     198              :          */
     199            0 :         Tlhs getAvg(bool Vol = false) {
     200            0 :             Tlhs avg = this->lagrangeSpace_m.computeAvg(*(this->lhs_mp));
     201            0 :             if (Vol) {
     202            0 :                 lhs_type unit((this->lhs_mp)->get_mesh(), (this->lhs_mp)->getLayout());
     203            0 :                 unit = 1.0;
     204            0 :                 Tlhs vol = this->lagrangeSpace_m.computeAvg(unit);
     205            0 :                 return avg/vol;
     206            0 :             } else {
     207            0 :                 return avg;
     208              :             }
     209              :         }
     210              : 
     211              :     protected:
     212              :         PCGSolverAlgorithm_t pcg_algo_m;
     213              : 
     214            0 :         virtual void setDefaultParameters() override {
     215            0 :             this->params_m.add("max_iterations", 1000);
     216            0 :             this->params_m.add("tolerance", (Tlhs)1e-13);
     217            0 :         }
     218              : 
     219              :         ElementType refElement_m;
     220              :         QuadratureType quadrature_m;
     221              :         LagrangeType lagrangeSpace_m;
     222              :     };
     223              : 
     224              : }  // namespace ippl
     225              : 
     226              : #endif
         |