diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp
index 517043597..be4cc1848 100644
--- a/mlx/backend/metal/distributed.cpp
+++ b/mlx/backend/metal/distributed.cpp
@@ -12,11 +12,11 @@ namespace mlx::core::distributed {
 
-void signal_and_wait(const array& in, const array& out) {
-  if (in.event().valid()) {
-    encode_signal(in.event());
+void signal_and_wait(const Event& e_signal, const Event& e_wait) {
+  if (e_signal.valid()) {
+    encode_signal(e_signal);
   }
-  encode_wait(out.event());
+  encode_wait(e_wait);
 }
 
 void AllReduce::eval_gpu(
@@ -33,8 +33,12 @@ void AllReduce::eval_gpu(
     out.set_data(allocator::malloc_or_wait(out.nbytes()));
   }
 
+  auto e = Event(stream());
+  e.set_value(1);
+  signal_and_wait(in.event(), e);
   auto task = [in = in,
                out = out,
+               e = std::move(e),
                reduce_type = reduce_type_,
                group = group()]() mutable {
     if (in.event().valid()) {
@@ -48,11 +52,9 @@ void AllReduce::eval_gpu(
       default:
         throw std::runtime_error("Only all reduce sum is supported for now");
     }
-    out.event().signal();
+    e.signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
-
-  signal_and_wait(in, out);
 }
 
 void AllGather::eval_gpu(
@@ -65,15 +67,19 @@ void AllGather::eval_gpu(
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
-  auto task = [in = in, out = out, group = group()]() mutable {
-    if (in.event().valid()) {
-      in.event().wait();
-    }
-    distributed::detail::all_gather(group, in, out);
-    out.event().signal();
-  };
+  auto e = Event(stream());
+  e.set_value(1);
+  signal_and_wait(in.event(), e);
+
+  auto task =
+      [in = in, out = out, e = std::move(e), group = group()]() mutable {
+        if (in.event().valid()) {
+          in.event().wait();
+        }
+        distributed::detail::all_gather(group, in, out);
+        e.signal();
+      };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
-  signal_and_wait(in, out);
 }
 
 void Send::eval_gpu(
@@ -92,12 +98,10 @@ void Send::eval_gpu(
       in.event().wait();
     }
     distributed::detail::send(group, out, dst);
-    out.event().signal();
   };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
 
-  // Encode a signal event for the input but not a wait since we don't need to
-  // wait on the output.
+  // Encode a signal event for the input
   if (in.event().valid()) {
     encode_signal(in.event());
   }
@@ -113,15 +117,18 @@ void Recv::eval_gpu(
 
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
 
+  auto e = Event(stream());
+  e.set_value(1);
+
+  encode_wait(e);
+
   // Schedule an async recv on the comm stream
-  auto task = [out = out, group = group(), src = src_]() mutable {
-    distributed::detail::recv(group, out, src);
-    out.event().signal();
-  };
+  auto task =
+      [out = out, e = std::move(e), group = group(), src = src_]() mutable {
+        distributed::detail::recv(group, out, src);
+        e.signal();
+      };
   scheduler::enqueue(detail::communication_stream(), std::move(task));
-
-  // Encode a wait event as there is no input for the recv to encode a signal.
-  encode_wait(out.event());
 }
 
 } // namespace mlx::core::distributed
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index 3a2491e2d..47190daf3 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -316,13 +316,17 @@ void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
     read_task();
     return;
   }
+
   auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
-  auto signal_task = [out = out, fut = std::move(fut)]() {
+
+  auto e = Event(stream());
+  e.set_value(1);
+  encode_wait(e);
+  auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
     fut.wait();
-    out.event().signal();
+    e.signal();
   };
   scheduler::enqueue(io_stream(), std::move(signal_task));
-  encode_wait(out.event());
 }
 
 void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {