Commit 27d21dc

Authored and committed by tyoc213
Adding silu to ops.cpp #2717
1 parent: c16fedb

6 files changed: +35 -3 lines changed


test/cpp/test_aten_xla_tensor.cpp

Lines changed: 10 additions & 0 deletions
@@ -3067,6 +3067,16 @@ TEST_F(AtenXlaTensorTest, TestLogsumexp) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestSiLU) {
+  torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor b = torch::silu(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::silu(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestSigmoid) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::sigmoid(a);
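The test builds a CPU reference with torch::silu and checks that the XLA lowering produces the same values on every available device. As a side note, a minimal standalone LibTorch sketch (not part of this commit) of the identity the op implements, SiLU(x) = x * sigmoid(x):

#include <iostream>
#include <torch/torch.h>

// Standalone sketch, not from this commit: torch::silu should agree with
// x * sigmoid(x) elementwise on a floating-point tensor.
int main() {
  torch::Tensor x = torch::randn({2, 2});
  torch::Tensor y = torch::silu(x);
  torch::Tensor ref = x * torch::sigmoid(x);
  std::cout << std::boolalpha << torch::allclose(y, ref) << std::endl;  // true
  return 0;
}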

torch_xla/csrc/elementwise.cpp

Lines changed: 8 additions & 0 deletions
@@ -180,6 +180,14 @@ xla::XlaOp BuildSigmoid(xla::XlaOp input) {
   return half + half * xla::Tanh(half * input);
 }
 
+
+xla::XlaOp BuildSiLU(xla::XlaOp input) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp half = XlaHelpers::ScalarValue<float>(0.5, shape.element_type(),
+                                                   input.builder());
+  return input * (half + half * xla::Tanh(half * input));
+}
+
 xla::XlaOp BuildReciprocal(xla::XlaOp input) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp one = xla::One(input.builder(), shape.element_type());
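BuildSiLU reuses the tanh-based sigmoid expression from BuildSigmoid directly above and multiplies it by the input. A standalone C++ sketch (not part of the commit) checking numerically that this formula agrees with the logistic definition SiLU(x) = x / (1 + exp(-x)):

#include <cmath>
#include <cstdio>

// Standalone sketch, not from the commit: compare the tanh-based formula
// used by BuildSiLU against the logistic-function definition of SiLU.
int main() {
  for (double x = -5.0; x <= 5.0; x += 0.5) {
    double via_tanh = x * (0.5 + 0.5 * std::tanh(0.5 * x));
    double via_logistic = x / (1.0 + std::exp(-x));
    std::printf("x=%+.2f  tanh=%.9f  logistic=%.9f\n", x, via_tanh,
                via_logistic);
  }
  return 0;
}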

torch_xla/csrc/elementwise.h

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,10 @@ xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
 // Sigmoid(x) = (tanh(x ∗ 0.5) + 1) ∗ 0.5
 xla::XlaOp BuildSigmoid(xla::XlaOp input);
 
+// Computes the SiLU function using Tanh.
+// SiLU(x) = x * (tanh(x ∗ 0.5) + 1) ∗ 0.5
+xla::XlaOp BuildSiLU(xla::XlaOp input);
+
 // Computes the reciprocal function.
 // Reciprocal(x) = 1 / x
 xla::XlaOp BuildReciprocal(xla::XlaOp input);
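Why the tanh form in the comment is valid (a short derivation, not part of the commit): tanh(x/2) = (1 - e^(-x)) / (1 + e^(-x)), so (tanh(x * 0.5) + 1) * 0.5 = 1 / (1 + e^(-x)) = sigmoid(x); multiplying by x gives SiLU(x) = x * sigmoid(x).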

torch_xla/csrc/ops/ops.cpp

Lines changed: 9 additions & 0 deletions
@@ -206,6 +206,15 @@ NodePtr LogSigmoidBackward(const Value& grad_output, const Value& input,
   return grad_output * (Neg(max_deriv) - sign * (buffer - one) / buffer);
 }
 
+NodePtr SiLU(const Value& input) {
+  auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    return node.ReturnOp(BuildSiLU(xla_input), loctx);
+  };
+  return GenericOp(OpKind(at::aten::silu), {input}, input.shape(),
+                   std::move(lower_fn));
+}
+
 NodePtr Sigmoid(const Value& input) {
   auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
     xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
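This mirrors the Sigmoid node that follows it: GenericOp creates an IR node tagged at::aten::silu, reusing the input's shape since SiLU is elementwise, and the lower_fn lambda resolves the node's single operand and emits the XLA computation via BuildSiLU when the graph is lowered.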

torch_xla/csrc/ops/ops.h

Lines changed: 2 additions & 0 deletions
@@ -126,6 +126,8 @@ NodePtr LogSigmoidBackward(const Value& grad_output, const Value& input,
 
 NodePtr Sigmoid(const Value& input);
 
+NodePtr SiLU(const Value& input);
+
 NodePtr SigmoidBackward(const Value& grad_output, const Value& output);
 
 NodePtr LogSoftmaxBackwardOp(const Value& grad_output, const Value& output,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 2 additions & 3 deletions
@@ -2337,9 +2337,8 @@ XLATensor XLATensor::select(const XLATensor& input, xla::int64 dim,
   return tensor_ops::Select(input, dim, index);
 }
 
-void XLATensor::silu_out(XLATensor& input, XLATensor& out){
-  auto value = input.GetIrValue() * ir::ops::Sigmoid(input.GetIrValue());
-  out.SetInPlaceIrValue(value);
+void XLATensor::silu_out(XLATensor& input, XLATensor& out) {
+  out.SetInPlaceIrValue(ir::ops::SiLU(input.GetIrValue()));
 }
 
 XLATensor XLATensor::sigmoid(const XLATensor& input) {
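Net effect of this hunk: silu_out previously composed SiLU out of two IR nodes (a multiply and a Sigmoid); it now emits the dedicated SiLU node added above, which lowers to the single fused tanh-based expression in BuildSiLU.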
