4 changes: 2 additions & 2 deletions codegen/xla_native_functions.yaml
@@ -12,6 +12,8 @@ full_codegen:
   - amin
   - any
   - any.dim
+  - argmax
+  - argmin
   - asin
   - asinh
   - atan
@@ -132,8 +134,6 @@ supported:
   - addmm
   - alias_copy
   - arange.start_out
-  - argmax
-  - argmin
   - as_strided_copy
   - as_strided_scatter
   - atan2
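Note: moving `argmax` and `argmin` from `supported` to `full_codegen` shifts their boilerplate to torchgen's lazy-tensor codegen; the IR node classes and the `XLANativeFunctions` entry points are now generated, and only the `Lower` and `*OutputShape` functions in the files below stay hand-written. A rough sketch of what lands in the generated `LazyIr.h` (illustrative, assuming the usual lazy codegen pattern; `dim` and `keepdim` must be public members because the hand-written `Argmax::Lower` below reads them directly):

class Argmax : public XlaNode {
 public:
  Argmax(const torch::lazy::Value& self, const c10::optional<int64_t>& dim,
         bool keepdim)
      : XlaNode(torch::lazy::OpKind(at::aten::argmax), {self},
                [&]() { return ArgmaxOutputShape(self, dim, keepdim); },
                /*num_outputs=*/1, torch::lazy::MHash(dim, keepdim)),
        dim(dim),
        keepdim(keepdim) {}

  std::string ToString() const override;
  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;

  c10::optional<int64_t> dim;
  bool keepdim;
};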
20 changes: 0 additions & 20 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -689,26 +689,6 @@ at::Tensor& XLANativeFunctions::arange_out(const at::Scalar& start,
   return out;
 }
 
-at::Tensor XLANativeFunctions::argmax(const at::Tensor& self,
-                                      c10::optional<int64_t> dim,
-                                      bool keepdim) {
-  TORCH_LAZY_FN_COUNTER("xla::");
-  return dim ? bridge::AtenFromXlaTensor(tensor_methods::argmax(
-                   bridge::GetXlaTensor(self), *dim, keepdim))
-             : bridge::AtenFromXlaTensor(
-                   tensor_methods::argmax(bridge::GetXlaTensor(self)));
-}
-
-at::Tensor XLANativeFunctions::argmin(const at::Tensor& self,
-                                      c10::optional<int64_t> dim,
-                                      bool keepdim) {
-  TORCH_LAZY_FN_COUNTER("xla::");
-  return dim ? bridge::AtenFromXlaTensor(tensor_methods::argmin(
-                   bridge::GetXlaTensor(self), *dim, keepdim))
-             : bridge::AtenFromXlaTensor(
-                   tensor_methods::argmin(bridge::GetXlaTensor(self)));
-}
-
 at::Tensor XLANativeFunctions::as_strided_copy(
     const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride,
     c10::optional<int64_t> storage_offset) {
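The hand-written dispatcher overloads above are deleted because codegen now emits equivalent glue. A minimal sketch of the generated `XLANativeFunctions::argmax`, under the assumption that it follows the standard full_codegen template (the real generated code also handles device resolution and IR node reuse/caching):

at::Tensor XLANativeFunctions::argmax(const at::Tensor& self,
                                      c10::optional<int64_t> dim,
                                      bool keepdim) {
  TORCH_LAZY_FN_COUNTER("xla::");
  XLATensorPtr xla_self = bridge::GetXlaTensor(self);
  // The optional dim is stored on the node as-is; the Long output dtype now
  // falls out of the inferred s64 shape instead of being hard-coded as in
  // the removed tensor_methods:: helpers.
  return bridge::AtenFromXlaTensor(xla_self->CreateFrom(
      torch::lazy::MakeNode<Argmax>(xla_self->GetIrValue(), dim, keepdim)));
}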
28 changes: 28 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -1,3 +1,5 @@
+#include <torch/csrc/lazy/core/helpers.h>
+
 #include "torch_xla/csrc/LazyIr.h"
 #include "torch_xla/csrc/convert_ops.h"
 #include "torch_xla/csrc/data_ops.h"
@@ -107,6 +109,32 @@ torch_xla::XlaOpVector AnyDim::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildAny(input, {dim}, keepdim), loctx);
 }
 
+torch_xla::XlaOpVector Argmax::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  if (dim.has_value()) {
+    int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
+        dim.value(), input_shape.rank());
+    return ReturnOp(torch_xla::BuildArgMax(input, canonical_dim, keepdim),
+                    loctx);
+  } else {
+    return ReturnOp(torch_xla::BuildArgMax(input, -1, false), loctx);
+  }
+}
+
+torch_xla::XlaOpVector Argmin::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  if (dim.has_value()) {
+    int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
+        dim.value(), input_shape.rank());
+    return ReturnOp(torch_xla::BuildArgMin(input, canonical_dim, keepdim),
+                    loctx);
+  } else {
+    return ReturnOp(torch_xla::BuildArgMin(input, -1, false), loctx);
+  }
+}
+
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Asin(xla_input), loctx);
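Both lowerings canonicalize a negative `dim` against the input rank (`GetCanonicalDimensionIndex(-1, /*rank=*/2)` yields 1), while the dim-less path passes `-1` to `BuildArgMax`/`BuildArgMin`, which treats it as a flattened reduction over all elements. A minimal C++ sketch exercising both branches, assuming a torch_xla build where the `xla` device is registered:

#include <torch/torch.h>

void ExerciseArgLowerings() {
  torch::Device xla("xla:0");  // assumption: torch_xla is loaded
  torch::Tensor a = torch::rand({4, 8}).to(xla);
  // dim branch: dim=-1 is canonicalized to 1 before BuildArgMax.
  torch::Tensor row_max = torch::argmax(a, /*dim=*/-1, /*keepdim=*/true);
  // dim-less branch: lowers via BuildArgMin(input, -1, false).
  torch::Tensor global_min = torch::argmin(a);
}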
34 changes: 34 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -1,5 +1,7 @@
 #include "torch_xla/csrc/ops/ops_xla_shape_fn.h"
 
+#include <torch/csrc/lazy/core/helpers.h>
+
 #include "torch_xla/csrc/data_ops.h"
 #include "torch_xla/csrc/elementwise.h"
 #include "torch_xla/csrc/helpers.h"
@@ -181,6 +183,38 @@ xla::Shape AnyDimOutputShape(const torch::lazy::Value& input, int64_t dim,
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
 }
 
+xla::Shape ArgmaxOutputShape(const torch::lazy::Value& input,
+                             c10::optional<int64_t> dim, bool keepdim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    if (dim.has_value()) {
+      const xla::Shape& input_shape = GetXlaShape(input);
+      int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
+          dim.value(), input_shape.rank());
+      return BuildArgMax(operands[0], {canonical_dim}, keepdim);
+    } else {
+      return BuildArgMax(operands[0], {-1}, false);
+    }
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
+
+xla::Shape ArgminOutputShape(const torch::lazy::Value& input,
+                             c10::optional<int64_t> dim, bool keepdim) {
+  auto lower_for_shape_fn =
+      [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    if (dim.has_value()) {
+      const xla::Shape& input_shape = GetXlaShape(input);
+      int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
+          dim.value(), input_shape.rank());
+      return BuildArgMin(operands[0], {canonical_dim}, keepdim);
+    } else {
+      return BuildArgMin(operands[0], {-1}, false);
+    }
+  };
+  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
+}
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
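Each shape function replays the builder logic on dummy operands through `InferOutputShape`, so the inferred shape stays in lockstep with the lowering by construction. A CPU-side sanity sketch of the shapes these functions must reproduce, assuming they mirror `at::argmax` semantics (s64 output; reduced dim dropped, or kept as size 1 with keepdim):

#include <torch/torch.h>

#include <cassert>

void CheckArgmaxShapes() {
  torch::Tensor t = torch::rand({4, 8});
  // dim reduced away: f32[4,8] -> s64[4]
  assert(torch::argmax(t, /*dim=*/1).sizes() == torch::IntArrayRef({4}));
  // keepdim retains a size-1 dim: f32[4,8] -> s64[4,1]
  assert(torch::argmax(t, /*dim=*/1, /*keepdim=*/true).sizes() ==
         torch::IntArrayRef({4, 1}));
  // dim-less: flattened reduction to a 0-d index tensor.
  assert(torch::argmax(t).dim() == 0);
  assert(torch::argmax(t).scalar_type() == torch::kLong);
}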
6 changes: 6 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -45,6 +45,12 @@ xla::Shape AmaxOutputShape(const torch::lazy::Value& input,
 xla::Shape AminOutputShape(const torch::lazy::Value& input,
                            absl::Span<const int64_t> dim, bool keepdim);
 
+xla::Shape ArgmaxOutputShape(const torch::lazy::Value& input,
+                             c10::optional<int64_t> dim, bool keepdim);
+
+xla::Shape ArgminOutputShape(const torch::lazy::Value& input,
+                             c10::optional<int64_t> dim, bool keepdim);
+
 xla::Shape AnyOutputShape(const torch::lazy::Value& input);
 
 xla::Shape AnyDimOutputShape(const torch::lazy::Value& input, int64_t dim,
28 changes: 0 additions & 28 deletions torch_xla/csrc/tensor_methods.cpp
@@ -697,34 +697,6 @@ void arange_out(XLATensorPtr& out, const at::Scalar& start,
   out->SetScalarType(scalar_type);
 }
 
-XLATensorPtr argmax(const XLATensorPtr& input, int64_t dim, bool keepdim) {
-  int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
-      dim, input->shape().get().rank());
-  return input->CreateFrom(torch::lazy::MakeNode<ArgMax>(
-                               input->GetIrValue(), canonical_dim, keepdim),
-                           at::ScalarType::Long);
-}
-
-XLATensorPtr argmax(const XLATensorPtr& input) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<ArgMax>(input->GetIrValue(), -1, false),
-      at::ScalarType::Long);
-}
-
-XLATensorPtr argmin(const XLATensorPtr& input, int64_t dim, bool keepdim) {
-  int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
-      dim, input->shape().get().rank());
-  return input->CreateFrom(torch::lazy::MakeNode<ArgMin>(
-                               input->GetIrValue(), canonical_dim, keepdim),
-                           at::ScalarType::Long);
-}
-
-XLATensorPtr argmin(const XLATensorPtr& input) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<ArgMin>(input->GetIrValue(), -1, false),
-      at::ScalarType::Long);
-}
-
 XLATensorPtr as_strided(const XLATensorPtr& input, std::vector<int64_t> size,
                         std::vector<int64_t> stride,
                         c10::optional<int64_t> storage_offset) {
6 changes: 0 additions & 6 deletions torch_xla/csrc/tensor_methods.h
@@ -163,12 +163,6 @@ void arange_out(XLATensorPtr& out, const at::Scalar& start,
                 const at::Scalar& end, const at::Scalar& step,
                 at::ScalarType scalar_type);
 
-XLATensorPtr argmax(const XLATensorPtr& input, int64_t dim, bool keepdim);
-XLATensorPtr argmax(const XLATensorPtr& input);
-
-XLATensorPtr argmin(const XLATensorPtr& input, int64_t dim, bool keepdim);
-XLATensorPtr argmin(const XLATensorPtr& input);
-
 // Takes a slice from the input as R1 at the specified offset and reshapes it
 // into the provided size.
 XLATensorPtr as_strided(const XLATensorPtr& input, std::vector<int64_t> size,