Skip to content

Commit 4ccfc24

Browse files
committed
Do the necessary type conversion.
1 parent f8b3dfd commit 4ccfc24

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

torch_xla/csrc/elementwise.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,13 @@ xla::XlaOp BuildRelu(xla::XlaOp input) {
7171

7272
xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
7373
const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
74-
xla::PrimitiveType element_type = shape.element_type();
75-
xla::XlaOp zero = xla::Zero(input.builder(), element_type);
74+
xla::PrimitiveType input_element_type = shape.element_type();
75+
xla::XlaOp zero = xla::Zero(input.builder(), input_element_type);
7676

77+
// The conversion here is needed because when we do computation such as
78+
// broadcast or subtraction for input and lambda, XLA disallows mixed
79+
// precision for float point types.
80+
lambda = MaybeConvertTo(lambda, input_element_type);
7781
xla::XlaOp check_low = BuildComparisonOp(at::aten::ge, input, zero - lambda);
7882
xla::XlaOp check_high = BuildComparisonOp(at::aten::le, input, lambda);
7983
xla::XlaOp between = xla::And(check_low, check_high);

0 commit comments

Comments
 (0)