Skip to content

Commit

Permalink
Final cherry-pick for r1.6 release (#2479)
Browse files Browse the repository at this point in the history
Co-authored-by: Davide Libenzi <dlibenzi@google.com>
Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Co-authored-by: JackCaoG <jackcao@google.com>
  • Loading branch information
4 people authored Sep 3, 2020
1 parent 06d564b commit 9703109
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ xla_model
.. autofunction:: xrt_world_size
.. autofunction:: all_reduce
.. autofunction:: all_gather
.. autofunction:: all_to_all
.. autofunction:: add_step_closure
.. autofunction:: wait_device_ops
.. autofunction:: optimizer_step
Expand Down
2 changes: 2 additions & 0 deletions test/pytorch_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@
'test_masked_select_mem_overlap', # doesn't raise
'test_scatter_mem_overlap', # doesn't raise
'test_index_mem_overlap', # doesn't raise
'test_topk_nonfinite_xla_float32', # TFXLA update HLO changed for 1.6
'test_topk_nonfinite_xla_float64', # TFXLA update HLO changed for 1.6
},
'TestViewOpsXLA': {
'test_contiguous_nonview',
Expand Down
14 changes: 14 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,20 @@ def test_get_xla_tensor(self):
self.assertEqual(tx, sx.data.cpu())


class TestBinaryCrossEntropyLimitValue(XlaTestCase):

def test_cross_entropy_loss(self):

def test_fn(pred, target):
lossfn = nn.BCELoss()
return lossfn(pred, target)

pred = torch.tensor(1.0)
target = torch.tensor(1.0)
for offset in [1, 0, 1e-8, 1e-7]:
self.runAtenTest([pred - offset, target], test_fn)


class TestDynamicShape(XlaTestCase):

def test_nonzero_shape(self):
Expand Down
2 changes: 1 addition & 1 deletion third_party/tensorflow
Submodule tensorflow updated 5737 files
9 changes: 5 additions & 4 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socket
import time

from .version import __version__ as version
from .version import __version__


def _maybe_select_tpu_version():
Expand Down Expand Up @@ -40,17 +40,18 @@ def _wait_for_open(version, timeout=100, interval=10, log=True):

import cloud_tpu_client
client = cloud_tpu_client.Client(tpu_name)
client.configure_tpu_version(f'pytorch-{version}', restart_type='ifNeeded')
client.configure_tpu_version(
f'pytorch-{__version__}', restart_type='ifNeeded')
# client.wait_for_healthy() API doesn't work as we dont have TPU API access
_wait_for_open(version)
_wait_for_open(__version__)
except ImportError:
logging.warning((
'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
'installed. Ignore if not running on Colab/Kaggle TPU.'))
except Exception:
# This path is hit, when we get throttled by the verison changer
# when we import torch_xla from xmp.spawn-ed processes.
_wait_for_open(version, log=False)
_wait_for_open(__version__, log=False)


def _setup_grpc():
Expand Down
3 changes: 0 additions & 3 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,6 @@ def all_to_all(value,
groups=None):
"""Performs an XLA `AllToAll()` operation on the input tensor.
WARNING: This function is not very reliable, may produce wrong results under
certain inputs. Use it at your own risk.
See: https://www.tensorflow.org/xla/operation_semantics#alltoall
Args:
Expand Down
14 changes: 11 additions & 3 deletions torch_xla/csrc/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ xla::XlaOp CreateProduct(xla::XlaOp input,
xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
const absl::optional<xla::XlaOp>& weight,
ReductionMode reduction) {
static const float kLogBound = -100;
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp xweight;
if (weight) {
Expand All @@ -137,8 +138,11 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
}
xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
xla::XlaOp result = -xweight * (target * xla::Log(input) +
(one - target) * xla::Log(one - input));
xla::XlaOp log_bound = XlaHelpers::ScalarValue(
kLogBound, input_shape.element_type(), input.builder());
xla::XlaOp result =
-xweight * (target * xla::Max(xla::Log(input), log_bound) +
(one - target) * xla::Max(xla::Log(one - input), log_bound));
if (reduction == ReductionMode::kNone) {
return result;
}
Expand All @@ -154,6 +158,7 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
xla::XlaOp BuildBinaryCrossEntropyBackward(
xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp target,
const absl::optional<xla::XlaOp>& weight, ReductionMode reduction) {
static const float kEpsilon = 1e-12;
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp xweight;
if (weight) {
Expand All @@ -164,7 +169,10 @@ xla::XlaOp BuildBinaryCrossEntropyBackward(
XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
}
xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
xla::XlaOp result = xweight * (input - target) / input / (one - input);
xla::XlaOp epsilon = XlaHelpers::ScalarValue(
kEpsilon, input_shape.element_type(), input.builder());
xla::XlaOp result =
xweight * (input - target) / xla::Max(input * (one - input), epsilon);
if (reduction == ReductionMode::kNone) {
return result * grad_output;
}
Expand Down
8 changes: 6 additions & 2 deletions torch_xla/distributed/xla_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,13 @@ def _start_fn(index, pf_cfg, fn, args):
# Calling _setup_replication() will trigger XLA library initialization, so the
# environment must be fully setup before doing so.
_setup_replication()
fn(gindex, *args)


def _mp_start_fn(index, pf_cfg, fn, args):
exit_code = 0
try:
fn(gindex, *args)
_start_fn(index, pf_cfg, fn, args)
except Exception as e:
print(
'Exception in device={}: {}'.format(_get_multiprocessing_device(),
Expand Down Expand Up @@ -288,7 +292,7 @@ def spawn(fn,
_start_fn(0, pf_cfg, fn, args)
else:
return torch.multiprocessing.start_processes(
_start_fn,
_mp_start_fn,
args=(pf_cfg, fn, args),
nprocs=pf_cfg.num_devices,
join=join,
Expand Down

0 comments on commit 9703109

Please sign in to comment.