@@ -370,22 +370,19 @@ xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
370370 return grad_output * (xla::Neg (max_deriv) - sign * (buffer - one) / buffer);
371371}
372372
373- xla::XlaOp BuildElu (xla::XlaOp input, const at::Scalar& alpha,
374- const at::Scalar& scale, const at::Scalar& input_scale) {
373+ xla::XlaOp BuildElu (xla::XlaOp input, xla::XlaOp alpha, xla::XlaOp scale ,
374+ xla::XlaOp input_scale) {
375375 const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp (input);
376- xla::XlaOp scaled_input =
377- input * XlaHelpers::ScalarValue (input_scale, shape.element_type (),
378- input.builder ());
376+ alpha = MaybeConvertTo (alpha, shape.element_type ());
377+ scale = MaybeConvertTo (scale, shape.element_type ());
378+ input_scale = MaybeConvertTo (input_scale, shape.element_type ());
379+ xla::XlaOp scaled_input = input * input_scale;
379380 xla::XlaOp zero = xla::Zero (input.builder (), shape.element_type ());
380381 xla::XlaOp one = XlaHelpers::ScalarValue<float >(1.0 , shape.element_type (),
381382 input.builder ());
382- xla::XlaOp alpha_scalar =
383- XlaHelpers::ScalarValue (alpha, shape.element_type (), input.builder ());
384- xla::XlaOp scale_scalar =
385- XlaHelpers::ScalarValue (scale, shape.element_type (), input.builder ());
386383 return xla::Select (xla::Le (input, zero),
387- alpha_scalar * (xla::Exp (scaled_input) - one), input) *
388- scale_scalar ;
384+ alpha * (xla::Exp (scaled_input) - one), input) *
385+ scale ;
389386}
390387
391388xla::XlaOp BuildEluBackward (xla::XlaOp grad_output, xla::XlaOp output,
0 commit comments