LCOV - code coverage report
Current view: top level - src/Utility - ParallelDispatch.h (source / functions) Coverage Total Hit
Test: report.info Lines: 100.0 % 29 29
Test Date: 2025-05-12 09:25:18 Functions: 65.2 % 1818 1186
Branches: 55.0 % 20 11

             Branch data     Line data    Source code
       1                 :             : //
       2                 :             : // Parallel dispatch
       3                 :             : //   Utility functions relating to parallel dispatch in IPPL
       4                 :             : //
       5                 :             : 
       6                 :             : #ifndef IPPL_PARALLEL_DISPATCH_H
       7                 :             : #define IPPL_PARALLEL_DISPATCH_H
       8                 :             : 
       9                 :             : #include <Kokkos_Core.hpp>
      10                 :             : 
      11                 :             : #include <tuple>
      12                 :             : 
      13                 :             : #include "Types/Vector.h"
      14                 :             : 
      15                 :             : #include "Utility/IpplException.h"
      16                 :             : 
      17                 :             : namespace ippl {
      18                 :             :     /*!
      19                 :             :      * Wrapper type for Kokkos range policies with some convenience aliases
      20                 :             :      * @tparam Dim range policy rank
      21                 :             :      * @tparam PolicyArgs... additional template parameters for the range policy
      22                 :             :      */
      23                 :             :     template <unsigned Dim, class... PolicyArgs>
      24                 :             :     struct RangePolicy {
      25                 :             :         // The range policy type
      26                 :             :         using policy_type = Kokkos::MDRangePolicy<PolicyArgs..., Kokkos::Rank<Dim>>;
      27                 :             :         // The index type used by the range policy
      28                 :             :         using index_type = typename policy_type::array_index_type;
      29                 :             :         // A vector type containing the index type
      30                 :             :         using index_array_type = ::ippl::Vector<index_type, Dim>;
      31                 :             :     };
      32                 :             : 
      33                 :             :     /*!
      34                 :             :      * Specialized range policy for one dimension.
      35                 :             :      */
      36                 :             :     template <class... PolicyArgs>
      37                 :             :     struct RangePolicy<1, PolicyArgs...> {
      38                 :             :         using policy_type      = Kokkos::RangePolicy<PolicyArgs...>;
      39                 :             :         using index_type       = typename policy_type::index_type;
      40                 :             :         using index_array_type = ::ippl::Vector<index_type, 1>;
      41                 :             :     };
      42                 :             : 
      43                 :             :     /*!
      44                 :             :      * Create a range policy that spans an entire Kokkos view, excluding
      45                 :             :      * a specifiable number of ghost cells at the extremes.
      46                 :             :      * @tparam Tag range policy tag
      47                 :             :      * @tparam View the view type
      48                 :             :      *
      49                 :             :      * @param view to span
      50                 :             :      * @param shift number of ghost cells
      51                 :             :      *
      52                 :             :      * @return A (MD)RangePolicy that spans the desired elements of the given view
      53                 :             :      */
      54                 :             :     template <class... PolicyArgs, typename View>
      55                 :             :     typename RangePolicy<View::rank, typename View::execution_space, PolicyArgs...>::policy_type
      56                 :         914 :     getRangePolicy(const View& view, int shift = 0) {
      57                 :         914 :         constexpr unsigned Dim = View::rank;
      58                 :             :         using exec_space       = typename View::execution_space;
      59                 :             :         using policy_type      = typename RangePolicy<Dim, exec_space, PolicyArgs...>::policy_type;
      60                 :             :         if constexpr (Dim == 1) {
      61                 :         144 :             return policy_type(shift, view.size() - shift);
      62                 :             :         } else {
      63                 :             :             using index_type = typename RangePolicy<Dim, exec_space, PolicyArgs...>::index_type;
      64                 :             :             Kokkos::Array<index_type, Dim> begin, end;
      65         [ +  + ]:        3776 :             for (unsigned int d = 0; d < Dim; d++) {
      66                 :        3006 :                 begin[d] = shift;
      67                 :        3006 :                 end[d]   = view.extent(d) - shift;
      68                 :             :             }
      69         [ +  - ]:        1540 :             return policy_type(begin, end);
      70                 :             :         }
      71                 :             :         // Silences incorrect nvcc warning: missing return statement at end of non-void function
      72                 :             :         throw IpplException("detail::getRangePolicy", "Unreachable state");
      73                 :             :     }
      74                 :             : 
      75                 :             :     /*!
      76                 :             :      * Create a range policy for an index range given in the form of arrays
      77                 :             :      * (required because Kokkos doesn't allow the initialization of 1D range
      78                 :             :      * policies using arrays)
      79                 :             :      * @tparam Dim the dimension of the range
      80                 :             :      * @tparam PolicyArgs... additional template parameters for the range policy
      81                 :             :      *
      82                 :             :      * @param begin the starting indices
      83                 :             :      * @param end the ending indices
      84                 :             :      *
      85                 :             :      * @return A (MD)RangePolicy spanning the given range
      86                 :             :      */
      87                 :             :     template <size_t Dim, class... PolicyArgs>
      88                 :         798 :     typename RangePolicy<Dim, PolicyArgs...>::policy_type createRangePolicy(
      89                 :             :         const Kokkos::Array<typename RangePolicy<Dim, PolicyArgs...>::index_type, Dim>& begin,
      90                 :             :         const Kokkos::Array<typename RangePolicy<Dim, PolicyArgs...>::index_type, Dim>& end) {
      91                 :             :         using policy_type = typename RangePolicy<Dim, PolicyArgs...>::policy_type;
      92                 :             :         if constexpr (Dim == 1) {
      93         [ +  - ]:          38 :             return policy_type(begin[0], end[0]);
      94                 :             :         } else {
      95         [ +  - ]:        1520 :             return policy_type(begin, end);
      96                 :             :         }
      97                 :             :         // Silences incorrect nvcc warning: missing return statement at end of non-void function
      98                 :             :         throw IpplException("detail::createRangePolicy", "Unreachable state");
      99                 :             :     }
     100                 :             : 
     101                 :             :     namespace detail {
     102                 :             :         /*!
     103                 :             :          * Recursively templated struct for defining tuples with arbitrary
     104                 :             :          * length
     105                 :             :          * @tparam Dim the length of the tuple
     106                 :             :          * @tparam T the data type to repeat (default size_t)
     107                 :             :          */
     108                 :             :         template <unsigned Dim, typename T = size_t>
     109                 :             :         struct Coords {
     110                 :             :             // https://stackoverflow.com/a/53398815/2773311
     111                 :             :             // https://en.cppreference.com/w/cpp/utility/declval
     112                 :             :             using type =
     113                 :             :                 decltype(std::tuple_cat(std::declval<typename Coords<1, T>::type>(),
     114                 :             :                                         std::declval<typename Coords<Dim - 1, T>::type>()));
     115                 :             :         };
     116                 :             : 
     117                 :             :         template <typename T>
     118                 :             :         struct Coords<1, T> {
     119                 :             :             using type = std::tuple<T>;
     120                 :             :         };
     121                 :             : 
     122                 :             :         enum e_functor_type {
     123                 :             :             FOR,
     124                 :             :             REDUCE,
     125                 :             :             SCAN
     126                 :             :         };
     127                 :             : 
     128                 :             :         template <e_functor_type, typename, typename, typename, typename...>
     129                 :             :         struct FunctorWrapper;
     130                 :             : 
     131                 :             :         /*!
     132                 :             :          * Wrapper struct for reduction kernels
     133                 :             :          * Source:
     134                 :             :          * https://stackoverflow.com/questions/50713214/familiar-template-syntax-for-generic-lambdas
     135                 :             :          * @tparam Functor functor type
     136                 :             :          * @tparam Policy range policy type
     137                 :             :          * @tparam T... index types
     138                 :             :          * @tparam Acc accumulator data type
     139                 :             :          */
     140                 :             :         template <typename Functor, typename Policy, typename... T, typename... Acc>
     141                 :             :         struct FunctorWrapper<REDUCE, Functor, Policy, std::tuple<T...>, Acc...> {
     142                 :             :             Functor f;
     143                 :             : 
     144                 :             :             /*!
     145                 :             :              * Inline operator forwarding to a specialized instantiation
     146                 :             :              * of the functor's own operator()
     147                 :             :              * @param x... the indices
     148                 :             :              * @param res the accumulator variable
     149                 :             :              * @return The functor's return value
     150                 :             :              */
     151                 :     3579840 :             KOKKOS_INLINE_FUNCTION void operator()(T... x, Acc&... res) const {
     152                 :             :                 using index_type                       = typename Policy::index_type;
     153                 :     3579840 :                 typename Policy::index_array_type args = {static_cast<index_type>(x)...};
     154         [ +  - ]:     3579840 :                 f(args, res...);
     155                 :     3579840 :             }
     156                 :             :         };
     157                 :             : 
     158                 :             :         template <typename Functor, typename Policy, typename... T>
     159                 :             :         struct FunctorWrapper<FOR, Functor, Policy, std::tuple<T...>> {
     160                 :             :             Functor f;
     161                 :             : 
     162                 :    56663798 :             KOKKOS_INLINE_FUNCTION void operator()(T... x) const {
     163                 :             :                 using index_type                       = typename Policy::index_type;
     164                 :    56663798 :                 typename Policy::index_array_type args = {static_cast<index_type>(x)...};
     165         [ +  - ]:    56663798 :                 f(args);
     166                 :    56663798 :             }
     167                 :             :         };
     168                 :             : 
     169                 :             :         // Extracts the rank of a Kokkos range policy
     170                 :             :         template <typename>
     171                 :             :         struct ExtractRank;
     172                 :             : 
     173                 :             :         template <typename... T>
     174                 :             :         struct ExtractRank<Kokkos::RangePolicy<T...>> {
     175                 :             :             static constexpr int rank = 1;
     176                 :             :         };
     177                 :             :         template <typename... T>
     178                 :             :         struct ExtractRank<Kokkos::MDRangePolicy<T...>> {
     179                 :             :             static constexpr int rank = Kokkos::MDRangePolicy<T...>::rank;
     180                 :             :         };
     181                 :             :         template <typename T>
     182                 :             :         concept HasMemberValueType = requires() {
     183                 :             :                                          { typename T::value_type() };
     184                 :             :                                      };
     185                 :             :         template <typename T>
     186                 :             :         struct ExtractReducerReturnType {
     187                 :             :             using type = T;
     188                 :             :         };
     189                 :             :         template <HasMemberValueType T>
     190                 :             :         struct ExtractReducerReturnType<T> {
     191                 :             :             using type = typename T::value_type;
     192                 :             :         };
     193                 :             : 
     194                 :             :         /*!
     195                 :             :          * Convenience function for wrapping a functor with the wrapper struct.
     196                 :             :          * @tparam Functor the functor type
     197                 :             :          * @tparam Type the parallel dispatch type
     198                 :             :          * @tparam Policy the range policy type
     199                 :             :          * @tparam Acc... the accumulator type(s)
     200                 :             :          * @return A wrapper containing the given functor
     201                 :             :          */
     202                 :             :         template <e_functor_type Type, typename Policy, typename... Acc, typename Functor>
     203                 :        1616 :         auto functorize(const Functor& f) {
     204                 :        1616 :             constexpr unsigned Dim = ExtractRank<Policy>::rank;
     205                 :             :             using PolicyProperties = RangePolicy<Dim, typename Policy::execution_space>;
     206                 :             :             using index_type       = typename PolicyProperties::index_type;
     207                 :             :             return FunctorWrapper<Type, Functor, PolicyProperties,
     208                 :        1616 :                                   typename Coords<Dim, index_type>::type, Acc...>{f};
     209                 :             :         }
     210                 :             :     }  // namespace detail
     211                 :             : 
     212                 :             :     // Wrappers for Kokkos' parallel dispatch functions that use
     213                 :             :     // the IPPL functor wrapper
     214                 :             :     template <class ExecPolicy, class FunctorType>
     215                 :        1436 :     void parallel_for(const std::string& name, const ExecPolicy& policy,
     216                 :             :                       const FunctorType& functor) {
     217   [ +  -  +  - ]:        1436 :         Kokkos::parallel_for(name, policy, detail::functorize<detail::FOR, ExecPolicy>(functor));
     218                 :        1436 :     }
     219                 :             : 
     220                 :             :     template <class ExecPolicy, class FunctorType, class... ReducerArgument>
     221                 :         180 :     void parallel_reduce(const std::string& name, const ExecPolicy& policy,
     222                 :             :                          const FunctorType& functor, ReducerArgument&&... reducer) {
     223         [ +  - ]:         180 :         Kokkos::parallel_reduce(
     224                 :             :             name, policy,
     225                 :             :             detail::functorize<detail::REDUCE, ExecPolicy,
     226         [ +  - ]:         360 :                                typename detail::ExtractReducerReturnType<ReducerArgument>::type...>(
     227                 :             :                 functor),
     228                 :         180 :             std::forward<ReducerArgument>(reducer)...);
     229                 :         180 :     }
     230                 :             : }  // namespace ippl
     231                 :             : 
     232                 :             : #endif
        

Generated by: LCOV version 2.0-1