From e54604cb83376810c5cc5fdbfb85a20b201e37dc Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 27 Jul 2022 05:25:20 +0000 Subject: [PATCH 1/5] Codegen selu and silu --- test/cpp/test_aten_xla_tensor.cpp | 2 +- torch_xla/csrc/aten_xla_type.cpp | 26 ------------------------- torch_xla/csrc/ops/ops_lower_fn.cpp | 16 +++++++++++++++ torch_xla/csrc/ops/ops_xla_shape_fn.cpp | 19 ++++++++++++++++++ torch_xla/csrc/ops/ops_xla_shape_fn.h | 7 +++++++ torch_xla/csrc/tensor_methods.cpp | 18 ----------------- xla_native_functions.yaml | 7 +++---- 7 files changed, 46 insertions(+), 49 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 196ac8b6cda0..ec259d1ca7eb 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -6386,8 +6386,8 @@ TEST_F(AtenXlaTensorTest, TestSeluInPlace) { AllClose(input, xla_input); }); + // selu_ uses elu_ instead of selu ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::selu_", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestCelu) { diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 97c07c7741aa..4055e9beffc3 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2785,32 +2785,6 @@ at::Tensor XLANativeFunctions::select(const at::Tensor& self, int64_t dim, XLATensor::select(bridge::GetXlaTensor(self), dim, index)); } -at::Tensor XLANativeFunctions::selu(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::selu(bridge::GetXlaTensor(self))); -} - -at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - XLATensor::selu_(self_tensor); - return self; -} - -at::Tensor XLANativeFunctions::silu(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::silu(bridge::GetXlaTensor(self))); -} - -at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output, - const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output); - XLATensorPtr self_tensor = bridge::GetXlaTensor(self); - return bridge::AtenFromXlaTensor( - XLATensor::silu_backward(grad_output_tensor, self_tensor)); -} - at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index cccf1f8c542f..e4fbc6ce3564 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -216,6 +216,11 @@ torch_xla::XlaOpVector Rsqrt::Lower(LoweringContext* loctx) const { return ReturnOp(xla::Rsqrt(xla_input), loctx); } +torch_xla::XlaOpVector Selu::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(BuildSelu(xla_input), loctx); +} + torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); return ReturnOp(BuildSgn(xla_input), loctx); @@ -226,6 +231,17 @@ torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const { return ReturnOp(BuildSign(xla_input), loctx); } +torch_xla::XlaOpVector Silu::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla_input * BuildSigmoid(xla_input), loctx); +} + +torch_xla::XlaOpVector SiluBackward::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_grad_output = loctx->GetOutputOp(operand(0)); + xla::XlaOp xla_input = loctx->GetOutputOp(operand(1)); + return ReturnOp(BuildSiLUBackward(xla_grad_output, xla_input), loctx); +} + torch_xla::XlaOpVector Sin::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); return ReturnOp(xla::Sin(xla_input), loctx); diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index 8fe686348ea0..fc04fde976e7 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -2,6 +2,7 @@ #include "tensorflow/compiler/xla/client/lib/logdet.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/pooling.h" @@ -208,6 +209,10 @@ xla::Shape RsqrtOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } +xla::Shape SeluOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + xla::Shape SgnOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } @@ -216,6 +221,20 @@ xla::Shape SignOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } +xla::Shape SiluOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + +xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output, + const torch::lazy::Value& input) { + auto lower_for_shape_fn = + [](absl::Span operands) -> xla::XlaOp { + return BuildSiLUBackward(operands[0], operands[1]); + }; + return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input)}, + lower_for_shape_fn); +} + xla::Shape SinOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index 368850adb9e4..e636559b17e6 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -84,10 +84,17 @@ xla::Shape RoundOutputShape(const torch::lazy::Value& input); xla::Shape RsqrtOutputShape(const torch::lazy::Value& input); +xla::Shape SeluOutputShape(const torch::lazy::Value& input); + xla::Shape SgnOutputShape(const torch::lazy::Value& input); xla::Shape SignOutputShape(const torch::lazy::Value& input); +xla::Shape SiluOutputShape(const torch::lazy::Value& input); + +xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output, + const torch::lazy::Value& input); + xla::Shape SinOutputShape(const torch::lazy::Value& input); xla::Shape SinhOutputShape(const torch::lazy::Value& input); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 2f5845c6a4cf..b6e6ee6d4f19 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2420,24 +2420,6 @@ XLATensorPtr XLATensor::select(const XLATensorPtr& input, int64_t dim, return tensor_ops::Select(input, dim, index); } -XLATensorPtr XLATensor::selu(const XLATensorPtr& input) { - return input->CreateFrom(Selu(input->GetIrValue())); -} - -void XLATensor::selu_(XLATensorPtr& input) { - input->SetInPlaceIrValue(Selu(input->GetIrValue())); -} - -XLATensorPtr XLATensor::silu(const XLATensorPtr& input) { - return input->CreateFrom(SiLU(input->GetIrValue())); -} - -XLATensorPtr XLATensor::silu_backward(XLATensorPtr& grad_output, - XLATensorPtr& input) { - return input->CreateFrom( - SiLUBackward(grad_output->GetIrValue(), input->GetIrValue())); -} - XLATensorPtr XLATensor::sigmoid(const XLATensorPtr& input) { return input->CreateFrom(Sigmoid(input->GetIrValue())); } diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index 0042d955ac63..884911b6d093 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -34,8 +34,11 @@ full_codegen: - reciprocal - round - rsqrt + - selu - sgn - sign + - silu + - silu_backward - sin - sinh - tan @@ -271,12 +274,8 @@ supported: - scatter.value_reduce - scatter_add - select.int - - selu - - selu_ - sigmoid - sigmoid_backward - - silu - - silu_backward - slice.Tensor - slogdet - smooth_l1_loss From c259c0ff93c9f7e60bfa0193fc0afbbd337e7d4f Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 27 Jul 2022 05:35:31 +0000 Subject: [PATCH 2/5] Add torch pin --- torch_patches/.torch_pin | 1 + 1 file changed, 1 insertion(+) create mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin new file mode 100644 index 000000000000..c6489037910f --- /dev/null +++ b/torch_patches/.torch_pin @@ -0,0 +1 @@ +#82297 From e242d8691bb0343fde252aa169893dc8a90367aa Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 27 Jul 2022 23:35:25 +0000 Subject: [PATCH 3/5] include error print --- test/cpp/test_aten_xla_tensor.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index ec259d1ca7eb..1de174fabb03 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -6382,6 +6382,8 @@ TEST_F(AtenXlaTensorTest, TestSeluInPlace) { torch::Tensor xla_input = CopyToDevice(input, device); torch::Tensor output = torch::selu_(input); torch::Tensor xla_output = torch::selu_(xla_input); + std::cerr << output << "\n"; + std::cerr << xla_output << "\n"; AllClose(output, xla_output); AllClose(input, xla_input); }); From 24595462f3914ed0edee10d522cc1f94a782cbdc Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Fri, 29 Jul 2022 04:47:25 +0000 Subject: [PATCH 4/5] Add selu_ back --- test/cpp/test_aten_xla_tensor.cpp | 4 +--- torch_xla/csrc/aten_xla_type.cpp | 8 ++++++++ torch_xla/csrc/tensor.h | 1 - torch_xla/csrc/tensor_methods.cpp | 4 ++++ xla_native_functions.yaml | 1 + 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 1de174fabb03..196ac8b6cda0 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -6382,14 +6382,12 @@ TEST_F(AtenXlaTensorTest, TestSeluInPlace) { torch::Tensor xla_input = CopyToDevice(input, device); torch::Tensor output = torch::selu_(input); torch::Tensor xla_output = torch::selu_(xla_input); - std::cerr << output << "\n"; - std::cerr << xla_output << "\n"; AllClose(output, xla_output); AllClose(input, xla_input); }); - // selu_ uses elu_ instead of selu ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::selu_", cpp_test::GetIgnoredCounters()); } TEST_F(AtenXlaTensorTest, TestCelu) { diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 4055e9beffc3..e480c12df159 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2785,6 +2785,14 @@ at::Tensor XLANativeFunctions::select(const at::Tensor& self, int64_t dim, XLATensor::select(bridge::GetXlaTensor(self), dim, index)); } +// TODO(JackCaoG): Remove after elu being codegened +at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) { + XLA_FN_COUNTER("xla::"); + XLATensorPtr self_tensor = bridge::GetXlaTensor(self); + XLATensor::selu_(self_tensor); + return self; +} + at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor( diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index b301315b1ec7..2ccee7c82cef 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -1037,7 +1037,6 @@ class XLATensor : public c10::intrusive_ptr_target { static XLATensorPtr select(const XLATensorPtr& input, int64_t dim, int64_t index); - static XLATensorPtr selu(const XLATensorPtr& input); static void selu_(XLATensorPtr& input); static XLATensorPtr silu(const XLATensorPtr& input); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index b6e6ee6d4f19..b04847d56ccb 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2420,6 +2420,10 @@ XLATensorPtr XLATensor::select(const XLATensorPtr& input, int64_t dim, return tensor_ops::Select(input, dim, index); } +void XLATensor::selu_(XLATensorPtr& input) { + input->SetInPlaceIrValue(Selu(input->GetIrValue())); +} + XLATensorPtr XLATensor::sigmoid(const XLATensorPtr& input) { return input->CreateFrom(Sigmoid(input->GetIrValue())); } diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index 884911b6d093..35c745e23c89 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -274,6 +274,7 @@ supported: - scatter.value_reduce - scatter_add - select.int + - selu_ - sigmoid - sigmoid_backward - slice.Tensor From a39d3ca1d961c38eb86ab7b974e29e4b66ed4ccc Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Fri, 29 Jul 2022 04:47:40 +0000 Subject: [PATCH 5/5] remove pin --- torch_patches/.torch_pin | 1 - 1 file changed, 1 deletion(-) delete mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin deleted file mode 100644 index c6489037910f..000000000000 --- a/torch_patches/.torch_pin +++ /dev/null @@ -1 +0,0 @@ -#82297