Branch data Line data Source code
1 : : //
2 : : // Class Communicator
3 : : // Defines a class to do MPI communication.
4 : : //
5 : : #ifndef IPPL_MPI_COMMUNICATOR_H
6 : : #define IPPL_MPI_COMMUNICATOR_H
7 : :
8 : : #include <memory>
9 : : #include <mpi.h>
10 : :
11 : : #include "Communicate/BufferHandler.h"
12 : : #include "Communicate/LoggingBufferHandler.h"
13 : : #include "Communicate/Request.h"
14 : : #include "Communicate/Status.h"
15 : :
16 : : ////////////////////////////////////////////////
17 : : // For message size check; see below
18 : : #include <climits>
19 : : #include <cstdlib>
20 : :
21 : : #include "Utility/TypeUtils.h"
22 : :
23 : : #include "Communicate/Archive.h"
24 : : #include "Communicate/TagMaker.h"
25 : : #include "Communicate/Tags.h"
26 : : ////////////////////////////////////////////////////
27 : :
28 : : namespace ippl {
29 : : namespace mpi {
30 : :
31 : : class Communicator : public TagMaker {
32 : : public:
33 : : Communicator();
34 : :
35 : : Communicator(MPI_Comm comm);
36 : :
37 : : Communicator& operator=(MPI_Comm comm);
38 : :
39 : 3602 : ~Communicator() = default;
40 : :
41 : : Communicator split(int color, int key) const;
42 : :
43 : 60 : operator const MPI_Comm&() const noexcept { return *comm_m; }
44 : :
45 : 1750 : int size() const noexcept { return size_m; }
46 : :
47 : 4398 : int rank() const noexcept { return rank_m; }
48 : :
49 : 382 : void barrier() { MPI_Barrier(*comm_m); }
50 : :
51 : 0 : void abort(int errorcode = -1) { MPI_Abort(*comm_m, errorcode); }
52 : :
53 : : /*
54 : : * Blocking point-to-point communication
55 : : *
56 : : */
57 : :
58 : : template <typename T>
59 : : void send(const T& buffer, int count, int dest, int tag);
60 : :
61 : : template <typename T>
62 : : void send(const T* buffer, int count, int dest, int tag);
63 : :
64 : : template <typename T>
65 : : void recv(T& output, int count, int source, int tag, Status& status);
66 : :
67 : : template <typename T>
68 : : void recv(T* output, int count, int source, int tag, Status& status);
69 : :
70 : : void probe(int source, int tag, Status& status);
71 : :
72 : : /*
73 : : * Non-blocking point-to-point communication
74 : : *
75 : : */
76 : :
77 : : template <typename T>
78 : : void isend(const T& buffer, int count, int dest, int tag, Request& request);
79 : :
80 : : template <typename T>
81 : : void isend(const T* buffer, int count, int dest, int tag, Request& request);
82 : :
83 : : template <typename T>
84 : : void irecv(T& buffer, int count, int source, int tag, Request& request);
85 : :
86 : : template <typename T>
87 : : void irecv(T* buffer, int count, int source, int tag, Request& request);
88 : :
89 : : bool iprobe(int source, int tag, Status& status);
90 : :
91 : : /*
92 : : * Collective communication
93 : : */
94 : :
95 : : /* Gather the data in the given source container from all other nodes to a
96 : : * specific node (default: 0).
97 : : */
98 : : template <typename T>
99 : : void gather(const T* input, T* output, int count, int root = 0);
100 : :
101 : : /* Scatter the data from all other nodes to a
102 : : * specific node (default: 0).
103 : : */
104 : : template <typename T>
105 : : void scatter(const T* input, T* output, int count, int root = 0);
106 : :
107 : : /* Reduce data coming from all nodes to a specific node
108 : : * (default: 0). Apply certain operation
109 : : *
110 : : */
111 : : template <typename T, class Op>
112 : : void reduce(const T* input, T* output, int count, Op op, int root = 0);
113 : :
114 : : template <typename T, class Op>
115 : : void reduce(const T& input, T& output, int count, Op op, int root = 0);
116 : :
117 : : template <typename T, class Op>
118 : : void allreduce(const T* input, T* output, int count, Op op);
119 : :
120 : : template <typename T, class Op>
121 : : void allreduce(const T& input, T& output, int count, Op op);
122 : :
123 : : template <typename T, class Op>
124 : : void allreduce(T* inout, int count, Op op);
125 : :
126 : : template <typename T, class Op>
127 : : void allreduce(T& inout, int count, Op op);
128 : :
129 : : /////////////////////////////////////////////////////////////////////////////////////
130 : : template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
131 : : using archive_type = detail::Archive<MemorySpace>;
132 : :
133 : : template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
134 : : using buffer_type = std::shared_ptr<archive_type<MemorySpace>>;
135 : :
136 : : private:
137 : : template <typename MemorySpace>
138 : : using buffer_container_type = LoggingBufferHandler<MemorySpace>;
139 : :
140 : : using buffer_handler_type =
141 : : typename detail::ContainerForAllSpaces<buffer_container_type>::type;
142 : :
143 : : public:
144 : : using size_type = detail::size_type;
145 : 264 : double getDefaultOverallocation() const { return defaultOveralloc_m; }
146 : : void setDefaultOverallocation(double factor);
147 : :
148 : : template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space,
149 : : typename T = char>
150 : : buffer_type<MemorySpace> getBuffer(size_type size, double overallocation = 1.0);
151 : :
152 : : void deleteAllBuffers();
153 : : void freeAllBuffers();
154 : :
155 : : template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
156 : : void freeBuffer(buffer_type<MemorySpace> buffer);
157 : :
158 : 8 : const MPI_Comm& getCommunicator() const noexcept { return *comm_m; }
159 : :
160 : : template <class Buffer, typename Archive>
161 : 0 : void recv(int src, int tag, Buffer& buffer, Archive& ar, size_type msize,
162 : : size_type nrecvs) {
163 : : // Temporary fix. MPI communication seems to have problems when the
164 : : // count argument exceeds the range of int, so large messages should
165 : : // be split into smaller messages
166 [ # # ]: 0 : if (msize > INT_MAX) {
167 [ # # # # ]: 0 : std::cerr << "Message size exceeds range of int" << std::endl;
168 [ # # ]: 0 : this->abort();
169 : : }
170 : : MPI_Status status;
171 [ # # # # ]: 0 : MPI_Recv(ar.getBuffer(), msize, MPI_BYTE, src, tag, *comm_m, &status);
172 : :
173 [ # # ]: 0 : buffer.deserialize(ar, nrecvs);
174 : 0 : }
175 : :
176 : : template <class Buffer, typename Archive>
177 : 0 : void isend(int dest, int tag, Buffer& buffer, Archive& ar, MPI_Request& request,
178 : : size_type nsends) {
179 [ # # ]: 0 : if (ar.getSize() > INT_MAX) {
180 : 0 : std::cerr << "Message size exceeds range of int" << std::endl;
181 : 0 : this->abort();
182 : : }
183 : 0 : buffer.serialize(ar, nsends);
184 : 0 : MPI_Isend(ar.getBuffer(), ar.getSize(), MPI_BYTE, dest, tag, *comm_m, &request);
185 : 0 : }
186 : :
187 : : template <typename Archive>
188 : : void irecv(int src, int tag, Archive& ar, MPI_Request& request, size_type msize) {
189 : : if (msize > INT_MAX) {
190 : : std::cerr << "Message size exceeds range of int" << std::endl;
191 : : this->abort();
192 : : }
193 : : MPI_Irecv(ar.getBuffer(), msize, MPI_BYTE, src, tag, *comm_m, &request);
194 : : }
195 : :
196 : : void printLogs(const std::string& filename);
197 : :
198 : : private:
199 : : std::vector<LogEntry> gatherLocalLogs();
200 : : void sendLogsToRank0(const std::vector<LogEntry>& localLogs);
201 : : std::vector<LogEntry> gatherLogsFromAllRanks(const std::vector<LogEntry>& localLogs);
202 : : void writeLogsToFile(const std::vector<LogEntry>& allLogs, const std::string& filename);
203 : :
204 : : buffer_handler_type buffer_handlers_m;
205 : :
206 : : double defaultOveralloc_m = 1.0;
207 : :
208 : : /////////////////////////////////////////////////////////////////////////////////////
209 : :
210 : : protected:
211 : : std::shared_ptr<MPI_Comm> comm_m;
212 : : int size_m;
213 : : int rank_m;
214 : : };
215 : : } // namespace mpi
216 : : } // namespace ippl
217 : :
218 : : #include "Communicate/Collectives.hpp"
219 : : #include "Communicate/PointToPoint.hpp"
220 : :
221 : : ////////////////////////////////////
222 : :
223 : : #include "Communicate/Buffers.hpp"
224 : :
225 : : ////////////////////////////////////
226 : :
227 : : #endif
|