Skip to content

Commit 2459546

Browse files
committed
Add selu_ back
1 parent e242d86 commit 2459546

File tree

5 files changed

+14
-4
lines changed

5 files changed

+14
-4
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6382,14 +6382,12 @@ TEST_F(AtenXlaTensorTest, TestSeluInPlace) {
63826382
torch::Tensor xla_input = CopyToDevice(input, device);
63836383
torch::Tensor output = torch::selu_(input);
63846384
torch::Tensor xla_output = torch::selu_(xla_input);
6385-
std::cerr << output << "\n";
6386-
std::cerr << xla_output << "\n";
63876385
AllClose(output, xla_output);
63886386
AllClose(input, xla_input);
63896387
});
63906388

6391-
// selu_ uses elu_ instead of selu
63926389
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
6390+
ExpectCounterChanged("xla::selu_", cpp_test::GetIgnoredCounters());
63936391
}
63946392

63956393
TEST_F(AtenXlaTensorTest, TestCelu) {

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2785,6 +2785,14 @@ at::Tensor XLANativeFunctions::select(const at::Tensor& self, int64_t dim,
27852785
XLATensor::select(bridge::GetXlaTensor(self), dim, index));
27862786
}
27872787

2788+
// TODO(JackCaoG): Remove after elu being codegened
2789+
at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) {
2790+
XLA_FN_COUNTER("xla::");
2791+
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
2792+
XLATensor::selu_(self_tensor);
2793+
return self;
2794+
}
2795+
27882796
at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
27892797
XLA_FN_COUNTER("xla::");
27902798
return bridge::AtenFromXlaTensor(

torch_xla/csrc/tensor.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,6 @@ class XLATensor : public c10::intrusive_ptr_target {
10371037
static XLATensorPtr select(const XLATensorPtr& input, int64_t dim,
10381038
int64_t index);
10391039

1040-
static XLATensorPtr selu(const XLATensorPtr& input);
10411040
static void selu_(XLATensorPtr& input);
10421041

10431042
static XLATensorPtr silu(const XLATensorPtr& input);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2420,6 +2420,10 @@ XLATensorPtr XLATensor::select(const XLATensorPtr& input, int64_t dim,
24202420
return tensor_ops::Select(input, dim, index);
24212421
}
24222422

2423+
void XLATensor::selu_(XLATensorPtr& input) {
2424+
input->SetInPlaceIrValue(Selu(input->GetIrValue()));
2425+
}
2426+
24232427
XLATensorPtr XLATensor::sigmoid(const XLATensorPtr& input) {
24242428
return input->CreateFrom(Sigmoid(input->GetIrValue()));
24252429
}

xla_native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ supported:
274274
- scatter.value_reduce
275275
- scatter_add
276276
- select.int
277+
- selu_
277278
- sigmoid
278279
- sigmoid_backward
279280
- slice.Tensor

0 commit comments

Comments
 (0)