Skip to content

Commit 6090542

Browse files
committed
Add float16 reduction to MPI
1 parent 0b93f63 commit 6090542

File tree

1 file changed

+67
-5
lines changed

1 file changed

+67
-5
lines changed

mlx/distributed/mpi/mpi.cpp

+67-5
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,30 @@ array ensure_row_contiguous(const array& arr) {
3131
}
3232
}
3333

34+
// TODO: Change to a vectorized sum
35+
template <typename T>
36+
void simple_sum(
37+
void* input,
38+
void* accumulator,
39+
int* len,
40+
MPI_Datatype* datatype) {
41+
T* in = (T*)input;
42+
T* acc = (T*)accumulator;
43+
int N = *len;
44+
45+
while (N-- > 0) {
46+
*acc += *in;
47+
acc++;
48+
in++;
49+
}
50+
}
51+
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
52+
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);
53+
3454
struct MPIWrapper {
3555
MPIWrapper() {
56+
initialized_ = false;
57+
3658
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
3759
if (libmpi_handle_ == nullptr) {
3860
return;
@@ -47,6 +69,9 @@ struct MPIWrapper {
4769
LOAD_SYMBOL(MPI_Comm_free, comm_free);
4870
LOAD_SYMBOL(MPI_Allreduce, all_reduce);
4971
LOAD_SYMBOL(MPI_Allgather, all_gather);
72+
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
73+
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
74+
LOAD_SYMBOL(MPI_Op_create, mpi_op_create);
5075

5176
// Objects
5277
LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -76,7 +101,24 @@ struct MPIWrapper {
76101
if (!is_available()) {
77102
return false;
78103
}
79-
return init(nullptr, nullptr) == MPI_SUCCESS;
104+
bool success = init(nullptr, nullptr) == MPI_SUCCESS;
105+
106+
// Initialize custom types and ops
107+
if (success && !initialized_) {
108+
// Custom float16 dtypes
109+
mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
110+
mpi_type_commit(&mpi_float16_);
111+
mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
112+
mpi_type_commit(&mpi_bfloat16_);
113+
114+
// Custom sum ops
115+
mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
116+
mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);
117+
118+
initialized_ = true;
119+
}
120+
121+
return success;
80122
}
81123

82124
void finalize_safe() {
@@ -114,13 +156,21 @@ struct MPIWrapper {
114156
case complex64:
115157
return mpi_complex_;
116158
case float16:
159+
return mpi_float16_;
117160
case bfloat16:
118-
throw std::runtime_error("MPI doesn't support 16-bit floats");
161+
return mpi_bfloat16_;
119162
}
120163
}
121164

122-
MPI_Op op_sum() {
123-
return op_sum_;
165+
MPI_Op op_sum(const array& arr) {
166+
switch (arr.dtype()) {
167+
case float16:
168+
return op_sum_f16_;
169+
case bfloat16:
170+
return op_sum_bf16_;
171+
default:
172+
return op_sum_;
173+
}
124174
}
125175

126176
void* libmpi_handle_;
@@ -147,6 +197,8 @@ struct MPIWrapper {
147197

148198
// Ops
149199
MPI_Op op_sum_;
200+
MPI_Op op_sum_f16_;
201+
MPI_Op op_sum_bf16_;
150202

151203
// Datatypes
152204
MPI_Datatype mpi_bool_;
@@ -160,6 +212,16 @@ struct MPIWrapper {
160212
MPI_Datatype mpi_uint64_;
161213
MPI_Datatype mpi_float_;
162214
MPI_Datatype mpi_complex_;
215+
MPI_Datatype mpi_float16_;
216+
MPI_Datatype mpi_bfloat16_;
217+
218+
private:
219+
bool initialized_;
220+
221+
// Private API
222+
int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
223+
int (*mpi_type_commit)(MPI_Datatype*);
224+
int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
163225
};
164226

165227
MPIWrapper& mpi() {
@@ -268,7 +330,7 @@ void all_sum(Group group, const array& input_, array& output) {
268330
output.data<void>(),
269331
input.size(),
270332
mpi().datatype(input),
271-
mpi().op_sum(),
333+
mpi().op_sum(input),
272334
to_comm(group));
273335
}
274336

0 commit comments

Comments
 (0)