@@ -68,7 +68,6 @@ namespace torch_xla {
6868 std::move (lower_fn)); \
6969 }
7070
71- PTXLA_UNARY_OP (Tanh, at::aten::tanh, xla::Tanh);
7271PTXLA_UNARY_OP (Neg, at::aten::neg, xla::Neg);
7372PTXLA_UNARY_OP (Exp, at::aten::exp, xla::Exp);
7473PTXLA_UNARY_OP (Expm1, at::aten::expm1, xla::Expm1);
@@ -866,7 +865,9 @@ torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) {
866865 torch::lazy::NodePtr one = ScalarOp (1 , shape);
867866 torch::lazy::NodePtr half = ScalarOp (0.5 , shape);
868867 torch::lazy::NodePtr inner = beta * (input + kappa * Pow (input, three));
869- return half * input * (one + Tanh (inner));
868+ return half * input *
869+ (one + torch::lazy::MakeNode<Tanh>(inner,
870+ std::vector<torch::lazy::Shape>()));
870871}
871872
872873torch::lazy::NodePtr TanhGeluBackward (const torch::lazy::Value& grad,
@@ -882,7 +883,8 @@ torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
882883 torch::lazy::NodePtr three = ScalarOp (3 , shape);
883884 torch::lazy::NodePtr half = ScalarOp (0.5 , shape);
884885 torch::lazy::NodePtr inner = beta * (input + kappa * Pow (input, three));
885- torch::lazy::NodePtr tanh_inner = Tanh (inner);
886+ torch::lazy::NodePtr tanh_inner =
887+ torch::lazy::MakeNode<Tanh>(inner, std::vector<torch::lazy::Shape>());
886888
887889 torch::lazy::NodePtr left = half * input;
888890 torch::lazy::NodePtr right = one + tanh_inner;
0 commit comments