LCOV - code coverage report
Current view: top level - src/Communicate - Communicator.h (source / functions) Coverage Total Hit
Test: final_report.info Lines: 31.8 % 22 7
Test Date: 2025-07-10 08:04:31 Functions: 9.6 % 83 8

            Line data    Source code
       1              : //
       2              : // Class Communicator
       3              : //   Defines a class to do MPI communication.
       4              : //
       5              : #ifndef IPPL_MPI_COMMUNICATOR_H
       6              : #define IPPL_MPI_COMMUNICATOR_H
       7              : 
       8              : #include <memory>
       9              : #include <mpi.h>
      10              : 
      11              : #include "Communicate/BufferHandler.h"
      12              : #include "Communicate/LoggingBufferHandler.h"
      13              : #include "Communicate/Request.h"
      14              : #include "Communicate/Status.h"
      15              : 
      16              : ////////////////////////////////////////////////
      17              : // For message size check; see below
      18              : #include <climits>
      19              : #include <cstdlib>
      20              : 
      21              : #include "Utility/TypeUtils.h"
      22              : 
      23              : #include "Communicate/Archive.h"
      24              : #include "Communicate/TagMaker.h"
      25              : #include "Communicate/Tags.h"
      26              : ////////////////////////////////////////////////////
      27              : 
      28              : namespace ippl {
      29              :     namespace mpi {
      30              : 
      31              :         class Communicator : public TagMaker {
      32              :         public:
      33              :             Communicator();
      34              : 
      35              :             Communicator(MPI_Comm comm);
      36              : 
      37              :             Communicator& operator=(MPI_Comm comm);
      38              : 
      39         3602 :             ~Communicator() = default;
      40              : 
      41              :             Communicator split(int color, int key) const;
      42              : 
      43           60 :             operator const MPI_Comm&() const noexcept { return *comm_m; }
      44              : 
      45         1750 :             int size() const noexcept { return size_m; }
      46              : 
      47         4398 :             int rank() const noexcept { return rank_m; }
      48              : 
      49          382 :             void barrier() { MPI_Barrier(*comm_m); }
      50              : 
      51            0 :             void abort(int errorcode = -1) { MPI_Abort(*comm_m, errorcode); }
      52              : 
      53              :             /*
      54              :              * Blocking point-to-point communication
      55              :              *
      56              :              */
      57              : 
      58              :             template <typename T>
      59              :             void send(const T& buffer, int count, int dest, int tag);
      60              : 
      61              :             template <typename T>
      62              :             void send(const T* buffer, int count, int dest, int tag);
      63              : 
      64              :             template <typename T>
      65              :             void recv(T& output, int count, int source, int tag, Status& status);
      66              : 
      67              :             template <typename T>
      68              :             void recv(T* output, int count, int source, int tag, Status& status);
      69              : 
      70              :             void probe(int source, int tag, Status& status);
      71              : 
      72              :             /*
      73              :              * Non-blocking point-to-point communication
      74              :              *
      75              :              */
      76              : 
      77              :             template <typename T>
      78              :             void isend(const T& buffer, int count, int dest, int tag, Request& request);
      79              : 
      80              :             template <typename T>
      81              :             void isend(const T* buffer, int count, int dest, int tag, Request& request);
      82              : 
      83              :             template <typename T>
      84              :             void irecv(T& buffer, int count, int source, int tag, Request& request);
      85              : 
      86              :             template <typename T>
      87              :             void irecv(T* buffer, int count, int source, int tag, Request& request);
      88              : 
      89              :             bool iprobe(int source, int tag, Status& status);
      90              : 
      91              :             /*
      92              :              * Collective communication
      93              :              */
      94              : 
      95              :             /* Gather the data in the given source container from all other nodes to a
      96              :              * specific node (default: 0).
      97              :              */
      98              :             template <typename T>
      99              :             void gather(const T* input, T* output, int count, int root = 0);
     100              : 
     101              :             /* Scatter the data from all other nodes to a
     102              :              * specific node (default: 0).
     103              :              */
     104              :             template <typename T>
     105              :             void scatter(const T* input, T* output, int count, int root = 0);
     106              : 
     107              :             /* Reduce data coming from all nodes to a specific node
     108              :              * (default: 0). Apply certain operation
     109              :              *
     110              :              */
     111              :             template <typename T, class Op>
     112              :             void reduce(const T* input, T* output, int count, Op op, int root = 0);
     113              : 
     114              :             template <typename T, class Op>
     115              :             void reduce(const T& input, T& output, int count, Op op, int root = 0);
     116              : 
     117              :             template <typename T, class Op>
     118              :             void allreduce(const T* input, T* output, int count, Op op);
     119              : 
     120              :             template <typename T, class Op>
     121              :             void allreduce(const T& input, T& output, int count, Op op);
     122              : 
     123              :             template <typename T, class Op>
     124              :             void allreduce(T* inout, int count, Op op);
     125              : 
     126              :             template <typename T, class Op>
     127              :             void allreduce(T& inout, int count, Op op);
     128              : 
     129              :             /////////////////////////////////////////////////////////////////////////////////////
     130              :             template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
     131              :             using archive_type = detail::Archive<MemorySpace>;
     132              : 
     133              :             template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
     134              :             using buffer_type = std::shared_ptr<archive_type<MemorySpace>>;
     135              : 
     136              :         private:
     137              :             template <typename MemorySpace>
     138              :             using buffer_container_type = LoggingBufferHandler<MemorySpace>;
     139              : 
     140              :             using buffer_handler_type =
     141              :                 typename detail::ContainerForAllSpaces<buffer_container_type>::type;
     142              : 
     143              :         public:
     144              :             using size_type = detail::size_type;
     145          264 :             double getDefaultOverallocation() const { return defaultOveralloc_m; }
     146              :             void setDefaultOverallocation(double factor);
     147              : 
     148              :             template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space,
     149              :                       typename T           = char>
     150              :             buffer_type<MemorySpace> getBuffer(size_type size, double overallocation = 1.0);
     151              : 
     152              :             void deleteAllBuffers();
     153              :             void freeAllBuffers();
     154              : 
     155              :             template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
     156              :             void freeBuffer(buffer_type<MemorySpace> buffer);
     157              : 
     158            8 :             const MPI_Comm& getCommunicator() const noexcept { return *comm_m; }
     159              : 
     160              :             template <class Buffer, typename Archive>
     161            0 :             void recv(int src, int tag, Buffer& buffer, Archive& ar, size_type msize,
     162              :                       size_type nrecvs) {
     163              :                 // Temporary fix. MPI communication seems to have problems when the
     164              :                 // count argument exceeds the range of int, so large messages should
     165              :                 // be split into smaller messages
     166            0 :                 if (msize > INT_MAX) {
     167            0 :                     std::cerr << "Message size exceeds range of int" << std::endl;
     168            0 :                     this->abort();
     169              :                 }
     170              :                 MPI_Status status;
     171            0 :                 MPI_Recv(ar.getBuffer(), msize, MPI_BYTE, src, tag, *comm_m, &status);
     172              : 
     173            0 :                 buffer.deserialize(ar, nrecvs);
     174            0 :             }
     175              : 
     176              :             template <class Buffer, typename Archive>
     177            0 :             void isend(int dest, int tag, Buffer& buffer, Archive& ar, MPI_Request& request,
     178              :                        size_type nsends) {
     179            0 :                 if (ar.getSize() > INT_MAX) {
     180            0 :                     std::cerr << "Message size exceeds range of int" << std::endl;
     181            0 :                     this->abort();
     182              :                 }
     183            0 :                 buffer.serialize(ar, nsends);
     184            0 :                 MPI_Isend(ar.getBuffer(), ar.getSize(), MPI_BYTE, dest, tag, *comm_m, &request);
     185            0 :             }
     186              : 
     187              :             template <typename Archive>
     188              :             void irecv(int src, int tag, Archive& ar, MPI_Request& request, size_type msize) {
     189              :                 if (msize > INT_MAX) {
     190              :                     std::cerr << "Message size exceeds range of int" << std::endl;
     191              :                     this->abort();
     192              :                 }
     193              :                 MPI_Irecv(ar.getBuffer(), msize, MPI_BYTE, src, tag, *comm_m, &request);
     194              :             }
     195              : 
     196              :             void printLogs(const std::string& filename);
     197              : 
     198              :         private:
     199              :             std::vector<LogEntry> gatherLocalLogs();
     200              :             void sendLogsToRank0(const std::vector<LogEntry>& localLogs);
     201              :             std::vector<LogEntry> gatherLogsFromAllRanks(const std::vector<LogEntry>& localLogs);
     202              :             void writeLogsToFile(const std::vector<LogEntry>& allLogs, const std::string& filename);
     203              : 
     204              :             buffer_handler_type buffer_handlers_m;
     205              : 
     206              :             double defaultOveralloc_m = 1.0;
     207              : 
     208              :             /////////////////////////////////////////////////////////////////////////////////////
     209              : 
     210              :         protected:
     211              :             std::shared_ptr<MPI_Comm> comm_m;
     212              :             int size_m;
     213              :             int rank_m;
     214              :         };
     215              :     }  // namespace mpi
     216              : }  // namespace ippl
     217              : 
     218              : #include "Communicate/Collectives.hpp"
     219              : #include "Communicate/PointToPoint.hpp"
     220              : 
     221              : ////////////////////////////////////
     222              : 
     223              : #include "Communicate/Buffers.hpp"
     224              : 
     225              : ////////////////////////////////////
     226              : 
     227              : #endif
        

Generated by: LCOV version 2.0-1