diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index e5fe1adfc2fc..0cd0ee2928bb 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1533,11 +1533,6 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self, XLATensor::clamp(bridge::GetXlaTensor(self), min_val, max_val)); } -at::Tensor XLANativeFunctions::tanh(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::tanh(bridge::GetXlaTensor(self))); -} - at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output, const at::Tensor& self, const at::Scalar& min_val, diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 18fee13ab886..93219ac426c0 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -68,7 +68,6 @@ namespace torch_xla { std::move(lower_fn)); \ } -PTXLA_UNARY_OP(Tanh, at::aten::tanh, xla::Tanh); PTXLA_UNARY_OP(Neg, at::aten::neg, xla::Neg); PTXLA_UNARY_OP(Exp, at::aten::exp, xla::Exp); PTXLA_UNARY_OP(Expm1, at::aten::expm1, xla::Expm1); @@ -867,7 +866,9 @@ torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) { torch::lazy::NodePtr one = ScalarOp(1, shape); torch::lazy::NodePtr half = ScalarOp(0.5, shape); torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three)); - return half * input * (one + Tanh(inner)); + return half * input * + (one + torch::lazy::MakeNode<Tanh>(inner, + std::vector<torch::lazy::Shape>())); } torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad, @@ -883,7 +884,8 @@ torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad, torch::lazy::NodePtr three = ScalarOp(3, shape); torch::lazy::NodePtr half = ScalarOp(0.5, shape); torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three)); - torch::lazy::NodePtr tanh_inner = Tanh(inner); + torch::lazy::NodePtr tanh_inner = + torch::lazy::MakeNode<Tanh>(inner, std::vector<torch::lazy::Shape>()); torch::lazy::NodePtr left = half * input;
torch::lazy::NodePtr right = one + tanh_inner; diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 9e67b824897b..307640e35e17 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -68,8 +68,6 @@ torch::lazy::NodePtr Atan2(const torch::lazy::Value& input, torch::lazy::NodePtr Tan(const torch::lazy::Value& input); -torch::lazy::NodePtr Tanh(const torch::lazy::Value& input); - torch::lazy::NodePtr Neg(const torch::lazy::Value& input); torch::lazy::NodePtr SgnOp(const torch::lazy::Value& input); diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index ffeed40049fa..acd16e93739c 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -139,4 +139,9 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const { return ReturnOp(xla::Tan(xla_input), loctx); } +torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::Tanh(xla_input), loctx); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index bdd74a9112cd..67e3c0970034 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -131,4 +131,8 @@ xla::Shape TanOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } +xla::Shape TanhOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index 679875404d41..12f3e662fb9d 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -56,4 +56,6 @@ xla::Shape SinhOutputShape(const torch::lazy::Value& input); xla::Shape TanOutputShape(const torch::lazy::Value& input); +xla::Shape TanhOutputShape(const torch::lazy::Value& input); + } // namespace torch_xla diff --git 
a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 6758a9224338..51196e38cd83 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -1169,8 +1169,6 @@ class XLATensor : public c10::intrusive_ptr_target { static XLATensorPtr take(const XLATensorPtr& input, const XLATensorPtr& index); - static XLATensorPtr tanh(const XLATensorPtr& input); - static XLATensorPtr tanh_backward(const XLATensorPtr& grad_output, const XLATensorPtr& output); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 3e2c485a4370..527b5164df38 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1924,9 +1924,10 @@ void XLATensor::min_out(XLATensorPtr& min, XLATensorPtr& min_indices, } XLATensorPtr XLATensor::mish(const XLATensorPtr& input) { - return input->CreateFrom( - input->GetIrValue() * - Tanh(tensor_ops::Softplus(input, 1, 20)->GetIrValue())); + return input->CreateFrom(input->GetIrValue() * + torch::lazy::MakeNode<Tanh>( + tensor_ops::Softplus(input, 1, 20)->GetIrValue(), + std::vector<torch::lazy::Shape>())); } XLATensorPtr XLATensor::mm(const XLATensorPtr& input, @@ -2776,10 +2777,6 @@ XLATensorPtr XLATensor::take(const XLATensorPtr& input, return input->CreateFrom(Take(input->GetIrValue(), index->GetIrValue())); } -XLATensorPtr XLATensor::tanh(const XLATensorPtr& input) { - return input->CreateFrom(Tanh(input->GetIrValue())); -} - XLATensorPtr XLATensor::tanh_backward(const XLATensorPtr& grad_output, const XLATensorPtr& output) { return XLATensor::mul(grad_output, diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index 6992707e63c1..0dec51728c3d 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -25,6 +25,7 @@ full_codegen: - sin - sinh - tan + - tanh supported: - __ilshift__.Scalar - __ilshift__.Tensor @@ -302,7 +303,6 @@ supported: - t - t_ - take - - tanh - tanh_backward - threshold - threshold_backward