Skip to content

Commit

Permalink
Cherrypick fix for DLRM real dataset crash (#2409)
Browse files Browse the repository at this point in the history
Co-authored-by: Davide Libenzi <davide.libenzi@gmail.com>
  • Loading branch information
jysohn23 and davidel authored Aug 10, 2020
1 parent ade6927 commit 06d564b
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 37 deletions.
28 changes: 19 additions & 9 deletions third_party/xla_client/multi_wait.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void MultiWait::Done() {
{
std::lock_guard<std::mutex> lock(mutex_);
completed_count_ += 1;
notify = completed_count_ >= count_;
notify = completed_count_ == count_;
}
if (notify) {
cv_.notify_all();
Expand Down Expand Up @@ -45,17 +45,27 @@ void MultiWait::Reset(size_t count) {
}

std::function<void()> MultiWait::Completer(std::function<void()> func) {
auto completer = [this, func = std::move(func)]() {
try {
func();
} catch (...) {
std::lock_guard<std::mutex> lock(mutex_);
exptr_ = std::current_exception();
}
Done();
auto completer = [this, func = std::move(func)]() { Complete(func); };
return completer;
}

std::function<void()> MultiWait::Completer(std::shared_ptr<MultiWait> mwait,
                                           std::function<void()> func) {
  // Move both the MultiWait shared pointer and the user function into the
  // closure, so the MultiWait object is kept alive for at least as long as
  // the returned functor exists. This avoids the dangling-reference hazard
  // of the non-static Completer() overload.
  return [mwait = std::move(mwait), func = std::move(func)]() {
    mwait->Complete(func);
  };
}

// Runs func and records completion on this MultiWait. If func throws, the
// exception is captured into exptr_ (under mutex_) so a later waiter can
// observe/rethrow it; the completion is still counted either way.
void MultiWait::Complete(const std::function<void()>& func) {
  try {
    func();
  } catch (...) {
    // Store only the most recent exception; guarded by mutex_ since multiple
    // completers may run concurrently.
    std::lock_guard<std::mutex> lock(mutex_);
    exptr_ = std::current_exception();
  }
  // Signal completion even when func threw, so Wait() is never stranded.
  Done();
}

} // namespace util
} // namespace xla
12 changes: 11 additions & 1 deletion third_party/xla_client/multi_wait.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>

#include "tensorflow/compiler/xla/types.h"
Expand Down Expand Up @@ -31,10 +32,19 @@ class MultiWait {

// Creates a completer functor which signals the multi wait object once func
// has completed. Handles exceptions by signaling the multi wait with the
// proper status value. This API returns a function which captures a MultiWait
// reference, so care must be taken such that the reference remains valid for
// the whole lifetime of the returned function.
std::function<void()> Completer(std::function<void()> func);

// Similar to the above API, but with explicit capture of the MultiWait shared
// pointer.
static std::function<void()> Completer(std::shared_ptr<MultiWait> mwait,
std::function<void()> func);

private:
void Complete(const std::function<void()>& func);

std::mutex mutex_;
std::condition_variable cv_;
size_t count_ = 0;
Expand Down
48 changes: 27 additions & 21 deletions third_party/xla_client/xrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
}
XLA_COUNTER("XrtPartitionedTransferToServer", 1);

util::MultiWait mwait(partitions.size());
auto mwait = std::make_shared<util::MultiWait>(partitions.size());
std::vector<DataPtr> results(tensors.size());
for (size_t i = 0; i < partitions.size(); ++i) {
auto sender = [&, i]() {
Expand All @@ -316,9 +316,10 @@ std::vector<ComputationClient::DataPtr> XrtComputationClient::TransferToServer(
results[base_index + r] = std::move(partitions_results[r]);
}
};
env::ScheduleIoClosure(mwait.Completer(std::move(sender)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(sender)));
}
mwait.Wait();
mwait->Wait();
return results;
}

Expand All @@ -330,7 +331,7 @@ XrtComputationClient::TransferToServerInternal(
std::mutex lock;
XrtSessionCache::SessionMap session_map;
int64 total_size = 0;
util::MultiWait mwait(tensors.size());
auto mwait = std::make_shared<util::MultiWait>(tensors.size());
std::map<XrtSession*, SessionWork> session_work_map;
{
metrics::TimedSection timed(TransferToServerTransformMetric());
Expand Down Expand Up @@ -363,13 +364,14 @@ XrtComputationClient::TransferToServerInternal(
total_size += tdata.size();
}
};
env::ScheduleClosure(mwait.Completer(std::move(converter)));
env::ScheduleClosure(
util::MultiWait::Completer(mwait, std::move(converter)));
}
mwait.Wait();
mwait->Wait();
}
OutboundDataMetric()->AddSample(total_size);

mwait.Reset(session_work_map.size());
mwait->Reset(session_work_map.size());
std::vector<DataPtr> results(tensors.size());
for (auto& session_session_work : session_work_map) {
XrtSession* session = session_session_work.first;
Expand All @@ -388,9 +390,10 @@ XrtComputationClient::TransferToServerInternal(
}
CreateDataHandlesCounter()->AddValue(outputs.size());
};
env::ScheduleIoClosure(mwait.Completer(std::move(runner)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(runner)));
}
mwait.Wait();
mwait->Wait();
return results;
}

Expand Down Expand Up @@ -426,7 +429,7 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
session_work->index_mapping.push_back(i);
}

util::MultiWait mwait(session_work_map.size());
auto mwait = std::make_shared<util::MultiWait>(session_work_map.size());
std::atomic<int64> total_size(0);
std::vector<Literal> results(handles.size());
for (auto& session_session_work : session_work_map) {
Expand All @@ -446,9 +449,10 @@ std::vector<Literal> XrtComputationClient::TransferFromServer(
total_size += results[li].size_bytes();
}
};
env::ScheduleIoClosure(mwait.Completer(std::move(runner)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(runner)));
}
mwait.Wait();
mwait->Wait();
InboundDataMetric()->AddSample(total_size.load());
return results;
}
Expand All @@ -458,7 +462,7 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
metrics::TimedSection timed(CompileMetric());

std::mutex lock;
util::MultiWait mwait(instances.size());
auto mwait = std::make_shared<util::MultiWait>(instances.size());
std::vector<ProgramShape> program_shapes(instances.size());
std::vector<ComputationPtr> results(instances.size());
std::vector<CompilationCacheKey> cache_keys(instances.size());
Expand Down Expand Up @@ -499,10 +503,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
results[i] = computation_ptr;
}
};
env::ScheduleClosure(mwait.Completer(std::move(builder)));
env::ScheduleClosure(util::MultiWait::Completer(mwait, std::move(builder)));
}
mwait.Wait();
mwait.Reset(session_work_map.size());
mwait->Wait();
mwait->Reset(session_work_map.size());

for (auto& session_and_work : session_work_map) {
XrtSession* session = session_and_work.first;
Expand Down Expand Up @@ -532,9 +536,10 @@ std::vector<ComputationClient::ComputationPtr> XrtComputationClient::Compile(
CreateCompileHandlesCounter()->AddValue(1);
}
};
env::ScheduleIoClosure(mwait.Completer(std::move(session_runner)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(session_runner)));
}
mwait.Wait();
mwait->Wait();
return results;
}

Expand Down Expand Up @@ -626,7 +631,7 @@ XrtComputationClient::RunComputations(
}
XLA_CHECK_EQ(computations.size(), devices.size());

util::MultiWait mwait(session_replicas.size());
auto mwait = std::make_shared<util::MultiWait>(session_replicas.size());
std::vector<std::vector<DataPtr>> results(devices.size());
for (auto& sess_replica : session_replicas) {
XrtSession* session = sess_replica.first;
Expand Down Expand Up @@ -655,9 +660,10 @@ XrtComputationClient::RunComputations(
GetEffectiveDevice(devices[replica]));
}
};
env::ScheduleIoClosure(mwait.Completer(std::move(session_runner)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(session_runner)));
}
mwait.Wait();
mwait->Wait();
return results;
}

Expand Down
7 changes: 4 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ py::bytes ReadTfFile(tensorflow::RandomAccessFile* file, uint64_t offset,
std::min<size_t>(num_threads, std::thread::hardware_concurrency());
size_t block_size = size / num_threads;

xla::util::MultiWait mwait(num_threads);
auto mwait = std::make_shared<xla::util::MultiWait>(num_threads);
for (size_t i = 0; i < num_threads; ++i) {
auto reader = [&, i]() {
uint64_t base = static_cast<uint64_t>(i) * block_size;
Expand All @@ -491,9 +491,10 @@ py::bytes ReadTfFile(tensorflow::RandomAccessFile* file, uint64_t offset,
XLA_CHECK_OK(
file->Read(offset + base, tsize, &result, buffer.get() + base));
};
xla::env::ScheduleIoClosure(mwait.Completer(std::move(reader)));
xla::env::ScheduleIoClosure(
xla::util::MultiWait::Completer(mwait, std::move(reader)));
}
mwait.Wait();
mwait->Wait();
}
return py::bytes(buffer.get(), size);
}
Expand Down
7 changes: 4 additions & 3 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,15 +423,16 @@ void CopyTensors(const void* src_buffer, const xla::Shape& src_shape,
std::vector<xla::int64> iter_dims = GetIterationDimensions(dest_shape);
std::vector<CopyPartition> parts =
CreateCopyPartitions(dest_shape.dimensions(), iter_dims.front());
xla::util::MultiWait mwait(parts.size());
auto mwait = std::make_shared<xla::util::MultiWait>(parts.size());
for (size_t i = 0; i < parts.size(); ++i) {
auto copy_fn = [&, i]() {
SlicedCopy<SType, DType>(dest_shape.dimensions(), src_data, src_strides,
dest_data, dest_strides, iter_dims, parts[i]);
};
xla::env::ScheduleClosure(mwait.Completer(std::move(copy_fn)));
xla::env::ScheduleClosure(
xla::util::MultiWait::Completer(mwait, std::move(copy_fn)));
}
mwait.Wait();
mwait->Wait();
}
}

Expand Down

0 comments on commit 06d564b

Please sign in to comment.