LCOV - code coverage report
Current view: top level - src/Communicate - Operations.h (source / functions) Coverage Total Hit
Test: report.info Lines: 74.2 % 31 23
Test Date: 2025-05-12 09:25:18 Functions: 72.7 % 132 96
Branches: 3.8 % 132 5

             Branch data     Line data    Source code
       1                 :             : //
       2                 :             : // File Operations
       3                 :             : //   Definition of MPI operations following the implementation of Boost.MPI.
       4                 :             : //
       5                 :             : #ifndef IPPL_MPI_OPERATIONS_H
       6                 :             : #define IPPL_MPI_OPERATIONS_H
       7                 :             : 
       8                 :             : #include <Kokkos_Complex.hpp>
       9                 :             : #include <algorithm>
      10                 :             : #include <complex>
      11                 :             : #include <functional>
      12                 :             : #include <map>
      13                 :             : #include <mpi.h>
      14                 :             : #include <typeindex>
      15                 :             : #include <utility>
      16                 :             : 
      17                 :             : namespace ippl {
      18                 :             :     namespace mpi {
      19                 :             : 
      20                 :             :         enum struct binaryOperationKind {
      21                 :             :             SUM,
      22                 :             :             MIN,
      23                 :             :             MAX,
      24                 :             :             MULTIPLICATION  // TODO: Add all
      25                 :             :         };
      26                 :             : 
      27                 :             :         /**
      28                 :             :          * @brief Helper struct to distinguish between the four basic associative operation types
      29                 :             :          *
      30                 :             :          * @tparam T
      31                 :             :          */
      32                 :             :         template <typename T>
      33                 :             :         struct extractBinaryOperationKind {};
      34                 :             :         template <typename T>
      35                 :             :         struct extractBinaryOperationKind<std::plus<T>> {
      36                 :             :             constexpr static binaryOperationKind value = binaryOperationKind::SUM;
      37                 :             :         };
      38                 :             :         template <typename T>
      39                 :             :         struct extractBinaryOperationKind<std::multiplies<T>> {
      40                 :             :             constexpr static binaryOperationKind value = binaryOperationKind::MULTIPLICATION;
      41                 :             :         };
      42                 :             :         template <typename T>
      43                 :             :         struct extractBinaryOperationKind<std::less<T>> {
      44                 :             :             constexpr static binaryOperationKind value = binaryOperationKind::MIN;
      45                 :             :         };
      46                 :             :         template <typename T>
      47                 :             :         struct extractBinaryOperationKind<std::greater<T>> {
      48                 :             :             constexpr static binaryOperationKind value = binaryOperationKind::MAX;
      49                 :             :         };
      50                 :             :         template <typename T>
      51                 :             :         struct always_false : std::false_type {};
      52                 :             :         template <class>
      53                 :             :         struct is_ippl_mpi_type : std::false_type {};
      54                 :             :         struct dummy {};
      55                 :             : 
      56                 :             :         // Global map from {Operation, Type} to MPI_Op, for example: {std::plus<Vector<double, 3>>,
      57                 :             :         // Vector<double, 3>} -> _some MPI_Op_
      58                 :             :         static std::map<std::pair<std::type_index, std::type_index>, MPI_Op> mpiOperations;
      59                 :             : 
      60                 :             :         template <typename CppOpType, typename Datatype_IfNotTrivial>
      61                 :             :         struct getMpiOpImpl {
      62                 :             :             constexpr MPI_Op operator()() const noexcept {
      63                 :             :                 // We can't do static_assert(false) because static_assert(false), even if never
      64                 :             :                 // instantiated, was only made well-formed in 2022
      65                 :             :                 // (https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/p2593r1.html)
      66                 :             :                 static_assert(always_false<CppOpType>::value, "This optype is not supported");
      67                 :             :                 return 0;  // Dummy return
      68                 :             :             }
      69                 :             :         };
      70                 :             :         /**
      71                 :             :          * @brief Helper struct to look up and store MPI_Op types for custom types and custom
      72                 :             :          * operations
      73                 :             :          *
      74                 :             :          * @tparam Op Operation type, for examples std::plus<Vector<double, 3>>,
      75                 :             :          * std::multiplies<...> etc.
      76                 :             :          * @tparam Type The underlying datatype, such as Vector<double, 3>, Matrix<...> etc.
      77                 :             :          */
      78                 :             :         template <class Op, typename Type>
      79                 :             :         struct getNontrivialMpiOpImpl /*<Op, Type>*/ {
      80                 :             :             /**
      81                 :             :              * @brief Get the MPI_Op for this CppOp + Type combo
      82                 :             :              *
      83                 :             :              * @return MPI_Op
      84                 :             :              */
      85                 :          36 :             MPI_Op operator()() {
      86                 :          72 :                 std::pair<std::type_index, std::type_index> pear{std::type_index(typeid(Op)),
      87                 :          36 :                                                                  std::type_index(typeid(Type))};
      88   [ +  -  -  + ]:          36 :                 if (mpiOperations.contains(pear)) {
      89         [ #  # ]:           0 :                     return mpiOperations.at(pear);
      90                 :             :                 }
      91                 :          36 :                 constexpr binaryOperationKind opKind = extractBinaryOperationKind<Op>::value;
      92                 :             :                 MPI_Op ret;
      93         [ +  - ]:          36 :                 MPI_Op_create(
      94                 :             :                     /**
      95                 :             :                      * @brief Construct a new lambda object without captures, therefore convertible
      96                 :             :                      * to a function pointer
      97                 :             :                      *
      98                 :             :                      * @param inputBuffer pointing to a Type object
      99                 :             :                      * @param outputBuffer pointing to a Type object
     100                 :             :                      * @param len Amount of _Type objects_! NOT amount of bytes!
     101                 :             :                      */
     102                 :           0 :                     [](void* inputBuffer, void* outputBuffer, int* len, MPI_Datatype*) {
     103                 :           0 :                         Type* input = (Type*)inputBuffer;
     104                 :             : 
     105                 :           0 :                         Type* output = (Type*)outputBuffer;
     106                 :             : 
     107   [ #  #  #  #  :           0 :                         for (int i = 0; i < *len; i++) {
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
                #  #  # ]
     108                 :             :                             if constexpr (opKind == binaryOperationKind::SUM) {
     109                 :           0 :                                 output[i] += input[i];
     110                 :             :                             }
     111                 :             :                             if constexpr (opKind == binaryOperationKind::MIN) {
     112   [ #  #  #  #  :           0 :                                 output[i] = min(output[i], input[i]);
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
             #  #  #  #  
                      # ]
     113                 :             :                             }
     114                 :             :                             if constexpr (opKind == binaryOperationKind::MAX) {
     115   [ #  #  #  #  :           0 :                                 output[i] = max(output[i], input[i]);
          #  #  #  #  #  
          #  #  #  #  #  
          #  #  #  #  #  
             #  #  #  #  
                      # ]
     116                 :             :                             }
     117                 :             :                             if constexpr (opKind == binaryOperationKind::MULTIPLICATION) {
     118                 :             :                                 output[i] *= input[i];
     119                 :             :                             }
     120                 :             :                         }
     121                 :             :                     },
     122                 :             :                     1, &ret);
     123         [ +  - ]:          36 :                 mpiOperations[pear] = ret;
     124                 :          36 :                 return ret;
     125                 :             :             }
     126                 :             :         };
     127                 :             : 
     128                 :             : #define IPPL_MPI_OP(CppOp, MPIOp)                                      \
     129                 :             :     template <typename Datatype_IfNotTrivial>                          \
     130                 :             :     struct getMpiOpImpl<CppOp, Datatype_IfNotTrivial> {                \
     131                 :             :         constexpr MPI_Op operator()() const noexcept { return MPIOp; } \
     132                 :             :     };                                                                 \
     133                 :             :     template <>                                                        \
     134                 :             :     struct is_ippl_mpi_type<CppOp> : std::true_type {};
     135                 :             : 
     136                 :             :         /* with C++14 we should be able
     137                 :             :          * to simply write
     138                 :             :          *
     139                 :             :          * IPPL_MPI_OP(std::plus<>, MPI_SUM);
     140                 :             :          *
     141                 :             :          */
     142                 :             : 
     143                 :             :         IPPL_MPI_OP(std::plus<char>, MPI_SUM);
     144                 :             :         IPPL_MPI_OP(std::plus<short>, MPI_SUM);
     145                 :             :         IPPL_MPI_OP(std::plus<int>, MPI_SUM);
     146                 :             :         IPPL_MPI_OP(std::plus<long>, MPI_SUM);
     147                 :             :         IPPL_MPI_OP(std::plus<long long>, MPI_SUM);
     148                 :             :         IPPL_MPI_OP(std::plus<unsigned char>, MPI_SUM);
     149                 :             :         IPPL_MPI_OP(std::plus<unsigned short>, MPI_SUM);
     150                 :          12 :         IPPL_MPI_OP(std::plus<unsigned int>, MPI_SUM);
     151                 :         180 :         IPPL_MPI_OP(std::plus<unsigned long>, MPI_SUM);
     152                 :             :         IPPL_MPI_OP(std::plus<unsigned long long>, MPI_SUM);
     153                 :          36 :         IPPL_MPI_OP(std::plus<float>, MPI_SUM);
     154                 :          72 :         IPPL_MPI_OP(std::plus<double>, MPI_SUM);
     155                 :             :         IPPL_MPI_OP(std::plus<long double>, MPI_SUM);
     156                 :             : 
     157                 :             :         IPPL_MPI_OP(std::plus<std::complex<float>>, MPI_SUM);
     158                 :             :         IPPL_MPI_OP(std::plus<std::complex<double>>, MPI_SUM);
     159                 :           2 :         IPPL_MPI_OP(std::plus<Kokkos::complex<float>>, MPI_SUM);
     160                 :           2 :         IPPL_MPI_OP(std::plus<Kokkos::complex<double>>, MPI_SUM);
     161                 :             : 
     162                 :             :         IPPL_MPI_OP(std::less<char>, MPI_MIN);
     163                 :             :         IPPL_MPI_OP(std::less<short>, MPI_MIN);
     164                 :             :         IPPL_MPI_OP(std::less<int>, MPI_MIN);
     165                 :             :         IPPL_MPI_OP(std::less<long>, MPI_MIN);
     166                 :             :         IPPL_MPI_OP(std::less<long long>, MPI_MIN);
     167                 :             :         IPPL_MPI_OP(std::less<unsigned char>, MPI_MIN);
     168                 :             :         IPPL_MPI_OP(std::less<unsigned short>, MPI_MIN);
     169                 :             :         IPPL_MPI_OP(std::less<unsigned int>, MPI_MIN);
     170                 :             :         IPPL_MPI_OP(std::less<unsigned long>, MPI_MIN);
     171                 :             :         IPPL_MPI_OP(std::less<unsigned long long>, MPI_MIN);
     172                 :           6 :         IPPL_MPI_OP(std::less<float>, MPI_MIN);
     173                 :           6 :         IPPL_MPI_OP(std::less<double>, MPI_MIN);
     174                 :             :         IPPL_MPI_OP(std::less<long double>, MPI_MIN);
     175                 :             : 
     176                 :             :         IPPL_MPI_OP(std::greater<char>, MPI_MAX);
     177                 :             :         IPPL_MPI_OP(std::greater<short>, MPI_MAX);
     178                 :             :         IPPL_MPI_OP(std::greater<int>, MPI_MAX);
     179                 :             :         IPPL_MPI_OP(std::greater<long>, MPI_MAX);
     180                 :             :         IPPL_MPI_OP(std::greater<long long>, MPI_MAX);
     181                 :             :         IPPL_MPI_OP(std::greater<unsigned char>, MPI_MAX);
     182                 :             :         IPPL_MPI_OP(std::greater<unsigned short>, MPI_MAX);
     183                 :             :         IPPL_MPI_OP(std::greater<unsigned int>, MPI_MAX);
     184                 :             :         IPPL_MPI_OP(std::greater<unsigned long>, MPI_MAX);
     185                 :             :         IPPL_MPI_OP(std::greater<unsigned long long>, MPI_MAX);
     186                 :          14 :         IPPL_MPI_OP(std::greater<float>, MPI_MAX);
     187                 :          14 :         IPPL_MPI_OP(std::greater<double>, MPI_MAX);
     188                 :             :         IPPL_MPI_OP(std::greater<long double>, MPI_MAX);
     189                 :             : 
     190                 :             :         IPPL_MPI_OP(std::multiplies<short>, MPI_PROD);
     191                 :             :         IPPL_MPI_OP(std::multiplies<int>, MPI_PROD);
     192                 :             :         IPPL_MPI_OP(std::multiplies<long>, MPI_PROD);
     193                 :             :         IPPL_MPI_OP(std::multiplies<long long>, MPI_PROD);
     194                 :             :         IPPL_MPI_OP(std::multiplies<unsigned short>, MPI_PROD);
     195                 :             :         IPPL_MPI_OP(std::multiplies<unsigned int>, MPI_PROD);
     196                 :             :         IPPL_MPI_OP(std::multiplies<unsigned long>, MPI_PROD);
     197                 :             :         IPPL_MPI_OP(std::multiplies<unsigned long long>, MPI_PROD);
     198                 :           6 :         IPPL_MPI_OP(std::multiplies<float>, MPI_PROD);
     199                 :           6 :         IPPL_MPI_OP(std::multiplies<double>, MPI_PROD);
     200                 :             :         IPPL_MPI_OP(std::multiplies<long double>, MPI_PROD);
     201                 :             : 
     202                 :             :         IPPL_MPI_OP(std::logical_or<bool>, MPI_LOR);
     203                 :             :         IPPL_MPI_OP(std::logical_and<bool>, MPI_LAND);
     204                 :             : 
     205                 :             :         template <typename Op, typename Datatype>
     206                 :         392 :         MPI_Op get_mpi_op() {
     207                 :             :             if constexpr (is_ippl_mpi_type<Op>::value) {
     208                 :         356 :                 return getMpiOpImpl<Op, Datatype>{}();
     209                 :             :             } else {
     210         [ +  - ]:          36 :                 return getNontrivialMpiOpImpl<Op, Datatype>{}();
     211                 :             :             }
     212                 :             :         }
     213                 :             :     }  // namespace mpi
     214                 :             : }  // namespace ippl
     215                 :             : 
     216                 :             : #endif
        

Generated by: LCOV version 2.0-1