@@ -709,31 +709,30 @@ torch::lazy::NodePtr Identity(int64_t lines, int64_t cols,
709709torch::lazy::NodePtr Elu (const XlaValue& input, const at::Scalar& alpha,
710710 const at::Scalar& scale,
711711 const at::Scalar& input_scale) {
712- ScopePusher ir_scope (at::aten::elu.toQualString ());
713- const xla::Shape& shape = input.xla_shape ();
714- torch::lazy::NodePtr scaled_input = input * ScalarOp (input_scale, shape);
715- torch::lazy::NodePtr zero = ScalarOp (0 , shape);
716- torch::lazy::NodePtr one = ScalarOp (1 , shape);
717- torch::lazy::NodePtr alpha_scalar = ScalarOp (alpha, shape);
718- return Where (ComparisonOp (at::aten::le, input, zero),
719- alpha_scalar * (Exp (scaled_input) - one), input) *
720- ScalarOp (scale, shape);
712+ auto lower_fn = [=](const XlaNode& node,
713+ LoweringContext* loctx) -> XlaOpVector {
714+ xla::XlaOp xla_input = loctx->GetOutputOp (node.operand (0 ));
715+ return node.ReturnOp (BuildElu (xla_input, alpha, scale, input_scale), loctx);
716+ };
717+ return GenericOp (torch::lazy::OpKind (at::aten::elu),
718+ {input}, input.xla_shape (),
719+ std::move (lower_fn));
721720}
722721
723722torch::lazy::NodePtr EluBackward (const XlaValue& grad_output,
724723 const XlaValue& output,
725724 const at::Scalar& alpha,
726725 const at::Scalar& scale,
727726 const at::Scalar& input_scale) {
728- ScopePusher ir_scope (at::aten::elu_backward. toQualString ());
729- const xla::Shape& shape = grad_output. xla_shape ();
730- torch::lazy::NodePtr negative_output_branch =
731- ScalarOp (input_scale, shape) *
732- (output + ScalarOp ( alpha, shape) * ScalarOp ( scale, shape) );
733- torch::lazy::NodePtr positive_output_branch = ScalarOp (scale, shape) ;
734- return grad_output *
735- Where ( ComparisonOp (at::aten::gt , output, ScalarOp ( 0 , shape) ),
736- positive_output_branch, negative_output_branch );
727+ auto lower_fn = [=]( const XlaNode& node,
728+ LoweringContext* loctx) -> XlaOpVector {
729+ xla::XlaOp xla_grad_output = loctx-> GetOutputOp (node. operand ( 0 ));
730+ xla::XlaOp xla_output = loctx-> GetOutputOp (node. operand ( 1 ));
731+ return node. ReturnOp ( BuildEluBackward (xla_grad_output, xla_output, alpha, scale, input_scale), loctx );
732+ } ;
733+ return GenericOp ( torch::lazy::OpKind (at::aten::elu_backward),
734+ {grad_output , output}, output. xla_shape ( ),
735+ std::move (lower_fn) );
737736}
738737
739738torch::lazy::NodePtr Gelu (const XlaValue& input) {
0 commit comments