Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torch_patches/.torch_pin
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#81286
11 changes: 7 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1975,13 +1975,16 @@ at::Tensor XLANativeFunctions::mean(const at::Tensor& self,
/*keep_reduced_dimensions=*/false, dtype));
}

at::Tensor XLANativeFunctions::mean(const at::Tensor& self, at::IntArrayRef dim,
bool keepdim,
at::Tensor XLANativeFunctions::mean(const at::Tensor& self,
at::OptionalIntArrayRef dim, bool keepdim,
c10::optional<at::ScalarType> dtype) {
XLA_FN_COUNTER("xla::");
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
return bridge::AtenFromXlaTensor(XLATensor::mean(
bridge::GetXlaTensor(self), torch::lazy::ToVector<int64_t>(dim),
/*keep_reduced_dimensions=*/keepdim, dtype));
self_tensor,
dim ? torch::lazy::ToVector<int64_t>(*dim)
: torch::lazy::Iota<int64_t>(self_tensor->shape().get().rank()),
keepdim, dtype));
}

at::Tensor XLANativeFunctions::min(const at::Tensor& self) {
Expand Down