diff --git a/third_party/xla_client/multi_wait.cc b/third_party/xla_client/multi_wait.cc index 3d6e62146fa..77d7a633344 100644 --- a/third_party/xla_client/multi_wait.cc +++ b/third_party/xla_client/multi_wait.cc @@ -11,7 +11,7 @@ void MultiWait::Done() { { std::lock_guard lock(mutex_); completed_count_ += 1; - notify = completed_count_ >= count_; + notify = completed_count_ == count_; } if (notify) { cv_.notify_all(); @@ -45,17 +45,27 @@ void MultiWait::Reset(size_t count) { } std::function MultiWait::Completer(std::function func) { - auto completer = [this, func = std::move(func)]() { - try { - func(); - } catch (...) { - std::lock_guard lock(mutex_); - exptr_ = std::current_exception(); - } - Done(); + auto completer = [this, func = std::move(func)]() { Complete(func); }; + return completer; +} + +std::function MultiWait::Completer(std::shared_ptr mwait, + std::function func) { + auto completer = [mwait = std::move(mwait), func = std::move(func)]() { + mwait->Complete(func); }; return completer; } +void MultiWait::Complete(const std::function& func) { + try { + func(); + } catch (...) { + std::lock_guard lock(mutex_); + exptr_ = std::current_exception(); + } + Done(); +} + } // namespace util } // namespace xla diff --git a/third_party/xla_client/multi_wait.h b/third_party/xla_client/multi_wait.h index c01ccf76f53..d5260e5c5c4 100644 --- a/third_party/xla_client/multi_wait.h +++ b/third_party/xla_client/multi_wait.h @@ -3,6 +3,7 @@ #include #include +#include #include #include "tensorflow/compiler/xla/types.h" @@ -31,10 +32,19 @@ class MultiWait { // Creates a completer functor which signals the mult wait object once func // has completed. Handles exceptions by signaling the multi wait with the - // proper status value. + // proper status value. This API returns a function which captures a MultiWait + // reference, so care must be taken such that the reference remains valid for + // the whole lifetime of the returned function. std::function Completer(std::function func); + // Similar as the above API, but with explicit capture of the MultiWait shared + // pointer. + static std::function Completer(std::shared_ptr mwait, + std::function func); + private: + void Complete(const std::function& func); + std::mutex mutex_; std::condition_variable cv_; size_t count_ = 0; diff --git a/third_party/xla_client/xrt_computation_client.cc b/third_party/xla_client/xrt_computation_client.cc index 52d57d5abac..653093d0626 100644 --- a/third_party/xla_client/xrt_computation_client.cc +++ b/third_party/xla_client/xrt_computation_client.cc @@ -302,7 +302,7 @@ std::vector XrtComputationClient::TransferToServer( } XLA_COUNTER("XrtPartitionedTransferToServer", 1); - util::MultiWait mwait(partitions.size()); + auto mwait = std::make_shared(partitions.size()); std::vector results(tensors.size()); for (size_t i = 0; i < partitions.size(); ++i) { auto sender = [&, i]() { @@ -316,9 +316,10 @@ std::vector XrtComputationClient::TransferToServer( results[base_index + r] = std::move(partitions_results[r]); } }; - env::ScheduleIoClosure(mwait.Completer(std::move(sender))); + env::ScheduleIoClosure( + util::MultiWait::Completer(mwait, std::move(sender))); } - mwait.Wait(); + mwait->Wait(); return results; } @@ -330,7 +331,7 @@ XrtComputationClient::TransferToServerInternal( std::mutex lock; XrtSessionCache::SessionMap session_map; int64 total_size = 0; - util::MultiWait mwait(tensors.size()); + auto mwait = std::make_shared(tensors.size()); std::map session_work_map; { metrics::TimedSection timed(TransferToServerTransformMetric()); @@ -363,13 +364,14 @@ XrtComputationClient::TransferToServerInternal( total_size += tdata.size(); } }; - env::ScheduleClosure(mwait.Completer(std::move(converter))); + env::ScheduleClosure( + util::MultiWait::Completer(mwait, std::move(converter))); } - mwait.Wait(); + mwait->Wait(); } OutboundDataMetric()->AddSample(total_size); - mwait.Reset(session_work_map.size()); + mwait->Reset(session_work_map.size()); std::vector results(tensors.size()); for (auto& session_session_work : session_work_map) { XrtSession* session = session_session_work.first; @@ -388,9 +390,10 @@ XrtComputationClient::TransferToServerInternal( } CreateDataHandlesCounter()->AddValue(outputs.size()); }; - env::ScheduleIoClosure(mwait.Completer(std::move(runner))); + env::ScheduleIoClosure( + util::MultiWait::Completer(mwait, std::move(runner))); } - mwait.Wait(); + mwait->Wait(); return results; } @@ -426,7 +429,7 @@ std::vector XrtComputationClient::TransferFromServer( session_work->index_mapping.push_back(i); } - util::MultiWait mwait(session_work_map.size()); + auto mwait = std::make_shared(session_work_map.size()); std::atomic total_size(0); std::vector results(handles.size()); for (auto& session_session_work : session_work_map) { @@ -446,9 +449,10 @@ std::vector XrtComputationClient::TransferFromServer( total_size += results[li].size_bytes(); } }; - env::ScheduleIoClosure(mwait.Completer(std::move(runner))); + env::ScheduleIoClosure( + util::MultiWait::Completer(mwait, std::move(runner))); } - mwait.Wait(); + mwait->Wait(); InboundDataMetric()->AddSample(total_size.load()); return results; } @@ -458,7 +462,7 @@ std::vector XrtComputationClient::Compile( metrics::TimedSection timed(CompileMetric()); std::mutex lock; - util::MultiWait mwait(instances.size()); + auto mwait = std::make_shared(instances.size()); std::vector program_shapes(instances.size()); std::vector results(instances.size()); std::vector cache_keys(instances.size()); @@ -499,10 +503,10 @@ std::vector XrtComputationClient::Compile( results[i] = computation_ptr; } }; - env::ScheduleClosure(mwait.Completer(std::move(builder))); + env::ScheduleClosure(util::MultiWait::Completer(mwait, std::move(builder))); } - mwait.Wait(); - mwait.Reset(session_work_map.size()); + mwait->Wait(); + mwait->Reset(session_work_map.size()); for (auto& session_and_work : session_work_map) { XrtSession* session = session_and_work.first; @@ -532,9 +536,10 @@ std::vector XrtComputationClient::Compile( CreateCompileHandlesCounter()->AddValue(1); } }; - env::ScheduleIoClosure(mwait.Completer(std::move(session_runner))); + env::ScheduleIoClosure( + util::MultiWait::Completer(mwait, std::move(session_runner))); } - mwait.Wait(); + mwait->Wait(); return results; } @@ -626,7 +631,7 @@ XrtComputationClient::RunComputations( } XLA_CHECK_EQ(computations.size(), devices.size()); - util::MultiWait mwait(session_replicas.size()); + auto mwait = std::make_shared(session_replicas.size()); std::vector> results(devices.size()); for (auto& sess_replica : session_replicas) { XrtSession* session = sess_replica.first; @@ -655,9 +660,10 @@ XrtComputationClient::RunComputations( GetEffectiveDevice(devices[replica])); } }; - env::ScheduleIoClosure(mwait.Completer(std::move(session_runner))); + env::ScheduleIoClosure( + util::MultiWait::Completer(mwait, std::move(session_runner))); } - mwait.Wait(); + mwait->Wait(); return results; } diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index ae81b13ace3..18557355cae 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -480,7 +480,7 @@ py::bytes ReadTfFile(tensorflow::RandomAccessFile* file, uint64_t offset, std::min(num_threads, std::thread::hardware_concurrency()); size_t block_size = size / num_threads; - xla::util::MultiWait mwait(num_threads); + auto mwait = std::make_shared(num_threads); for (size_t i = 0; i < num_threads; ++i) { auto reader = [&, i]() { uint64_t base = static_cast(i) * block_size; @@ -491,9 +491,10 @@ py::bytes ReadTfFile(tensorflow::RandomAccessFile* file, uint64_t offset, XLA_CHECK_OK( file->Read(offset + base, tsize, &result, buffer.get() + base)); }; - xla::env::ScheduleIoClosure(mwait.Completer(std::move(reader))); + xla::env::ScheduleIoClosure( + xla::util::MultiWait::Completer(mwait, std::move(reader))); } - mwait.Wait(); + mwait->Wait(); } return py::bytes(buffer.get(), size); } diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 2c160ed074a..9b8bf676436 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -423,15 +423,16 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape, std::vector iter_dims = GetIterationDimensions(dest_shape); std::vector parts = CreateCopyPartitions(dest_shape.dimensions(), iter_dims.front()); - xla::util::MultiWait mwait(parts.size()); + auto mwait = std::make_shared(parts.size()); for (size_t i = 0; i < parts.size(); ++i) { auto copy_fn = [&, i]() { SlicedCopy(dest_shape.dimensions(), src_data, src_strides, dest_data, dest_strides, iter_dims, parts[i]); }; - xla::env::ScheduleClosure(mwait.Completer(std::move(copy_fn))); + xla::env::ScheduleClosure( + xla::util::MultiWait::Completer(mwait, std::move(copy_fn))); } - mwait.Wait(); + mwait->Wait(); } }