Line data Source code
1 : //
2 : // Class Communicator
3 : // Defines a class to do MPI communication.
4 : //
5 : #ifndef IPPL_MPI_COMMUNICATOR_H
6 : #define IPPL_MPI_COMMUNICATOR_H
7 :
8 : #include <memory>
9 : #include <mpi.h>
10 :
11 : #include "Communicate/BufferHandler.h"
12 : #include "Communicate/LoggingBufferHandler.h"
13 : #include "Communicate/Request.h"
14 : #include "Communicate/Status.h"
15 :
16 : ////////////////////////////////////////////////
17 : // For message size check; see below
18 : #include <climits>
19 : #include <cstdlib>
20 :
21 : #include "Utility/TypeUtils.h"
22 :
23 : #include "Communicate/Archive.h"
24 : #include "Communicate/TagMaker.h"
25 : #include "Communicate/Tags.h"
26 : ////////////////////////////////////////////////////
27 :
28 : namespace ippl {
29 : namespace mpi {
30 :
31 : class Communicator : public TagMaker {
32 : public:
33 : Communicator();
34 :
35 : Communicator(MPI_Comm comm);
36 :
37 : Communicator& operator=(MPI_Comm comm);
38 :
39 3602 : ~Communicator() = default;
40 :
41 : Communicator split(int color, int key) const;
42 :
43 60 : operator const MPI_Comm&() const noexcept { return *comm_m; }
44 :
45 1750 : int size() const noexcept { return size_m; }
46 :
47 4398 : int rank() const noexcept { return rank_m; }
48 :
49 382 : void barrier() { MPI_Barrier(*comm_m); }
50 :
51 0 : void abort(int errorcode = -1) { MPI_Abort(*comm_m, errorcode); }
52 :
53 : /*
54 : * Blocking point-to-point communication
55 : *
56 : */
57 :
58 : template <typename T>
59 : void send(const T& buffer, int count, int dest, int tag);
60 :
61 : template <typename T>
62 : void send(const T* buffer, int count, int dest, int tag);
63 :
64 : template <typename T>
65 : void recv(T& output, int count, int source, int tag, Status& status);
66 :
67 : template <typename T>
68 : void recv(T* output, int count, int source, int tag, Status& status);
69 :
70 : void probe(int source, int tag, Status& status);
71 :
72 : /*
73 : * Non-blocking point-to-point communication
74 : *
75 : */
76 :
77 : template <typename T>
78 : void isend(const T& buffer, int count, int dest, int tag, Request& request);
79 :
80 : template <typename T>
81 : void isend(const T* buffer, int count, int dest, int tag, Request& request);
82 :
83 : template <typename T>
84 : void irecv(T& buffer, int count, int source, int tag, Request& request);
85 :
86 : template <typename T>
87 : void irecv(T* buffer, int count, int source, int tag, Request& request);
88 :
89 : bool iprobe(int source, int tag, Status& status);
90 :
91 : /*
92 : * Collective communication
93 : */
94 :
95 : /* Gather the data in the given source container from all other nodes to a
96 : * specific node (default: 0).
97 : */
98 : template <typename T>
99 : void gather(const T* input, T* output, int count, int root = 0);
100 :
101 : /* Scatter the data from all other nodes to a
102 : * specific node (default: 0).
103 : */
104 : template <typename T>
105 : void scatter(const T* input, T* output, int count, int root = 0);
106 :
107 : /* Reduce data coming from all nodes to a specific node
108 : * (default: 0). Apply certain operation
109 : *
110 : */
111 : template <typename T, class Op>
112 : void reduce(const T* input, T* output, int count, Op op, int root = 0);
113 :
114 : template <typename T, class Op>
115 : void reduce(const T& input, T& output, int count, Op op, int root = 0);
116 :
117 : template <typename T, class Op>
118 : void allreduce(const T* input, T* output, int count, Op op);
119 :
120 : template <typename T, class Op>
121 : void allreduce(const T& input, T& output, int count, Op op);
122 :
123 : template <typename T, class Op>
124 : void allreduce(T* inout, int count, Op op);
125 :
126 : template <typename T, class Op>
127 : void allreduce(T& inout, int count, Op op);
128 :
129 : /////////////////////////////////////////////////////////////////////////////////////
130 : template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
131 : using archive_type = detail::Archive<MemorySpace>;
132 :
133 : template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
134 : using buffer_type = std::shared_ptr<archive_type<MemorySpace>>;
135 :
136 : private:
137 : template <typename MemorySpace>
138 : using buffer_container_type = LoggingBufferHandler<MemorySpace>;
139 :
140 : using buffer_handler_type =
141 : typename detail::ContainerForAllSpaces<buffer_container_type>::type;
142 :
143 : public:
144 : using size_type = detail::size_type;
145 264 : double getDefaultOverallocation() const { return defaultOveralloc_m; }
146 : void setDefaultOverallocation(double factor);
147 :
148 : template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space,
149 : typename T = char>
150 : buffer_type<MemorySpace> getBuffer(size_type size, double overallocation = 1.0);
151 :
152 : void deleteAllBuffers();
153 : void freeAllBuffers();
154 :
155 : template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
156 : void freeBuffer(buffer_type<MemorySpace> buffer);
157 :
158 8 : const MPI_Comm& getCommunicator() const noexcept { return *comm_m; }
159 :
160 : template <class Buffer, typename Archive>
161 0 : void recv(int src, int tag, Buffer& buffer, Archive& ar, size_type msize,
162 : size_type nrecvs) {
163 : // Temporary fix. MPI communication seems to have problems when the
164 : // count argument exceeds the range of int, so large messages should
165 : // be split into smaller messages
166 0 : if (msize > INT_MAX) {
167 0 : std::cerr << "Message size exceeds range of int" << std::endl;
168 0 : this->abort();
169 : }
170 : MPI_Status status;
171 0 : MPI_Recv(ar.getBuffer(), msize, MPI_BYTE, src, tag, *comm_m, &status);
172 :
173 0 : buffer.deserialize(ar, nrecvs);
174 0 : }
175 :
176 : template <class Buffer, typename Archive>
177 0 : void isend(int dest, int tag, Buffer& buffer, Archive& ar, MPI_Request& request,
178 : size_type nsends) {
179 0 : if (ar.getSize() > INT_MAX) {
180 0 : std::cerr << "Message size exceeds range of int" << std::endl;
181 0 : this->abort();
182 : }
183 0 : buffer.serialize(ar, nsends);
184 0 : MPI_Isend(ar.getBuffer(), ar.getSize(), MPI_BYTE, dest, tag, *comm_m, &request);
185 0 : }
186 :
187 : template <typename Archive>
188 : void irecv(int src, int tag, Archive& ar, MPI_Request& request, size_type msize) {
189 : if (msize > INT_MAX) {
190 : std::cerr << "Message size exceeds range of int" << std::endl;
191 : this->abort();
192 : }
193 : MPI_Irecv(ar.getBuffer(), msize, MPI_BYTE, src, tag, *comm_m, &request);
194 : }
195 :
196 : void printLogs(const std::string& filename);
197 :
198 : private:
199 : std::vector<LogEntry> gatherLocalLogs();
200 : void sendLogsToRank0(const std::vector<LogEntry>& localLogs);
201 : std::vector<LogEntry> gatherLogsFromAllRanks(const std::vector<LogEntry>& localLogs);
202 : void writeLogsToFile(const std::vector<LogEntry>& allLogs, const std::string& filename);
203 :
204 : buffer_handler_type buffer_handlers_m;
205 :
206 : double defaultOveralloc_m = 1.0;
207 :
208 : /////////////////////////////////////////////////////////////////////////////////////
209 :
210 : protected:
211 : std::shared_ptr<MPI_Comm> comm_m;
212 : int size_m;
213 : int rank_m;
214 : };
215 : } // namespace mpi
216 : } // namespace ippl
217 :
218 : #include "Communicate/Collectives.hpp"
219 : #include "Communicate/PointToPoint.hpp"
220 :
221 : ////////////////////////////////////
222 :
223 : #include "Communicate/Buffers.hpp"
224 :
225 : ////////////////////////////////////
226 :
227 : #endif
|