Skip to content

Commit 68d8b0f

Browse files
committed
Codegen for Tanh
1 parent 9fa40f6 commit 68d8b0f

File tree

9 files changed

+15
-18
lines changed

9 files changed

+15
-18
lines changed

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,11 +1533,6 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
15331533
XLATensor::clamp(bridge::GetXlaTensor(self), min_val, max_val));
15341534
}
15351535

1536-
at::Tensor XLANativeFunctions::tanh(const at::Tensor& self) {
1537-
XLA_FN_COUNTER("xla::");
1538-
return bridge::AtenFromXlaTensor(XLATensor::tanh(bridge::GetXlaTensor(self)));
1539-
}
1540-
15411536
at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output,
15421537
const at::Tensor& self,
15431538
const at::Scalar& min_val,

torch_xla/csrc/ops/ops.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ namespace torch_xla {
6868
std::move(lower_fn)); \
6969
}
7070

71-
PTXLA_UNARY_OP(Tanh, at::aten::tanh, xla::Tanh);
7271
PTXLA_UNARY_OP(Neg, at::aten::neg, xla::Neg);
7372
PTXLA_UNARY_OP(Exp, at::aten::exp, xla::Exp);
7473
PTXLA_UNARY_OP(Expm1, at::aten::expm1, xla::Expm1);
@@ -867,7 +866,7 @@ torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) {
867866
torch::lazy::NodePtr one = ScalarOp(1, shape);
868867
torch::lazy::NodePtr half = ScalarOp(0.5, shape);
869868
torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three));
870-
return half * input * (one + Tanh(inner));
869+
return half * input * (one + torch::lazy::MakeNode<Tanh>(inner, std::vector<torch::lazy::Shape>()));
871870
}
872871

873872
torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
@@ -883,7 +882,7 @@ torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
883882
torch::lazy::NodePtr three = ScalarOp(3, shape);
884883
torch::lazy::NodePtr half = ScalarOp(0.5, shape);
885884
torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three));
886-
torch::lazy::NodePtr tanh_inner = Tanh(inner);
885+
torch::lazy::NodePtr tanh_inner = torch::lazy::MakeNode<Tanh>(inner, std::vector<torch::lazy::Shape>());
887886

888887
torch::lazy::NodePtr left = half * input;
889888
torch::lazy::NodePtr right = one + tanh_inner;

torch_xla/csrc/ops/ops.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ torch::lazy::NodePtr Atan2(const torch::lazy::Value& input,
6868

6969
torch::lazy::NodePtr Tan(const torch::lazy::Value& input);
7070

71-
torch::lazy::NodePtr Tanh(const torch::lazy::Value& input);
72-
7371
torch::lazy::NodePtr Neg(const torch::lazy::Value& input);
7472

7573
torch::lazy::NodePtr SgnOp(const torch::lazy::Value& input);

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,9 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
139139
return ReturnOp(xla::Tan(xla_input), loctx);
140140
}
141141

142+
torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
143+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
144+
return ReturnOp(xla::Tanh(xla_input), loctx);
145+
}
146+
142147
} // namespace torch_xla

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,8 @@ xla::Shape TanOutputShape(const torch::lazy::Value& input) {
131131
return GetXlaShape(input);
132132
}
133133

134+
xla::Shape TanhOutputShape(const torch::lazy::Value& input) {
135+
return GetXlaShape(input);
136+
}
137+
134138
} // namespace torch_xla

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,6 @@ xla::Shape SinhOutputShape(const torch::lazy::Value& input);
5656

5757
xla::Shape TanOutputShape(const torch::lazy::Value& input);
5858

59+
xla::Shape TanhOutputShape(const torch::lazy::Value& input);
60+
5961
} // namespace torch_xla

torch_xla/csrc/tensor.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,8 +1169,6 @@ class XLATensor : public c10::intrusive_ptr_target {
11691169
static XLATensorPtr take(const XLATensorPtr& input,
11701170
const XLATensorPtr& index);
11711171

1172-
static XLATensorPtr tanh(const XLATensorPtr& input);
1173-
11741172
static XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
11751173
const XLATensorPtr& output);
11761174

torch_xla/csrc/tensor_methods.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1926,7 +1926,7 @@ void XLATensor::min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
19261926
XLATensorPtr XLATensor::mish(const XLATensorPtr& input) {
19271927
return input->CreateFrom(
19281928
input->GetIrValue() *
1929-
Tanh(tensor_ops::Softplus(input, 1, 20)->GetIrValue()));
1929+
torch::lazy::MakeNode<Tanh>(tensor_ops::Softplus(input, 1, 20)->GetIrValue(), std::vector<torch::lazy::Shape>()));
19301930
}
19311931

19321932
XLATensorPtr XLATensor::mm(const XLATensorPtr& input,
@@ -2776,10 +2776,6 @@ XLATensorPtr XLATensor::take(const XLATensorPtr& input,
27762776
return input->CreateFrom(Take(input->GetIrValue(), index->GetIrValue()));
27772777
}
27782778

2779-
XLATensorPtr XLATensor::tanh(const XLATensorPtr& input) {
2780-
return input->CreateFrom(Tanh(input->GetIrValue()));
2781-
}
2782-
27832779
XLATensorPtr XLATensor::tanh_backward(const XLATensorPtr& grad_output,
27842780
const XLATensorPtr& output) {
27852781
return XLATensor::mul(grad_output,

xla_native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ full_codegen:
2525
- sin
2626
- sinh
2727
- tan
28+
- tanh
2829
supported:
2930
- __ilshift__.Scalar
3031
- __ilshift__.Tensor
@@ -302,7 +303,6 @@ supported:
302303
- t
303304
- t_
304305
- take
305-
- tanh
306306
- tanh_backward
307307
- threshold
308308
- threshold_backward

0 commit comments

Comments
 (0)