Line data Source code
1 : //
2 : // Parallel dispatch
3 : // Utility functions relating to parallel dispatch in IPPL
4 : //
5 :
6 : #ifndef IPPL_PARALLEL_DISPATCH_H
7 : #define IPPL_PARALLEL_DISPATCH_H
8 :
9 : #include <Kokkos_Core.hpp>
10 :
11 : #include <tuple>
12 :
13 : #include "Types/Vector.h"
14 :
15 : #include "Utility/IpplException.h"
16 :
17 : namespace ippl {
18 : /*!
19 : * Wrapper type for Kokkos range policies with some convenience aliases
20 : * @tparam Dim range policy rank
21 : * @tparam PolicyArgs... additional template parameters for the range policy
22 : */
23 : template <unsigned Dim, class... PolicyArgs>
24 : struct RangePolicy {
25 : // The range policy type
26 : using policy_type = Kokkos::MDRangePolicy<PolicyArgs..., Kokkos::Rank<Dim>>;
27 : // The index type used by the range policy
28 : using index_type = typename policy_type::array_index_type;
29 : // A vector type containing the index type
30 : using index_array_type = ::ippl::Vector<index_type, Dim>;
31 : };
32 :
33 : /*!
34 : * Specialized range policy for one dimension.
35 : */
36 : template <class... PolicyArgs>
37 : struct RangePolicy<1, PolicyArgs...> {
38 : using policy_type = Kokkos::RangePolicy<PolicyArgs...>;
39 : using index_type = typename policy_type::index_type;
40 : using index_array_type = ::ippl::Vector<index_type, 1>;
41 : };
42 :
43 : /*!
44 : * Create a range policy that spans an entire Kokkos view, excluding
45 : * a specifiable number of ghost cells at the extremes.
46 : * @tparam Tag range policy tag
47 : * @tparam View the view type
48 : *
49 : * @param view to span
50 : * @param shift number of ghost cells
51 : *
52 : * @return A (MD)RangePolicy that spans the desired elements of the given view
53 : */
54 : template <class... PolicyArgs, typename View>
55 : typename RangePolicy<View::rank, typename View::execution_space, PolicyArgs...>::policy_type
56 920 : getRangePolicy(const View& view, int shift = 0) {
57 920 : constexpr unsigned Dim = View::rank;
58 : using exec_space = typename View::execution_space;
59 : using policy_type = typename RangePolicy<Dim, exec_space, PolicyArgs...>::policy_type;
60 : if constexpr (Dim == 1) {
61 150 : return policy_type(shift, view.size() - shift);
62 : } else {
63 : using index_type = typename RangePolicy<Dim, exec_space, PolicyArgs...>::index_type;
64 : Kokkos::Array<index_type, Dim> begin, end;
65 3776 : for (unsigned int d = 0; d < Dim; d++) {
66 3006 : begin[d] = shift;
67 3006 : end[d] = view.extent(d) - shift;
68 : }
69 1540 : return policy_type(begin, end);
70 : }
71 : // Silences incorrect nvcc warning: missing return statement at end of non-void function
72 : throw IpplException("detail::getRangePolicy", "Unreachable state");
73 : }
74 :
75 : /*!
76 : * Create a range policy for an index range given in the form of arrays
77 : * (required because Kokkos doesn't allow the initialization of 1D range
78 : * policies using arrays)
79 : * @tparam Dim the dimension of the range
80 : * @tparam PolicyArgs... additional template parameters for the range policy
81 : *
82 : * @param begin the starting indices
83 : * @param end the ending indices
84 : *
85 : * @return A (MD)RangePolicy spanning the given range
86 : */
87 : template <size_t Dim, class... PolicyArgs>
88 798 : typename RangePolicy<Dim, PolicyArgs...>::policy_type createRangePolicy(
89 : const Kokkos::Array<typename RangePolicy<Dim, PolicyArgs...>::index_type, Dim>& begin,
90 : const Kokkos::Array<typename RangePolicy<Dim, PolicyArgs...>::index_type, Dim>& end) {
91 : using policy_type = typename RangePolicy<Dim, PolicyArgs...>::policy_type;
92 : if constexpr (Dim == 1) {
93 38 : return policy_type(begin[0], end[0]);
94 : } else {
95 1520 : return policy_type(begin, end);
96 : }
97 : // Silences incorrect nvcc warning: missing return statement at end of non-void function
98 : throw IpplException("detail::createRangePolicy", "Unreachable state");
99 : }
100 :
101 : namespace detail {
102 : /*!
103 : * Recursively templated struct for defining tuples with arbitrary
104 : * length
105 : * @tparam Dim the length of the tuple
106 : * @tparam T the data type to repeat (default size_t)
107 : */
108 : template <unsigned Dim, typename T = size_t>
109 : struct Coords {
110 : // https://stackoverflow.com/a/53398815/2773311
111 : // https://en.cppreference.com/w/cpp/utility/declval
112 : using type =
113 : decltype(std::tuple_cat(std::declval<typename Coords<1, T>::type>(),
114 : std::declval<typename Coords<Dim - 1, T>::type>()));
115 : };
116 :
117 : template <typename T>
118 : struct Coords<1, T> {
119 : using type = std::tuple<T>;
120 : };
121 :
/// Kind of parallel dispatch a functor is wrapped for.
enum e_functor_type {
    FOR    = 0,  ///< used by ippl::parallel_for
    REDUCE = 1,  ///< used by ippl::parallel_reduce
    SCAN   = 2   ///< presumably for scans; no FunctorWrapper specialization exists for it here
};

/*!
 * Primary template of the functor wrappers. It is only declared; the usable
 * wrappers are its partial specializations.
 * @tparam Kind the dispatch type
 * @tparam Functor the wrapped functor type
 * @tparam Policy the range policy properties type
 * @tparam IndexTuple a std::tuple of the per-dimension index types
 * @tparam Acc... accumulator types (reductions only)
 */
template <e_functor_type Kind, typename Functor, typename Policy, typename IndexTuple,
          typename... Acc>
struct FunctorWrapper;
130 :
131 : /*!
132 : * Wrapper struct for reduction kernels
133 : * Source:
134 : * https://stackoverflow.com/questions/50713214/familiar-template-syntax-for-generic-lambdas
135 : * @tparam Functor functor type
136 : * @tparam Policy range policy type
137 : * @tparam T... index types
138 : * @tparam Acc accumulator data type
139 : */
140 : template <typename Functor, typename Policy, typename... T, typename... Acc>
141 : struct FunctorWrapper<REDUCE, Functor, Policy, std::tuple<T...>, Acc...> {
142 : Functor f;
143 :
144 : /*!
145 : * Inline operator forwarding to a specialized instantiation
146 : * of the functor's own operator()
147 : * @param x... the indices
148 : * @param res the accumulator variable
149 : * @return The functor's return value
150 : */
151 3579850 : KOKKOS_INLINE_FUNCTION void operator()(T... x, Acc&... res) const {
152 : using index_type = typename Policy::index_type;
153 3579850 : typename Policy::index_array_type args = {static_cast<index_type>(x)...};
154 3579850 : f(args, res...);
155 3579850 : }
156 : };
157 :
158 : template <typename Functor, typename Policy, typename... T>
159 : struct FunctorWrapper<FOR, Functor, Policy, std::tuple<T...>> {
160 : Functor f;
161 :
162 56663822 : KOKKOS_INLINE_FUNCTION void operator()(T... x) const {
163 : using index_type = typename Policy::index_type;
164 56663822 : typename Policy::index_array_type args = {static_cast<index_type>(x)...};
165 56663822 : f(args);
166 56663822 : }
167 : };
168 :
169 : // Extracts the rank of a Kokkos range policy
170 : template <typename>
171 : struct ExtractRank;
172 :
173 : template <typename... T>
174 : struct ExtractRank<Kokkos::RangePolicy<T...>> {
175 : static constexpr int rank = 1;
176 : };
177 : template <typename... T>
178 : struct ExtractRank<Kokkos::MDRangePolicy<T...>> {
179 : static constexpr int rank = Kokkos::MDRangePolicy<T...>::rank;
180 : };
181 : template <typename T>
182 : concept HasMemberValueType = requires() {
183 : { typename T::value_type() };
184 : };
185 : template <typename T>
186 : struct ExtractReducerReturnType {
187 : using type = T;
188 : };
189 : template <HasMemberValueType T>
190 : struct ExtractReducerReturnType<T> {
191 : using type = typename T::value_type;
192 : };
193 :
194 : /*!
195 : * Convenience function for wrapping a functor with the wrapper struct.
196 : * @tparam Functor the functor type
197 : * @tparam Type the parallel dispatch type
198 : * @tparam Policy the range policy type
199 : * @tparam Acc... the accumulator type(s)
200 : * @return A wrapper containing the given functor
201 : */
202 : template <e_functor_type Type, typename Policy, typename... Acc, typename Functor>
203 1622 : auto functorize(const Functor& f) {
204 1622 : constexpr unsigned Dim = ExtractRank<Policy>::rank;
205 : using PolicyProperties = RangePolicy<Dim, typename Policy::execution_space>;
206 : using index_type = typename PolicyProperties::index_type;
207 : return FunctorWrapper<Type, Functor, PolicyProperties,
208 1622 : typename Coords<Dim, index_type>::type, Acc...>{f};
209 : }
210 : } // namespace detail
211 :
212 : // Wrappers for Kokkos' parallel dispatch functions that use
213 : // the IPPL functor wrapper
214 : template <class ExecPolicy, class FunctorType>
215 1440 : void parallel_for(const std::string& name, const ExecPolicy& policy,
216 : const FunctorType& functor) {
217 1440 : Kokkos::parallel_for(name, policy, detail::functorize<detail::FOR, ExecPolicy>(functor));
218 1440 : }
219 :
220 : template <class ExecPolicy, class FunctorType, class... ReducerArgument>
221 182 : void parallel_reduce(const std::string& name, const ExecPolicy& policy,
222 : const FunctorType& functor, ReducerArgument&&... reducer) {
223 182 : Kokkos::parallel_reduce(
224 : name, policy,
225 : detail::functorize<detail::REDUCE, ExecPolicy,
226 364 : typename detail::ExtractReducerReturnType<ReducerArgument>::type...>(
227 : functor),
228 182 : std::forward<ReducerArgument>(reducer)...);
229 182 : }
230 : } // namespace ippl
231 :
232 : #endif
|