Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix synchronization bug for GPU stream async CPU work #1768

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 32 additions & 25 deletions mlx/backend/metal/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

namespace mlx::core::distributed {

void signal_and_wait(const array& in, const array& out) {
if (in.event().valid()) {
encode_signal(in.event());
void signal_and_wait(const Event& e_signal, const Event& e_wait) {
if (e_signal.valid()) {
encode_signal(e_signal);
}
encode_wait(out.event());
encode_wait(e_wait);
}

void AllReduce::eval_gpu(
Expand All @@ -33,8 +33,12 @@ void AllReduce::eval_gpu(
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}

auto e = Event(stream());
e.set_value(1);
signal_and_wait(in.event(), e);
auto task = [in = in,
out = out,
e = std::move(e),
reduce_type = reduce_type_,
group = group()]() mutable {
if (in.event().valid()) {
Expand All @@ -48,11 +52,9 @@ void AllReduce::eval_gpu(
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
out.event().signal();
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));

signal_and_wait(in, out);
}

void AllGather::eval_gpu(
Expand All @@ -65,15 +67,19 @@ void AllGather::eval_gpu(

out.set_data(allocator::malloc_or_wait(out.nbytes()));

auto task = [in = in, out = out, group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::all_gather(group, in, out);
out.event().signal();
};
auto e = Event(stream());
e.set_value(1);
signal_and_wait(in.event(), e);

auto task =
[in = in, out = out, e = std::move(e), group = group()]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::all_gather(group, in, out);
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
signal_and_wait(in, out);
}

void Send::eval_gpu(
Expand All @@ -92,12 +98,10 @@ void Send::eval_gpu(
in.event().wait();
}
distributed::detail::send(group, out, dst);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));

// Encode a signal event for the input but not a wait since we don't need to
// wait on the output.
// Encode a signal event for the input
if (in.event().valid()) {
encode_signal(in.event());
}
Expand All @@ -113,15 +117,18 @@ void Recv::eval_gpu(

out.set_data(allocator::malloc_or_wait(out.nbytes()));

auto e = Event(stream());
e.set_value(1);

encode_wait(e);

// Schedule an async recv on the comm stream
auto task = [out = out, group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
out.event().signal();
};
auto task =
[out = out, e = std::move(e), group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
e.signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));

// Encode a wait event as there is no input for the recv to encode a signal.
encode_wait(out.event());
}

} // namespace mlx::core::distributed
10 changes: 7 additions & 3 deletions mlx/backend/metal/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,17 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
read_task();
return;
}

auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
auto signal_task = [out = out, fut = std::move(fut)]() {

auto e = Event(stream());
e.set_value(1);
encode_wait(e);
auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
fut.wait();
out.event().signal();
e.signal();
};
scheduler::enqueue(io_stream(), std::move(signal_task));
encode_wait(out.event());
}

void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
Expand Down