Branch data Line data Source code
1 : : //
2 : : // Type Utilities
3 : : // Metaprogramming utility functions for type manipulation
4 : : //
5 : :
6 : : #ifndef IPPL_TYPE_UTILS_H
7 : : #define IPPL_TYPE_UTILS_H
8 : :
9 : : #include <Kokkos_Core.hpp>
10 : :
11 : : #include "Types/Variant.h"
12 : :
13 : : #include "Utility/IpplException.h"
14 : :
15 : : namespace ippl {
16 : : namespace detail {
17 : :
18 : : /*!
19 : : * Variant verification struct
20 : : * Checks that a given type has not already been added to a variant
21 : : * @tparam Check the type for whose presence to check
22 : : * @tparam Collection... a collection of types
23 : : */
24 : : template <typename Check, typename... Collection>
25 : : struct IsUnique {
26 : : constexpr static bool enable = !std::disjunction_v<std::is_same<Check, Collection>...>;
27 : : typedef Check type;
28 : : };
29 : :
30 : : /*!
31 : : * Defines a variant verification struct
32 : : * Performs the same check as IsUnique, but instead of using the provided types
33 : : * directly, the types are wrapped in another provided type. For example, if the
34 : : * wrapper type is std::shared_ptr and the types are <int, float, int>, then
35 : : * the final variant will allow std::shared_ptr<int> and std::shared_ptr<float>.
36 : : * @tparam Wrapper the wrapper type
37 : : */
38 : : template <template <typename> class Wrapper>
39 : : struct WrapUnique {
40 : : template <typename Check, typename... Collection>
41 : : struct Verifier {
42 : : typedef Wrapper<Check> type;
43 : : constexpr static bool enable =
44 : : !std::disjunction_v<std::is_same<type, Collection>...>;
45 : : };
46 : : };
47 : :
48 : : /*!
49 : : * Convenience alias for types that should or should not be included
50 : : * in variants constructed with ConstructVariant (defined below) based
51 : : * on some compile-time constant
52 : : * @tparam B whether the type should be enabled
53 : : * @tparam T the type
54 : : */
55 : : template <bool B, typename T>
56 : : using ConditionalType = std::conditional_t<B, T, void>;
57 : :
58 : : /*!
59 : : * Variant verification struct
60 : : * Enables the type if it is not void (intended for use with std::conditional_t
61 : : * where the user passes void if the type should not be included)
62 : : * @tparam Type the type that should be added, or void
63 : : * @tparam ... dummy parameter to ensure compatibility with IsPresent
64 : : */
65 : : template <typename Type, typename...>
66 : : struct IsEnabled {
67 : : constexpr static bool enable = !std::is_void_v<Type>;
68 : : typedef Type type;
69 : : };
70 : :
71 : : /*!
72 : : * Base struct declaration (see full declaration below for details)
73 : : */
74 : : template <typename, typename, template <typename...> class Verifier = IsUnique>
75 : : struct ConstructVariant;
76 : :
77 : : /*!
78 : : * Base case for variant construction with no types to add
79 : : */
80 : : template <template <typename...> class Verifier>
81 : : struct ConstructVariant<std::variant<>, std::variant<>, Verifier> {
82 : : typedef std::variant<> type;
83 : : };
84 : :
85 : : /*!
86 : : * Base case for a fully constructed variant
87 : : * @tparam T... the types to be included in the variant
88 : : */
89 : : template <typename... T, template <typename...> class Verifier>
90 : : struct ConstructVariant<std::variant<>, std::variant<T...>, Verifier> {
91 : : typedef std::variant<T...> type;
92 : : };
93 : :
94 : : /*!
95 : : * Constructs a variant type containing all the provided types that fulfill a certain
96 : : * condition. This is done by recursively adding types to the variant based on the
97 : : * inclusion criteria.
98 : : *
99 : : * The default verification struct is IsPresent defined above, which includes the type
100 : : * if it has not already been added to the variant before. This is useful
101 : : * if the provided types include duplicates, e.g. if they are type aliases that
102 : : * can sometimes refer to the same type. In particular, Kokkos memory spaces can
103 : : * sometimes have different names but refer to the same memory space. Variants do
104 : : * not allow duplicate types to appear in their parameter packs; each type may only
105 : : * appear once.
106 : : *
107 : : * The verification struct can be user defined as long as it conforms to the variant
108 : : * verification struct interface. The struct must accept at least a parameter pack; the
109 : : * first parameter is the type currently being checked and the rest are the types already
110 : : * added to the variant. The struct must expose a boolean `enable` that indicates whether
111 : : * the next type should be included in the variant. The struct must also expose a type
112 : : * `type`, which is the type to be added to the variant
113 : : *
114 : : * @tparam Next the next type to add to the variant
115 : : * @tparam ToAdd... the remaining types waiting to be added to the variant
116 : : * @tparam Added... the types that have already been added to the variant
117 : : * @tparam Verifier the variant verification struct
118 : : */
119 : : template <typename Next, typename... ToAdd, typename... Added,
120 : : template <typename...> class Verifier>
121 : : struct ConstructVariant<std::variant<Next, ToAdd...>, std::variant<Added...>, Verifier> {
122 : : // Convenience aliases
123 : : template <bool B, class T, class F>
124 : : using cond = std::conditional_t<B, T, F>;
125 : : template <typename... T>
126 : : using variant = std::variant<T...>;
127 : :
128 : : using Check = Verifier<Next, Added...>;
129 : :
130 : : typedef cond<
131 : : Check::enable,
132 : : // The verifier has indicated that this type should be added
133 : : typename ConstructVariant<variant<ToAdd...>,
134 : : variant<typename Check::type, Added...>, Verifier>::type,
135 : : // The verifier has indicated that the type should not be added
136 : : typename ConstructVariant<variant<ToAdd...>, variant<Added...>, Verifier>::type>
137 : : type;
138 : : };
139 : :
140 : : /*!
141 : : * A variant containing all the enabled types,
142 : : * where "enabled" types are assumed to be void
143 : : * when disabled (i.e. std::conditional_t<B, T, void>)
144 : : */
145 : : template <typename... Types>
146 : : using VariantFromConditionalTypes =
147 : : typename ConstructVariant<std::variant<Types...>, std::variant<>, IsEnabled>::type;
148 : :
149 : : /*!
150 : : * A variant containing just the unique types
151 : : * from the pack
152 : : */
153 : : template <typename... Types>
154 : : using VariantFromUniqueTypes =
155 : : typename ConstructVariant<std::variant<Types...>, std::variant<>, IsUnique>::type;
156 : :
157 : : /*!
158 : : * A variant containing the types enabled by a custom
159 : : * verifier; to implement a custom verifier, provide the following:
160 : : * - template <typename Next, typename... Added>
161 : : * Next: the next input type to check
162 : : * Added: the types that have already been added
163 : : * - bool enable: whether the type should be added
164 : : * - typename type: the output type to be added
165 : : */
166 : : template <template <typename...> class Verifier, typename... Types>
167 : : using VariantWithVerifier =
168 : : typename ConstructVariant<std::variant<Types...>, std::variant<>, Verifier>::type;
169 : :
170 : : /*!
171 : : * Utility struct for forwarding parameter packs
172 : : * (see specializations)
173 : : * @tparam Type the templated type
174 : : * @tparam Pack a type containing the parameters to forward
175 : : */
176 : : template <template <typename...> class Type, typename Pack>
177 : : struct Forward;
178 : :
179 : : /*!
180 : : * Forwards the types in a variant to another type
181 : : */
182 : : template <template <typename...> class Type, typename... Spaces>
183 : : struct Forward<Type, std::variant<Spaces...>> {
184 : : using type = Type<Spaces...>;
185 : : };
186 : :
187 : : /*!
188 : : * Forwards the properties of a Kokkos view to another type
189 : : */
190 : : template <template <typename...> class Type, typename T, typename... Properties>
191 : : struct Forward<Type, Kokkos::View<T, Properties...>> {
192 : : using type = Type<Properties...>;
193 : : };
194 : :
195 : : /*!
196 : : * Constructs a uniform type based on Kokkos views' uniform
197 : : * types (i.e. a type where all optional template parameters
198 : : * are explicitly specified)
199 : : * @tparam Type the type to specialize
200 : : * @tparam View the view type
201 : : */
202 : : template <template <typename...> class Type, typename View>
203 : : struct CreateUniformType {
204 : : using view_type = typename View::uniform_type;
205 : : using type = typename Forward<Type, view_type>::type;
206 : : };
207 : :
208 : : /*!
209 : : * Instantiates a parameter pack with all the available Kokkos memory spaces
210 : : */
211 : : template <template <typename...> class Type>
212 : : struct TypeForAllSpaces {
213 : : using unique_memory_spaces = VariantFromUniqueTypes<
214 : : Kokkos::HostSpace, Kokkos::SharedSpace, Kokkos::SharedHostPinnedSpace
215 : : #ifdef KOKKOS_ENABLE_CUDA
216 : : ,
217 : : Kokkos::CudaSpace, Kokkos::CudaHostPinnedSpace, Kokkos::CudaUVMSpace
218 : : #endif
219 : : #ifdef KOKKOS_ENABLE_HIP
220 : : ,
221 : : Kokkos::HIPSpace, Kokkos::HIPHostPinnedSpace, Kokkos::HIPManagedSpace
222 : : #endif
223 : : #ifdef KOKKOS_ENABLE_SYCL
224 : : ,
225 : : Kokkos::Experimental::SYCLDeviceUSMSpace, Kokkos::Experimental::SYCLHostUSMSpace,
226 : : Kokkos::Experimental::SYCLSharedUSMSpace
227 : : #endif
228 : : >;
229 : :
230 : : using unique_exec_spaces = VariantFromUniqueTypes<Kokkos::DefaultExecutionSpace
231 : : #ifdef KOKKOS_ENABLE_OPENMP
232 : : ,
233 : : Kokkos::OpenMP
234 : : #endif
235 : : #ifdef KOKKOS_ENABLE_OPENMPTARGET
236 : : ,
237 : : Kokkos::OpenMPTarget
238 : : #endif
239 : : #ifdef KOKKOS_ENABLE_THREADS
240 : : ,
241 : : Kokkos::Thread
242 : : #endif
243 : : #ifdef KOKKOS_ENABLE_SERIAL
244 : : ,
245 : : Kokkos::Serial
246 : : #endif
247 : : #ifdef KOKKOS_ENABLE_CUDA
248 : : ,
249 : : Kokkos::Cuda
250 : : #endif
251 : : #ifdef KOKKOS_ENABLE_HIP
252 : : ,
253 : : Kokkos::HIP
254 : : #endif
255 : : #ifdef KOKKOS_ENABLE_SYCL
256 : : ,
257 : : Kokkos::Experimental::SYCL
258 : : #endif
259 : : #ifdef KOKKOS_ENABLE_HPX
260 : : ,
261 : : Kokkos::HPX
262 : : #endif
263 : : >;
264 : :
265 : : using memory_spaces_type = typename Forward<Type, unique_memory_spaces>::type;
266 : : using exec_spaces_type = typename Forward<Type, unique_exec_spaces>::type;
267 : : };
268 : :
269 : : /*!
270 : : * A container indexed by type instead of by numerical indices;
271 : : * designed for storing elements associated with Kokkos memory spaces
272 : : * @tparam Type the element type
273 : : * @tparam Spaces... the memory spaces of interest
274 : : */
275 : : template <template <typename> class Type, typename... Spaces>
276 : : class MultispaceContainer {
277 : : template <typename T, typename... Ts>
278 : : using Verifier = typename WrapUnique<Type>::template Verifier<T, Ts...>;
279 : :
280 : : using Types = VariantWithVerifier<Verifier, Spaces...>;
281 : :
282 : : std::array<Types, sizeof...(Spaces)> elements_m;
283 : :
284 : : /*!
285 : : * Locates an element associated with a space
286 : : * @tparam Space the memory space
287 : : * @return The numerical index for that space's element
288 : : */
289 : : template <typename Space, unsigned Idx = 0>
290 : 3878 : constexpr static unsigned spaceToIndex() {
291 : : static_assert(Idx < sizeof...(Spaces));
292 : : if constexpr (std::is_same_v<Space,
293 : : std::tuple_element_t<Idx, std::tuple<Spaces...>>>) {
294 : 3878 : return Idx;
295 : : } else {
296 : : return spaceToIndex<Space, Idx + 1>();
297 : : }
298 : : // Silences incorrect nvcc warning: missing return statement at end of non-void
299 : : // function
300 : : throw IpplException("detail::MultispaceContainer::spaceToIndex",
301 : : "Unreachable state");
302 : : }
303 : :
304 : : /*!
305 : : * Initializes the element for a space
306 : : */
307 : : template <typename Space>
308 : 2382 : void initElements() {
309 [ + - + - ]: 2382 : elements_m[spaceToIndex<Space>()] = Type<Space>{};
[ + - + -
+ - ]
310 : 2382 : }
311 : :
312 : : /*!
313 : : * Determine whether the element for a space should be initialized,
314 : : * possibly based on a predicate functor
315 : : */
316 : : template <typename MemorySpace, typename Filter,
317 : : std::enable_if_t<std::is_null_pointer_v<std::decay_t<Filter>>, int> = 0>
318 : : constexpr bool copyToSpace(Filter&&) {
319 : : return true;
320 : : }
321 : :
322 : : template <typename MemorySpace, typename Filter,
323 : : std::enable_if_t<!std::is_null_pointer_v<std::decay_t<Filter>>, int> = 0>
324 : : bool copyToSpace(Filter&& predicate) {
325 : : return predicate.template operator()<MemorySpace>();
326 : : }
327 : :
328 : : public:
329 [ + - ]: 2382 : MultispaceContainer() { (initElements<Spaces>(), ...); }
330 : :
331 : : /*!
332 : : * Constructs a container where all spaces have a mirror with
333 : : * the same data as the provided data structure; a predicate
334 : : * functor can be provided to skip any undesired memory spaces
335 : : * @tparam DataType the type of the provided element
336 : : * @tparam Filter the predicate type, or nullptr_t if there is no predicate
337 : : * @param data the original data
338 : : * @param predicate an optional functor that determines which memory spaces need a copy
339 : : * of the data
340 : : */
341 : : template <typename DataType, typename Filter = std::nullptr_t>
342 : 0 : MultispaceContainer(const DataType& data, Filter&& predicate = nullptr)
343 : 0 : : MultispaceContainer() {
344 : : using space = typename DataType::memory_space;
345 : : static_assert(std::is_same_v<DataType, Type<space>>);
346 : :
347 [ # # # # ]: 0 : elements_m[spaceToIndex<space>()] = data;
348 [ # # ]: 0 : copyToOtherSpaces<space>(predicate);
349 : 0 : }
350 : :
351 : : /*!
352 : : * Copies the data from one memory space to all other memory spaces
353 : : * @tparam Space the source space
354 : : * @tparam Filter the predicate type
355 : : * @param predicate an optional functor that determines which memory spaces need a copy
356 : : * of the data
357 : : */
358 : : template <typename Space, typename Filter = std::nullptr_t>
359 : 24 : void copyToOtherSpaces(Filter&& predicate = nullptr) {
360 [ + - ]: 48 : forAll([&]<typename DataType>(DataType& dst) {
361 : : using memory_space = typename DataType::memory_space;
362 : : if constexpr (!std::is_same_v<Space, memory_space>) {
363 : : if (copyToSpace<memory_space>(predicate)) {
364 : : dst = Kokkos::create_mirror_view_and_copy(
365 : : Kokkos::view_alloc(memory_space{}, Kokkos::WithoutInitializing),
366 : : get<Space>());
367 : : }
368 : : }
369 : : });
370 : 24 : }
371 : :
372 : : /*!
373 : : * Accessor for a space's element
374 : : * @tparam Space the memory space
375 : : * @return The element associated with that space
376 : : */
377 : : template <typename Space>
378 : 12 : const Type<Space>& get() const {
379 : 12 : return std::get<Type<Space>>(elements_m[spaceToIndex<Space>()]);
380 : : }
381 : :
382 : : template <typename Space>
383 : 1484 : Type<Space>& get() {
384 : 1484 : return std::get<Type<Space>>(elements_m[spaceToIndex<Space>()]);
385 : : }
386 : :
387 : : /*!
388 : : * Performs an action for each element
389 : : * @tparam Functor the functor type
390 : : * @param f a functor taking an element for a given space
391 : : */
392 : : template <typename Functor>
393 : : void forAll(Functor&& f) const {
394 : : (f(get<Spaces>()), ...);
395 : : }
396 : :
397 : : template <typename Functor>
398 : 1040 : void forAll(Functor&& f) {
399 : 1040 : (f(get<Spaces>()), ...);
400 : 1040 : }
401 : : };
402 : :
403 : : /*!
404 : : * Constructs a MultispaceContainer for all the available Kokkos memory spaces
405 : : * @tparam Type the element type
406 : : */
407 : : template <template <typename> class Type>
408 : : struct ContainerForAllSpaces {
409 : : template <typename... Spaces>
410 : : using container_type = MultispaceContainer<Type, Spaces...>;
411 : :
412 : : using type = typename TypeForAllSpaces<container_type>::memory_spaces_type;
413 : :
414 : : // Static factory function that takes a lambda to initialize each memory space
415 : : template <typename Functor>
416 : : static type createContainer(Functor&& initFunc) {
417 : : return type{std::forward<Functor>(initFunc)};
418 : : }
419 : : };
420 : :
421 : : /*!
422 : : * Performs an action for all memory spaces
423 : : * @tparam Functor the functor type
424 : : * @param f a functor object whose call operator takes a memory space as a template
425 : : * parameter
426 : : */
427 : : template <typename Functor>
428 : 24 : void runForAllSpaces(Functor&& f) {
429 : : using all_spaces = typename TypeForAllSpaces<std::variant>::memory_spaces_type;
430 : 72 : auto runner = [&]<typename... Spaces>(const std::variant<Spaces...>&) {
431 : 24 : (f.template operator()<Spaces>(), ...);
432 : : };
433 [ + - ]: 24 : runner(all_spaces{});
434 : 24 : }
435 : : } // namespace detail
436 : : } // namespace ippl
437 : :
438 : : #endif
|