
Commit 8fb44f9

Silu backward r1 10 (#3202)
* SiLU backward lowering
* Use silu as symbol since silu_backward string is missing from pytorch 1.10
1 parent 81d26e3 commit 8fb44f9

File tree (9 files changed, +62 -0 lines changed):

test/cpp/test_aten_xla_tensor.cpp
torch_xla/csrc/aten_xla_type.cpp
torch_xla/csrc/elementwise.cpp
torch_xla/csrc/elementwise.h
torch_xla/csrc/ops/ops.cpp
torch_xla/csrc/ops/ops.h
torch_xla/csrc/tensor.h
torch_xla/csrc/tensor_methods.cpp
xla_native_functions.yaml

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 14 additions & 0 deletions

@@ -3406,6 +3406,20 @@ TEST_F(AtenXlaTensorTest, TestSiLU) {
   ExpectCounterChanged("xla::silu_out", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestSiLUBackward) {
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::silu(inputs[0]);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward(
+        {torch::rand({2, 2},
+                     torch::TensorOptions(torch::kFloat).requires_grad(true))},
+        device, testfn);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::silu_backward", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestSigmoid) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::sigmoid(a);
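
As a quick sanity check of the gradient this test exercises, the following standalone libtorch snippet (a hypothetical sketch, not part of the commit; assumes a plain CPU build of libtorch) compares autograd through torch::silu against the closed-form expression implemented by the XLA lowering:

// sanity_check_silu_backward.cpp -- hypothetical standalone example, not part
// of this commit. Compares autograd's SiLU gradient with the analytic form
// grad_output * sigmoid(x) * (1 + x * (1 - sigmoid(x))) on CPU.
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor x = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).requires_grad(true));
  torch::silu(x).sum().backward();  // grad_output is all ones after sum()

  torch::Tensor xd = x.detach();
  torch::Tensor s = torch::sigmoid(xd);
  torch::Tensor one = torch::ones_like(xd);
  torch::Tensor expected = s * (one + xd * (one - s));

  std::cout << std::boolalpha << torch::allclose(x.grad(), expected)
            << std::endl;  // expected output: true
  return 0;
}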

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 9 additions & 0 deletions

@@ -2879,6 +2879,15 @@ at::Tensor& XLANativeFunctions::silu_out(const at::Tensor& self,
   return out;
 }
 
+at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output,
+                                             const at::Tensor& self) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor grad_output_tensor = bridge::GetXlaTensor(grad_output);
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  return bridge::AtenFromXlaTensor(
+      XLATensor::silu_backward(grad_output_tensor, self_tensor));
+}
+
 at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(

torch_xla/csrc/elementwise.cpp

Lines changed: 7 additions & 0 deletions

@@ -182,6 +182,13 @@ xla::XlaOp BuildSigmoid(xla::XlaOp input) {
   return half + half * xla::Tanh(half * input);
 }
 
+xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp one = xla::One(input.builder(), shape.element_type());
+  xla::XlaOp input_sigmoid = BuildSigmoid(input);
+  return grad_output * (input_sigmoid * (one + input * (one - input_sigmoid)));
+}
+
 xla::XlaOp BuildReciprocal(xla::XlaOp input) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp one = xla::One(input.builder(), shape.element_type());
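
For reference (not part of the diff), the expression returned by BuildSiLUBackward is just the derivative of SiLU(x) = x·σ(x), using σ'(x) = σ(x)(1 − σ(x)):

\[
\frac{d}{dx}\,\mathrm{SiLU}(x) = \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr) = \sigma(x)\bigl(1 + x\bigl(1 - \sigma(x)\bigr)\bigr),
\]

which the lowering multiplies by grad_output, reusing BuildSigmoid for σ(x).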

torch_xla/csrc/elementwise.h

Lines changed: 4 additions & 0 deletions

@@ -48,6 +48,10 @@ xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
 // Sigmoid(x) = (tanh(x ∗ 0.5) + 1) ∗ 0.5
 xla::XlaOp BuildSigmoid(xla::XlaOp input);
 
+// Computes the backward of Silu
+// grad_output * (sigmoid(input) * (1 + input * (1 - sigmoid(input))))
+xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input);
+
 // Computes the reciprocal function.
 // Reciprocal(x) = 1 / x
 xla::XlaOp BuildReciprocal(xla::XlaOp input);

torch_xla/csrc/ops/ops.cpp

Lines changed: 19 additions & 0 deletions

@@ -219,6 +219,25 @@ NodePtr SiLU(const Value& input) {
                    std::move(lower_fn));
 }
 
+NodePtr SiLUBackward(const Value& grad_output, const Value& input) {
+  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
+    return node.ReturnOp(BuildSiLUBackward(xla_grad_output, xla_input), loctx);
+  };
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildSiLUBackward(operands[0], operands[1]);
+  };
+  return GenericOp(OpKind(at::aten::silu), {grad_output, input},
+                   [&]() {
+                     return InferOutputShape(
+                         {grad_output.shape(), input.shape()},
+                         lower_for_shape_fn);
+                   },
+                   std::move(lower_fn));
+}
+
 NodePtr Sigmoid(const Value& input) {
   auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));

torch_xla/csrc/ops/ops.h

Lines changed: 2 additions & 0 deletions

@@ -128,6 +128,8 @@ NodePtr Sigmoid(const Value& input);
 
 NodePtr SiLU(const Value& input);
 
+NodePtr SiLUBackward(const Value& grad_output, const Value& input);
+
 NodePtr SigmoidBackward(const Value& grad_output, const Value& output);
 
 NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output,

torch_xla/csrc/tensor.h

Lines changed: 1 addition & 0 deletions

@@ -966,6 +966,7 @@ class XLATensor {
                           xla::int64_t index);
 
   static void silu_out(XLATensor& input, XLATensor& out);
+  static XLATensor silu_backward(XLATensor& grad_output, XLATensor& input);
   static XLATensor sigmoid(const XLATensor& input);
   static XLATensor sigmoid_backward(const XLATensor& grad_output,
                                     const XLATensor& output);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 5 additions & 0 deletions

@@ -2286,6 +2286,11 @@ void XLATensor::silu_out(XLATensor& input, XLATensor& out) {
   out.SetInPlaceIrValue(ir::ops::SiLU(input.GetIrValue()));
 }
 
+XLATensor XLATensor::silu_backward(XLATensor& grad_output, XLATensor& input) {
+  return input.CreateFrom(
+      ir::ops::SiLUBackward(grad_output.GetIrValue(), input.GetIrValue()));
+}
+
 XLATensor XLATensor::sigmoid(const XLATensor& input) {
   return input.CreateFrom(ir::ops::Sigmoid(input.GetIrValue()));
 }

xla_native_functions.yaml

Lines changed: 1 addition & 0 deletions

@@ -111,6 +111,7 @@ supported:
   - rsqrt
   - select.int
   - silu.out
+  - silu_backward
   - sigmoid
   - sin
   - sinh
