Skip to content

Commit 1262dd4

Browse files
authored
Codegen for Tanh (#3724)
1 parent 75ac08b commit 1262dd4

File tree

9 files changed

+21
-20
lines changed

9 files changed

+21
-20
lines changed

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,11 +1533,6 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
15331533
XLATensor::clamp(bridge::GetXlaTensor(self), min_val, max_val));
15341534
}
15351535

1536-
at::Tensor XLANativeFunctions::tanh(const at::Tensor& self) {
1537-
XLA_FN_COUNTER("xla::");
1538-
return bridge::AtenFromXlaTensor(XLATensor::tanh(bridge::GetXlaTensor(self)));
1539-
}
1540-
15411536
at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output,
15421537
const at::Tensor& self,
15431538
const at::Scalar& min_val,

torch_xla/csrc/ops/ops.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ namespace torch_xla {
6868
std::move(lower_fn)); \
6969
}
7070

71-
PTXLA_UNARY_OP(Tanh, at::aten::tanh, xla::Tanh);
7271
PTXLA_UNARY_OP(Neg, at::aten::neg, xla::Neg);
7372
PTXLA_UNARY_OP(Exp, at::aten::exp, xla::Exp);
7473
PTXLA_UNARY_OP(Expm1, at::aten::expm1, xla::Expm1);
@@ -866,7 +865,9 @@ torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) {
866865
torch::lazy::NodePtr one = ScalarOp(1, shape);
867866
torch::lazy::NodePtr half = ScalarOp(0.5, shape);
868867
torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three));
869-
return half * input * (one + Tanh(inner));
868+
return half * input *
869+
(one + torch::lazy::MakeNode<Tanh>(inner,
870+
std::vector<torch::lazy::Shape>()));
870871
}
871872

872873
torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
@@ -882,7 +883,8 @@ torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
882883
torch::lazy::NodePtr three = ScalarOp(3, shape);
883884
torch::lazy::NodePtr half = ScalarOp(0.5, shape);
884885
torch::lazy::NodePtr inner = beta * (input + kappa * Pow(input, three));
885-
torch::lazy::NodePtr tanh_inner = Tanh(inner);
886+
torch::lazy::NodePtr tanh_inner =
887+
torch::lazy::MakeNode<Tanh>(inner, std::vector<torch::lazy::Shape>());
886888

887889
torch::lazy::NodePtr left = half * input;
888890
torch::lazy::NodePtr right = one + tanh_inner;

torch_xla/csrc/ops/ops.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ torch::lazy::NodePtr Atan2(const torch::lazy::Value& input,
6868

6969
torch::lazy::NodePtr Tan(const torch::lazy::Value& input);
7070

71-
torch::lazy::NodePtr Tanh(const torch::lazy::Value& input);
72-
7371
torch::lazy::NodePtr Neg(const torch::lazy::Value& input);
7472

7573
torch::lazy::NodePtr SgnOp(const torch::lazy::Value& input);

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,9 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
144144
return ReturnOp(xla::Tan(xla_input), loctx);
145145
}
146146

147+
torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
148+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
149+
return ReturnOp(xla::Tanh(xla_input), loctx);
150+
}
151+
147152
} // namespace torch_xla

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,8 @@ xla::Shape TanOutputShape(const torch::lazy::Value& input) {
135135
return GetXlaShape(input);
136136
}
137137

138+
xla::Shape TanhOutputShape(const torch::lazy::Value& input) {
139+
return GetXlaShape(input);
140+
}
141+
138142
} // namespace torch_xla

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,6 @@ xla::Shape SinhOutputShape(const torch::lazy::Value& input);
5858

5959
xla::Shape TanOutputShape(const torch::lazy::Value& input);
6060

61+
xla::Shape TanhOutputShape(const torch::lazy::Value& input);
62+
6163
} // namespace torch_xla

torch_xla/csrc/tensor.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,8 +1167,6 @@ class XLATensor : public c10::intrusive_ptr_target {
11671167
static XLATensorPtr take(const XLATensorPtr& input,
11681168
const XLATensorPtr& index);
11691169

1170-
static XLATensorPtr tanh(const XLATensorPtr& input);
1171-
11721170
static XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
11731171
const XLATensorPtr& output);
11741172

torch_xla/csrc/tensor_methods.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,9 +1924,10 @@ void XLATensor::min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
19241924
}
19251925

19261926
XLATensorPtr XLATensor::mish(const XLATensorPtr& input) {
1927-
return input->CreateFrom(
1928-
input->GetIrValue() *
1929-
Tanh(tensor_ops::Softplus(input, 1, 20)->GetIrValue()));
1927+
return input->CreateFrom(input->GetIrValue() *
1928+
torch::lazy::MakeNode<Tanh>(
1929+
tensor_ops::Softplus(input, 1, 20)->GetIrValue(),
1930+
std::vector<torch::lazy::Shape>()));
19301931
}
19311932

19321933
XLATensorPtr XLATensor::mm(const XLATensorPtr& input,
@@ -2772,10 +2773,6 @@ XLATensorPtr XLATensor::take(const XLATensorPtr& input,
27722773
return input->CreateFrom(Take(input->GetIrValue(), index->GetIrValue()));
27732774
}
27742775

2775-
XLATensorPtr XLATensor::tanh(const XLATensorPtr& input) {
2776-
return input->CreateFrom(Tanh(input->GetIrValue()));
2777-
}
2778-
27792776
XLATensorPtr XLATensor::tanh_backward(const XLATensorPtr& grad_output,
27802777
const XLATensorPtr& output) {
27812778
return XLATensor::mul(grad_output,

xla_native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ full_codegen:
2626
- sin
2727
- sinh
2828
- tan
29+
- tanh
2930
supported:
3031
- __ilshift__.Scalar
3132
- __ilshift__.Tensor
@@ -302,7 +303,6 @@ supported:
302303
- t
303304
- t_
304305
- take
305-
- tanh
306306
- tanh_backward
307307
- threshold
308308
- threshold_backward

0 commit comments

Comments
 (0)