LCOV - code coverage report
Current view: top level - src/Communicate - Communicator.h (source / functions) Coverage Total Hit
Test: report.info Lines: 31.8 % 22 7
Test Date: 2025-05-21 12:58:26 Functions: 9.9 % 81 8
Branches: 0.0 % 16 0

             Branch data     Line data    Source code
       1                 :             : //
       2                 :             : // Class Communicator
       3                 :             : //   Defines a class to do MPI communication.
       4                 :             : //
       5                 :             : #ifndef IPPL_MPI_COMMUNICATOR_H
       6                 :             : #define IPPL_MPI_COMMUNICATOR_H
       7                 :             : 
       8                 :             : #include <memory>
       9                 :             : #include <mpi.h>
      10                 :             : 
      11                 :             : #include "Communicate/BufferHandler.h"
      12                 :             : #include "Communicate/LoggingBufferHandler.h"
      13                 :             : #include "Communicate/Request.h"
      14                 :             : #include "Communicate/Status.h"
      15                 :             : 
      16                 :             : ////////////////////////////////////////////////
      17                 :             : // For message size check; see below
      18                 :             : #include <climits>
      19                 :             : #include <cstdlib>
      20                 :             : 
      21                 :             : #include "Utility/TypeUtils.h"
      22                 :             : 
      23                 :             : #include "Communicate/Archive.h"
      24                 :             : #include "Communicate/TagMaker.h"
      25                 :             : #include "Communicate/Tags.h"
      26                 :             : ////////////////////////////////////////////////////
      27                 :             : 
      28                 :             : namespace ippl {
      29                 :             :     namespace mpi {
      30                 :             : 
      31                 :             :         class Communicator : public TagMaker {
      32                 :             :         public:
      33                 :             :             Communicator();
      34                 :             : 
      35                 :             :             Communicator(MPI_Comm comm);
      36                 :             : 
      37                 :             :             Communicator& operator=(MPI_Comm comm);
      38                 :             : 
      39                 :        3602 :             ~Communicator() = default;
      40                 :             : 
      41                 :             :             Communicator split(int color, int key) const;
      42                 :             : 
      43                 :          60 :             operator const MPI_Comm&() const noexcept { return *comm_m; }
      44                 :             : 
      45                 :        1750 :             int size() const noexcept { return size_m; }
      46                 :             : 
      47                 :        4398 :             int rank() const noexcept { return rank_m; }
      48                 :             : 
      49                 :         382 :             void barrier() { MPI_Barrier(*comm_m); }
      50                 :             : 
      51                 :           0 :             void abort(int errorcode = -1) { MPI_Abort(*comm_m, errorcode); }
      52                 :             : 
      53                 :             :             /*
      54                 :             :              * Blocking point-to-point communication
      55                 :             :              *
      56                 :             :              */
      57                 :             : 
      58                 :             :             template <typename T>
      59                 :             :             void send(const T& buffer, int count, int dest, int tag);
      60                 :             : 
      61                 :             :             template <typename T>
      62                 :             :             void send(const T* buffer, int count, int dest, int tag);
      63                 :             : 
      64                 :             :             template <typename T>
      65                 :             :             void recv(T& output, int count, int source, int tag, Status& status);
      66                 :             : 
      67                 :             :             template <typename T>
      68                 :             :             void recv(T* output, int count, int source, int tag, Status& status);
      69                 :             : 
      70                 :             :             void probe(int source, int tag, Status& status);
      71                 :             : 
      72                 :             :             /*
      73                 :             :              * Non-blocking point-to-point communication
      74                 :             :              *
      75                 :             :              */
      76                 :             : 
      77                 :             :             template <typename T>
      78                 :             :             void isend(const T& buffer, int count, int dest, int tag, Request& request);
      79                 :             : 
      80                 :             :             template <typename T>
      81                 :             :             void isend(const T* buffer, int count, int dest, int tag, Request& request);
      82                 :             : 
      83                 :             :             template <typename T>
      84                 :             :             void irecv(T& buffer, int count, int source, int tag, Request& request);
      85                 :             : 
      86                 :             :             template <typename T>
      87                 :             :             void irecv(T* buffer, int count, int source, int tag, Request& request);
      88                 :             : 
      89                 :             :             bool iprobe(int source, int tag, Status& status);
      90                 :             : 
      91                 :             :             /*
      92                 :             :              * Collective communication
      93                 :             :              */
      94                 :             : 
      95                 :             :             /* Gather the data in the given source container from all other nodes to a
      96                 :             :              * specific node (default: 0).
      97                 :             :              */
      98                 :             :             template <typename T>
      99                 :             :             void gather(const T* input, T* output, int count, int root = 0);
     100                 :             : 
     101                 :             :             /* Scatter the data from all other nodes to a
     102                 :             :              * specific node (default: 0).
     103                 :             :              */
     104                 :             :             template <typename T>
     105                 :             :             void scatter(const T* input, T* output, int count, int root = 0);
     106                 :             : 
     107                 :             :             /* Reduce data coming from all nodes to a specific node
     108                 :             :              * (default: 0). Apply certain operation
     109                 :             :              *
     110                 :             :              */
     111                 :             :             template <typename T, class Op>
     112                 :             :             void reduce(const T* input, T* output, int count, Op op, int root = 0);
     113                 :             : 
     114                 :             :             template <typename T, class Op>
     115                 :             :             void reduce(const T& input, T& output, int count, Op op, int root = 0);
     116                 :             : 
     117                 :             :             template <typename T, class Op>
     118                 :             :             void allreduce(const T* input, T* output, int count, Op op);
     119                 :             : 
     120                 :             :             template <typename T, class Op>
     121                 :             :             void allreduce(const T& input, T& output, int count, Op op);
     122                 :             : 
     123                 :             :             template <typename T, class Op>
     124                 :             :             void allreduce(T* inout, int count, Op op);
     125                 :             : 
     126                 :             :             template <typename T, class Op>
     127                 :             :             void allreduce(T& inout, int count, Op op);
     128                 :             : 
     129                 :             :             /////////////////////////////////////////////////////////////////////////////////////
     130                 :             :             template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
     131                 :             :             using archive_type = detail::Archive<MemorySpace>;
     132                 :             : 
     133                 :             :             template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
     134                 :             :             using buffer_type = std::shared_ptr<archive_type<MemorySpace>>;
     135                 :             : 
     136                 :             :         private:
     137                 :             :             template <typename MemorySpace>
     138                 :             :             using buffer_container_type = LoggingBufferHandler<MemorySpace>;
     139                 :             : 
     140                 :             :             using buffer_handler_type =
     141                 :             :                 typename detail::ContainerForAllSpaces<buffer_container_type>::type;
     142                 :             : 
     143                 :             :         public:
     144                 :             :             using size_type = detail::size_type;
     145                 :         264 :             double getDefaultOverallocation() const { return defaultOveralloc_m; }
     146                 :             :             void setDefaultOverallocation(double factor);
     147                 :             : 
     148                 :             :             template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space,
     149                 :             :                       typename T           = char>
     150                 :             :             buffer_type<MemorySpace> getBuffer(size_type size, double overallocation = 1.0);
     151                 :             : 
     152                 :             :             void deleteAllBuffers();
     153                 :             :             void freeAllBuffers();
     154                 :             : 
     155                 :             :             template <typename MemorySpace = Kokkos::DefaultExecutionSpace::memory_space>
     156                 :             :             void freeBuffer(buffer_type<MemorySpace> buffer);
     157                 :             : 
     158                 :           8 :             const MPI_Comm& getCommunicator() const noexcept { return *comm_m; }
     159                 :             : 
     160                 :             :             template <class Buffer, typename Archive>
     161                 :           0 :             void recv(int src, int tag, Buffer& buffer, Archive& ar, size_type msize,
     162                 :             :                       size_type nrecvs) {
     163                 :             :                 // Temporary fix. MPI communication seems to have problems when the
     164                 :             :                 // count argument exceeds the range of int, so large messages should
     165                 :             :                 // be split into smaller messages
     166         [ #  # ]:           0 :                 if (msize > INT_MAX) {
     167   [ #  #  #  # ]:           0 :                     std::cerr << "Message size exceeds range of int" << std::endl;
     168         [ #  # ]:           0 :                     this->abort();
     169                 :             :                 }
     170                 :             :                 MPI_Status status;
     171   [ #  #  #  # ]:           0 :                 MPI_Recv(ar.getBuffer(), msize, MPI_BYTE, src, tag, *comm_m, &status);
     172                 :             : 
     173         [ #  # ]:           0 :                 buffer.deserialize(ar, nrecvs);
     174                 :           0 :             }
     175                 :             : 
     176                 :             :             template <class Buffer, typename Archive>
     177                 :           0 :             void isend(int dest, int tag, Buffer& buffer, Archive& ar, MPI_Request& request,
     178                 :             :                        size_type nsends) {
     179         [ #  # ]:           0 :                 if (ar.getSize() > INT_MAX) {
     180                 :           0 :                     std::cerr << "Message size exceeds range of int" << std::endl;
     181                 :           0 :                     this->abort();
     182                 :             :                 }
     183                 :           0 :                 buffer.serialize(ar, nsends);
     184                 :           0 :                 MPI_Isend(ar.getBuffer(), ar.getSize(), MPI_BYTE, dest, tag, *comm_m, &request);
     185                 :           0 :             }
     186                 :             : 
     187                 :             :             template <typename Archive>
     188                 :             :             void irecv(int src, int tag, Archive& ar, MPI_Request& request, size_type msize) {
     189                 :             :                 if (msize > INT_MAX) {
     190                 :             :                     std::cerr << "Message size exceeds range of int" << std::endl;
     191                 :             :                     this->abort();
     192                 :             :                 }
     193                 :             :                 MPI_Irecv(ar.getBuffer(), msize, MPI_BYTE, src, tag, *comm_m, &request);
     194                 :             :             }
     195                 :             : 
     196                 :             :             void printLogs(const std::string& filename);
     197                 :             : 
     198                 :             :         private:
     199                 :             :             std::vector<LogEntry> gatherLocalLogs();
     200                 :             :             void sendLogsToRank0(const std::vector<LogEntry>& localLogs);
     201                 :             :             std::vector<LogEntry> gatherLogsFromAllRanks(const std::vector<LogEntry>& localLogs);
     202                 :             :             void writeLogsToFile(const std::vector<LogEntry>& allLogs, const std::string& filename);
     203                 :             : 
     204                 :             :             buffer_handler_type buffer_handlers_m;
     205                 :             : 
     206                 :             :             double defaultOveralloc_m = 1.0;
     207                 :             : 
     208                 :             :             /////////////////////////////////////////////////////////////////////////////////////
     209                 :             : 
     210                 :             :         protected:
     211                 :             :             std::shared_ptr<MPI_Comm> comm_m;
     212                 :             :             int size_m;
     213                 :             :             int rank_m;
     214                 :             :         };
     215                 :             :     }  // namespace mpi
     216                 :             : }  // namespace ippl
     217                 :             : 
     218                 :             : #include "Communicate/Collectives.hpp"
     219                 :             : #include "Communicate/PointToPoint.hpp"
     220                 :             : 
     221                 :             : ////////////////////////////////////
     222                 :             : 
     223                 :             : #include "Communicate/Buffers.hpp"
     224                 :             : 
     225                 :             : ////////////////////////////////////
     226                 :             : 
     227                 :             : #endif
        

Generated by: LCOV version 2.0-1