Branch data Line data Source code
1 : : //
2 : : // Class PoissonCG
3 : : // Solves the Poisson problem with the CG algorithm
4 : : //
5 : :
6 : : #ifndef IPPL_POISSON_CG_H
7 : : #define IPPL_POISSON_CG_H
8 : :
9 : : #include "LaplaceHelpers.h"
10 : : #include "LinearSolvers/PCG.h"
11 : : #include "Poisson.h"
12 : : namespace ippl {
13 : :
// Builds a captureless lambda taking one argument of `type` and returning
// `fun(argument)`. The `fun` tokens are pasted verbatim in front of the call,
// so a sign prefix works: IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, T) produces a
// lambda that returns -laplace(argument).
// fun:  the differential operator to wrap (optionally signed), e.g. ippl::laplace
// type: the argument type, which should match the LHS type for the solver
#define IPPL_SOLVER_OPERATOR_WRAPPER(fun, type) \
    [](type operand) -> auto { return fun(operand); }

21 : :
22 : : template <typename FieldLHS, typename FieldRHS = FieldLHS>
23 : : class PoissonCG : public Poisson<FieldLHS, FieldRHS> {
24 : : using Tlhs = typename FieldLHS::value_type;
25 : :
26 : : public:
27 : : using Base = Poisson<FieldLHS, FieldRHS>;
28 : : constexpr static unsigned Dim = FieldLHS::dim;
29 : : using typename Base::lhs_type, typename Base::rhs_type;
30 : : using OperatorRet = UnaryMinus<detail::meta_laplace<lhs_type>>;
31 : : using LowerRet = UnaryMinus<detail::meta_lower_laplace<lhs_type>>;
32 : : using UpperRet = UnaryMinus<detail::meta_upper_laplace<lhs_type>>;
33 : : using UpperAndLowerRet = UnaryMinus<detail::meta_upper_and_lower_laplace<lhs_type>>;
34 : : using InverseDiagonalRet = double;
35 : : using DiagRet = double;
36 : :
37 : 0 : PoissonCG()
38 : : : Base()
39 : 0 : , algo_m(nullptr) {
40 : : static_assert(std::is_floating_point<Tlhs>::value, "Not a floating point type");
41 [ # # ]: 0 : setDefaultParameters();
42 : 0 : }
43 : :
44 : : PoissonCG(lhs_type& lhs, rhs_type& rhs)
45 : : : Base(lhs, rhs)
46 : : , algo_m(nullptr) {
47 : : static_assert(std::is_floating_point<Tlhs>::value, "Not a floating point type");
48 : : setDefaultParameters();
49 : : }
50 : :
51 : 0 : void setSolver(lhs_type lhs) {
52 [ # # # # ]: 0 : std::string solver_type = this->params_m.template get<std::string>("solver");
53 : 0 : typename lhs_type::Mesh_t mesh = lhs.get_mesh();
54 [ # # # # ]: 0 : typename lhs_type::Layout_t layout = lhs.getLayout();
55 : 0 : double beta = 0;
56 : 0 : double alpha = 0;
57 [ # # # # ]: 0 : if (solver_type == "preconditioned") {
58 : 0 : algo_m = std::move(
59 : : std::make_unique<PCG<OperatorRet, LowerRet, UpperRet, UpperAndLowerRet,
60 [ # # ]: 0 : InverseDiagonalRet, DiagRet, FieldLHS, FieldRHS>>());
61 [ # # ]: 0 : std::string preconditioner_type =
62 [ # # ]: 0 : this->params_m.template get<std::string>("preconditioner_type");
63 [ # # # # ]: 0 : int level = this->params_m.template get<int>("newton_level");
64 [ # # # # ]: 0 : int degree = this->params_m.template get<int>("chebyshev_degree");
65 [ # # # # ]: 0 : int inner = this->params_m.template get<int>("gauss_seidel_inner_iterations");
66 [ # # # # ]: 0 : int outer = this->params_m.template get<int>("gauss_seidel_outer_iterations");
67 [ # # # # ]: 0 : double omega = this->params_m.template get<double>("ssor_omega");
68 : : int richardson_iterations =
69 [ # # # # ]: 0 : this->params_m.template get<int>("richardson_iterations");
70 [ # # # # ]: 0 : int communication = this->params_m.template get<int>("communication");
71 : : // Analytical eigenvalues for the d dimensional laplace operator
72 : : // Going brute force through all possible eigenvalues seems to be the only way to
73 : : // find max and min
74 : :
75 : : unsigned long n;
76 : : double h;
77 [ # # ]: 0 : for (unsigned int d = 0; d < Dim; ++d) {
78 : 0 : n = mesh.getGridsize(d);
79 [ # # ]: 0 : h = mesh.getMeshSpacing(d);
80 : 0 : double local_min = 4 / std::pow(h, 2); // theoretical maximum
81 : 0 : double local_max = 0;
82 : : double test;
83 [ # # ]: 0 : for (unsigned int i = 1; i < n; ++i) {
84 : 0 : test = 4. / std::pow(h, 2) * std::pow(std::sin(i * M_PI * h / 2.), 2);
85 [ # # ]: 0 : if (test > local_max) {
86 : 0 : local_max = test;
87 : : }
88 [ # # ]: 0 : if (test < local_min) {
89 : 0 : local_min = test;
90 : : }
91 : : }
92 : 0 : beta += local_max;
93 : 0 : alpha += local_min;
94 : : }
95 [ # # ]: 0 : if (communication) {
96 [ # # # # ]: 0 : algo_m->setPreconditioner(
97 [ # # # # ]: 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type),
98 [ # # # # ]: 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-lower_laplace, lhs_type),
99 [ # # # # ]: 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-upper_laplace, lhs_type),
100 [ # # # # ]: 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-upper_and_lower_laplace, lhs_type),
101 : 0 : IPPL_SOLVER_OPERATOR_WRAPPER(negative_inverse_diagonal_laplace, lhs_type),
102 : 0 : IPPL_SOLVER_OPERATOR_WRAPPER(diagonal_laplace, lhs_type), alpha, beta,
103 : : preconditioner_type, level, degree, richardson_iterations, inner, outer,
104 : : omega);
105 : : } else {
106 [ # # # # ]: 0 : algo_m->setPreconditioner(
107 [ # # # # ]: 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type),
108 [ # # # # ]: 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-lower_laplace_no_comm, lhs_type),
109 [ # # # # ]: 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-upper_laplace_no_comm, lhs_type),
110 [ # # # # ]: 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-upper_and_lower_laplace_no_comm, lhs_type),
111 : 0 : IPPL_SOLVER_OPERATOR_WRAPPER(negative_inverse_diagonal_laplace, lhs_type),
112 : 0 : IPPL_SOLVER_OPERATOR_WRAPPER(diagonal_laplace, lhs_type), alpha, beta,
113 : : preconditioner_type, level, degree, richardson_iterations, inner, outer,
114 : : omega);
115 : : }
116 : 0 : } else {
117 : 0 : algo_m = std::move(
118 : : std::make_unique<CG<OperatorRet, LowerRet, UpperRet, UpperAndLowerRet,
119 [ # # ]: 0 : InverseDiagonalRet, DiagRet, FieldLHS, FieldRHS>>());
120 : : }
121 : 0 : }
122 : :
123 : 0 : void solve() override {
124 [ # # # # ]: 0 : setSolver(*(this->lhs_mp));
125 [ # # # # : 0 : algo_m->setOperator(IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type));
# # ]
126 : 0 : algo_m->operator()(*(this->lhs_mp), *(this->rhs_mp), this->params_m);
127 : :
128 [ # # # # ]: 0 : int output = this->params_m.template get<int>("output_type");
129 [ # # ]: 0 : if (output & Base::GRAD) {
130 [ # # # # : 0 : *(this->grad_mp) = -grad(*(this->lhs_mp));
# # ]
131 : : }
132 : 0 : }
133 : :
134 : : /*!
135 : : * Query how many iterations were required to obtain the solution
136 : : * the last time this solver was used
137 : : * @return Iteration count of last solve
138 : : */
139 : 0 : int getIterationCount() { return algo_m->getIterationCount(); }
140 : :
141 : : /*!
142 : : * Query the residue
143 : : * @return Residue norm from last solve
144 : : */
145 : : Tlhs getResidue() const { return algo_m->getResidue(); }
146 : :
147 : : protected:
148 : : std::unique_ptr<CG<OperatorRet, LowerRet, UpperRet, UpperAndLowerRet, InverseDiagonalRet,
149 : : DiagRet, FieldLHS, FieldRHS>>
150 : : algo_m;
151 : :
152 : 0 : void setDefaultParameters() override {
153 [ # # # # ]: 0 : this->params_m.add("max_iterations", 2000);
154 [ # # # # ]: 0 : this->params_m.add("tolerance", (Tlhs)1e-13);
155 [ # # # # ]: 0 : this->params_m.add("solver", "non-preconditioned");
156 : 0 : }
157 : : };
158 : :
159 : : } // namespace ippl
160 : :
161 : : #endif
|