
Commit 73a3937

Silu backward (#3195)
* SiLU backward lowering
* Torch pin
* Delete .torch_pin
1 parent 047d89c

File tree: 9 files changed, +62 -0 lines changed


test/cpp/test_aten_xla_tensor.cpp

Lines changed: 14 additions & 0 deletions
@@ -3428,6 +3428,20 @@ TEST_F(AtenXlaTensorTest, TestSiLU) {
   ExpectCounterChanged("xla::silu_out", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestSiLUBackward) {
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::silu(inputs[0]);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward(
+        {torch::rand({2, 2},
+                     torch::TensorOptions(torch::kFloat).requires_grad(true))},
+        device, testfn);
+  });
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::silu_backward", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestSigmoid) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::sigmoid(a);
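For context, here is a minimal standalone sketch of what the new test exercises: SiLU's gradient flowing through autograd. It is not part of the commit; it assumes only LibTorch (no XLA device) and checks the gradient against the closed form this commit lowers.

#include <iostream>
#include <torch/torch.h>

int main() {
  // Same input setup as the test above.
  torch::Tensor x = torch::rand(
      {2, 2}, torch::TensorOptions(torch::kFloat).requires_grad(true));
  torch::silu(x).sum().backward();

  // Closed form: d/dx SiLU(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
  torch::Tensor xd = x.detach();
  torch::Tensor s = torch::sigmoid(xd);
  torch::Tensor expected = s * (1 + xd * (1 - s));

  std::cout << torch::allclose(x.grad(), expected) << std::endl;  // prints 1
  return 0;
}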

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 9 additions & 0 deletions
@@ -2894,6 +2894,15 @@ at::Tensor& XLANativeFunctions::silu_out(const at::Tensor& self,
   return out;
 }
 
+at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output,
+                                             const at::Tensor& self) {
+  XLA_FN_COUNTER("xla::");
+  XLATensor grad_output_tensor = bridge::GetXlaTensor(grad_output);
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  return bridge::AtenFromXlaTensor(
+      XLATensor::silu_backward(grad_output_tensor, self_tensor));
+}
+
 at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(

torch_xla/csrc/elementwise.cpp

Lines changed: 7 additions & 0 deletions
@@ -182,6 +182,13 @@ xla::XlaOp BuildSigmoid(xla::XlaOp input) {
   return half + half * xla::Tanh(half * input);
 }
 
+xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp one = xla::One(input.builder(), shape.element_type());
+  xla::XlaOp input_sigmoid = BuildSigmoid(input);
+  return grad_output * (input_sigmoid * (one + input * (one - input_sigmoid)));
+}
+
 xla::XlaOp BuildReciprocal(xla::XlaOp input) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp one = xla::One(input.builder(), shape.element_type());
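For reference (commentary, not part of the diff): the expression above is the product-rule derivative of SiLU(x) = x * sigma(x), using sigma'(x) = sigma(x)(1 - sigma(x)):

\frac{d}{dx}\bigl[x\,\sigma(x)\bigr]
  = \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr)
  = \sigma(x)\bigl(1 + x\,(1 - \sigma(x))\bigr)

Multiplying by grad_output gives exactly the value BuildSiLUBackward returns.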

torch_xla/csrc/elementwise.h

Lines changed: 4 additions & 0 deletions
@@ -48,6 +48,10 @@ xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
 // Sigmoid(x) = (tanh(x ∗ 0.5) + 1) ∗ 0.5
 xla::XlaOp BuildSigmoid(xla::XlaOp input);
 
+// Computes the backward of SiLU:
+// grad_output * (sigmoid(input) * (1 + input * (1 - sigmoid(input))))
+xla::XlaOp BuildSiLUBackward(xla::XlaOp grad_output, xla::XlaOp input);
+
 // Computes the reciprocal function.
 // Reciprocal(x) = 1 / x
 xla::XlaOp BuildReciprocal(xla::XlaOp input);

torch_xla/csrc/ops/ops.cpp

Lines changed: 19 additions & 0 deletions
@@ -219,6 +219,25 @@ NodePtr SiLU(const Value& input) {
                    std::move(lower_fn));
 }
 
+NodePtr SiLUBackward(const Value& grad_output, const Value& input) {
+  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
+    return node.ReturnOp(BuildSiLUBackward(xla_grad_output, xla_input), loctx);
+  };
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildSiLUBackward(operands[0], operands[1]);
+  };
+  return GenericOp(OpKind(at::aten::silu_backward), {grad_output, input},
+                   [&]() {
+                     return InferOutputShape(
+                         {grad_output.shape(), input.shape()},
+                         lower_for_shape_fn);
+                   },
+                   std::move(lower_fn));
+}
+
 NodePtr Sigmoid(const Value& input) {
   auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
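As a sanity check on the closed form this lowering emits, here is a dependency-free C++ sketch (not part of the commit) comparing it against a central finite difference:

#include <cmath>
#include <cstdio>

double sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }
double silu(double x) { return x * sigmoid(x); }

// Closed form used by BuildSiLUBackward (with grad_output == 1):
// sigmoid(x) * (1 + x * (1 - sigmoid(x))).
double silu_grad(double x) {
  double s = sigmoid(x);
  return s * (1.0 + x * (1.0 - s));
}

int main() {
  const double eps = 1e-6;
  for (double x : {-2.0, -0.5, 0.0, 1.0, 3.0}) {
    // Central difference approximation of d/dx SiLU(x).
    double fd = (silu(x + eps) - silu(x - eps)) / (2.0 * eps);
    std::printf("x=%5.2f  closed=%.8f  fd=%.8f\n", x, silu_grad(x), fd);
  }
  return 0;
}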

torch_xla/csrc/ops/ops.h

Lines changed: 2 additions & 0 deletions
@@ -131,6 +131,8 @@ NodePtr Sigmoid(const Value& input);
 
 NodePtr SiLU(const Value& input);
 
+NodePtr SiLUBackward(const Value& grad_output, const Value& input);
+
 NodePtr SigmoidBackward(const Value& grad_output, const Value& output);
 
 NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output,

torch_xla/csrc/tensor.h

Lines changed: 1 addition & 0 deletions
@@ -972,6 +972,7 @@ class XLATensor {
                           xla::int64_t index);
 
   static void silu_out(XLATensor& input, XLATensor& out);
+  static XLATensor silu_backward(XLATensor& grad_output, XLATensor& input);
   static XLATensor sigmoid(const XLATensor& input);
   static XLATensor sigmoid_backward(const XLATensor& grad_output,
                                     const XLATensor& output);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 5 additions & 0 deletions
@@ -2311,6 +2311,11 @@ void XLATensor::silu_out(XLATensor& input, XLATensor& out) {
   out.SetInPlaceIrValue(ir::ops::SiLU(input.GetIrValue()));
 }
 
+XLATensor XLATensor::silu_backward(XLATensor& grad_output, XLATensor& input) {
+  return input.CreateFrom(
+      ir::ops::SiLUBackward(grad_output.GetIrValue(), input.GetIrValue()));
+}
+
 XLATensor XLATensor::sigmoid(const XLATensor& input) {
   return input.CreateFrom(ir::ops::Sigmoid(input.GetIrValue()));
 }

xla_native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -261,6 +261,7 @@ supported:
   - sigmoid_backward
   - sign
   - silu.out
+  - silu_backward
   - sin
   - sinh
   - slice.Tensor