Line data Source code
1 : //
2 : // Type Utilities
3 : // Metaprogramming utility functions for type manipulation
4 : //
5 :
6 : #ifndef IPPL_TYPE_UTILS_H
7 : #define IPPL_TYPE_UTILS_H
8 :
9 : #include <Kokkos_Core.hpp>
10 :
11 : #include "Types/Variant.h"
12 :
13 : #include "Utility/IpplException.h"
14 :
15 : namespace ippl {
16 : namespace detail {
17 :
/*!
 * Variant verification struct that rejects duplicates.
 * Reports whether a candidate type is absent from a pack of types that
 * have already been collected.
 * @tparam Check the candidate type to look for
 * @tparam Collection... the types collected so far
 */
template <typename Check, typename... Collection>
struct IsUnique {
    // The candidate is passed through unchanged
    using type = Check;
    // True when no member of Collection is the same type as Check
    // (trivially true for an empty collection)
    static constexpr bool enable = (!std::is_same_v<Check, Collection> && ...);
};
29 :
/*!
 * Factory for a variant verification struct that wraps its candidates.
 * The nested Verifier performs the same uniqueness check as IsUnique, except
 * that the candidate is first wrapped in the given template; the wrapped type
 * is what gets compared and exposed. For example, with std::shared_ptr as the
 * wrapper and <int, float, int> as the inputs, the final variant holds
 * std::shared_ptr<int> and std::shared_ptr<float>, each exactly once.
 * @tparam Wrapper the single-parameter template applied to every candidate
 */
template <template <typename> class Wrapper>
struct WrapUnique {
    /*!
     * @tparam Check the unwrapped candidate type
     * @tparam Collection... the (already wrapped) types collected so far
     */
    template <typename Check, typename... Collection>
    struct Verifier {
        // The type offered to the variant is the wrapped candidate
        using type = Wrapper<Check>;
        // True when the wrapped candidate does not occur in Collection
        static constexpr bool enable = (!std::is_same_v<type, Collection> && ...);
    };
};
47 :
/*!
 * Shorthand for a type that is switched on or off by a compile-time
 * constant: yields the type itself when enabled and void otherwise.
 * Intended for the inputs of variants assembled with ConstructVariant
 * (defined below).
 * @tparam B whether the type should be enabled
 * @tparam T the type
 */
template <bool B, typename T>
using ConditionalType = typename std::conditional<B, T, void>::type;
57 :
/*!
 * Variant verification struct that filters out void.
 * Enables the type unless it is void (intended for use with
 * std::conditional_t, where void is passed for a type that should not be
 * included).
 * @tparam Type the type that should be added, or void
 * @tparam ... ignored; accepted only so that the struct has the same
 * template signature as the other verifiers (e.g. IsUnique)
 */
template <typename Type, typename...>
struct IsEnabled {
    // The type is passed through unchanged
    using type = Type;
    // False for (possibly cv-qualified) void, true for everything else
    static constexpr bool enable = !std::is_void<Type>::value;
};
70 :
71 : /*!
72 : * Base struct declaration (see full declaration below for details)
73 : */
74 : template <typename, typename, template <typename...> class Verifier = IsUnique>
75 : struct ConstructVariant;
76 :
77 : /*!
78 : * Base case for variant construction with no types to add
79 : */
80 : template <template <typename...> class Verifier>
81 : struct ConstructVariant<std::variant<>, std::variant<>, Verifier> {
82 : typedef std::variant<> type;
83 : };
84 :
85 : /*!
86 : * Base case for a fully constructed variant
87 : * @tparam T... the types to be included in the variant
88 : */
89 : template <typename... T, template <typename...> class Verifier>
90 : struct ConstructVariant<std::variant<>, std::variant<T...>, Verifier> {
91 : typedef std::variant<T...> type;
92 : };
93 :
94 : /*!
95 : * Constructs a variant type containing all the provided types that fulfill a certain
96 : * condition. This is done by recursively adding types to the variant based on the
97 : * inclusion criteria.
98 : *
99 : * The default verification struct is IsPresent defined above, which includes the type
100 : * if it has not already been added to the variant before. This is useful
101 : * if the provided types include duplicates, e.g. if they are type aliases that
102 : * can sometimes refer to the same type. In particular, Kokkos memory spaces can
103 : * sometimes have different names but refer to the same memory space. Variants do
104 : * not allow duplicate types to appear in their parameter packs; each type may only
105 : * appear once.
106 : *
107 : * The verification struct can be user defined as long as it conforms to the variant
108 : * verification struct interface. The struct must accept at least a parameter pack; the
109 : * first parameter is the type currently being checked and the rest are the types already
110 : * added to the variant. The struct must expose a boolean `enable` that indicates whether
111 : * the next type should be included in the variant. The struct must also expose a type
112 : * `type`, which is the type to be added to the variant
113 : *
114 : * @tparam Next the next type to add to the variant
115 : * @tparam ToAdd... the remaining types waiting to be added to the variant
116 : * @tparam Added... the types that have already been added to the variant
117 : * @tparam Verifier the variant verification struct
118 : */
119 : template <typename Next, typename... ToAdd, typename... Added,
120 : template <typename...> class Verifier>
121 : struct ConstructVariant<std::variant<Next, ToAdd...>, std::variant<Added...>, Verifier> {
122 : // Convenience aliases
123 : template <bool B, class T, class F>
124 : using cond = std::conditional_t<B, T, F>;
125 : template <typename... T>
126 : using variant = std::variant<T...>;
127 :
128 : using Check = Verifier<Next, Added...>;
129 :
130 : typedef cond<
131 : Check::enable,
132 : // The verifier has indicated that this type should be added
133 : typename ConstructVariant<variant<ToAdd...>,
134 : variant<typename Check::type, Added...>, Verifier>::type,
135 : // The verifier has indicated that the type should not be added
136 : typename ConstructVariant<variant<ToAdd...>, variant<Added...>, Verifier>::type>
137 : type;
138 : };
139 :
140 : /*!
141 : * A variant containing all the enabled types,
142 : * where "enabled" types are assumed to be void
143 : * when disabled (i.e. std::conditional_t<B, T, void>)
144 : */
145 : template <typename... Types>
146 : using VariantFromConditionalTypes =
147 : typename ConstructVariant<std::variant<Types...>, std::variant<>, IsEnabled>::type;
148 :
149 : /*!
150 : * A variant containing just the unique types
151 : * from the pack
152 : */
153 : template <typename... Types>
154 : using VariantFromUniqueTypes =
155 : typename ConstructVariant<std::variant<Types...>, std::variant<>, IsUnique>::type;
156 :
157 : /*!
158 : * A variant containing the types enabled by a custom
159 : * verifier; to implement a custom verifier, provide the following:
160 : * - template <typename Next, typename... Added>
161 : * Next: the next input type to check
162 : * Added: the types that have already been added
163 : * - bool enable: whether the type should be added
164 : * - typename type: the output type to be added
165 : */
166 : template <template <typename...> class Verifier, typename... Types>
167 : using VariantWithVerifier =
168 : typename ConstructVariant<std::variant<Types...>, std::variant<>, Verifier>::type;
169 :
/*!
 * Helper that hands the template arguments carried by one type over to
 * another template. Only declared here; see the specializations for the
 * supported carrier types.
 * @tparam Type the template that receives the arguments
 * @tparam Pack a type containing the parameters to forward
 */
template <template <typename...> class Type, typename Pack>
struct Forward;
178 :
179 : /*!
180 : * Forwards the types in a variant to another type
181 : */
182 : template <template <typename...> class Type, typename... Spaces>
183 : struct Forward<Type, std::variant<Spaces...>> {
184 : using type = Type<Spaces...>;
185 : };
186 :
187 : /*!
188 : * Forwards the properties of a Kokkos view to another type
189 : */
190 : template <template <typename...> class Type, typename T, typename... Properties>
191 : struct Forward<Type, Kokkos::View<T, Properties...>> {
192 : using type = Type<Properties...>;
193 : };
194 :
195 : /*!
196 : * Constructs a uniform type based on Kokkos views' uniform
197 : * types (i.e. a type where all optional template parameters
198 : * are explicitly specified)
199 : * @tparam Type the type to specialize
200 : * @tparam View the view type
201 : */
202 : template <template <typename...> class Type, typename View>
203 : struct CreateUniformType {
204 : using view_type = typename View::uniform_type;
205 : using type = typename Forward<Type, view_type>::type;
206 : };
207 :
208 : /*!
209 : * Instantiates a parameter pack with all the available Kokkos memory spaces
210 : */
211 : template <template <typename...> class Type>
212 : struct TypeForAllSpaces {
213 : using unique_memory_spaces = VariantFromUniqueTypes<
214 : Kokkos::HostSpace, Kokkos::SharedSpace, Kokkos::SharedHostPinnedSpace
215 : #ifdef KOKKOS_ENABLE_CUDA
216 : ,
217 : Kokkos::CudaSpace, Kokkos::CudaHostPinnedSpace, Kokkos::CudaUVMSpace
218 : #endif
219 : #ifdef KOKKOS_ENABLE_HIP
220 : ,
221 : Kokkos::HIPSpace, Kokkos::HIPHostPinnedSpace, Kokkos::HIPManagedSpace
222 : #endif
223 : #ifdef KOKKOS_ENABLE_SYCL
224 : ,
225 : Kokkos::Experimental::SYCLDeviceUSMSpace, Kokkos::Experimental::SYCLHostUSMSpace,
226 : Kokkos::Experimental::SYCLSharedUSMSpace
227 : #endif
228 : >;
229 :
230 : using unique_exec_spaces = VariantFromUniqueTypes<Kokkos::DefaultExecutionSpace
231 : #ifdef KOKKOS_ENABLE_OPENMP
232 : ,
233 : Kokkos::OpenMP
234 : #endif
235 : #ifdef KOKKOS_ENABLE_OPENMPTARGET
236 : ,
237 : Kokkos::OpenMPTarget
238 : #endif
239 : #ifdef KOKKOS_ENABLE_THREADS
240 : ,
241 : Kokkos::Thread
242 : #endif
243 : #ifdef KOKKOS_ENABLE_SERIAL
244 : ,
245 : Kokkos::Serial
246 : #endif
247 : #ifdef KOKKOS_ENABLE_CUDA
248 : ,
249 : Kokkos::Cuda
250 : #endif
251 : #ifdef KOKKOS_ENABLE_HIP
252 : ,
253 : Kokkos::HIP
254 : #endif
255 : #ifdef KOKKOS_ENABLE_SYCL
256 : ,
257 : Kokkos::Experimental::SYCL
258 : #endif
259 : #ifdef KOKKOS_ENABLE_HPX
260 : ,
261 : Kokkos::HPX
262 : #endif
263 : >;
264 :
265 : using memory_spaces_type = typename Forward<Type, unique_memory_spaces>::type;
266 : using exec_spaces_type = typename Forward<Type, unique_exec_spaces>::type;
267 : };
268 :
269 : /*!
270 : * A container indexed by type instead of by numerical indices;
271 : * designed for storing elements associated with Kokkos memory spaces
272 : * @tparam Type the element type
273 : * @tparam Spaces... the memory spaces of interest
274 : */
275 : template <template <typename> class Type, typename... Spaces>
276 : class MultispaceContainer {
277 : template <typename T, typename... Ts>
278 : using Verifier = typename WrapUnique<Type>::template Verifier<T, Ts...>;
279 :
280 : using Types = VariantWithVerifier<Verifier, Spaces...>;
281 :
282 : std::array<Types, sizeof...(Spaces)> elements_m;
283 :
284 : /*!
285 : * Locates an element associated with a space
286 : * @tparam Space the memory space
287 : * @return The numerical index for that space's element
288 : */
289 : template <typename Space, unsigned Idx = 0>
290 3366 : constexpr static unsigned spaceToIndex() {
291 : static_assert(Idx < sizeof...(Spaces));
292 : if constexpr (std::is_same_v<Space,
293 : std::tuple_element_t<Idx, std::tuple<Spaces...>>>) {
294 3366 : return Idx;
295 : } else {
296 : return spaceToIndex<Space, Idx + 1>();
297 : }
298 : // Silences incorrect nvcc warning: missing return statement at end of non-void
299 : // function
300 : throw IpplException("detail::MultispaceContainer::spaceToIndex",
301 : "Unreachable state");
302 : }
303 :
304 : /*!
305 : * Initializes the element for a space
306 : */
307 : template <typename Space>
308 2090 : void initElements() {
309 2090 : elements_m[spaceToIndex<Space>()] = Type<Space>{};
310 2090 : }
311 :
312 : /*!
313 : * Determine whether the element for a space should be initialized,
314 : * possibly based on a predicate functor
315 : */
316 : template <typename MemorySpace, typename Filter,
317 : std::enable_if_t<std::is_null_pointer_v<std::decay_t<Filter>>, int> = 0>
318 : constexpr bool copyToSpace(Filter&&) {
319 : return true;
320 : }
321 :
322 : template <typename MemorySpace, typename Filter,
323 : std::enable_if_t<!std::is_null_pointer_v<std::decay_t<Filter>>, int> = 0>
324 : bool copyToSpace(Filter&& predicate) {
325 : return predicate.template operator()<MemorySpace>();
326 : }
327 :
328 : public:
329 2090 : MultispaceContainer() { (initElements<Spaces>(), ...); }
330 :
331 : /*!
332 : * Constructs a container where all spaces have a mirror with
333 : * the same data as the provided data structure; a predicate
334 : * functor can be provided to skip any undesired memory spaces
335 : * @tparam DataType the type of the provided element
336 : * @tparam Filter the predicate type, or nullptr_t if there is no predicate
337 : * @param data the original data
338 : * @param predicate an optional functor that determines which memory spaces need a copy
339 : * of the data
340 : */
341 : template <typename DataType, typename Filter = std::nullptr_t>
342 0 : MultispaceContainer(const DataType& data, Filter&& predicate = nullptr)
343 0 : : MultispaceContainer() {
344 : using space = typename DataType::memory_space;
345 : static_assert(std::is_same_v<DataType, Type<space>>);
346 :
347 0 : elements_m[spaceToIndex<space>()] = data;
348 0 : copyToOtherSpaces<space>(predicate);
349 0 : }
350 :
351 : /*!
352 : * Copies the data from one memory space to all other memory spaces
353 : * @tparam Space the source space
354 : * @tparam Filter the predicate type
355 : * @param predicate an optional functor that determines which memory spaces need a copy
356 : * of the data
357 : */
358 : template <typename Space, typename Filter = std::nullptr_t>
359 24 : void copyToOtherSpaces(Filter&& predicate = nullptr) {
360 48 : forAll([&]<typename DataType>(DataType& dst) {
361 : using memory_space = typename DataType::memory_space;
362 : if constexpr (!std::is_same_v<Space, memory_space>) {
363 : if (copyToSpace<memory_space>(predicate)) {
364 : dst = Kokkos::create_mirror_view_and_copy(
365 : Kokkos::view_alloc(memory_space{}, Kokkos::WithoutInitializing),
366 : get<Space>());
367 : }
368 : }
369 : });
370 24 : }
371 :
372 : /*!
373 : * Accessor for a space's element
374 : * @tparam Space the memory space
375 : * @return The element associated with that space
376 : */
377 : template <typename Space>
378 12 : const Type<Space>& get() const {
379 12 : return std::get<Type<Space>>(elements_m[spaceToIndex<Space>()]);
380 : }
381 :
382 : template <typename Space>
383 1264 : Type<Space>& get() {
384 1264 : return std::get<Type<Space>>(elements_m[spaceToIndex<Space>()]);
385 : }
386 :
387 : /*!
388 : * Performs an action for each element
389 : * @tparam Functor the functor type
390 : * @param f a functor taking an element for a given space
391 : */
392 : template <typename Functor>
393 : void forAll(Functor&& f) const {
394 : (f(get<Spaces>()), ...);
395 : }
396 :
397 : template <typename Functor>
398 930 : void forAll(Functor&& f) {
399 930 : (f(get<Spaces>()), ...);
400 930 : }
401 : };
402 :
403 : /*!
404 : * Constructs a MultispaceContainer for all the available Kokkos memory spaces
405 : * @tparam Type the element type
406 : */
407 : template <template <typename> class Type>
408 : struct ContainerForAllSpaces {
409 : template <typename... Spaces>
410 : using container_type = MultispaceContainer<Type, Spaces...>;
411 :
412 : using type = typename TypeForAllSpaces<container_type>::memory_spaces_type;
413 :
414 : // Static factory function that takes a lambda to initialize each memory space
415 : template <typename Functor>
416 : static type createContainer(Functor&& initFunc) {
417 : return type{std::forward<Functor>(initFunc)};
418 : }
419 : };
420 :
421 : /*!
422 : * Performs an action for all memory spaces
423 : * @tparam Functor the functor type
424 : * @param f a functor object whose call operator takes a memory space as a template
425 : * parameter
426 : */
427 : template <typename Functor>
428 24 : void runForAllSpaces(Functor&& f) {
429 : using all_spaces = typename TypeForAllSpaces<std::variant>::memory_spaces_type;
430 72 : auto runner = [&]<typename... Spaces>(const std::variant<Spaces...>&) {
431 24 : (f.template operator()<Spaces>(), ...);
432 : };
433 24 : runner(all_spaces{});
434 24 : }
435 : } // namespace detail
436 : } // namespace ippl
437 :
438 : #endif
|