Branch data Line data Source code
1 : : //
2 : : // Parallel dispatch
3 : : // Utility functions relating to parallel dispatch in IPPL
4 : : //
5 : :
6 : : #ifndef IPPL_PARALLEL_DISPATCH_H
7 : : #define IPPL_PARALLEL_DISPATCH_H
8 : :
9 : : #include <Kokkos_Core.hpp>
10 : :
11 : : #include <tuple>
12 : :
13 : : #include "Types/Vector.h"
14 : :
15 : : #include "Utility/IpplException.h"
16 : :
17 : : namespace ippl {
18 : : /*!
19 : : * Wrapper type for Kokkos range policies with some convenience aliases
20 : : * @tparam Dim range policy rank
21 : : * @tparam PolicyArgs... additional template parameters for the range policy
22 : : */
23 : : template <unsigned Dim, class... PolicyArgs>
24 : : struct RangePolicy {
25 : : // The range policy type
26 : : using policy_type = Kokkos::MDRangePolicy<PolicyArgs..., Kokkos::Rank<Dim>>;
27 : : // The index type used by the range policy
28 : : using index_type = typename policy_type::array_index_type;
29 : : // A vector type containing the index type
30 : : using index_array_type = ::ippl::Vector<index_type, Dim>;
31 : : };
32 : :
33 : : /*!
34 : : * Specialized range policy for one dimension.
35 : : */
36 : : template <class... PolicyArgs>
37 : : struct RangePolicy<1, PolicyArgs...> {
38 : : using policy_type = Kokkos::RangePolicy<PolicyArgs...>;
39 : : using index_type = typename policy_type::index_type;
40 : : using index_array_type = ::ippl::Vector<index_type, 1>;
41 : : };
42 : :
43 : : /*!
44 : : * Create a range policy that spans an entire Kokkos view, excluding
45 : : * a specifiable number of ghost cells at the extremes.
46 : : * @tparam Tag range policy tag
47 : : * @tparam View the view type
48 : : *
49 : : * @param view to span
50 : : * @param shift number of ghost cells
51 : : *
52 : : * @return A (MD)RangePolicy that spans the desired elements of the given view
53 : : */
54 : : template <class... PolicyArgs, typename View>
55 : : typename RangePolicy<View::rank, typename View::execution_space, PolicyArgs...>::policy_type
56 : 914 : getRangePolicy(const View& view, int shift = 0) {
57 : 914 : constexpr unsigned Dim = View::rank;
58 : : using exec_space = typename View::execution_space;
59 : : using policy_type = typename RangePolicy<Dim, exec_space, PolicyArgs...>::policy_type;
60 : : if constexpr (Dim == 1) {
61 : 144 : return policy_type(shift, view.size() - shift);
62 : : } else {
63 : : using index_type = typename RangePolicy<Dim, exec_space, PolicyArgs...>::index_type;
64 : : Kokkos::Array<index_type, Dim> begin, end;
65 [ + + ]: 3776 : for (unsigned int d = 0; d < Dim; d++) {
66 : 3006 : begin[d] = shift;
67 : 3006 : end[d] = view.extent(d) - shift;
68 : : }
69 [ + - ]: 1540 : return policy_type(begin, end);
70 : : }
71 : : // Silences incorrect nvcc warning: missing return statement at end of non-void function
72 : : throw IpplException("detail::getRangePolicy", "Unreachable state");
73 : : }
74 : :
75 : : /*!
76 : : * Create a range policy for an index range given in the form of arrays
77 : : * (required because Kokkos doesn't allow the initialization of 1D range
78 : : * policies using arrays)
79 : : * @tparam Dim the dimension of the range
80 : : * @tparam PolicyArgs... additional template parameters for the range policy
81 : : *
82 : : * @param begin the starting indices
83 : : * @param end the ending indices
84 : : *
85 : : * @return A (MD)RangePolicy spanning the given range
86 : : */
87 : : template <size_t Dim, class... PolicyArgs>
88 : 798 : typename RangePolicy<Dim, PolicyArgs...>::policy_type createRangePolicy(
89 : : const Kokkos::Array<typename RangePolicy<Dim, PolicyArgs...>::index_type, Dim>& begin,
90 : : const Kokkos::Array<typename RangePolicy<Dim, PolicyArgs...>::index_type, Dim>& end) {
91 : : using policy_type = typename RangePolicy<Dim, PolicyArgs...>::policy_type;
92 : : if constexpr (Dim == 1) {
93 [ + - ]: 38 : return policy_type(begin[0], end[0]);
94 : : } else {
95 [ + - ]: 1520 : return policy_type(begin, end);
96 : : }
97 : : // Silences incorrect nvcc warning: missing return statement at end of non-void function
98 : : throw IpplException("detail::createRangePolicy", "Unreachable state");
99 : : }
100 : :
101 : : namespace detail {
102 : : /*!
103 : : * Recursively templated struct for defining tuples with arbitrary
104 : : * length
105 : : * @tparam Dim the length of the tuple
106 : : * @tparam T the data type to repeat (default size_t)
107 : : */
108 : : template <unsigned Dim, typename T = size_t>
109 : : struct Coords {
110 : : // https://stackoverflow.com/a/53398815/2773311
111 : : // https://en.cppreference.com/w/cpp/utility/declval
112 : : using type =
113 : : decltype(std::tuple_cat(std::declval<typename Coords<1, T>::type>(),
114 : : std::declval<typename Coords<Dim - 1, T>::type>()));
115 : : };
116 : :
117 : : template <typename T>
118 : : struct Coords<1, T> {
119 : : using type = std::tuple<T>;
120 : : };
121 : :
122 : : enum e_functor_type {
123 : : FOR,
124 : : REDUCE,
125 : : SCAN
126 : : };
127 : :
128 : : template <e_functor_type, typename, typename, typename, typename...>
129 : : struct FunctorWrapper;
130 : :
131 : : /*!
132 : : * Wrapper struct for reduction kernels
133 : : * Source:
134 : : * https://stackoverflow.com/questions/50713214/familiar-template-syntax-for-generic-lambdas
135 : : * @tparam Functor functor type
136 : : * @tparam Policy range policy type
137 : : * @tparam T... index types
138 : : * @tparam Acc accumulator data type
139 : : */
140 : : template <typename Functor, typename Policy, typename... T, typename... Acc>
141 : : struct FunctorWrapper<REDUCE, Functor, Policy, std::tuple<T...>, Acc...> {
142 : : Functor f;
143 : :
144 : : /*!
145 : : * Inline operator forwarding to a specialized instantiation
146 : : * of the functor's own operator()
147 : : * @param x... the indices
148 : : * @param res the accumulator variable
149 : : * @return The functor's return value
150 : : */
151 : 3579840 : KOKKOS_INLINE_FUNCTION void operator()(T... x, Acc&... res) const {
152 : : using index_type = typename Policy::index_type;
153 : 3579840 : typename Policy::index_array_type args = {static_cast<index_type>(x)...};
154 [ + - ]: 3579840 : f(args, res...);
155 : 3579840 : }
156 : : };
157 : :
158 : : template <typename Functor, typename Policy, typename... T>
159 : : struct FunctorWrapper<FOR, Functor, Policy, std::tuple<T...>> {
160 : : Functor f;
161 : :
162 : 56663798 : KOKKOS_INLINE_FUNCTION void operator()(T... x) const {
163 : : using index_type = typename Policy::index_type;
164 : 56663798 : typename Policy::index_array_type args = {static_cast<index_type>(x)...};
165 [ + - ]: 56663798 : f(args);
166 : 56663798 : }
167 : : };
168 : :
169 : : // Extracts the rank of a Kokkos range policy
170 : : template <typename>
171 : : struct ExtractRank;
172 : :
173 : : template <typename... T>
174 : : struct ExtractRank<Kokkos::RangePolicy<T...>> {
175 : : static constexpr int rank = 1;
176 : : };
177 : : template <typename... T>
178 : : struct ExtractRank<Kokkos::MDRangePolicy<T...>> {
179 : : static constexpr int rank = Kokkos::MDRangePolicy<T...>::rank;
180 : : };
181 : : template <typename T>
182 : : concept HasMemberValueType = requires() {
183 : : { typename T::value_type() };
184 : : };
185 : : template <typename T>
186 : : struct ExtractReducerReturnType {
187 : : using type = T;
188 : : };
189 : : template <HasMemberValueType T>
190 : : struct ExtractReducerReturnType<T> {
191 : : using type = typename T::value_type;
192 : : };
193 : :
194 : : /*!
195 : : * Convenience function for wrapping a functor with the wrapper struct.
196 : : * @tparam Functor the functor type
197 : : * @tparam Type the parallel dispatch type
198 : : * @tparam Policy the range policy type
199 : : * @tparam Acc... the accumulator type(s)
200 : : * @return A wrapper containing the given functor
201 : : */
202 : : template <e_functor_type Type, typename Policy, typename... Acc, typename Functor>
203 : 1616 : auto functorize(const Functor& f) {
204 : 1616 : constexpr unsigned Dim = ExtractRank<Policy>::rank;
205 : : using PolicyProperties = RangePolicy<Dim, typename Policy::execution_space>;
206 : : using index_type = typename PolicyProperties::index_type;
207 : : return FunctorWrapper<Type, Functor, PolicyProperties,
208 : 1616 : typename Coords<Dim, index_type>::type, Acc...>{f};
209 : : }
210 : : } // namespace detail
211 : :
212 : : // Wrappers for Kokkos' parallel dispatch functions that use
213 : : // the IPPL functor wrapper
214 : : template <class ExecPolicy, class FunctorType>
215 : 1436 : void parallel_for(const std::string& name, const ExecPolicy& policy,
216 : : const FunctorType& functor) {
217 [ + - + - ]: 1436 : Kokkos::parallel_for(name, policy, detail::functorize<detail::FOR, ExecPolicy>(functor));
218 : 1436 : }
219 : :
220 : : template <class ExecPolicy, class FunctorType, class... ReducerArgument>
221 : 180 : void parallel_reduce(const std::string& name, const ExecPolicy& policy,
222 : : const FunctorType& functor, ReducerArgument&&... reducer) {
223 [ + - ]: 180 : Kokkos::parallel_reduce(
224 : : name, policy,
225 : : detail::functorize<detail::REDUCE, ExecPolicy,
226 [ + - ]: 360 : typename detail::ExtractReducerReturnType<ReducerArgument>::type...>(
227 : : functor),
228 : 180 : std::forward<ReducerArgument>(reducer)...);
229 : 180 : }
230 : : } // namespace ippl
231 : :
232 : : #endif
|