Skip to content

Commit

Permalink
Fix synchronization bug for in-stream async work
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Jan 14, 2025
1 parent 5cc5201 commit 58ae068
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 28 deletions.
58 changes: 33 additions & 25 deletions mlx/backend/metal/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

namespace mlx::core::distributed {

void signal_and_wait(const array& in, const array& out) {
if (in.event().valid()) {
encode_signal(in.event());
void signal_and_wait(const Event& e_signal, const Event& e_wait) {
if (e_signal.valid()) {
encode_signal(e_signal);
}
encode_wait(out.event());
encode_wait(e_wait);
}

void AllReduce::eval_gpu(
Expand All @@ -33,8 +33,12 @@ void AllReduce::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}

auto e = Event(stream());
e.set_value(1);
signal_and_wait(in.event(), e);
auto task = [in = in,
out = out,
e = std::move(e),
reduce_type = reduce_type_,
group = group()]() mutable {
if (in.event().valid()) {
Expand All @@ -48,11 +52,9 @@ void AllReduce::eval_gpu(
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
out.event().signal();
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));

signal_and_wait(in, out);
}

void AllGather::eval_gpu(
Expand All @@ -65,15 +67,19 @@ void AllGather::eval_gpu(

out.set_data(allocator::malloc_or_wait(out.nbytes()));

auto task = [in = in, out = out, group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::all_gather(group, in, out);
out.event().signal();
};
auto e = Event(stream());
e.set_value(1);
signal_and_wait(in.event(), e);

auto task =
[in = in, out = out, e = std::move(e), group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::all_gather(group, in, out);
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out);
}

void Send::eval_gpu(
Expand All @@ -92,15 +98,14 @@ void Send::eval_gpu(
in.event().wait();
}
distributed::detail::send(group, out, dst);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));

// Encode a signal event for the input but not a wait since we don't need to
// wait on the output.
// Encode a signal event for the input
if (in.event().valid()) {
encode_signal(in.event());
}
// TODO do we need an event wait as well?
}

void Recv::eval_gpu(
Expand All @@ -113,15 +118,18 @@ void Recv::eval_gpu(

out.set_data(allocator::malloc_or_wait(out.nbytes()));

auto e = Event(stream());
e.set_value(1);

encode_wait(e);

// Schedule an async recv on the comm stream
auto task = [out = out, group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
out.event().signal();
};
auto task =
[out = out, e = std::move(e), group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));

// Encode a wait event as there is no input for the recv to encode a signal.
encode_wait(out.event());
}

} // namespace mlx::core::distributed
10 changes: 7 additions & 3 deletions mlx/backend/metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,17 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
read_task();
return;
}

auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
auto signal_task = [out = out, fut = std::move(fut)]() {

auto e = Event(stream());
e.set_value(1);
encode_wait(e);
auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
fut.wait();
out.event().signal();
e.signal();
};
scheduler::enqueue(io_stream(), std::move(signal_task));
encode_wait(out.event());
}

void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
Expand Down

0 comments on commit 58ae068

Please sign in to comment.