Branch data Line data Source code
1 : : #ifndef IPPL_RANDOM_UTILITY_H
2 : : #define IPPL_RANDOM_UTILITY_H
3 : :
4 : : #include <Kokkos_MathematicalConstants.hpp>
5 : : #include <Kokkos_MathematicalFunctions.hpp>
6 : : #include <Kokkos_Random.hpp>
7 : :
8 : : #include "Types/ViewTypes.h"
9 : :
10 : : namespace ippl {
11 : : namespace random {
12 : : namespace detail {
13 : : /*!
14 : : * @struct NewtonRaphson
15 : : * @brief Functor for solving equations using the Newton-Raphson method.
16 : : *
17 : : * In particular, find the root x of the equation dist.obj(x, u)= 0 for a given u using
18 : : * Newton-Raphson method.
19 : : *
20 : : * @tparam T Data type for the equation variables.
21 : : * @tparam Distribution Class of target distribution to sample from.
22 : : * @param dist Distribution object providing objective function cdf(x)-u and its
23 : : * derivative.
24 : : * @param atol Absolute tolerance for convergence (default: 1.0e-12).
25 : : * @param max_iter Maximum number of iterations (default: 20).
26 : : */
27 : : template <typename T, class Distribution>
28 : : struct NewtonRaphson {
29 : : Distribution dist;
30 : : double atol = 1e-12;
31 : : unsigned int max_iter = 20;
32 : :
33 : : KOKKOS_FUNCTION
34 : : NewtonRaphson() = default;
35 : :
36 : : KOKKOS_FUNCTION
37 : 0 : ~NewtonRaphson() = default;
38 : :
39 : 0 : KOKKOS_INLINE_FUNCTION NewtonRaphson(const Distribution& dist_)
40 : 0 : : dist(dist_) {}
41 : :
42 : : /*!
43 : : * @brief Solve an equation using the Newton-Raphson method.
44 : : *
45 : : * This function iteratively solves an equation of the form "cdf(x) - u = 0"
46 : : * for a given sample `u` using the Newton-Raphson method.
47 : : *
48 : : * @param d Dimension index.
49 : : * @param x Variable to solve for (initial guess and final solution).
50 : : * @param u Given sample from a uniform distribution [0, 1].
51 : : */
52 : 0 : KOKKOS_INLINE_FUNCTION void solve(unsigned int d, T& x, T& u) {
53 : 0 : unsigned int iter = 0;
54 [ # # # # : 0 : while (iter < max_iter && Kokkos::fabs(dist.getObjFunc(x, d, u)) > atol) {
# # ]
55 : : // Find x, such that "cdf(x) - u = 0" for a given sample of u~uniform(0,1)
56 : 0 : x = x - (dist.getObjFunc(x, d, u) / dist.getDerObjFunc(x, d));
57 : 0 : iter += 1;
58 : : }
59 : 0 : }
60 : : };
61 : : } // namespace detail
62 : : } // namespace random
63 : : } // namespace ippl
64 : :
65 : : #endif
|