Skip to content

Commit b549612

Browse files
authored
[WIP] Partially Codegen Atan2 (#4184)
* partially codegen atan2 * clang-format-7 * clang-format-7 again
1 parent b1f167e commit b549612

File tree

9 files changed

+31
-20
lines changed

9 files changed

+31
-20
lines changed

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -688,11 +688,15 @@ at::Tensor XLANativeFunctions::atan2(const at::Tensor& self,
688688
return at::native::call_fallback_fn<&xla_cpu_fallback,
689689
ATEN_OP(atan2)>::call(self, other);
690690
}
691-
return DoBinaryOp(self, other,
692-
[&](const XLATensorPtr& xself, const XLATensorPtr& xother,
693-
at::ScalarType dtype) {
694-
return XLATensor::atan2(xself, xother, dtype);
695-
});
691+
692+
auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
693+
XLA_CHECK(common_device);
694+
torch::lazy::NodePtr node =
695+
torch::lazy::MakeNode<Atan2>(bridge::GetXlaTensor(self)->GetIrValue(),
696+
bridge::GetXlaTensor(other)->GetIrValue());
697+
698+
return torch_xla::bridge::AtenFromXlaTensor(
699+
torch_xla::XLATensor::Create(std::move(node), *common_device));
696700
}
697701

698702
at::Tensor XLANativeFunctions::avg_pool2d(

torch_xla/csrc/ops/ops.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ PTXLA_UNARY_OP(Sqrt, at::aten::sqrt, xla::Sqrt);
7777
PTXLA_BINARY_OP(Min, at::aten::min, xla::Min);
7878
PTXLA_BINARY_OP(Pow, at::aten::pow, xla::Pow);
7979
PTXLA_BINARY_OP(Fmod, at::aten::fmod, xla::Rem);
80-
PTXLA_BINARY_OP(Atan2, at::aten::atan2, xla::Atan2);
8180

8281
torch::lazy::NodePtr LogBase(const torch::lazy::Value& input,
8382
torch::lazy::OpKind op, double base) {

torch_xla/csrc/ops/ops.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,6 @@ torch::lazy::NodePtr Sin(const torch::lazy::Value& input);
7575

7676
torch::lazy::NodePtr Sinh(const torch::lazy::Value& input);
7777

78-
torch::lazy::NodePtr Atan2(const torch::lazy::Value& input,
79-
const torch::lazy::Value& other);
80-
8178
torch::lazy::NodePtr Tan(const torch::lazy::Value& input);
8279

8380
torch::lazy::NodePtr Neg(const torch::lazy::Value& input);

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ torch_xla::XlaOpVector Atan::Lower(LoweringContext* loctx) const {
120120
return ReturnOp(xla::Atan(xla_input), loctx);
121121
}
122122

123+
torch_xla::XlaOpVector Atan2::Lower(LoweringContext* loctx) const {
124+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
125+
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
126+
auto promoted = XlaHelpers::Promote(xla_input, xla_other);
127+
return ReturnOp(xla::Atan2(promoted.first, promoted.second), loctx);
128+
}
129+
123130
torch_xla::XlaOpVector Atanh::Lower(LoweringContext* loctx) const {
124131
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
125132
return ReturnOp(xla::Atanh(xla_input), loctx);

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,17 @@ xla::Shape AtanOutputShape(const torch::lazy::Value& input) {
193193
return GetXlaShape(input);
194194
}
195195

196+
xla::Shape Atan2OutputShape(const torch::lazy::Value& input,
197+
const torch::lazy::Value& other) {
198+
auto lower_for_shape_fn =
199+
[&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
200+
auto promoted = XlaHelpers::Promote(operands[0], operands[1]);
201+
return xla::Atan2(promoted.first, promoted.second);
202+
};
203+
return InferOutputShape({GetXlaShape(input), GetXlaShape(other)},
204+
lower_for_shape_fn);
205+
}
206+
196207
xla::Shape AtanhOutputShape(const torch::lazy::Value& input) {
197208
return GetXlaShape(input);
198209
}

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ xla::Shape AsinhOutputShape(const torch::lazy::Value& input);
5353

5454
xla::Shape AtanOutputShape(const torch::lazy::Value& input);
5555

56+
xla::Shape Atan2OutputShape(const torch::lazy::Value& input,
57+
const torch::lazy::Value& other);
58+
5659
xla::Shape AtanhOutputShape(const torch::lazy::Value& input);
5760

5861
xla::Shape BinaryCrossEntropyOutputShape(

torch_xla/csrc/tensor.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,6 @@ class XLATensor : public c10::intrusive_ptr_target {
428428
std::vector<int64_t> stride,
429429
c10::optional<int64_t> storage_offset);
430430

431-
static XLATensorPtr atan2(
432-
const XLATensorPtr& input, const XLATensorPtr& other,
433-
c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
434-
435431
static XLATensorPtr avg_pool_nd(const XLATensorPtr& input,
436432
int64_t spatial_dim_count,
437433
std::vector<int64_t> kernel_size,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -711,13 +711,6 @@ void XLATensor::as_strided_(XLATensorPtr& input, std::vector<int64_t> size,
711711
}
712712
}
713713

714-
XLATensorPtr XLATensor::atan2(
715-
const XLATensorPtr& input, const XLATensorPtr& other,
716-
c10::optional<at::ScalarType> logical_element_type) {
717-
return input->CreateFrom(Atan2(input->GetIrValue(), other->GetIrValue()),
718-
logical_element_type);
719-
}
720-
721714
XLATensorPtr XLATensor::avg_pool_nd(const XLATensorPtr& input,
722715
int64_t spatial_dim_count,
723716
std::vector<int64_t> kernel_size,

xla_native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ ir_gen:
8787
- _adaptive_avg_pool2d_backward
8888
- _adaptive_avg_pool3d
8989
- _adaptive_avg_pool3d_backward
90+
- atan2
9091
- bitwise_and.Tensor
9192
- bitwise_or.Tensor
9293
- bitwise_xor.Tensor

0 commit comments

Comments
 (0)