Branch data Line data Source code
1 : : // Class FEMPoissonSolver
2 : : // Solves the poisson equation using finite element methods and Conjugate
3 : : // Gradient
4 : :
5 : : #ifndef IPPL_FEMPOISSONSOLVER_H
6 : : #define IPPL_FEMPOISSONSOLVER_H
7 : :
8 : : #include "LinearSolvers/PCG.h"
9 : : #include "Poisson.h"
10 : :
11 : : namespace ippl {
12 : :
13 : : template <typename Tlhs, unsigned Dim, unsigned numElemDOFs>
14 : : struct EvalFunctor {
15 : : const Vector<Tlhs, Dim> DPhiInvT;
16 : : const Tlhs absDetDPhi;
17 : :
18 : 0 : EvalFunctor(Vector<Tlhs, Dim> DPhiInvT, Tlhs absDetDPhi)
19 : 0 : : DPhiInvT(DPhiInvT)
20 : 0 : , absDetDPhi(absDetDPhi) {}
21 : :
22 : 0 : KOKKOS_FUNCTION const auto operator()(
23 : : const size_t& i, const size_t& j,
24 : : const Vector<Vector<Tlhs, Dim>, numElemDOFs>& grad_b_q_k) const {
25 [ # # # # : 0 : return dot((DPhiInvT * grad_b_q_k[j]), (DPhiInvT * grad_b_q_k[i])).apply() * absDetDPhi;
# # # # ]
[ # # # #
# # ]
26 : : }
27 : : };
28 : :
29 : : /**
30 : : * @brief A solver for the poisson equation using finite element methods and
31 : : * Conjugate Gradient (CG)
32 : : *
33 : : * @tparam FieldLHS field type for the left hand side
34 : : * @tparam FieldRHS field type for the right hand side
35 : : */
36 : : template <typename FieldLHS, typename FieldRHS = FieldLHS, unsigned Order = 1, unsigned QuadNumNodes = 5>
37 : : class FEMPoissonSolver : public Poisson<FieldLHS, FieldRHS> {
38 : : constexpr static unsigned Dim = FieldLHS::dim;
39 : : using Tlhs = typename FieldLHS::value_type;
40 : :
41 : : public:
42 : : using Base = Poisson<FieldLHS, FieldRHS>;
43 : : using typename Base::lhs_type, typename Base::rhs_type;
44 : : using MeshType = typename FieldRHS::Mesh_t;
45 : :
46 : : // PCG (Preconditioned Conjugate Gradient) is the solver algorithm used
47 : : using PCGSolverAlgorithm_t =
48 : : CG<lhs_type, lhs_type, lhs_type, lhs_type, lhs_type, FieldLHS, FieldRHS>;
49 : :
50 : : // FEM Space types
51 : : using ElementType =
52 : : std::conditional_t<Dim == 1, ippl::EdgeElement<Tlhs>,
53 : : std::conditional_t<Dim == 2, ippl::QuadrilateralElement<Tlhs>,
54 : : ippl::HexahedralElement<Tlhs>>>;
55 : :
56 : : using QuadratureType = GaussJacobiQuadrature<Tlhs, QuadNumNodes, ElementType>;
57 : :
58 : : using LagrangeType = LagrangeSpace<Tlhs, Dim, Order, ElementType, QuadratureType, FieldLHS, FieldRHS>;
59 : :
60 : : // default constructor (compatibility with Alpine)
61 : : FEMPoissonSolver()
62 : : : Base()
63 : : , refElement_m()
64 : : , quadrature_m(refElement_m, 0.0, 0.0)
65 : : , lagrangeSpace_m(*(new MeshType(NDIndex<Dim>(Vector<unsigned, Dim>(0)), Vector<Tlhs, Dim>(0),
66 : : Vector<Tlhs, Dim>(0))), refElement_m, quadrature_m)
67 : : {}
68 : :
69 : 0 : FEMPoissonSolver(lhs_type& lhs, rhs_type& rhs)
70 : : : Base(lhs, rhs)
71 : : , refElement_m()
72 [ # # # # ]: 0 : , quadrature_m(refElement_m, 0.0, 0.0)
73 [ # # # # ]: 0 : , lagrangeSpace_m(rhs.get_mesh(), refElement_m, quadrature_m, rhs.getLayout()) {
74 : : static_assert(std::is_floating_point<Tlhs>::value, "Not a floating point type");
75 [ # # ]: 0 : setDefaultParameters();
76 : :
77 : : // start a timer
78 [ # # # # : 0 : static IpplTimings::TimerRef init = IpplTimings::getTimer("initFEM");
# # # # ]
79 [ # # ]: 0 : IpplTimings::startTimer(init);
80 : :
81 [ # # ]: 0 : rhs.fillHalo();
82 : :
83 [ # # ]: 0 : lagrangeSpace_m.evaluateLoadVector(rhs);
84 : :
85 [ # # ]: 0 : rhs.fillHalo();
86 : :
87 [ # # ]: 0 : IpplTimings::stopTimer(init);
88 : 0 : }
89 : :
90 : 0 : void setRhs(rhs_type& rhs) override {
91 : 0 : Base::setRhs(rhs);
92 : :
93 : 0 : lagrangeSpace_m.initialize(rhs.get_mesh(), rhs.getLayout());
94 : :
95 : 0 : rhs.fillHalo();
96 : :
97 : 0 : lagrangeSpace_m.evaluateLoadVector(rhs);
98 : :
99 : 0 : rhs.fillHalo();
100 : 0 : }
101 : :
102 : : /**
103 : : * @brief Solve the poisson equation using finite element methods.
104 : : * The problem is described by -laplace(lhs) = rhs
105 : : */
106 : 0 : void solve() override {
107 : : // start a timer
108 [ # # # # : 0 : static IpplTimings::TimerRef solve = IpplTimings::getTimer("solve");
# # # # ]
109 [ # # ]: 0 : IpplTimings::startTimer(solve);
110 : :
111 [ # # ]: 0 : const Vector<size_t, Dim> zeroNdIndex = Vector<size_t, Dim>(0);
112 : :
113 : : // We can pass the zeroNdIndex here, since the transformation jacobian does not depend
114 : : // on translation
115 [ # # ]: 0 : const auto firstElementVertexPoints =
116 : : lagrangeSpace_m.getElementMeshVertexPoints(zeroNdIndex);
117 : :
118 : : // Compute Inverse Transpose Transformation Jacobian ()
119 [ # # ]: 0 : const Vector<Tlhs, Dim> DPhiInvT =
120 : : refElement_m.getInverseTransposeTransformationJacobian(firstElementVertexPoints);
121 : :
122 : : // Compute absolute value of the determinant of the transformation jacobian (|det D
123 : : // Phi_K|)
124 [ # # ]: 0 : const Tlhs absDetDPhi = Kokkos::abs(
125 : : refElement_m.getDeterminantOfTransformationJacobian(firstElementVertexPoints));
126 : :
127 : 0 : EvalFunctor<Tlhs, Dim, this->lagrangeSpace_m.numElementDOFs> poissonEquationEval(
128 : : DPhiInvT, absDetDPhi);
129 : :
130 : : // get BC type of our RHS
131 : 0 : BConds<FieldRHS, Dim>& bcField = (this->rhs_mp)->getFieldBC();
132 [ # # ]: 0 : FieldBC bcType = bcField[0]->getBCType();
133 : :
134 : 0 : const auto algoOperator = [poissonEquationEval, &bcField, this](rhs_type field) -> lhs_type {
135 : : // start a timer
136 [ # # # # : 0 : static IpplTimings::TimerRef opTimer = IpplTimings::getTimer("operator");
# # # # ]
[ # # # #
# # # # #
# # # # #
# # # # #
# # # #
# ]
137 : 0 : IpplTimings::startTimer(opTimer);
138 : :
139 : : // set appropriate BCs for the field as the info gets lost in the CG iteration
140 : 0 : field.setFieldBC(bcField);
141 : :
142 : 0 : field.fillHalo();
143 : :
144 : 0 : auto return_field = lagrangeSpace_m.evaluateAx(field, poissonEquationEval);
145 : :
146 [ # # ][ # # : 0 : IpplTimings::stopTimer(opTimer);
# # # # ]
147 : :
148 : 0 : return return_field;
149 : 0 : };
150 : :
151 [ # # ]: 0 : pcg_algo_m.setOperator(algoOperator);
152 : :
153 : : // send boundary values to RHS (load vector) i.e. lifting (Dirichlet BCs)
154 [ # # ]: 0 : if (bcType == CONSTANT_FACE) {
155 [ # # # # ]: 0 : *(this->rhs_mp) = *(this->rhs_mp) -
156 [ # # ]: 0 : lagrangeSpace_m.evaluateAx_lift(*(this->rhs_mp), poissonEquationEval);
157 : : }
158 : :
159 : : // start a timer
160 [ # # # # : 0 : static IpplTimings::TimerRef pcgTimer = IpplTimings::getTimer("pcg");
# # # # ]
161 [ # # ]: 0 : IpplTimings::startTimer(pcgTimer);
162 : :
163 [ # # ]: 0 : pcg_algo_m(*(this->lhs_mp), *(this->rhs_mp), this->params_m);
164 : :
165 [ # # ]: 0 : (this->lhs_mp)->fillHalo();
166 : :
167 [ # # ]: 0 : IpplTimings::stopTimer(pcgTimer);
168 : :
169 [ # # # # ]: 0 : int output = this->params_m.template get<int>("output_type");
170 [ # # ]: 0 : if (output & Base::GRAD) {
171 [ # # # # : 0 : *(this->grad_mp) = -grad(*(this->lhs_mp));
# # ]
172 : : }
173 : :
174 [ # # ]: 0 : IpplTimings::stopTimer(solve);
175 : 0 : }
176 : :
177 : : /**
178 : : * Query how many iterations were required to obtain the solution
179 : : * the last time this solver was used
180 : : * @return Iteration count of last solve
181 : : */
182 : 0 : int getIterationCount() { return pcg_algo_m.getIterationCount(); }
183 : :
184 : : /**
185 : : * Query the residue
186 : : * @return Residue norm from last solve
187 : : */
188 : 0 : Tlhs getResidue() const { return pcg_algo_m.getResidue(); }
189 : :
190 : : /**
191 : : * Query the L2-norm error compared to a given (analytical) sol
192 : : * @return L2 error after last solve
193 : : */
194 : : template <typename F>
195 : 0 : Tlhs getL2Error(const F& analytic) {
196 : 0 : Tlhs error_norm = this->lagrangeSpace_m.computeErrorL2(*(this->lhs_mp), analytic);
197 : 0 : return error_norm;
198 : : }
199 : :
200 : : /**
201 : : * Query the average of the solution
202 : : * @param vol Boolean indicating whether we divide by volume or not
203 : : * @return avg (offset for null space test cases if divided by volume)
204 : : */
205 : 0 : Tlhs getAvg(bool Vol = false) {
206 : 0 : Tlhs avg = this->lagrangeSpace_m.computeAvg(*(this->lhs_mp));
207 [ # # ]: 0 : if (Vol) {
208 [ # # # # ]: 0 : lhs_type unit((this->lhs_mp)->get_mesh(), (this->lhs_mp)->getLayout());
209 [ # # ]: 0 : unit = 1.0;
210 [ # # ]: 0 : Tlhs vol = this->lagrangeSpace_m.computeAvg(unit);
211 : 0 : return avg/vol;
212 : 0 : } else {
213 : 0 : return avg;
214 : : }
215 : : }
216 : :
217 : : protected:
218 : : PCGSolverAlgorithm_t pcg_algo_m;
219 : :
220 : 0 : virtual void setDefaultParameters() override {
221 [ # # # # ]: 0 : this->params_m.add("max_iterations", 1000);
222 [ # # # # ]: 0 : this->params_m.add("tolerance", (Tlhs)1e-13);
223 : 0 : }
224 : :
225 : : ElementType refElement_m;
226 : : QuadratureType quadrature_m;
227 : : LagrangeType lagrangeSpace_m;
228 : : };
229 : :
230 : : } // namespace ippl
231 : :
232 : : #endif
|