VirtualFluids 0.2.0
Parallel CFD LBM Solver
Loading...
Searching...
No Matches
MPICommunicator.h
Go to the documentation of this file.
1//=======================================================================================
2// ____ ____ __ ______ __________ __ __ __ __
3// \ \ | | | | | _ \ |___ ___| | | | | / \ | |
4// \ \ | | | | | |_) | | | | | | | / \ | |
5// \ \ | | | | | _ / | | | | | | / /\ \ | |
6// \ \ | | | | | | \ \ | | | \__/ | / ____ \ | |____
7// \ \ | | |__| |__| \__\ |__| \________/ /__/ \__\ |_______|
8// \ \ | | ________________________________________________________________
9// \ \ | | | ______________________________________________________________|
10// \ \| | | | __ __ __ __ ______ _______
11// \ | | |_____ | | | | | | | | | _ \ / _____)
12// \ | | _____| | | | | | | | | | | \ \ \_______
13// \ | | | | |_____ | \_/ | | | | |_/ / _____ |
14// \ _____| |__| |________| \_______/ |__| |______/ (_______/
15//
16// This file is part of VirtualFluids. VirtualFluids is free software: you can
17// redistribute it and/or modify it under the terms of the GNU General Public
18// License as published by the Free Software Foundation, either version 3 of
19// the License, or (at your option) any later version.
20//
21// VirtualFluids is distributed in the hope that it will be useful, but WITHOUT
22// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
23// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
24// for more details.
25//
26// SPDX-License-Identifier: GPL-3.0-or-later
27// SPDX-FileCopyrightText: Copyright © VirtualFluids Project contributors, see AUTHORS.md in root folder
28//
32//=======================================================================================
33#include <stdexcept>
34#if defined VF_MPI
35
36#ifndef MPI_MPICOMMUNICATOR_H
37#define MPI_MPICOMMUNICATOR_H
38
39#include "Communicator.h"
43#include <mpi.h>
44#include <string>
45#include <vector>
46
48#ifdef VF_DOUBLE_ACCURACY
49#define VF_MPI_REAL MPI_DOUBLE
50#else
51#define VF_MPI_REAL MPI_FLOAT
52#endif
54
55namespace vf::parallel
56{
57
62{
63public:
66
67 ~MPICommunicator() override;
68 static std::shared_ptr<Communicator> getInstance();
69 double Wtime() override;
70 int getBundleID() const override;
71 int getNumberOfBundles() const override;
72 int getProcessID() const override;
73 int getProcessID(int bundle, int rank) const override;
74 int getNumberOfProcesses() const override;
75 void *getNativeCommunicator() override;
76 int getRoot() const override;
77 int getBundleRoot() const override;
78 int getProcessRoot() const override;
79 int getNumberOfProcessesInBundle(int bundle) const override;
80 bool isRoot() const override;
81 void abort(int errorcode) override;
82
83 void sendSerializedObject(std::stringstream &ss, int target) override;
84 void receiveSerializedObject(std::stringstream &ss, int source) override;
85
86 void barrier() override;
87
88 std::vector<std::string> gather(const std::string &str) override;
89 std::vector<int> gather(std::vector<int> &values) override;
90 std::vector<float> gather(std::vector<float> &values) override;
91 std::vector<double> gather(std::vector<double> &values) override;
92 std::vector<unsigned long long> gather(std::vector<unsigned long long> &values) override;
93
94 void allGather(std::vector<int> &svalues, std::vector<int> &rvalues) override;
95 void allGather(std::vector<float> &svalues, std::vector<float> &rvalues) override;
96 void allGather(std::vector<double> &svalues, std::vector<double> &rvalues) override;
97 void allGather(std::vector<unsigned long long> &svalues, std::vector<unsigned long long> &rvalues) override;
98 void allGather(std::vector<uint> &svalues, std::vector<uint> &rvalues) override;
99
100 void allReduceSum(std::vector<float>& values) override;
101 void allReduceSum(std::vector<double>& values) override;
102 void allReduceSum(std::vector<uint>& values) override;
103
104 void broadcast(int &value) override;
105 void broadcast(float &value) override;
106 void broadcast(double &value) override;
107 void broadcast(long int &value) override;
108 void broadcast(std::vector<int> &values) override;
109 void broadcast(std::vector<float> &values) override;
110 void broadcast(std::vector<double> &values) override;
111 void broadcast(std::vector<long int> &values) override;
112
113 template <class T>
114 std::vector<T> gather(std::vector<T> &values);
115
116 template <class T>
117 void allGather(std::vector<T> &svalues, std::vector<T> &rvalues);
118 template <class T>
119 void allReduceSum(std::vector<T> &values);
120 template <class T>
121 void broadcast(std::vector<T> &values);
122
123 template <class T>
124 void broadcast(T &value);
125
127 int size_buffer_send, int neighbor_rank_send) const override;
128
129 void send(real *sbuf, int count_s, int nb_rank) const override;
130 double reduceSum(double quantityPerProcess) const override;
131
132 int mapCudaDevicesOnHosts(const std::vector<unsigned int> &devices, int numberOfDevices) const override;
134 int neighbor_rank) const override;
135
136 void receiveNonBlocking(real *rbuf, int count_r, int sourceRank) override;
137 void sendNonBlocking(real *sbuf, int count_s, int destinationRank) override;
138 void send(real *sbuf, int count_s, int destinationRank) override;
139 void waitAll() override;
140 void resetRequests() override;
141
142private:
144
145 int numprocs, PID;
146 MPI_Comm comm;
147 int root;
148
149 std::vector<MPI_Request> requests;
150};
151
153template<typename T>
155
156template<> inline MPI_Datatype getDataType<double>(){return MPI_DOUBLE; }
157template<> inline MPI_Datatype getDataType<float>(){return MPI_FLOAT; }
158template<> inline MPI_Datatype getDataType<int>(){return MPI_INT; }
160template<> inline MPI_Datatype getDataType<char>(){return MPI_CHAR; }
162
164template <class T>
165std::vector<T> MPICommunicator::gather(std::vector<T> &values)
166{
167 MPI_Datatype mpiDataType;
168 if ((std::string) typeid(T).name() == (std::string) typeid(double).name())
169 mpiDataType = MPI_DOUBLE;
170 else if ((std::string) typeid(T).name() == (std::string) typeid(float).name())
171 mpiDataType = MPI_FLOAT;
172 else if ((std::string) typeid(T).name() == (std::string) typeid(int).name())
173 mpiDataType = MPI_INT;
174 else if ((std::string) typeid(T).name() == (std::string) typeid(unsigned long long).name())
175 mpiDataType = MPI_UNSIGNED_LONG_LONG;
176 else if ((std::string) typeid(T).name() == (std::string) typeid(char).name())
177 mpiDataType = MPI_CHAR;
178 else
179 throw UbException(UB_EXARGS, "no MpiDataType for T" + (std::string) typeid(T).name());
180
181 int count = static_cast<int>(values.size());
182 std::vector<T> rvalues(1);
183
184 if (PID == root) {
185 rvalues.resize(numprocs * count);
186 }
187
188 MPI_Gather(&values[0], count, mpiDataType, &rvalues[0], count, mpiDataType, root, comm);
189
190 return rvalues;
191}
193template <class T>
194void MPICommunicator::allGather(std::vector<T> &svalues, std::vector<T> &rvalues)
195{
196 MPI_Datatype mpiDataType;
197 if ((std::string) typeid(T).name() == (std::string) typeid(double).name())
198 mpiDataType = MPI_DOUBLE;
199 else if ((std::string) typeid(T).name() == (std::string) typeid(float).name())
200 mpiDataType = MPI_FLOAT;
201 else if ((std::string) typeid(T).name() == (std::string) typeid(int).name())
202 mpiDataType = MPI_INT;
203 else if ((std::string) typeid(T).name() == (std::string) typeid(unsigned long long).name())
204 mpiDataType = MPI_UNSIGNED_LONG_LONG;
205 else
206 throw UbException(UB_EXARGS, "no MpiDataType for T" + (std::string) typeid(T).name());
207
208 int scount;
209 std::vector<int> displs, rcounts;
210
211 scount = (int)(svalues.size());
212
213 rcounts.resize(numprocs);
214 MPI_Allgather(&scount, 1, MPI_INT, &rcounts[0], 1, MPI_INT, comm);
215 displs.resize(numprocs);
216
217 if(numprocs < 1)
218 throw std::runtime_error("No processors!");
219
220 displs[0] = 0;
221
222 for (int i = 1; i < numprocs; ++i) {
223 displs[i] = displs[i - 1] + rcounts[i - 1];
224 }
225
226 rvalues.resize(displs[numprocs - 1] + rcounts[numprocs - 1]);
227
228 T* sval = NULL;
229 T* rval = NULL;
230
231 if (svalues.size() > 0) {
232 //svalues.resize(1);
233 //svalues[0] = 999;
234 sval = &svalues[0];
235 }
236
237 if (rvalues.size() > 0) {
238 //rvalues.resize(1);
239 //rvalues[0] = 999;
240 rval = &rvalues[0];
241 }
242
243 //MPI_Allgatherv(&svalues[0], scount, mpiDataType, &rvalues[0], &rcounts[0], &displs[0], mpiDataType, comm);
244 MPI_Allgatherv(sval, scount, mpiDataType, rval, &rcounts[0], &displs[0], mpiDataType, comm);
245}
247template <class T>
248void MPICommunicator::allReduceSum(std::vector<T> &values)
249{
250 MPI_Allreduce(MPI_IN_PLACE, values.data(), int(values.size()), getDataType<T>(), MPI_SUM, comm);
251}
253
254template <class T>
255void MPICommunicator::broadcast(std::vector<T> &values)
256{
257 MPI_Datatype mpiDataType;
258 if ((std::string) typeid(T).name() == (std::string) typeid(double).name())
259 mpiDataType = MPI_DOUBLE;
260 else if ((std::string) typeid(T).name() == (std::string) typeid(float).name())
261 mpiDataType = MPI_FLOAT;
262 else if ((std::string) typeid(T).name() == (std::string) typeid(int).name())
263 mpiDataType = MPI_INT;
264 else if ((std::string) typeid(T).name() == (std::string) typeid(long int).name())
265 mpiDataType = MPI_LONG_INT;
266 else
267 throw UbException(UB_EXARGS, "no MpiDataType for T" + (std::string) typeid(T).name());
268
269 int rcount;
270 if (this->PID == this->root) {
271 rcount = (int)values.size();
272 }
273
274 MPI_Bcast(&rcount, 1, MPI_INT, this->root, comm);
275
276 if (this->PID != this->root) {
277 values.resize(rcount);
278 }
279
280 MPI_Bcast(&values[0], (int)values.size(), mpiDataType, this->root, comm);
281}
283template <class T>
285{
286 MPI_Datatype mpiDataType;
287 if ((std::string) typeid(T).name() == (std::string) typeid(double).name())
288 mpiDataType = MPI_DOUBLE;
289 else if ((std::string) typeid(T).name() == (std::string) typeid(float).name())
290 mpiDataType = MPI_FLOAT;
291 else if ((std::string) typeid(T).name() == (std::string) typeid(int).name())
292 mpiDataType = MPI_INT;
293 else if ((std::string) typeid(T).name() == (std::string) typeid(long int).name())
294 mpiDataType = MPI_LONG_INT;
295 else
296 throw UbException(UB_EXARGS, "no MpiDataType for T" + (std::string) typeid(T).name());
297
298 MPI_Bcast(&value, 1, mpiDataType, this->root, comm);
299}
301
302
303#endif
304
305}
306
307#endif
308
An abstract class for communication between processes in parallel computation.
A class that uses the MPI library for communication.
double reduceSum(double quantityPerProcess) const override
void * getNativeCommunicator() override
void allGather(std::vector< int > &svalues, std::vector< int > &rvalues) override
MPICommunicator & operator=(MPICommunicator const &)=delete
void abort(int errorcode) override
int getProcessID() const override
void receiveNonBlocking(real *rbuf, int count_r, int sourceRank) override
int getBundleRoot() const override
void send(real *sbuf, int count_s, int nb_rank) const override
int getNumberOfBundles() const override
void allReduceSum(std::vector< float > &values) override
int mapCudaDevicesOnHosts(const std::vector< unsigned int > &devices, int numberOfDevices) const override
int getNumberOfProcesses() const override
MPICommunicator(MPICommunicator const &)=delete
std::vector< float > gather(std::vector< float > &values) override
void receiveSerializedObject(std::stringstream &ss, int source) override
void receiveSend(uint *buffer_receive, int size_buffer_recv, int neighbor_rank_recv, const uint *buffer_send, int size_buffer_send, int neighbor_rank_send) const override
int getProcessRoot() const override
std::vector< int > gather(std::vector< int > &values) override
void sendNonBlocking(real *sbuf, int count_s, int destinationRank) override
void broadcast(int &value) override
void sendSerializedObject(std::stringstream &ss, int target) override
std::vector< std::string > gather(const std::string &str) override
std::vector< double > gather(std::vector< double > &values) override
static std::shared_ptr< Communicator > getInstance()
int getNumberOfProcessesInBundle(int bundle) const override
std::shared_ptr< T > SPtr
float real
Definition DataTypes.h:42
unsigned int uint
Definition DataTypes.h:47
#define UB_EXARGS
Definition UbException.h:73
MPI_Datatype getDataType< char >()
MPI_Datatype getDataType< int >()
MPI_Datatype getDataType()
MPI_Datatype getDataType< double >()
MPI_Datatype getDataType< unsigned int >()
MPI_Datatype getDataType< unsigned long long >()
MPI_Datatype getDataType< float >()