Line data Source code
1 : //
2 : // File Operations
3 : // Definition of MPI operations following the implementation of Boost.MPI.
4 : //
5 : #ifndef IPPL_MPI_OPERATIONS_H
6 : #define IPPL_MPI_OPERATIONS_H
7 :
8 : #include <Kokkos_Complex.hpp>
9 : #include <algorithm>
10 : #include <complex>
11 : #include <functional>
12 : #include <map>
13 : #include <mpi.h>
14 : #include <typeindex>
15 : #include <utility>
16 :
17 : namespace ippl {
18 : namespace mpi {
19 :
/// Kinds of binary reduction that the custom (non-predefined) MPI_Op path knows how to apply.
enum class binaryOperationKind {
    SUM,
    MIN,
    MAX,
    MULTIPLICATION  // TODO: Add all
};

/**
 * @brief Maps a C++ function-object type onto the binaryOperationKind it stands for.
 *
 * std::plus -> SUM, std::less -> MIN, std::greater -> MAX, std::multiplies -> MULTIPLICATION.
 * The primary template is empty (no `value` member), so asking for the kind of any other
 * operation type is a compile-time error.
 *
 * @tparam T function-object type, e.g. std::plus<double>
 */
template <typename T>
struct extractBinaryOperationKind {};

template <typename T>
struct extractBinaryOperationKind<std::plus<T>> {
    static constexpr binaryOperationKind value{binaryOperationKind::SUM};
};

template <typename T>
struct extractBinaryOperationKind<std::less<T>> {
    static constexpr binaryOperationKind value{binaryOperationKind::MIN};
};

template <typename T>
struct extractBinaryOperationKind<std::greater<T>> {
    static constexpr binaryOperationKind value{binaryOperationKind::MAX};
};

template <typename T>
struct extractBinaryOperationKind<std::multiplies<T>> {
    static constexpr binaryOperationKind value{binaryOperationKind::MULTIPLICATION};
};
/// Always-false trait whose value depends on T, so a static_assert using it only fires
/// when the enclosing template is actually instantiated.
template <typename T>
struct always_false : std::integral_constant<bool, false> {};

/// Marks C++ operation types that map directly onto a predefined MPI_Op.
/// False by default; the IPPL_MPI_OP macro specializes it to true.
template <class>
struct is_ippl_mpi_type : std::integral_constant<bool, false> {};

/// Empty tag type (not referenced elsewhere in this header).
struct dummy {};
55 :
56 : // Global map from {Operation, Type} to MPI_Op, for example: {std::plus<Vector<double, 3>>,
57 : // Vector<double, 3>} -> _some MPI_Op_
58 : static std::map<std::pair<std::type_index, std::type_index>, MPI_Op> mpiOperations;
59 :
60 : template <typename CppOpType, typename Datatype_IfNotTrivial>
61 : struct getMpiOpImpl {
62 : constexpr MPI_Op operator()() const noexcept {
63 : // We can't do static_assert(false) because static_assert(false), even if never
64 : // instantiated, was only made well-formed in 2022
65 : // (https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/p2593r1.html)
66 : static_assert(always_false<CppOpType>::value, "This optype is not supported");
67 : return 0; // Dummy return
68 : }
69 : };
70 : /**
71 : * @brief Helper struct to look up and store MPI_Op types for custom types and custom
72 : * operations
73 : *
74 : * @tparam Op Operation type, for examples std::plus<Vector<double, 3>>,
75 : * std::multiplies<...> etc.
76 : * @tparam Type The underlying datatype, such as Vector<double, 3>, Matrix<...> etc.
77 : */
78 : template <class Op, typename Type>
79 : struct getNontrivialMpiOpImpl /*<Op, Type>*/ {
80 : /**
81 : * @brief Get the MPI_Op for this CppOp + Type combo
82 : *
83 : * @return MPI_Op
84 : */
85 72 : MPI_Op operator()() {
86 144 : std::pair<std::type_index, std::type_index> pear{std::type_index(typeid(Op)),
87 72 : std::type_index(typeid(Type))};
88 72 : if (mpiOperations.contains(pear)) {
89 0 : return mpiOperations.at(pear);
90 : }
91 72 : constexpr binaryOperationKind opKind = extractBinaryOperationKind<Op>::value;
92 : MPI_Op ret;
93 72 : MPI_Op_create(
94 : /**
95 : * @brief Construct a new lambda object without captures, therefore convertible
96 : * to a function pointer
97 : *
98 : * @param inputBuffer pointing to a Type object
99 : * @param outputBuffer pointing to a Type object
100 : * @param len Amount of _Type objects_! NOT amount of bytes!
101 : */
102 96 : [](void* inputBuffer, void* outputBuffer, int* len, MPI_Datatype*) {
103 72 : Type* input = (Type*)inputBuffer;
104 :
105 72 : Type* output = (Type*)outputBuffer;
106 :
107 144 : for (int i = 0; i < *len; i++) {
108 : if constexpr (opKind == binaryOperationKind::SUM) {
109 24 : output[i] += input[i];
110 : }
111 : if constexpr (opKind == binaryOperationKind::MIN) {
112 24 : output[i] = min(output[i], input[i]);
113 : }
114 : if constexpr (opKind == binaryOperationKind::MAX) {
115 24 : output[i] = max(output[i], input[i]);
116 : }
117 : if constexpr (opKind == binaryOperationKind::MULTIPLICATION) {
118 : output[i] *= input[i];
119 : }
120 : }
121 : },
122 : 1, &ret);
123 72 : mpiOperations[pear] = ret;
124 72 : return ret;
125 : }
126 : };
127 :
/*
 * IPPL_MPI_OP(CPP_OPERATION, PREDEFINED_MPI_OP) registers a C++ function-object type that
 * maps directly onto a predefined MPI operation:
 *   - getMpiOpImpl<CPP_OPERATION, Datatype> returns PREDEFINED_MPI_OP for every Datatype;
 *   - is_ippl_mpi_type<CPP_OPERATION> becomes std::true_type, so get_mpi_op() uses it.
 * Intended for use inside namespace ippl::mpi, where both primary templates are declared.
 */
#define IPPL_MPI_OP(CPP_OPERATION, PREDEFINED_MPI_OP)                               \
    template <typename Datatype>                                                     \
    struct getMpiOpImpl<CPP_OPERATION, Datatype> {                                   \
        constexpr MPI_Op operator()() const noexcept { return PREDEFINED_MPI_OP; }   \
    };                                                                               \
    template <>                                                                      \
    struct is_ippl_mpi_type<CPP_OPERATION> : std::true_type {};
135 :
/* With C++14 transparent operator functors (std::plus<> is std::plus<void>), we should
 * be able to replace the per-type registrations below with simply
 *
 * IPPL_MPI_OP(std::plus<>, MPI_SUM);
 *
 */
142 :
143 : IPPL_MPI_OP(std::plus<char>, MPI_SUM);
144 : IPPL_MPI_OP(std::plus<short>, MPI_SUM);
145 : IPPL_MPI_OP(std::plus<int>, MPI_SUM);
146 : IPPL_MPI_OP(std::plus<long>, MPI_SUM);
147 : IPPL_MPI_OP(std::plus<long long>, MPI_SUM);
148 : IPPL_MPI_OP(std::plus<unsigned char>, MPI_SUM);
149 : IPPL_MPI_OP(std::plus<unsigned short>, MPI_SUM);
150 24 : IPPL_MPI_OP(std::plus<unsigned int>, MPI_SUM);
151 312 : IPPL_MPI_OP(std::plus<unsigned long>, MPI_SUM);
152 : IPPL_MPI_OP(std::plus<unsigned long long>, MPI_SUM);
153 82 : IPPL_MPI_OP(std::plus<float>, MPI_SUM);
154 202 : IPPL_MPI_OP(std::plus<double>, MPI_SUM);
155 : IPPL_MPI_OP(std::plus<long double>, MPI_SUM);
156 :
157 : IPPL_MPI_OP(std::plus<std::complex<float>>, MPI_SUM);
158 : IPPL_MPI_OP(std::plus<std::complex<double>>, MPI_SUM);
159 4 : IPPL_MPI_OP(std::plus<Kokkos::complex<float>>, MPI_SUM);
160 4 : IPPL_MPI_OP(std::plus<Kokkos::complex<double>>, MPI_SUM);
161 :
162 : IPPL_MPI_OP(std::less<char>, MPI_MIN);
163 : IPPL_MPI_OP(std::less<short>, MPI_MIN);
164 : IPPL_MPI_OP(std::less<int>, MPI_MIN);
165 : IPPL_MPI_OP(std::less<long>, MPI_MIN);
166 : IPPL_MPI_OP(std::less<long long>, MPI_MIN);
167 : IPPL_MPI_OP(std::less<unsigned char>, MPI_MIN);
168 : IPPL_MPI_OP(std::less<unsigned short>, MPI_MIN);
169 : IPPL_MPI_OP(std::less<unsigned int>, MPI_MIN);
170 : IPPL_MPI_OP(std::less<unsigned long>, MPI_MIN);
171 : IPPL_MPI_OP(std::less<unsigned long long>, MPI_MIN);
172 12 : IPPL_MPI_OP(std::less<float>, MPI_MIN);
173 12 : IPPL_MPI_OP(std::less<double>, MPI_MIN);
174 : IPPL_MPI_OP(std::less<long double>, MPI_MIN);
175 :
176 : IPPL_MPI_OP(std::greater<char>, MPI_MAX);
177 : IPPL_MPI_OP(std::greater<short>, MPI_MAX);
178 : IPPL_MPI_OP(std::greater<int>, MPI_MAX);
179 : IPPL_MPI_OP(std::greater<long>, MPI_MAX);
180 : IPPL_MPI_OP(std::greater<long long>, MPI_MAX);
181 : IPPL_MPI_OP(std::greater<unsigned char>, MPI_MAX);
182 : IPPL_MPI_OP(std::greater<unsigned short>, MPI_MAX);
183 : IPPL_MPI_OP(std::greater<unsigned int>, MPI_MAX);
184 : IPPL_MPI_OP(std::greater<unsigned long>, MPI_MAX);
185 : IPPL_MPI_OP(std::greater<unsigned long long>, MPI_MAX);
186 28 : IPPL_MPI_OP(std::greater<float>, MPI_MAX);
187 28 : IPPL_MPI_OP(std::greater<double>, MPI_MAX);
188 : IPPL_MPI_OP(std::greater<long double>, MPI_MAX);
189 :
190 : IPPL_MPI_OP(std::multiplies<short>, MPI_PROD);
191 : IPPL_MPI_OP(std::multiplies<int>, MPI_PROD);
192 : IPPL_MPI_OP(std::multiplies<long>, MPI_PROD);
193 : IPPL_MPI_OP(std::multiplies<long long>, MPI_PROD);
194 : IPPL_MPI_OP(std::multiplies<unsigned short>, MPI_PROD);
195 : IPPL_MPI_OP(std::multiplies<unsigned int>, MPI_PROD);
196 : IPPL_MPI_OP(std::multiplies<unsigned long>, MPI_PROD);
197 : IPPL_MPI_OP(std::multiplies<unsigned long long>, MPI_PROD);
198 12 : IPPL_MPI_OP(std::multiplies<float>, MPI_PROD);
199 12 : IPPL_MPI_OP(std::multiplies<double>, MPI_PROD);
200 : IPPL_MPI_OP(std::multiplies<long double>, MPI_PROD);
201 :
202 : IPPL_MPI_OP(std::logical_or<bool>, MPI_LOR);
203 : IPPL_MPI_OP(std::logical_and<bool>, MPI_LAND);
204 :
205 : template <typename Op, typename Datatype>
206 804 : MPI_Op get_mpi_op() {
207 : if constexpr (is_ippl_mpi_type<Op>::value) {
208 732 : return getMpiOpImpl<Op, Datatype>{}();
209 : } else {
210 72 : return getNontrivialMpiOpImpl<Op, Datatype>{}();
211 : }
212 : }
213 : } // namespace mpi
214 : } // namespace ippl
215 :
216 : #endif
|