Skip to content

Commit 63a85d9

Browse files
authored
Codegen silu selu (#3780)
* Codegen selu and silu
* Add torch pin
* include error print
* Add selu_ back
* remove pin
1 parent 15f1239 commit 63a85d9

File tree

7 files changed

+46
-37
lines changed

7 files changed

+46
-37
lines changed

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 1 addition & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -2785,32 +2785,14 @@ at::Tensor XLANativeFunctions::select(const at::Tensor& self, int64_t dim,
27852785
XLATensor::select(bridge::GetXlaTensor(self), dim, index));
27862786
}
27872787

2788-
at::Tensor XLANativeFunctions::selu(const at::Tensor& self) {
2789-
XLA_FN_COUNTER("xla::");
2790-
return bridge::AtenFromXlaTensor(XLATensor::selu(bridge::GetXlaTensor(self)));
2791-
}
2792-
2788+
// TODO(JackCaoG): Remove after elu being codegened
27932789
at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) {
27942790
XLA_FN_COUNTER("xla::");
27952791
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
27962792
XLATensor::selu_(self_tensor);
27972793
return self;
27982794
}
27992795

2800-
at::Tensor XLANativeFunctions::silu(const at::Tensor& self) {
2801-
XLA_FN_COUNTER("xla::");
2802-
return bridge::AtenFromXlaTensor(XLATensor::silu(bridge::GetXlaTensor(self)));
2803-
}
2804-
2805-
at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output,
2806-
const at::Tensor& self) {
2807-
XLA_FN_COUNTER("xla::");
2808-
XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output);
2809-
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
2810-
return bridge::AtenFromXlaTensor(
2811-
XLATensor::silu_backward(grad_output_tensor, self_tensor));
2812-
}
2813-
28142796
at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
28152797
XLA_FN_COUNTER("xla::");
28162798
return bridge::AtenFromXlaTensor(

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -216,6 +216,11 @@ torch_xla::XlaOpVector Rsqrt::Lower(LoweringContext* loctx) const {
216216
return ReturnOp(xla::Rsqrt(xla_input), loctx);
217217
}
218218

219+
torch_xla::XlaOpVector Selu::Lower(LoweringContext* loctx) const {
220+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
221+
return ReturnOp(BuildSelu(xla_input), loctx);
222+
}
223+
219224
torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const {
220225
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
221226
return ReturnOp(BuildSgn(xla_input), loctx);
@@ -226,6 +231,17 @@ torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const {
226231
return ReturnOp(BuildSign(xla_input), loctx);
227232
}
228233

234+
torch_xla::XlaOpVector Silu::Lower(LoweringContext* loctx) const {
235+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
236+
return ReturnOp(xla_input * BuildSigmoid(xla_input), loctx);
237+
}
238+
239+
torch_xla::XlaOpVector SiluBackward::Lower(LoweringContext* loctx) const {
240+
xla::XlaOp xla_grad_output = loctx->GetOutputOp(operand(0));
241+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(1));
242+
return ReturnOp(BuildSiLUBackward(xla_grad_output, xla_input), loctx);
243+
}
244+
229245
torch_xla::XlaOpVector Sin::Lower(LoweringContext* loctx) const {
230246
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
231247
return ReturnOp(xla::Sin(xla_input), loctx);

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
#include "tensorflow/compiler/xla/client/lib/logdet.h"
44
#include "tensorflow/compiler/xla/shape_util.h"
5+
#include "torch_xla/csrc/elementwise.h"
56
#include "torch_xla/csrc/helpers.h"
67
#include "torch_xla/csrc/pooling.h"
78

@@ -208,6 +209,10 @@ xla::Shape RsqrtOutputShape(const torch::lazy::Value& input) {
208209
return GetXlaShape(input);
209210
}
210211

212+
xla::Shape SeluOutputShape(const torch::lazy::Value& input) {
213+
return GetXlaShape(input);
214+
}
215+
211216
xla::Shape SgnOutputShape(const torch::lazy::Value& input) {
212217
return GetXlaShape(input);
213218
}
@@ -216,6 +221,20 @@ xla::Shape SignOutputShape(const torch::lazy::Value& input) {
216221
return GetXlaShape(input);
217222
}
218223

224+
xla::Shape SiluOutputShape(const torch::lazy::Value& input) {
225+
return GetXlaShape(input);
226+
}
227+
228+
xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output,
229+
const torch::lazy::Value& input) {
230+
auto lower_for_shape_fn =
231+
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
232+
return BuildSiLUBackward(operands[0], operands[1]);
233+
};
234+
return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input)},
235+
lower_for_shape_fn);
236+
}
237+
219238
xla::Shape SinOutputShape(const torch::lazy::Value& input) {
220239
return GetXlaShape(input);
221240
}

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,10 +84,17 @@ xla::Shape RoundOutputShape(const torch::lazy::Value& input);
8484

8585
xla::Shape RsqrtOutputShape(const torch::lazy::Value& input);
8686

87+
xla::Shape SeluOutputShape(const torch::lazy::Value& input);
88+
8789
xla::Shape SgnOutputShape(const torch::lazy::Value& input);
8890

8991
xla::Shape SignOutputShape(const torch::lazy::Value& input);
9092

93+
xla::Shape SiluOutputShape(const torch::lazy::Value& input);
94+
95+
xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output,
96+
const torch::lazy::Value& input);
97+
9198
xla::Shape SinOutputShape(const torch::lazy::Value& input);
9299

93100
xla::Shape SinhOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/tensor.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1037,7 +1037,6 @@ class XLATensor : public c10::intrusive_ptr_target {
10371037
static XLATensorPtr select(const XLATensorPtr& input, int64_t dim,
10381038
int64_t index);
10391039

1040-
static XLATensorPtr selu(const XLATensorPtr& input);
10411040
static void selu_(XLATensorPtr& input);
10421041

10431042
static XLATensorPtr silu(const XLATensorPtr& input);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -2420,24 +2420,10 @@ XLATensorPtr XLATensor::select(const XLATensorPtr& input, int64_t dim,
24202420
return tensor_ops::Select(input, dim, index);
24212421
}
24222422

2423-
XLATensorPtr XLATensor::selu(const XLATensorPtr& input) {
2424-
return input->CreateFrom(Selu(input->GetIrValue()));
2425-
}
2426-
24272423
void XLATensor::selu_(XLATensorPtr& input) {
24282424
input->SetInPlaceIrValue(Selu(input->GetIrValue()));
24292425
}
24302426

2431-
XLATensorPtr XLATensor::silu(const XLATensorPtr& input) {
2432-
return input->CreateFrom(SiLU(input->GetIrValue()));
2433-
}
2434-
2435-
XLATensorPtr XLATensor::silu_backward(XLATensorPtr& grad_output,
2436-
XLATensorPtr& input) {
2437-
return input->CreateFrom(
2438-
SiLUBackward(grad_output->GetIrValue(), input->GetIrValue()));
2439-
}
2440-
24412427
XLATensorPtr XLATensor::sigmoid(const XLATensorPtr& input) {
24422428
return input->CreateFrom(Sigmoid(input->GetIrValue()));
24432429
}

xla_native_functions.yaml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,11 @@ full_codegen:
3434
- reciprocal
3535
- round
3636
- rsqrt
37+
- selu
3738
- sgn
3839
- sign
40+
- silu
41+
- silu_backward
3942
- sin
4043
- sinh
4144
- tan
@@ -271,12 +274,9 @@ supported:
271274
- scatter.value_reduce
272275
- scatter_add
273276
- select.int
274-
- selu
275277
- selu_
276278
- sigmoid
277279
- sigmoid_backward
278-
- silu
279-
- silu_backward
280280
- slice.Tensor
281281
- slogdet
282282
- smooth_l1_loss

0 commit comments

Comments (0)