Skip to content

Commit 0478d2f

Browse files
committed
Add manual dtype conversion for aten_reciprocal
1 parent 3865443 commit 0478d2f

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,10 @@ torch_xla::XlaOpVector NeTensor::Lower(LoweringContext* loctx) const {
680680

681681
torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const {
682682
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
683+
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
684+
xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
685+
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
686+
}
683687
return ReturnOp(BuildReciprocal(xla_input), loctx);
684688
}
685689

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,11 @@ xla::Shape NeTensorOutputShape(const torch::lazy::Value& self,
758758
}
759759

760760
xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) {
761-
return GetXlaShape(input);
761+
xla::Shape result_shape = GetXlaShape(input);
762+
if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
763+
result_shape.set_element_type(xla::PrimitiveType::F32);
764+
}
765+
return result_shape;
762766
}
763767

764768
xla::Shape ReluOutputShape(const torch::lazy::Value& input) {

0 commit comments

Comments
 (0)