Skip to content

Commit ce1bd4e

Browse files
authored
Revert "Support torch.mean with dim=None (#3752)" (#3757)
This reverts commit d9a03dd.
1 parent c88a86b commit ce1bd4e

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1975,16 +1975,13 @@ at::Tensor XLANativeFunctions::mean(const at::Tensor& self,
19751975
/*keep_reduced_dimensions=*/false, dtype));
19761976
}
19771977

1978-
at::Tensor XLANativeFunctions::mean(const at::Tensor& self,
1979-
at::OptionalIntArrayRef dim, bool keepdim,
1978+
at::Tensor XLANativeFunctions::mean(const at::Tensor& self, at::IntArrayRef dim,
1979+
bool keepdim,
19801980
c10::optional<at::ScalarType> dtype) {
19811981
XLA_FN_COUNTER("xla::");
1982-
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
19831982
return bridge::AtenFromXlaTensor(XLATensor::mean(
1984-
self_tensor,
1985-
dim ? torch::lazy::ToVector<int64_t>(*dim)
1986-
: torch::lazy::Iota<int64_t>(self_tensor->shape().get().rank()),
1987-
keepdim, dtype));
1983+
bridge::GetXlaTensor(self), torch::lazy::ToVector<int64_t>(dim),
1984+
/*keep_reduced_dimensions=*/keepdim, dtype));
19881985
}
19891986

19901987
at::Tensor XLANativeFunctions::min(const at::Tensor& self) {

0 commit comments

Comments
 (0)