@@ -69,10 +69,20 @@ xla::XlaOp BuildRelu(xla::XlaOp input) {
6969 0 , input_shape.element_type (), input.builder ()));
7070}
7171
72- xla::XlaOp BuildHardshrink (xla::XlaOp input, const at::Scalar& lambda) {
72+ xla::XlaOp BuildHardshrink (xla::XlaOp input, xla::XlaOp lambda) {
7373 const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp (input);
74- xla::XlaOp zero = xla::Zero (input.builder (), shape.element_type ());
75- return xla::Select (Between (input, -lambda, lambda), zero, input);
74+ xla::PrimitiveType input_element_type = shape.element_type ();
75+ xla::XlaOp zero = xla::Zero (input.builder (), input_element_type);
76+
77+ // The conversion here is needed because when we do computation such as
78+ // broadcast or subtraction for input and lambda, XLA disallows mixed
79+ // precision for float point types.
80+ lambda = MaybeConvertTo (lambda, input_element_type);
81+ xla::XlaOp check_low = BuildComparisonOp (at::aten::ge, input, zero - lambda);
82+ xla::XlaOp check_high = BuildComparisonOp (at::aten::le, input, lambda);
83+ xla::XlaOp between = xla::And (check_low, check_high);
84+
85+ return xla::Select (between, zero, input);
7686}
7787
7888xla::XlaOp BuildHardSigmoid (xla::XlaOp input) {
0 commit comments