Line data Source code
1 : //
2 : // Class PoissonCG
3 : // Solves the Poisson problem with the CG algorithm
4 : //
5 :
6 : #ifndef IPPL_POISSON_CG_H
7 : #define IPPL_POISSON_CG_H
8 :
9 : #include "LaplaceHelpers.h"
10 : #include "LinearSolvers/PCG.h"
11 : #include "Poisson.h"
12 : namespace ippl {
13 :
// IPPL_SOLVER_OPERATOR_WRAPPER(fun, type)
// Produces a capture-less lambda taking one argument of `type` and returning
// the result of applying the operator expression `fun` to it.
//   fun:  operator expression to wrap, e.g. ippl::laplace or -laplace. It is
//         pasted verbatim directly in front of the call parentheses, so
//         passing `-laplace` yields `-laplace(arg)`, i.e. the negated result.
//         For that reason `fun` is intentionally NOT wrapped in parentheses.
//   type: the lambda's parameter type; should match the solver's LHS type
#define IPPL_SOLVER_OPERATOR_WRAPPER(fun, type) \
    [](type ippl_wrapped_operand_) { return fun(ippl_wrapped_operand_); }
21 :
22 : template <typename FieldLHS, typename FieldRHS = FieldLHS>
23 : class PoissonCG : public Poisson<FieldLHS, FieldRHS> {
24 : using Tlhs = typename FieldLHS::value_type;
25 :
26 : public:
27 : using Base = Poisson<FieldLHS, FieldRHS>;
28 : constexpr static unsigned Dim = FieldLHS::dim;
29 : using typename Base::lhs_type, typename Base::rhs_type;
30 : using OperatorRet = UnaryMinus<detail::meta_laplace<lhs_type>>;
31 : using LowerRet = UnaryMinus<detail::meta_lower_laplace<lhs_type>>;
32 : using UpperRet = UnaryMinus<detail::meta_upper_laplace<lhs_type>>;
33 : using UpperAndLowerRet = UnaryMinus<detail::meta_upper_and_lower_laplace<lhs_type>>;
34 : using InverseDiagonalRet = double;
35 : using DiagRet = double;
36 :
37 0 : PoissonCG()
38 : : Base()
39 0 : , algo_m(nullptr) {
40 : static_assert(std::is_floating_point<Tlhs>::value, "Not a floating point type");
41 0 : setDefaultParameters();
42 0 : }
43 :
44 : PoissonCG(lhs_type& lhs, rhs_type& rhs)
45 : : Base(lhs, rhs)
46 : , algo_m(nullptr) {
47 : static_assert(std::is_floating_point<Tlhs>::value, "Not a floating point type");
48 : setDefaultParameters();
49 : }
50 :
51 0 : void setSolver(lhs_type lhs) {
52 0 : std::string solver_type = this->params_m.template get<std::string>("solver");
53 0 : typename lhs_type::Mesh_t mesh = lhs.get_mesh();
54 0 : typename lhs_type::Layout_t layout = lhs.getLayout();
55 0 : double beta = 0;
56 0 : double alpha = 0;
57 0 : if (solver_type == "preconditioned") {
58 0 : algo_m = std::move(
59 : std::make_unique<PCG<OperatorRet, LowerRet, UpperRet, UpperAndLowerRet,
60 0 : InverseDiagonalRet, DiagRet, FieldLHS, FieldRHS>>());
61 0 : std::string preconditioner_type =
62 0 : this->params_m.template get<std::string>("preconditioner_type");
63 0 : int level = this->params_m.template get<int>("newton_level");
64 0 : int degree = this->params_m.template get<int>("chebyshev_degree");
65 0 : int inner = this->params_m.template get<int>("gauss_seidel_inner_iterations");
66 0 : int outer = this->params_m.template get<int>("gauss_seidel_outer_iterations");
67 0 : double omega = this->params_m.template get<double>("ssor_omega");
68 : int richardson_iterations =
69 0 : this->params_m.template get<int>("richardson_iterations");
70 0 : int communication = this->params_m.template get<int>("communication");
71 : // Analytical eigenvalues for the d dimensional laplace operator
72 : // Going brute force through all possible eigenvalues seems to be the only way to
73 : // find max and min
74 :
75 : unsigned long n;
76 : double h;
77 0 : for (unsigned int d = 0; d < Dim; ++d) {
78 0 : n = mesh.getGridsize(d);
79 0 : h = mesh.getMeshSpacing(d);
80 0 : double local_min = 4 / std::pow(h, 2); // theoretical maximum
81 0 : double local_max = 0;
82 : double test;
83 0 : for (unsigned int i = 1; i < n; ++i) {
84 0 : test = 4. / std::pow(h, 2) * std::pow(std::sin(i * M_PI * h / 2.), 2);
85 0 : if (test > local_max) {
86 0 : local_max = test;
87 : }
88 0 : if (test < local_min) {
89 0 : local_min = test;
90 : }
91 : }
92 0 : beta += local_max;
93 0 : alpha += local_min;
94 : }
95 0 : if (communication) {
96 0 : algo_m->setPreconditioner(
97 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type),
98 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-lower_laplace, lhs_type),
99 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-upper_laplace, lhs_type),
100 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-upper_and_lower_laplace, lhs_type),
101 0 : IPPL_SOLVER_OPERATOR_WRAPPER(negative_inverse_diagonal_laplace, lhs_type),
102 0 : IPPL_SOLVER_OPERATOR_WRAPPER(diagonal_laplace, lhs_type), alpha, beta,
103 : preconditioner_type, level, degree, richardson_iterations, inner, outer,
104 : omega);
105 : } else {
106 0 : algo_m->setPreconditioner(
107 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type),
108 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-lower_laplace_no_comm, lhs_type),
109 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-upper_laplace_no_comm, lhs_type),
110 0 : IPPL_SOLVER_OPERATOR_WRAPPER(-upper_and_lower_laplace_no_comm, lhs_type),
111 0 : IPPL_SOLVER_OPERATOR_WRAPPER(negative_inverse_diagonal_laplace, lhs_type),
112 0 : IPPL_SOLVER_OPERATOR_WRAPPER(diagonal_laplace, lhs_type), alpha, beta,
113 : preconditioner_type, level, degree, richardson_iterations, inner, outer,
114 : omega);
115 : }
116 0 : } else {
117 0 : algo_m = std::move(
118 : std::make_unique<CG<OperatorRet, LowerRet, UpperRet, UpperAndLowerRet,
119 0 : InverseDiagonalRet, DiagRet, FieldLHS, FieldRHS>>());
120 : }
121 0 : }
122 :
123 0 : void solve() override {
124 0 : setSolver(*(this->lhs_mp));
125 0 : algo_m->setOperator(IPPL_SOLVER_OPERATOR_WRAPPER(-laplace, lhs_type));
126 0 : algo_m->operator()(*(this->lhs_mp), *(this->rhs_mp), this->params_m);
127 :
128 0 : int output = this->params_m.template get<int>("output_type");
129 0 : if (output & Base::GRAD) {
130 0 : *(this->grad_mp) = -grad(*(this->lhs_mp));
131 : }
132 0 : }
133 :
134 : /*!
135 : * Query how many iterations were required to obtain the solution
136 : * the last time this solver was used
137 : * @return Iteration count of last solve
138 : */
139 0 : int getIterationCount() { return algo_m->getIterationCount(); }
140 :
141 : /*!
142 : * Query the residue
143 : * @return Residue norm from last solve
144 : */
145 : Tlhs getResidue() const { return algo_m->getResidue(); }
146 :
147 : protected:
148 : std::unique_ptr<CG<OperatorRet, LowerRet, UpperRet, UpperAndLowerRet, InverseDiagonalRet,
149 : DiagRet, FieldLHS, FieldRHS>>
150 : algo_m;
151 :
152 0 : void setDefaultParameters() override {
153 0 : this->params_m.add("max_iterations", 2000);
154 0 : this->params_m.add("tolerance", (Tlhs)1e-13);
155 0 : this->params_m.add("solver", "non-preconditioned");
156 0 : }
157 : };
158 :
159 : } // namespace ippl
160 :
161 : #endif
|