Skip to content

Commit e54604c

Browse files
committed
Codegen selu and silu
1 parent 15f1239 commit e54604c

File tree

7 files changed

+46
-49
lines changed

7 files changed

+46
-49
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6386,8 +6386,8 @@ TEST_F(AtenXlaTensorTest, TestSeluInPlace) {
     AllClose(input, xla_input);
   });
 
+  // selu_ uses elu_ instead of selu
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::selu_", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestCelu) {

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2785,32 +2785,6 @@ at::Tensor XLANativeFunctions::select(const at::Tensor& self, int64_t dim,
       XLATensor::select(bridge::GetXlaTensor(self), dim, index));
 }
 
-at::Tensor XLANativeFunctions::selu(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::selu(bridge::GetXlaTensor(self)));
-}
-
-at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::selu_(self_tensor);
-  return self;
-}
-
-at::Tensor XLANativeFunctions::silu(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::silu(bridge::GetXlaTensor(self)));
-}
-
-at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output,
-                                             const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output);
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  return bridge::AtenFromXlaTensor(
-      XLATensor::silu_backward(grad_output_tensor, self_tensor));
-}
-
 at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ torch_xla::XlaOpVector Rsqrt::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Rsqrt(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Selu::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildSelu(xla_input), loctx);
+}
+
 torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildSgn(xla_input), loctx);
@@ -226,6 +231,17 @@ torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildSign(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Silu::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(xla_input * BuildSigmoid(xla_input), loctx);
+}
+
+torch_xla::XlaOpVector SiluBackward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(1));
+  return ReturnOp(BuildSiLUBackward(xla_grad_output, xla_input), loctx);
+}
+
 torch_xla::XlaOpVector Sin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Sin(xla_input), loctx);

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
 
 #include "tensorflow/compiler/xla/client/lib/logdet.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "torch_xla/csrc/elementwise.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/pooling.h"
 
@@ -208,6 +209,10 @@ xla::Shape RsqrtOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape SeluOutputShape(const torch::lazy::Value& input) {
+  return GetXlaShape(input);
+}
+
 xla::Shape SgnOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
@@ -216,6 +221,20 @@ xla::Shape SignOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape SiluOutputShape(const torch::lazy::Value& input) {
+  return GetXlaShape(input);
+}
+
+xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output,
+                                   const torch::lazy::Value& input) {
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildSiLUBackward(operands[0], operands[1]);
+  };
+  return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input)},
+                          lower_for_shape_fn);
+}
+
 xla::Shape SinOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,17 @@ xla::Shape RoundOutputShape(const torch::lazy::Value& input);
 
 xla::Shape RsqrtOutputShape(const torch::lazy::Value& input);
 
+xla::Shape SeluOutputShape(const torch::lazy::Value& input);
+
 xla::Shape SgnOutputShape(const torch::lazy::Value& input);
 
 xla::Shape SignOutputShape(const torch::lazy::Value& input);
 
+xla::Shape SiluOutputShape(const torch::lazy::Value& input);
+
+xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output,
+                                   const torch::lazy::Value& input);
+
 xla::Shape SinOutputShape(const torch::lazy::Value& input);
 
 xla::Shape SinhOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2420,24 +2420,6 @@ XLATensorPtr XLATensor::select(const XLATensorPtr& input, int64_t dim,
   return tensor_ops::Select(input, dim, index);
 }
 
-XLATensorPtr XLATensor::selu(const XLATensorPtr& input) {
-  return input->CreateFrom(Selu(input->GetIrValue()));
-}
-
-void XLATensor::selu_(XLATensorPtr& input) {
-  input->SetInPlaceIrValue(Selu(input->GetIrValue()));
-}
-
-XLATensorPtr XLATensor::silu(const XLATensorPtr& input) {
-  return input->CreateFrom(SiLU(input->GetIrValue()));
-}
-
-XLATensorPtr XLATensor::silu_backward(XLATensorPtr& grad_output,
-                                      XLATensorPtr& input) {
-  return input->CreateFrom(
-      SiLUBackward(grad_output->GetIrValue(), input->GetIrValue()));
-}
-
 XLATensorPtr XLATensor::sigmoid(const XLATensorPtr& input) {
   return input->CreateFrom(Sigmoid(input->GetIrValue()));
 }

xla_native_functions.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@ full_codegen:
 - reciprocal
 - round
 - rsqrt
+- selu
 - sgn
 - sign
+- silu
+- silu_backward
 - sin
 - sinh
 - tan
@@ -271,12 +274,8 @@ supported:
 - scatter.value_reduce
 - scatter_add
 - select.int
-- selu
-- selu_
 - sigmoid
 - sigmoid_backward
-- silu
-- silu_backward
 - slice.Tensor
 - slogdet
 - smooth_l1_loss

0 commit comments

Comments
 (0)