Line data Source code
1 : #ifndef IPPL_RANDOM_UTILITY_H
2 : #define IPPL_RANDOM_UTILITY_H
3 :
4 : #include <Kokkos_MathematicalConstants.hpp>
5 : #include <Kokkos_MathematicalFunctions.hpp>
6 : #include <Kokkos_Random.hpp>
7 :
8 : #include "Types/ViewTypes.h"
9 :
10 : namespace ippl {
11 : namespace random {
12 : namespace detail {
13 : /*!
14 : * @struct NewtonRaphson
15 : * @brief Functor for solving equations using the Newton-Raphson method.
16 : *
17 : * In particular, find the root x of the equation dist.obj(x, u)= 0 for a given u using
18 : * Newton-Raphson method.
19 : *
20 : * @tparam T Data type for the equation variables.
21 : * @tparam Distribution Class of target distribution to sample from.
22 : * @param dist Distribution object providing objective function cdf(x)-u and its
23 : * derivative.
24 : * @param atol Absolute tolerance for convergence (default: 1.0e-12).
25 : * @param max_iter Maximum number of iterations (default: 20).
26 : */
27 : template <typename T, class Distribution>
28 : struct NewtonRaphson {
29 : Distribution dist;
30 : double atol = 1e-12;
31 : unsigned int max_iter = 20;
32 :
33 : KOKKOS_FUNCTION
34 : NewtonRaphson() = default;
35 :
36 : KOKKOS_FUNCTION
37 0 : ~NewtonRaphson() = default;
38 :
39 0 : KOKKOS_INLINE_FUNCTION NewtonRaphson(const Distribution& dist_)
40 0 : : dist(dist_) {}
41 :
42 : /*!
43 : * @brief Solve an equation using the Newton-Raphson method.
44 : *
45 : * This function iteratively solves an equation of the form "cdf(x) - u = 0"
46 : * for a given sample `u` using the Newton-Raphson method.
47 : *
48 : * @param d Dimension index.
49 : * @param x Variable to solve for (initial guess and final solution).
50 : * @param u Given sample from a uniform distribution [0, 1].
51 : */
52 0 : KOKKOS_INLINE_FUNCTION void solve(unsigned int d, T& x, T& u) {
53 0 : unsigned int iter = 0;
54 0 : while (iter < max_iter && Kokkos::fabs(dist.getObjFunc(x, d, u)) > atol) {
55 : // Find x, such that "cdf(x) - u = 0" for a given sample of u~uniform(0,1)
56 0 : x = x - (dist.getObjFunc(x, d, u) / dist.getDerObjFunc(x, d));
57 0 : iter += 1;
58 : }
59 0 : }
60 : };
61 : } // namespace detail
62 : } // namespace random
63 : } // namespace ippl
64 :
65 : #endif
|