Branch data Line data Source code
1 : : //
2 : : // File Operations
3 : : // Definition of MPI operations following the implementation of Boost.MPI.
4 : : //
5 : : #ifndef IPPL_MPI_OPERATIONS_H
6 : : #define IPPL_MPI_OPERATIONS_H
7 : :
8 : : #include <Kokkos_Complex.hpp>
9 : : #include <algorithm>
10 : : #include <complex>
11 : : #include <functional>
12 : : #include <map>
13 : : #include <mpi.h>
14 : : #include <typeindex>
15 : : #include <utility>
16 : :
17 : : namespace ippl {
18 : : namespace mpi {
19 : :
20 : : enum struct binaryOperationKind {
21 : : SUM,
22 : : MIN,
23 : : MAX,
24 : : MULTIPLICATION // TODO: Add all
25 : : };
26 : :
27 : : /**
28 : : * @brief Helper struct to distinguish between the four basic associative operation types
29 : : *
30 : : * @tparam T
31 : : */
32 : : template <typename T>
33 : : struct extractBinaryOperationKind {};
34 : : template <typename T>
35 : : struct extractBinaryOperationKind<std::plus<T>> {
36 : : constexpr static binaryOperationKind value = binaryOperationKind::SUM;
37 : : };
38 : : template <typename T>
39 : : struct extractBinaryOperationKind<std::multiplies<T>> {
40 : : constexpr static binaryOperationKind value = binaryOperationKind::MULTIPLICATION;
41 : : };
42 : : template <typename T>
43 : : struct extractBinaryOperationKind<std::less<T>> {
44 : : constexpr static binaryOperationKind value = binaryOperationKind::MIN;
45 : : };
46 : : template <typename T>
47 : : struct extractBinaryOperationKind<std::greater<T>> {
48 : : constexpr static binaryOperationKind value = binaryOperationKind::MAX;
49 : : };
50 : : template <typename T>
51 : : struct always_false : std::false_type {};
52 : : template <class>
53 : : struct is_ippl_mpi_type : std::false_type {};
54 : : struct dummy {};
55 : :
56 : : // Global map from {Operation, Type} to MPI_Op, for example: {std::plus<Vector<double, 3>>,
57 : : // Vector<double, 3>} -> _some MPI_Op_
58 : : static std::map<std::pair<std::type_index, std::type_index>, MPI_Op> mpiOperations;
59 : :
60 : : template <typename CppOpType, typename Datatype_IfNotTrivial>
61 : : struct getMpiOpImpl {
62 : : constexpr MPI_Op operator()() const noexcept {
63 : : // We can't do static_assert(false) because static_assert(false), even if never
64 : : // instantiated, was only made well-formed in 2022
65 : : // (https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/p2593r1.html)
66 : : static_assert(always_false<CppOpType>::value, "This optype is not supported");
67 : : return 0; // Dummy return
68 : : }
69 : : };
70 : : /**
71 : : * @brief Helper struct to look up and store MPI_Op types for custom types and custom
72 : : * operations
73 : : *
74 : : * @tparam Op Operation type, for examples std::plus<Vector<double, 3>>,
75 : : * std::multiplies<...> etc.
76 : : * @tparam Type The underlying datatype, such as Vector<double, 3>, Matrix<...> etc.
77 : : */
78 : : template <class Op, typename Type>
79 : : struct getNontrivialMpiOpImpl /*<Op, Type>*/ {
80 : : /**
81 : : * @brief Get the MPI_Op for this CppOp + Type combo
82 : : *
83 : : * @return MPI_Op
84 : : */
85 : 36 : MPI_Op operator()() {
86 : 72 : std::pair<std::type_index, std::type_index> pear{std::type_index(typeid(Op)),
87 : 36 : std::type_index(typeid(Type))};
88 [ + - - + ]: 36 : if (mpiOperations.contains(pear)) {
89 [ # # ]: 0 : return mpiOperations.at(pear);
90 : : }
91 : 36 : constexpr binaryOperationKind opKind = extractBinaryOperationKind<Op>::value;
92 : : MPI_Op ret;
93 [ + - ]: 36 : MPI_Op_create(
94 : : /**
95 : : * @brief Construct a new lambda object without captures, therefore convertible
96 : : * to a function pointer
97 : : *
98 : : * @param inputBuffer pointing to a Type object
99 : : * @param outputBuffer pointing to a Type object
100 : : * @param len Amount of _Type objects_! NOT amount of bytes!
101 : : */
102 : 0 : [](void* inputBuffer, void* outputBuffer, int* len, MPI_Datatype*) {
103 : 0 : Type* input = (Type*)inputBuffer;
104 : :
105 : 0 : Type* output = (Type*)outputBuffer;
106 : :
107 [ # # # # : 0 : for (int i = 0; i < *len; i++) {
# # # # #
# # # # #
# # # # #
# # # # #
# # # # #
# # # # #
# # # # #
# # # # #
# # # # #
# # # # #
# # # # #
# # # # #
# # # # #
# # # ]
108 : : if constexpr (opKind == binaryOperationKind::SUM) {
109 : 0 : output[i] += input[i];
110 : : }
111 : : if constexpr (opKind == binaryOperationKind::MIN) {
112 [ # # # # : 0 : output[i] = min(output[i], input[i]);
# # # # #
# # # # #
# # # # #
# # # #
# ]
113 : : }
114 : : if constexpr (opKind == binaryOperationKind::MAX) {
115 [ # # # # : 0 : output[i] = max(output[i], input[i]);
# # # # #
# # # # #
# # # # #
# # # #
# ]
116 : : }
117 : : if constexpr (opKind == binaryOperationKind::MULTIPLICATION) {
118 : : output[i] *= input[i];
119 : : }
120 : : }
121 : : },
122 : : 1, &ret);
123 [ + - ]: 36 : mpiOperations[pear] = ret;
124 : 36 : return ret;
125 : : }
126 : : };
127 : :
128 : : #define IPPL_MPI_OP(CppOp, MPIOp) \
129 : : template <typename Datatype_IfNotTrivial> \
130 : : struct getMpiOpImpl<CppOp, Datatype_IfNotTrivial> { \
131 : : constexpr MPI_Op operator()() const noexcept { return MPIOp; } \
132 : : }; \
133 : : template <> \
134 : : struct is_ippl_mpi_type<CppOp> : std::true_type {};
135 : :
136 : : /* with C++14 we should be able
137 : : * to simply write
138 : : *
139 : : * IPPL_MPI_OP(std::plus<>, MPI_SUM);
140 : : *
141 : : */
142 : :
143 : : IPPL_MPI_OP(std::plus<char>, MPI_SUM);
144 : : IPPL_MPI_OP(std::plus<short>, MPI_SUM);
145 : : IPPL_MPI_OP(std::plus<int>, MPI_SUM);
146 : : IPPL_MPI_OP(std::plus<long>, MPI_SUM);
147 : : IPPL_MPI_OP(std::plus<long long>, MPI_SUM);
148 : : IPPL_MPI_OP(std::plus<unsigned char>, MPI_SUM);
149 : : IPPL_MPI_OP(std::plus<unsigned short>, MPI_SUM);
150 : 12 : IPPL_MPI_OP(std::plus<unsigned int>, MPI_SUM);
151 : 180 : IPPL_MPI_OP(std::plus<unsigned long>, MPI_SUM);
152 : : IPPL_MPI_OP(std::plus<unsigned long long>, MPI_SUM);
153 : 37 : IPPL_MPI_OP(std::plus<float>, MPI_SUM);
154 : 73 : IPPL_MPI_OP(std::plus<double>, MPI_SUM);
155 : : IPPL_MPI_OP(std::plus<long double>, MPI_SUM);
156 : :
157 : : IPPL_MPI_OP(std::plus<std::complex<float>>, MPI_SUM);
158 : : IPPL_MPI_OP(std::plus<std::complex<double>>, MPI_SUM);
159 : 2 : IPPL_MPI_OP(std::plus<Kokkos::complex<float>>, MPI_SUM);
160 : 2 : IPPL_MPI_OP(std::plus<Kokkos::complex<double>>, MPI_SUM);
161 : :
162 : : IPPL_MPI_OP(std::less<char>, MPI_MIN);
163 : : IPPL_MPI_OP(std::less<short>, MPI_MIN);
164 : : IPPL_MPI_OP(std::less<int>, MPI_MIN);
165 : : IPPL_MPI_OP(std::less<long>, MPI_MIN);
166 : : IPPL_MPI_OP(std::less<long long>, MPI_MIN);
167 : : IPPL_MPI_OP(std::less<unsigned char>, MPI_MIN);
168 : : IPPL_MPI_OP(std::less<unsigned short>, MPI_MIN);
169 : : IPPL_MPI_OP(std::less<unsigned int>, MPI_MIN);
170 : : IPPL_MPI_OP(std::less<unsigned long>, MPI_MIN);
171 : : IPPL_MPI_OP(std::less<unsigned long long>, MPI_MIN);
172 : 6 : IPPL_MPI_OP(std::less<float>, MPI_MIN);
173 : 6 : IPPL_MPI_OP(std::less<double>, MPI_MIN);
174 : : IPPL_MPI_OP(std::less<long double>, MPI_MIN);
175 : :
176 : : IPPL_MPI_OP(std::greater<char>, MPI_MAX);
177 : : IPPL_MPI_OP(std::greater<short>, MPI_MAX);
178 : : IPPL_MPI_OP(std::greater<int>, MPI_MAX);
179 : : IPPL_MPI_OP(std::greater<long>, MPI_MAX);
180 : : IPPL_MPI_OP(std::greater<long long>, MPI_MAX);
181 : : IPPL_MPI_OP(std::greater<unsigned char>, MPI_MAX);
182 : : IPPL_MPI_OP(std::greater<unsigned short>, MPI_MAX);
183 : : IPPL_MPI_OP(std::greater<unsigned int>, MPI_MAX);
184 : : IPPL_MPI_OP(std::greater<unsigned long>, MPI_MAX);
185 : : IPPL_MPI_OP(std::greater<unsigned long long>, MPI_MAX);
186 : 14 : IPPL_MPI_OP(std::greater<float>, MPI_MAX);
187 : 14 : IPPL_MPI_OP(std::greater<double>, MPI_MAX);
188 : : IPPL_MPI_OP(std::greater<long double>, MPI_MAX);
189 : :
190 : : IPPL_MPI_OP(std::multiplies<short>, MPI_PROD);
191 : : IPPL_MPI_OP(std::multiplies<int>, MPI_PROD);
192 : : IPPL_MPI_OP(std::multiplies<long>, MPI_PROD);
193 : : IPPL_MPI_OP(std::multiplies<long long>, MPI_PROD);
194 : : IPPL_MPI_OP(std::multiplies<unsigned short>, MPI_PROD);
195 : : IPPL_MPI_OP(std::multiplies<unsigned int>, MPI_PROD);
196 : : IPPL_MPI_OP(std::multiplies<unsigned long>, MPI_PROD);
197 : : IPPL_MPI_OP(std::multiplies<unsigned long long>, MPI_PROD);
198 : 6 : IPPL_MPI_OP(std::multiplies<float>, MPI_PROD);
199 : 6 : IPPL_MPI_OP(std::multiplies<double>, MPI_PROD);
200 : : IPPL_MPI_OP(std::multiplies<long double>, MPI_PROD);
201 : :
202 : : IPPL_MPI_OP(std::logical_or<bool>, MPI_LOR);
203 : : IPPL_MPI_OP(std::logical_and<bool>, MPI_LAND);
204 : :
205 : : template <typename Op, typename Datatype>
206 : 394 : MPI_Op get_mpi_op() {
207 : : if constexpr (is_ippl_mpi_type<Op>::value) {
208 : 358 : return getMpiOpImpl<Op, Datatype>{}();
209 : : } else {
210 [ + - ]: 36 : return getNontrivialMpiOpImpl<Op, Datatype>{}();
211 : : }
212 : : }
213 : : } // namespace mpi
214 : : } // namespace ippl
215 : :
216 : : #endif
|