Line data Source code
1 : #include "Communicate/DataTypes.h"
2 :
3 : #include "Communicate/Operations.h"
4 :
5 : namespace ippl {
6 : namespace mpi {
7 : template <typename T>
8 : void Communicator::gather(const T* input, T* output, int count, int root) {
9 : MPI_Datatype type = get_mpi_datatype<T>(*input);
10 :
11 : MPI_Gather(const_cast<T*>(input), count, type, output, count, type, root, *comm_m);
12 : }
13 :
14 : template <typename T>
15 : void Communicator::scatter(const T* input, T* output, int count, int root) {
16 : MPI_Datatype type = get_mpi_datatype<T>(*input);
17 :
18 : MPI_Scatter(const_cast<T*>(input), count, type, output, count, type, root, *comm_m);
19 : }
20 :
21 : template <typename T, class Op>
22 16 : void Communicator::reduce(const T* input, T* output, int count, Op, int root) {
23 16 : MPI_Datatype type = get_mpi_datatype<T>(*input);
24 :
25 16 : MPI_Op mpiOp = get_mpi_op<Op, T>();
26 :
27 16 : MPI_Reduce(const_cast<T*>(input), output, count, type, mpiOp, root, *comm_m);
28 16 : }
29 :
30 : template <typename T, class Op>
31 16 : void Communicator::reduce(const T& input, T& output, int count, Op op, int root) {
32 16 : reduce(&input, &output, count, op, root);
33 16 : }
34 :
35 : template <typename T, class Op>
36 378 : void Communicator::allreduce(const T* input, T* output, int count, Op) {
37 378 : MPI_Datatype type = get_mpi_datatype<T>(*input);
38 :
39 378 : MPI_Op mpiOp = get_mpi_op<Op, T>();
40 :
41 378 : MPI_Allreduce(const_cast<T*>(input), output, count, type, mpiOp, *comm_m);
42 378 : }
43 :
44 : template <typename T, class Op>
45 378 : void Communicator::allreduce(const T& input, T& output, int count, Op op) {
46 378 : allreduce(&input, &output, count, op);
47 378 : }
48 :
49 : template <typename T, class Op>
50 : void Communicator::allreduce(T* inout, int count, Op) {
51 : MPI_Datatype type = get_mpi_datatype<T>(*inout);
52 :
53 : MPI_Op mpiOp = get_mpi_op<Op, T>();
54 :
55 : MPI_Allreduce(MPI_IN_PLACE, inout, count, type, mpiOp, *comm_m);
56 : }
57 :
58 : template <typename T, class Op>
59 : void Communicator::allreduce(T& inout, int count, Op op) {
60 : allreduce(&inout, count, op);
61 : }
62 : } // namespace mpi
63 : } // namespace ippl
|