@@ -31,8 +31,30 @@ array ensure_row_contiguous(const array& arr) {
31
31
}
32
32
}
33
33
34
+ // TODO: Change to a vectorized sum
35
+ template <typename T>
36
+ void simple_sum (
37
+ void * input,
38
+ void * accumulator,
39
+ int * len,
40
+ MPI_Datatype* datatype) {
41
+ T* in = (T*)input;
42
+ T* acc = (T*)accumulator;
43
+ int N = *len;
44
+
45
+ while (N-- > 0 ) {
46
+ *acc += *in;
47
+ acc++;
48
+ in++;
49
+ }
50
+ }
51
+ template void simple_sum<float16_t >(void *, void *, int *, MPI_Datatype*);
52
+ template void simple_sum<bfloat16_t >(void *, void *, int *, MPI_Datatype*);
53
+
34
54
struct MPIWrapper {
35
55
MPIWrapper () {
56
+ initialized_ = false ;
57
+
36
58
libmpi_handle_ = dlopen (" libmpi.dylib" , RTLD_NOW | RTLD_GLOBAL);
37
59
if (libmpi_handle_ == nullptr ) {
38
60
return ;
@@ -47,6 +69,9 @@ struct MPIWrapper {
47
69
LOAD_SYMBOL (MPI_Comm_free, comm_free);
48
70
LOAD_SYMBOL (MPI_Allreduce, all_reduce);
49
71
LOAD_SYMBOL (MPI_Allgather, all_gather);
72
+ LOAD_SYMBOL (MPI_Type_contiguous, mpi_type_contiguous);
73
+ LOAD_SYMBOL (MPI_Type_commit, mpi_type_commit);
74
+ LOAD_SYMBOL (MPI_Op_create, mpi_op_create);
50
75
51
76
// Objects
52
77
LOAD_SYMBOL (ompi_mpi_comm_world, comm_world_);
@@ -76,7 +101,24 @@ struct MPIWrapper {
76
101
if (!is_available ()) {
77
102
return false ;
78
103
}
79
- return init (nullptr , nullptr ) == MPI_SUCCESS;
104
+ bool success = init (nullptr , nullptr ) == MPI_SUCCESS;
105
+
106
+ // Initialize custom types and ops
107
+ if (success && !initialized_) {
108
+ // Custom float16 dtypes
109
+ mpi_type_contiguous (2 , mpi_uint8_, &mpi_float16_);
110
+ mpi_type_commit (&mpi_float16_);
111
+ mpi_type_contiguous (2 , mpi_uint8_, &mpi_bfloat16_);
112
+ mpi_type_commit (&mpi_bfloat16_);
113
+
114
+ // Custom sum ops
115
+ mpi_op_create (&simple_sum<float16_t >, 1 , &op_sum_f16_);
116
+ mpi_op_create (&simple_sum<bfloat16_t >, 1 , &op_sum_bf16_);
117
+
118
+ initialized_ = true ;
119
+ }
120
+
121
+ return success;
80
122
}
81
123
82
124
void finalize_safe () {
@@ -114,13 +156,21 @@ struct MPIWrapper {
114
156
case complex64:
115
157
return mpi_complex_;
116
158
case float16:
159
+ return mpi_float16_;
117
160
case bfloat16:
118
- throw std::runtime_error ( " MPI doesn't support 16-bit floats " ) ;
161
+ return mpi_bfloat16_ ;
119
162
}
120
163
}
121
164
122
- MPI_Op op_sum () {
123
- return op_sum_;
165
+ MPI_Op op_sum (const array& arr) {
166
+ switch (arr.dtype ()) {
167
+ case float16:
168
+ return op_sum_f16_;
169
+ case bfloat16:
170
+ return op_sum_bf16_;
171
+ default :
172
+ return op_sum_;
173
+ }
124
174
}
125
175
126
176
void * libmpi_handle_;
@@ -147,6 +197,8 @@ struct MPIWrapper {
147
197
148
198
// Ops
149
199
MPI_Op op_sum_;
200
+ MPI_Op op_sum_f16_;
201
+ MPI_Op op_sum_bf16_;
150
202
151
203
// Datatypes
152
204
MPI_Datatype mpi_bool_;
@@ -160,6 +212,16 @@ struct MPIWrapper {
160
212
MPI_Datatype mpi_uint64_;
161
213
MPI_Datatype mpi_float_;
162
214
MPI_Datatype mpi_complex_;
215
+ MPI_Datatype mpi_float16_;
216
+ MPI_Datatype mpi_bfloat16_;
217
+
218
+ private:
219
+ bool initialized_;
220
+
221
+ // Private API
222
+ int (*mpi_type_contiguous)(int , MPI_Datatype, MPI_Datatype*);
223
+ int (*mpi_type_commit)(MPI_Datatype*);
224
+ int (*mpi_op_create)(MPI_User_function*, int , MPI_Op*);
163
225
};
164
226
165
227
MPIWrapper& mpi () {
@@ -268,7 +330,7 @@ void all_sum(Group group, const array& input_, array& output) {
268
330
output.data <void >(),
269
331
input.size (),
270
332
mpi ().datatype (input),
271
- mpi ().op_sum (),
333
+ mpi ().op_sum (input ),
272
334
to_comm (group));
273
335
}
274
336
0 commit comments