@@ -68,7 +68,6 @@ namespace torch_xla {
6868 std::move (lower_fn)); \
6969 }
7070
71- PTXLA_UNARY_OP (Tanh, at::aten::tanh, xla::Tanh);
7271PTXLA_UNARY_OP (Neg, at::aten::neg, xla::Neg);
7372PTXLA_UNARY_OP (Exp, at::aten::exp, xla::Exp);
7473PTXLA_UNARY_OP (Expm1, at::aten::expm1, xla::Expm1);
@@ -867,7 +866,7 @@ torch::lazy::NodePtr TanhGelu(const torch::lazy::Value& input) {
867866 torch::lazy::NodePtr one = ScalarOp (1 , shape);
868867 torch::lazy::NodePtr half = ScalarOp (0.5 , shape);
869868 torch::lazy::NodePtr inner = beta * (input + kappa * Pow (input, three));
870- return half * input * (one + Tanh (inner));
869+ return half * input * (one + torch::lazy::MakeNode< Tanh> (inner, std::vector<torch::lazy::Shape>()));
871870}
872871
873872torch::lazy::NodePtr TanhGeluBackward (const torch::lazy::Value& grad,
@@ -883,7 +882,7 @@ torch::lazy::NodePtr TanhGeluBackward(const torch::lazy::Value& grad,
883882 torch::lazy::NodePtr three = ScalarOp (3 , shape);
884883 torch::lazy::NodePtr half = ScalarOp (0.5 , shape);
885884 torch::lazy::NodePtr inner = beta * (input + kappa * Pow (input, three));
886- torch::lazy::NodePtr tanh_inner = Tanh (inner);
885+ torch::lazy::NodePtr tanh_inner = torch::lazy::MakeNode< Tanh> (inner, std::vector<torch::lazy::Shape>() );
887886
888887 torch::lazy::NodePtr left = half * input;
889888 torch::lazy::NodePtr right = one + tanh_inner;
0 commit comments