LCOV - code coverage report
Current view: top level - src/Utility - TypeUtils.h (source / functions) Coverage Total Hit
Test: final_report.info Lines: 80.8 % 26 21
Test Date: 2025-07-17 17:44:22 Functions: 36.9 % 696 257

            Line data    Source code
       1              : //
       2              : // Type Utilities
       3              : //   Metaprogramming utility functions for type manipulation
       4              : //
       5              : 
       6              : #ifndef IPPL_TYPE_UTILS_H
       7              : #define IPPL_TYPE_UTILS_H
       8              : 
       9              : #include <Kokkos_Core.hpp>
      10              : 
      11              : #include "Types/Variant.h"
      12              : 
      13              : #include "Utility/IpplException.h"
      14              : 
      15              : namespace ippl {
      16              :     namespace detail {
      17              : 
      18              :         /*!
      19              :          * Variant verification struct
      20              :          * Checks that a given type has not already been added to a variant
      21              :          * @tparam Check the type for whose presence to check
      22              :          * @tparam Collection... a collection of types
      23              :          */
      24              :         template <typename Check, typename... Collection>
      25              :         struct IsUnique {
      26              :             constexpr static bool enable = !std::disjunction_v<std::is_same<Check, Collection>...>;
      27              :             typedef Check type;
      28              :         };
      29              : 
      30              :         /*!
      31              :          * Defines a variant verification struct
      32              :          * Performs the same check as IsUnique, but instead of using the provided types
      33              :          * directly, the types are wrapped in another provided type. For example, if the
      34              :          * wrapper type is std::shared_ptr and the types are <int, float, int>, then
      35              :          * the final variant will allow std::shared_ptr<int> and std::shared_ptr<float>.
      36              :          * @tparam Wrapper the wrapper type
      37              :          */
      38              :         template <template <typename> class Wrapper>
      39              :         struct WrapUnique {
      40              :             template <typename Check, typename... Collection>
      41              :             struct Verifier {
      42              :                 typedef Wrapper<Check> type;
      43              :                 constexpr static bool enable =
      44              :                     !std::disjunction_v<std::is_same<type, Collection>...>;
      45              :             };
      46              :         };
      47              : 
      48              :         /*!
      49              :          * Convenience alias for types that should or should not be included
      50              :          * in variants constructed with ConstructVariant (defined below) based
      51              :          * on some compile-time constant
      52              :          * @tparam B whether the type should be enabled
      53              :          * @tparam T the type
      54              :          */
      55              :         template <bool B, typename T>
      56              :         using ConditionalType = std::conditional_t<B, T, void>;
      57              : 
      58              :         /*!
      59              :          * Variant verification struct
      60              :          * Enables the type if it is not void (intended for use with std::conditional_t
      61              :          * where the user passes void if the type should not be included)
      62              :          * @tparam Type the type that should be added, or void
      63              :          * @tparam ... dummy parameter to ensure compatibility with IsUnique
      64              :          */
      65              :         template <typename Type, typename...>
      66              :         struct IsEnabled {
      67              :             constexpr static bool enable = !std::is_void_v<Type>;
      68              :             typedef Type type;
      69              :         };
      70              : 
      71              :         /*!
      72              :          * Base struct declaration (see full declaration below for details)
      73              :          */
      74              :         template <typename, typename, template <typename...> class Verifier = IsUnique>
      75              :         struct ConstructVariant;
      76              : 
      77              :         /*!
      78              :          * Base case for variant construction with no types to add
      79              :          */
      80              :         template <template <typename...> class Verifier>
      81              :         struct ConstructVariant<std::variant<>, std::variant<>, Verifier> {
      82              :             typedef std::variant<> type;
      83              :         };
      84              : 
      85              :         /*!
      86              :          * Base case for a fully constructed variant
      87              :          * @tparam T... the types to be included in the variant
      88              :          */
      89              :         template <typename... T, template <typename...> class Verifier>
      90              :         struct ConstructVariant<std::variant<>, std::variant<T...>, Verifier> {
      91              :             typedef std::variant<T...> type;
      92              :         };
      93              : 
      94              :         /*!
      95              :          * Constructs a variant type containing all the provided types that fulfill a certain
      96              :          * condition. This is done by recursively adding types to the variant based on the
      97              :          * inclusion criteria.
      98              :          *
      99              :          * The default verification struct is IsUnique defined above, which includes the type
     100              :          * if it has not already been added to the variant before. This is useful
     101              :          * if the provided types include duplicates, e.g. if they are type aliases that
     102              :          * can sometimes refer to the same type. In particular, Kokkos memory spaces can
     103              :          * sometimes have different names but refer to the same memory space. Variants do
     104              :          * not allow duplicate types to appear in their parameter packs; each type may only
     105              :          * appear once.
     106              :          *
     107              :          * The verification struct can be user defined as long as it conforms to the variant
     108              :          * verification struct interface. The struct must accept at least a parameter pack; the
     109              :          * first parameter is the type currently being checked and the rest are the types already
     110              :          * added to the variant. The struct must expose a boolean `enable` that indicates whether
     111              :          * the next type should be included in the variant. The struct must also expose a type
     112              :          * `type`, which is the type to be added to the variant
     113              :          *
     114              :          * @tparam Next the next type to add to the variant
     115              :          * @tparam ToAdd... the remaining types waiting to be added to the variant
     116              :          * @tparam Added... the types that have already been added to the variant
     117              :          * @tparam Verifier the variant verification struct
     118              :          */
     119              :         template <typename Next, typename... ToAdd, typename... Added,
     120              :                   template <typename...> class Verifier>
     121              :         struct ConstructVariant<std::variant<Next, ToAdd...>, std::variant<Added...>, Verifier> {
     122              :             // Convenience aliases
     123              :             template <bool B, class T, class F>
     124              :             using cond = std::conditional_t<B, T, F>;
     125              :             template <typename... T>
     126              :             using variant = std::variant<T...>;
     127              : 
     128              :             using Check = Verifier<Next, Added...>;
     129              : 
     130              :             typedef cond<
     131              :                 Check::enable,
     132              :                 // The verifier has indicated that this type should be added
     133              :                 typename ConstructVariant<variant<ToAdd...>,
     134              :                                           variant<typename Check::type, Added...>, Verifier>::type,
     135              :                 // The verifier has indicated that the type should not be added
     136              :                 typename ConstructVariant<variant<ToAdd...>, variant<Added...>, Verifier>::type>
     137              :                 type;
     138              :         };
     139              : 
     140              :         /*!
     141              :          * A variant containing all the enabled types,
     142              :          * where a disabled type is expected to be passed
     143              :          * as void (i.e. std::conditional_t<B, T, void>)
     144              :          */
     145              :         template <typename... Types>
     146              :         using VariantFromConditionalTypes =
     147              :             typename ConstructVariant<std::variant<Types...>, std::variant<>, IsEnabled>::type;
     148              : 
     149              :         /*!
     150              :          * A variant containing just the unique types
     151              :          * from the pack
     152              :          */
     153              :         template <typename... Types>
     154              :         using VariantFromUniqueTypes =
     155              :             typename ConstructVariant<std::variant<Types...>, std::variant<>, IsUnique>::type;
     156              : 
     157              :         /*!
     158              :          * A variant containing the types enabled by a custom
     159              :          * verifier; to implement a custom verifier, provide the following:
     160              :          * - template <typename Next, typename... Added>
     161              :          *   Next: the next input type to check
     162              :          *   Added: the types that have already been added
     163              :          * - bool enable: whether the type should be added
     164              :          * - typename type: the output type to be added
     165              :          */
     166              :         template <template <typename...> class Verifier, typename... Types>
     167              :         using VariantWithVerifier =
     168              :             typename ConstructVariant<std::variant<Types...>, std::variant<>, Verifier>::type;
     169              : 
     170              :         /*!
     171              :          * Utility struct for forwarding parameter packs
     172              :          * (see specializations)
     173              :          * @tparam Type the templated type
     174              :          * @tparam Pack a type containing the parameters to forward
     175              :          */
     176              :         template <template <typename...> class Type, typename Pack>
     177              :         struct Forward;
     178              : 
     179              :         /*!
     180              :          * Forwards the types in a variant to another type
     181              :          */
     182              :         template <template <typename...> class Type, typename... Spaces>
     183              :         struct Forward<Type, std::variant<Spaces...>> {
     184              :             using type = Type<Spaces...>;
     185              :         };
     186              : 
     187              :         /*!
     188              :          * Forwards the properties of a Kokkos view to another type
     189              :          */
     190              :         template <template <typename...> class Type, typename T, typename... Properties>
     191              :         struct Forward<Type, Kokkos::View<T, Properties...>> {
     192              :             using type = Type<Properties...>;
     193              :         };
     194              : 
     195              :         /*!
     196              :          * Constructs a uniform type based on Kokkos views' uniform
     197              :          * types (i.e. a type where all optional template parameters
     198              :          * are explicitly specified)
     199              :          * @tparam Type the type to specialize
     200              :          * @tparam View the view type
     201              :          */
     202              :         template <template <typename...> class Type, typename View>
     203              :         struct CreateUniformType {
     204              :             using view_type = typename View::uniform_type;
     205              :             using type      = typename Forward<Type, view_type>::type;
     206              :         };
     207              : 
     208              :         /*!
     209              :          * Instantiates a variadic template with the unique Kokkos memory spaces or execution spaces enabled in this build
     210              :          */
     211              :         template <template <typename...> class Type>
     212              :         struct TypeForAllSpaces {
     213              :             using unique_memory_spaces = VariantFromUniqueTypes<
     214              :                 Kokkos::HostSpace, Kokkos::SharedSpace, Kokkos::SharedHostPinnedSpace
     215              : #ifdef KOKKOS_ENABLE_CUDA
     216              :                 ,
     217              :                 Kokkos::CudaSpace, Kokkos::CudaHostPinnedSpace, Kokkos::CudaUVMSpace
     218              : #endif
     219              : #ifdef KOKKOS_ENABLE_HIP
     220              :                 ,
     221              :                 Kokkos::HIPSpace, Kokkos::HIPHostPinnedSpace, Kokkos::HIPManagedSpace
     222              : #endif
     223              : #ifdef KOKKOS_ENABLE_SYCL
     224              :                 ,
     225              :                 Kokkos::Experimental::SYCLDeviceUSMSpace, Kokkos::Experimental::SYCLHostUSMSpace,
     226              :                 Kokkos::Experimental::SYCLSharedUSMSpace
     227              : #endif
     228              :                 >;
     229              : 
     230              :             using unique_exec_spaces = VariantFromUniqueTypes<Kokkos::DefaultExecutionSpace
     231              : #ifdef KOKKOS_ENABLE_OPENMP
     232              :                                                               ,
     233              :                                                               Kokkos::OpenMP
     234              : #endif
     235              : #ifdef KOKKOS_ENABLE_OPENMPTARGET
     236              :                                                               ,
     237              :                                                               Kokkos::OpenMPTarget
     238              : #endif
     239              : #ifdef KOKKOS_ENABLE_THREADS
     240              :                                                               ,
     241              :                                                               Kokkos::Threads
     242              : #endif
     243              : #ifdef KOKKOS_ENABLE_SERIAL
     244              :                                                               ,
     245              :                                                               Kokkos::Serial
     246              : #endif
     247              : #ifdef KOKKOS_ENABLE_CUDA
     248              :                                                               ,
     249              :                                                               Kokkos::Cuda
     250              : #endif
     251              : #ifdef KOKKOS_ENABLE_HIP
     252              :                                                               ,
     253              :                                                               Kokkos::HIP
     254              : #endif
     255              : #ifdef KOKKOS_ENABLE_SYCL
     256              :                                                               ,
     257              :                                                               Kokkos::Experimental::SYCL
     258              : #endif
     259              : #ifdef KOKKOS_ENABLE_HPX
     260              :                                                               ,
     261              :                                                               Kokkos::HPX
     262              : #endif
     263              :                                                               >;
     264              : 
     265              :             using memory_spaces_type = typename Forward<Type, unique_memory_spaces>::type;
     266              :             using exec_spaces_type   = typename Forward<Type, unique_exec_spaces>::type;
     267              :         };
     268              : 
     269              :         /*!
     270              :          * A container indexed by type instead of by numerical indices;
     271              :          * designed for storing elements associated with Kokkos memory spaces
     272              :          * @tparam Type the element type
     273              :          * @tparam Spaces... the memory spaces of interest
     274              :          */
     275              :         template <template <typename> class Type, typename... Spaces>
     276              :         class MultispaceContainer {
     277              :             template <typename T, typename... Ts>
     278              :             using Verifier = typename WrapUnique<Type>::template Verifier<T, Ts...>;
     279              : 
     280              :             using Types = VariantWithVerifier<Verifier, Spaces...>;
     281              : 
     282              :             std::array<Types, sizeof...(Spaces)> elements_m;
     283              : 
     284              :             /*!
     285              :              * Locates an element associated with a space
     286              :              * @tparam Space the memory space
     287              :              * @return The numerical index for that space's element
     288              :              */
     289              :             template <typename Space, unsigned Idx = 0>
     290         3366 :             constexpr static unsigned spaceToIndex() {
     291              :                 static_assert(Idx < sizeof...(Spaces));
     292              :                 if constexpr (std::is_same_v<Space,
     293              :                                              std::tuple_element_t<Idx, std::tuple<Spaces...>>>) {
     294         3366 :                     return Idx;
     295              :                 } else {
     296              :                     return spaceToIndex<Space, Idx + 1>();
     297              :                 }
     298              :                 // Silences incorrect nvcc warning: missing return statement at end of non-void
     299              :                 // function
     300              :                 throw IpplException("detail::MultispaceContainer::spaceToIndex",
     301              :                                     "Unreachable state");
     302              :             }
     303              : 
     304              :             /*!
     305              :              * Initializes the element for a space
     306              :              */
     307              :             template <typename Space>
     308         2090 :             void initElements() {
     309         2090 :                 elements_m[spaceToIndex<Space>()] = Type<Space>{};
     310         2090 :             }
     311              : 
     312              :             /*!
     313              :              * Determine whether the element for a space should be initialized,
     314              :              * possibly based on a predicate functor
     315              :              */
     316              :             template <typename MemorySpace, typename Filter,
     317              :                       std::enable_if_t<std::is_null_pointer_v<std::decay_t<Filter>>, int> = 0>
     318              :             constexpr bool copyToSpace(Filter&&) {
     319              :                 return true;
     320              :             }
     321              : 
     322              :             template <typename MemorySpace, typename Filter,
     323              :                       std::enable_if_t<!std::is_null_pointer_v<std::decay_t<Filter>>, int> = 0>
     324              :             bool copyToSpace(Filter&& predicate) {
     325              :                 return predicate.template operator()<MemorySpace>();
     326              :             }
     327              : 
     328              :         public:
     329         2090 :             MultispaceContainer() { (initElements<Spaces>(), ...); }
     330              : 
     331              :             /*!
     332              :              * Constructs a container where all spaces have a mirror with
     333              :              * the same data as the provided data structure; a predicate
     334              :              * functor can be provided to skip any undesired memory spaces
     335              :              * @tparam DataType the type of the provided element
     336              :              * @tparam Filter the predicate type, or nullptr_t if there is no predicate
     337              :              * @param data the original data
     338              :              * @param predicate an optional functor that determines which memory spaces need a copy
     339              :              * of the data
     340              :              */
     341              :             template <typename DataType, typename Filter = std::nullptr_t>
     342            0 :             MultispaceContainer(const DataType& data, Filter&& predicate = nullptr)
     343            0 :                 : MultispaceContainer() {
     344              :                 using space = typename DataType::memory_space;
     345              :                 static_assert(std::is_same_v<DataType, Type<space>>);
     346              : 
     347            0 :                 elements_m[spaceToIndex<space>()] = data;
     348            0 :                 copyToOtherSpaces<space>(predicate);
     349            0 :             }
     350              : 
     351              :             /*!
     352              :              * Copies the data from one memory space to all other memory spaces
     353              :              * @tparam Space the source space
     354              :              * @tparam Filter the predicate type
     355              :              * @param predicate an optional functor that determines which memory spaces need a copy
     356              :              * of the data
     357              :              */
     358              :             template <typename Space, typename Filter = std::nullptr_t>
     359           24 :             void copyToOtherSpaces(Filter&& predicate = nullptr) {
     360           48 :                 forAll([&]<typename DataType>(DataType& dst) {
     361              :                     using memory_space = typename DataType::memory_space;
     362              :                     if constexpr (!std::is_same_v<Space, memory_space>) {
     363              :                         if (copyToSpace<memory_space>(predicate)) {
     364              :                             dst = Kokkos::create_mirror_view_and_copy(
     365              :                                 Kokkos::view_alloc(memory_space{}, Kokkos::WithoutInitializing),
     366              :                                 get<Space>());
     367              :                         }
     368              :                     }
     369              :                 });
     370           24 :             }
     371              : 
     372              :             /*!
     373              :              * Accessor for a space's element
     374              :              * @tparam Space the memory space
     375              :              * @return The element associated with that space
     376              :              */
     377              :             template <typename Space>
     378           12 :             const Type<Space>& get() const {
     379           12 :                 return std::get<Type<Space>>(elements_m[spaceToIndex<Space>()]);
     380              :             }
     381              : 
     382              :             template <typename Space>
     383         1264 :             Type<Space>& get() {
     384         1264 :                 return std::get<Type<Space>>(elements_m[spaceToIndex<Space>()]);
     385              :             }
     386              : 
     387              :             /*!
     388              :              * Performs an action for each element
     389              :              * @tparam Functor the functor type
     390              :              * @param f a functor taking an element for a given space
     391              :              */
     392              :             template <typename Functor>
     393              :             void forAll(Functor&& f) const {
     394              :                 (f(get<Spaces>()), ...);
     395              :             }
     396              : 
     397              :             template <typename Functor>
     398          930 :             void forAll(Functor&& f) {
     399          930 :                 (f(get<Spaces>()), ...);
     400          930 :             }
     401              :         };
     402              : 
     403              :         /*!
     404              :          * Constructs a MultispaceContainer for all the available Kokkos memory spaces
     405              :          * @tparam Type the element type
     406              :          */
     407              :         template <template <typename> class Type>
     408              :         struct ContainerForAllSpaces {
     409              :             template <typename... Spaces>
     410              :             using container_type = MultispaceContainer<Type, Spaces...>;
     411              : 
     412              :             using type = typename TypeForAllSpaces<container_type>::memory_spaces_type;
     413              : 
     414              :             // Static factory function that takes a lambda to initialize each memory space
     415              :             template <typename Functor>
     416              :             static type createContainer(Functor&& initFunc) {
     417              :                 return type{std::forward<Functor>(initFunc)};
     418              :             }
     419              :         };
     420              : 
     421              :         /*!
     422              :          * Performs an action for all memory spaces
     423              :          * @tparam Functor the functor type
     424              :          * @param f a functor object whose call operator takes a memory space as a template
     425              :          * parameter
     426              :          */
     427              :         template <typename Functor>
     428           24 :         void runForAllSpaces(Functor&& f) {
     429              :             using all_spaces = typename TypeForAllSpaces<std::variant>::memory_spaces_type;
     430           72 :             auto runner      = [&]<typename... Spaces>(const std::variant<Spaces...>&) {
     431           24 :                 (f.template operator()<Spaces>(), ...);
     432              :             };
     433           24 :             runner(all_spaces{});
     434           24 :         }
     435              :     }  // namespace detail
     436              : }  // namespace ippl
     437              : 
     438              : #endif
        

Generated by: LCOV version 2.0-1