Silu backward r1 10 (#3202)
* SiLU backward lowering

* Use silu as the symbol, since the silu_backward string is missing from PyTorch 1.10
JackCaoG authored Nov 8, 2021
1 parent 81d26e3 commit 8fb44f9
Showing 9 changed files with 62 additions and 0 deletions.
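
The change threads SiLU's backward through each layer of the usual torch_xla lowering path: the op is declared in xla_native_functions.yaml, overridden at the ATen level in aten_xla_type.cpp, exposed as a tensor method in tensor.h and tensor_methods.cpp, given an IR node in ops.h and ops.cpp, lowered to an XLA expression in elementwise.h and elementwise.cpp, and covered by a new C++ test in test_aten_xla_tensor.cpp.
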
14 changes: 14 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -3406,6 +3406,20 @@ TEST_F(AtenXlaTensorTest, TestSiLU) {
ExpectCounterChanged("xla::silu_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSiLUBackward) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::silu(inputs[0]);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::rand({2, 2},
torch::TensorOptions(torch::kFloat).requires_grad(true))},
device, testfn);
});
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::silu_backward", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSigmoid) {
torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::sigmoid(a);
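
The gradient exercised by TestSiLUBackward can also be checked outside the test harness. The following standalone sketch is not part of this commit; it assumes only public LibTorch C++ APIs (torch::silu, torch::sigmoid, torch::allclose) and compares autograd's gradient on CPU against the closed-form expression that the XLA lowering below implements.

// Hedged sketch, not part of this commit: verify via autograd that
// d/dx silu(x) == sigmoid(x) * (1 + x * (1 - sigmoid(x))).
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor x = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).requires_grad(true));
  torch::silu(x).sum().backward();  // sum() makes grad_output all ones

  torch::Tensor sig = torch::sigmoid(x.detach());
  torch::Tensor expected = sig * (1 + x.detach() * (1 - sig));
  std::cout << std::boolalpha << torch::allclose(x.grad(), expected)
            << std::endl;  // prints: true
  return 0;
}
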
9 changes: 9 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2879,6 +2879,15 @@ at::Tensor& XLANativeFunctions::silu_out(const at::Tensor& self,
  return out;
}

at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output,
                                             const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  XLATensor grad_output_tensor = bridge::GetXlaTensor(grad_output);
  XLATensor self_tensor = bridge::GetXlaTensor(self);
  return bridge::AtenFromXlaTensor(
      XLATensor::silu_backward(grad_output_tensor, self_tensor));
}

at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  return bridge::AtenFromXlaTensor(
7 changes: 7 additions & 0 deletions torch_xla/csrc/elementwise.cpp
@@ -182,6 +182,13 @@ xla::XlaOp BuildSigmoid(xla::XlaOp input) {
  return half + half * xla::Tanh(half * input);
}

xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) {
  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp one = xla::One(input.builder(), shape.element_type());
  xla::XlaOp input_sigmoid = BuildSigmoid(input);
  return grad_output * (input_sigmoid * (one + input * (one - input_sigmoid)));
}

xla::XlaOp BuildReciprocal(xla::XlaOp input) {
  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp one = xla::One(input.builder(), shape.element_type());
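
The expression built in BuildSiLUBackward is the derivative of SiLU(x) = x·σ(x), obtained from the product rule with σ'(x) = σ(x)(1 − σ(x)):

\frac{d}{dx}\bigl(x\,\sigma(x)\bigr) = \sigma(x) + x\,\sigma(x)\bigl(1-\sigma(x)\bigr) = \sigma(x)\bigl(1 + x\,(1-\sigma(x))\bigr)

Multiplying elementwise by grad_output gives exactly the XlaOp product returned above, and matches the formula documented in elementwise.h below.
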
4 changes: 4 additions & 0 deletions torch_xla/csrc/elementwise.h
@@ -48,6 +48,10 @@ xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
// Sigmoid(x) = (tanh(x ∗ 0.5) + 1) ∗ 0.5
xla::XlaOp BuildSigmoid(xla::XlaOp input);

// Computes the backward of SiLU.
// grad_output * (sigmoid(input) * (1 + input * (1 - sigmoid(input))))
xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input);

// Computes the reciprocal function.
// Reciprocal(x) = 1 / x
xla::XlaOp BuildReciprocal(xla::XlaOp input);
19 changes: 19 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
@@ -219,6 +219,25 @@ NodePtr SiLU(const Value& input) {
                   std::move(lower_fn));
}

NodePtr SiLUBackward(const Value& grad_output, const Value& input) {
  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
    return node.ReturnOp(BuildSiLUBackward(xla_grad_output, xla_input), loctx);
  };
  auto lower_for_shape_fn =
      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    return BuildSiLUBackward(operands[0], operands[1]);
  };
  return GenericOp(OpKind(at::aten::silu), {grad_output, input},
                   [&]() {
                     return InferOutputShape(
                         {grad_output.shape(), input.shape()},
                         lower_for_shape_fn);
                   },
                   std::move(lower_fn));
}

NodePtr Sigmoid(const Value& input) {
  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
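
As the commit message notes, the node is created with OpKind(at::aten::silu) rather than a dedicated silu_backward symbol, because the silu_backward string is not available in PyTorch 1.10; the backward semantics are carried entirely by the lowering and shape-inference lambdas attached to the node.
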
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops.h
@@ -128,6 +128,8 @@ NodePtr Sigmoid(const Value& input);

NodePtr SiLU(const Value& input);

NodePtr SiLUBackward(const Value& grad_output, const Value& input);

NodePtr SigmoidBackward(const Value& grad_output, const Value& output);

NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output,
1 change: 1 addition & 0 deletions torch_xla/csrc/tensor.h
@@ -966,6 +966,7 @@ class XLATensor {
                          xla::int64_t index);

  static void silu_out(XLATensor& input, XLATensor& out);
  static XLATensor silu_backward(XLATensor& grad_output, XLATensor& input);
  static XLATensor sigmoid(const XLATensor& input);
  static XLATensor sigmoid_backward(const XLATensor& grad_output,
                                    const XLATensor& output);
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2286,6 +2286,11 @@ void XLATensor::silu_out(XLATensor& input, XLATensor& out) {
  out.SetInPlaceIrValue(ir::ops::SiLU(input.GetIrValue()));
}

XLATensor XLATensor::silu_backward(XLATensor& grad_output, XLATensor& input) {
  return input.CreateFrom(
      ir::ops::SiLUBackward(grad_output.GetIrValue(), input.GetIrValue()));
}

XLATensor XLATensor::sigmoid(const XLATensor& input) {
  return input.CreateFrom(ir::ops::Sigmoid(input.GetIrValue()));
}
1 change: 1 addition & 0 deletions xla_native_functions.yaml
@@ -111,6 +111,7 @@ supported:
- rsqrt
- select.int
- silu.out
- silu_backward
- sigmoid
- sin
- sinh
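
Listing silu_backward under supported: is what routes at::silu_backward calls on XLA tensors to the new XLANativeFunctions::silu_backward override, which in turn bumps the xla::silu_backward counter the test above expects to change.
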
