File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -71,9 +71,13 @@ xla::XlaOp BuildRelu(xla::XlaOp input) {
7171
7272xla::XlaOp BuildHardshrink (xla::XlaOp input, xla::XlaOp lambda) {
7373 const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp (input);
74- xla::PrimitiveType element_type = shape.element_type ();
75- xla::XlaOp zero = xla::Zero (input.builder (), element_type );
74+ xla::PrimitiveType input_element_type = shape.element_type ();
75+ xla::XlaOp zero = xla::Zero (input.builder (), input_element_type );
7676
77+ // The conversion here is needed because when we do computation such as
78+ // broadcast or subtraction for input and lambda, XLA disallows mixed
79+ // precision for float point types.
80+ lambda = MaybeConvertTo (lambda, input_element_type);
7781 xla::XlaOp check_low = BuildComparisonOp (at::aten::ge, input, zero - lambda);
7882 xla::XlaOp check_high = BuildComparisonOp (at::aten::le, input, lambda);
7983 xla::XlaOp between = xla::And (check_low, check_high);
You can’t perform that action at this time.
0 commit comments