LCOV - code coverage report
Current view: top level - src/Communicate - Operations.h (source / functions) Coverage Total Hit
Test: final_report.info Lines: 96.8 % 31 30
Test Date: 2025-07-18 17:15:09 Functions: 100.0 % 132 132

            Line data    Source code
       1              : //
       2              : // File Operations
       3              : //   Definition of MPI operations following the implementation of Boost.MPI.
       4              : //
       5              : #ifndef IPPL_MPI_OPERATIONS_H
       6              : #define IPPL_MPI_OPERATIONS_H
       7              : 
       8              : #include <Kokkos_Complex.hpp>
       9              : #include <algorithm>
      10              : #include <complex>
      11              : #include <functional>
      12              : #include <map>
      13              : #include <mpi.h>
      14              : #include <typeindex>
      15              : #include <utility>
      16              : 
      17              : namespace ippl {
      18              :     namespace mpi {
      19              :         // Kinds of element-wise reduction that getNontrivialMpiOpImpl can generate an MPI_Op for.
      20              :         enum struct binaryOperationKind {
      21              :             SUM,  // output[i] += input[i]
      22              :             MIN,  // output[i] = min(output[i], input[i])
      23              :             MAX,  // output[i] = max(output[i], input[i])
      24              :             MULTIPLICATION  // output[i] *= input[i]; TODO: Add all
      25              :         };
      26              : 
      27              :         /**
      28              :          * @brief Trait mapping a standard function-object type to its binaryOperationKind
      29              :          * (std::plus -> SUM, std::multiplies -> MULTIPLICATION, std::less -> MIN, std::greater -> MAX)
      30              :          * @tparam T Functor type to classify; the primary template has no `value` member
      31              :          */
      32              :         template <typename T>
      33              :         struct extractBinaryOperationKind {};
      34              :         template <typename T>
      35              :         struct extractBinaryOperationKind<std::plus<T>> {
      36              :             constexpr static binaryOperationKind value = binaryOperationKind::SUM;
      37              :         };
      38              :         template <typename T>
      39              :         struct extractBinaryOperationKind<std::multiplies<T>> {
      40              :             constexpr static binaryOperationKind value = binaryOperationKind::MULTIPLICATION;
      41              :         };
      42              :         template <typename T>
      43              :         struct extractBinaryOperationKind<std::less<T>> {
      44              :             constexpr static binaryOperationKind value = binaryOperationKind::MIN;  // comparator used as the tag for "minimum"
      45              :         };
      46              :         template <typename T>
      47              :         struct extractBinaryOperationKind<std::greater<T>> {
      48              :             constexpr static binaryOperationKind value = binaryOperationKind::MAX;  // comparator used as the tag for "maximum"
      49              :         };
      50              :         template <typename T>
      51              :         struct always_false : std::false_type {};  // dependent `false` for the static_assert in getMpiOpImpl
      52              :         template <class>
      53              :         struct is_ippl_mpi_type : std::false_type {};  // specialized to std::true_type by IPPL_MPI_OP
      54              :         struct dummy {};  // empty type; not referenced anywhere else in this header
      55              :         // NOTE(review): `static` at namespace scope in a header gives each translation unit its own map.
      56              :         // Global map from {Operation, Type} to MPI_Op, for example: {std::plus<Vector<double, 3>>,
      57              :         // Vector<double, 3>} -> _some MPI_Op_. Filled lazily by getNontrivialMpiOpImpl.
      58              :         static std::map<std::pair<std::type_index, std::type_index>, MPI_Op> mpiOperations;  // entries are never freed in this header
      59              :         // Primary template: used for any CppOpType without an IPPL_MPI_OP specialization; instantiating operator() fails.
      60              :         template <typename CppOpType, typename Datatype_IfNotTrivial>
      61              :         struct getMpiOpImpl {
      62              :             constexpr MPI_Op operator()() const noexcept {
      63              :                 // A dependent condition is needed: a plain static_assert(false) in a template was
      64              :                 // ill-formed even if never instantiated, until P2593R1 made it well-formed
      65              :                 // (https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/p2593r1.html)
      66              :                 static_assert(always_false<CppOpType>::value, "This optype is not supported");
      67              :                 return 0;  // Dummy return; the static_assert above fires whenever this body is instantiated
      68              :             }
      69              :         };
      70              :         /**
      71              :          * @brief Functor returning the MPI_Op for a custom {Op, Type} pair: looks it up in
      72              :          * mpiOperations and, if absent, creates it with MPI_Op_create and stores it there
      73              :          *
      74              :          * @tparam Op Operation type, for example std::plus<Vector<double, 3>>,
      75              :          * std::multiplies<...> etc.; needs an extractBinaryOperationKind specialization
      76              :          * @tparam Type The underlying datatype, such as Vector<double, 3>, Matrix<...> etc.
      77              :          */
      78              :         template <class Op, typename Type>
      79              :         struct getNontrivialMpiOpImpl /*<Op, Type>*/ {
      80              :             /**
      81              :              * @brief Get the MPI_Op for this CppOp + Type combo, creating and caching it if needed
      82              :              *
      83              :              * @return MPI_Op that combines buffers of Type element by element according to Op
      84              :              */
      85           72 :             MPI_Op operator()() {
      86          144 :                 std::pair<std::type_index, std::type_index> pear{std::type_index(typeid(Op)),
      87           72 :                                                                  std::type_index(typeid(Type))};
      88           72 :                 if (mpiOperations.contains(pear)) {
      89            0 :                     return mpiOperations.at(pear);
      90              :                 }
      91           72 :                 constexpr binaryOperationKind opKind = extractBinaryOperationKind<Op>::value;
      92              :                 MPI_Op ret;  // set by MPI_Op_create below
      93           72 :                 MPI_Op_create(
      94              :                     /**
      95              :                      * @brief Construct a new lambda object without captures, therefore convertible
      96              :                      * to a function pointer (opKind is constexpr, so it needs no capture)
      97              :                      *
      98              :                      * @param inputBuffer pointing to len Type objects; only read here
      99              :                      * @param outputBuffer pointing to len Type objects; updated in place
     100              :                      * @param len Amount of _Type objects_! NOT amount of bytes!
     101              :                      */
     102           96 :                     [](void* inputBuffer, void* outputBuffer, int* len, MPI_Datatype*) {
     103           72 :                         Type* input = (Type*)inputBuffer;
     104              : 
     105           72 :                         Type* output = (Type*)outputBuffer;
     106              :                         // Only the branch matching opKind is compiled into this lambda
     107          144 :                         for (int i = 0; i < *len; i++) {
     108              :                             if constexpr (opKind == binaryOperationKind::SUM) {
     109           24 :                                 output[i] += input[i];
     110              :                             }
     111              :                             if constexpr (opKind == binaryOperationKind::MIN) {
     112           24 :                                 output[i] = min(output[i], input[i]);  // unqualified call: presumably found via ADL on Type - verify
     113              :                             }
     114              :                             if constexpr (opKind == binaryOperationKind::MAX) {
     115           24 :                                 output[i] = max(output[i], input[i]);  // unqualified call: presumably found via ADL on Type - verify
     116              :                             }
     117              :                             if constexpr (opKind == binaryOperationKind::MULTIPLICATION) {
     118              :                                 output[i] *= input[i];
     119              :                             }
     120              :                         }
     121              :                     },
     122              :                     1, &ret);  // second argument is MPI_Op_create's commute flag
     123           72 :                 mpiOperations[pear] = ret;  // cache for later lookups of the same {Op, Type}
     124           72 :                 return ret;
     125              :             }
     126              :         };
     127              : // IPPL_MPI_OP(CppOp, MPIOp): make getMpiOpImpl<CppOp, ...> return MPIOp and flag CppOp in is_ippl_mpi_type.
     128              : #define IPPL_MPI_OP(CppOp, MPIOp)                                      \
     129              :     template <typename Datatype_IfNotTrivial>                          \
     130              :     struct getMpiOpImpl<CppOp, Datatype_IfNotTrivial> {                \
     131              :         constexpr MPI_Op operator()() const noexcept { return MPIOp; } \
     132              :     };                                                                 \
     133              :     template <>                                                        \
     134              :     struct is_ippl_mpi_type<CppOp> : std::true_type {};
     135              : 
     136              :         /* Registrations of built-in MPI reductions, one per functor/element type.
     137              :          * With C++14 transparent functors we should be able to simply write
     138              :          *
     139              :          * IPPL_MPI_OP(std::plus<>, MPI_SUM);
     140              :          *
     141              :          */
     142              :         // Sums: std::plus<T> -> MPI_SUM
     143              :         IPPL_MPI_OP(std::plus<char>, MPI_SUM);
     144              :         IPPL_MPI_OP(std::plus<short>, MPI_SUM);
     145              :         IPPL_MPI_OP(std::plus<int>, MPI_SUM);
     146              :         IPPL_MPI_OP(std::plus<long>, MPI_SUM);
     147              :         IPPL_MPI_OP(std::plus<long long>, MPI_SUM);
     148              :         IPPL_MPI_OP(std::plus<unsigned char>, MPI_SUM);
     149              :         IPPL_MPI_OP(std::plus<unsigned short>, MPI_SUM);
     150           24 :         IPPL_MPI_OP(std::plus<unsigned int>, MPI_SUM);
     151          312 :         IPPL_MPI_OP(std::plus<unsigned long>, MPI_SUM);
     152              :         IPPL_MPI_OP(std::plus<unsigned long long>, MPI_SUM);
     153           82 :         IPPL_MPI_OP(std::plus<float>, MPI_SUM);
     154          202 :         IPPL_MPI_OP(std::plus<double>, MPI_SUM);
     155              :         IPPL_MPI_OP(std::plus<long double>, MPI_SUM);
     156              :         // Complex sums, for both std::complex and Kokkos::complex
     157              :         IPPL_MPI_OP(std::plus<std::complex<float>>, MPI_SUM);
     158              :         IPPL_MPI_OP(std::plus<std::complex<double>>, MPI_SUM);
     159            4 :         IPPL_MPI_OP(std::plus<Kokkos::complex<float>>, MPI_SUM);
     160            4 :         IPPL_MPI_OP(std::plus<Kokkos::complex<double>>, MPI_SUM);
     161              :         // Minima: std::less<T> serves as the tag for MPI_MIN
     162              :         IPPL_MPI_OP(std::less<char>, MPI_MIN);
     163              :         IPPL_MPI_OP(std::less<short>, MPI_MIN);
     164              :         IPPL_MPI_OP(std::less<int>, MPI_MIN);
     165              :         IPPL_MPI_OP(std::less<long>, MPI_MIN);
     166              :         IPPL_MPI_OP(std::less<long long>, MPI_MIN);
     167              :         IPPL_MPI_OP(std::less<unsigned char>, MPI_MIN);
     168              :         IPPL_MPI_OP(std::less<unsigned short>, MPI_MIN);
     169              :         IPPL_MPI_OP(std::less<unsigned int>, MPI_MIN);
     170              :         IPPL_MPI_OP(std::less<unsigned long>, MPI_MIN);
     171              :         IPPL_MPI_OP(std::less<unsigned long long>, MPI_MIN);
     172           12 :         IPPL_MPI_OP(std::less<float>, MPI_MIN);
     173           12 :         IPPL_MPI_OP(std::less<double>, MPI_MIN);
     174              :         IPPL_MPI_OP(std::less<long double>, MPI_MIN);
     175              :         // Maxima: std::greater<T> serves as the tag for MPI_MAX
     176              :         IPPL_MPI_OP(std::greater<char>, MPI_MAX);
     177              :         IPPL_MPI_OP(std::greater<short>, MPI_MAX);
     178              :         IPPL_MPI_OP(std::greater<int>, MPI_MAX);
     179              :         IPPL_MPI_OP(std::greater<long>, MPI_MAX);
     180              :         IPPL_MPI_OP(std::greater<long long>, MPI_MAX);
     181              :         IPPL_MPI_OP(std::greater<unsigned char>, MPI_MAX);
     182              :         IPPL_MPI_OP(std::greater<unsigned short>, MPI_MAX);
     183              :         IPPL_MPI_OP(std::greater<unsigned int>, MPI_MAX);
     184              :         IPPL_MPI_OP(std::greater<unsigned long>, MPI_MAX);
     185              :         IPPL_MPI_OP(std::greater<unsigned long long>, MPI_MAX);
     186           28 :         IPPL_MPI_OP(std::greater<float>, MPI_MAX);
     187           28 :         IPPL_MPI_OP(std::greater<double>, MPI_MAX);
     188              :         IPPL_MPI_OP(std::greater<long double>, MPI_MAX);
     189              :         // Products: std::multiplies<T> -> MPI_PROD (no char / unsigned char entries, unlike the groups above)
     190              :         IPPL_MPI_OP(std::multiplies<short>, MPI_PROD);
     191              :         IPPL_MPI_OP(std::multiplies<int>, MPI_PROD);
     192              :         IPPL_MPI_OP(std::multiplies<long>, MPI_PROD);
     193              :         IPPL_MPI_OP(std::multiplies<long long>, MPI_PROD);
     194              :         IPPL_MPI_OP(std::multiplies<unsigned short>, MPI_PROD);
     195              :         IPPL_MPI_OP(std::multiplies<unsigned int>, MPI_PROD);
     196              :         IPPL_MPI_OP(std::multiplies<unsigned long>, MPI_PROD);
     197              :         IPPL_MPI_OP(std::multiplies<unsigned long long>, MPI_PROD);
     198           12 :         IPPL_MPI_OP(std::multiplies<float>, MPI_PROD);
     199           12 :         IPPL_MPI_OP(std::multiplies<double>, MPI_PROD);
     200              :         IPPL_MPI_OP(std::multiplies<long double>, MPI_PROD);
     201              :         // Logical reductions, registered for bool only
     202              :         IPPL_MPI_OP(std::logical_or<bool>, MPI_LOR);
     203              :         IPPL_MPI_OP(std::logical_and<bool>, MPI_LAND);
     204              :         // Returns the MPI_Op for Op on Datatype; the branch is chosen at compile time from is_ippl_mpi_type<Op>.
     205              :         template <typename Op, typename Datatype>
     206          804 :         MPI_Op get_mpi_op() {
     207              :             if constexpr (is_ippl_mpi_type<Op>::value) {
     208          732 :                 return getMpiOpImpl<Op, Datatype>{}();  // Op was registered through IPPL_MPI_OP
     209              :             } else {
     210           72 :                 return getNontrivialMpiOpImpl<Op, Datatype>{}();  // looked up in / added to mpiOperations
     211              :             }
     212              :         }
     213              :     }  // namespace mpi
     214              : }  // namespace ippl
     215              : 
     216              : #endif
        

Generated by: LCOV version 2.0-1