Commit 85ce2b0

Codegen addcdiv and addcmul
1 parent ce1bd4e commit 85ce2b0

This commit moves addcdiv and addcmul off the hand-written kernel path and onto full_codegen: the ATen bindings in aten_xla_type.cpp and the tensor methods in tensor_methods.cpp are deleted, the two ops are listed under full_codegen in xla_native_functions.yaml, and only the XLA lowering and output-shape functions remain hand-written.

File tree: 6 files changed, +61 -63 lines

6 files changed

+61
-63
lines changed

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 31 deletions
@@ -597,37 +597,6 @@ at::Tensor XLANativeFunctions::add(const at::Tensor& self,
   });
 }
 
-at::Tensor XLANativeFunctions::addcdiv(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcdiv(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
-at::Tensor& XLANativeFunctions::addcdiv_(at::Tensor& self,
-                                         const at::Tensor& tensor1,
-                                         const at::Tensor& tensor2,
-                                         const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::addcdiv_(self_tensor, value, bridge::GetXlaTensor(tensor1),
-                      bridge::GetXlaTensor(tensor2));
-  return self;
-}
-
-at::Tensor XLANativeFunctions::addcmul(const at::Tensor& self,
-                                       const at::Tensor& tensor1,
-                                       const at::Tensor& tensor2,
-                                       const at::Scalar& value) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::addcmul(
-      bridge::GetXlaTensor(self), value, bridge::GetXlaTensor(tensor1),
-      bridge::GetXlaTensor(tensor2)));
-}
-
 at::Tensor XLANativeFunctions::addmm(const at::Tensor& self,
                                      const at::Tensor& mat1,
                                      const at::Tensor& mat2,
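Both the functional bindings and the in-place addcdiv_ wrapper are deleted with no hand-written replacement, so the generated bindings evidently cover the in-place form as well. As a reminder of the user-facing contract these bindings served, a minimal sketch in plain PyTorch (CPU tensors, since the arithmetic is backend-independent):

# Functional and in-place forms of addcdiv; t2 is kept away from zero
# so the division is well behaved.
import torch

a = torch.randn(3)
t1, t2 = torch.randn(3), torch.rand(3) + 0.5

out = torch.addcdiv(a, t1, t2, value=0.1)  # a + 0.1 * t1 / t2
a.addcdiv_(t1, t2, value=0.1)              # in-place variant, also removed above
assert torch.allclose(a, out)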

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 25 additions & 0 deletions
@@ -23,6 +23,31 @@ torch_xla::XlaOpVector Acosh::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Acosh(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Addcdiv::Lower(LoweringContext* loctx) const {
+  // xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  // torch::lazy::Value constant = GetIrValueForScalar(
+  //     value, tensor1->shape().get().element_type(), input->GetDevice());
+  // torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
+  // return input->CreateFrom(input->GetIrValue() + div * constant);
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(xla_input + (xla_t1 / xla_t2) * xla_val, loctx);
+}
+
+torch_xla::XlaOpVector Addcmul::Lower(LoweringContext* loctx) const {
+  // torch::lazy::Value constant = GetIrValueForScalar(
+  //     value, tensor1->shape().get().element_type(), input->GetDevice());
+  // torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
+  // return input->CreateFrom(input->GetIrValue() + mul * constant);
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_t1 = loctx->GetOutputOp(operand(1));
+  xla::XlaOp xla_t2 = loctx->GetOutputOp(operand(2));
+  xla::XlaOp xla_val = loctx->GetOutputOp(operand(3));
+  return ReturnOp(xla_input + (xla_t1 * xla_t2) * xla_val, loctx);
+}
+
 torch_xla::XlaOpVector Asin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Asin(xla_input), loctx);
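Each Lower body reads its operands off the lowering context (input, tensor1, tensor2, and value, which the codegen path passes in as a fourth operand rather than a C++ scalar) and emits the elementwise graph directly; the commented-out lines are the old tensor-method implementation kept for reference. A quick end-to-end check of the arithmetic, as a sketch that assumes a working torch_xla install (xm.xla_device() resolves to whatever XLA device is configured):

# Verify torch.addcdiv on an XLA device against the formula the lowering
# builds: input + (t1 / t2) * value.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
inp = torch.randn(4, 3, device=device)
t1 = torch.randn(4, 3, device=device)
t2 = torch.rand(4, 3, device=device) + 0.5  # avoid division by zero

out = torch.addcdiv(inp, t1, t2, value=0.25)
ref = inp + (t1 / t2) * 0.25
xm.mark_step()  # compile and run the traced graph
assert torch.allclose(out.cpu(), ref.cpu())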

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 24 additions & 0 deletions
@@ -18,6 +18,30 @@ xla::Shape AcoshOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return operands[0] + (operands[1] / operands[2]) * operands[3];
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value) {
+  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return operands[0] + (operands[1] * operands[2]) * operands[3];
+  };
+  return InferOutputShape({GetXlaShape(input), GetXlaShape(t1), GetXlaShape(t2),
+                           GetXlaShape(value)},
+                          shape_fn);
+}
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
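Instead of hand-deriving result shapes, each OutputShape helper replays the same ops on abstract operands and lets InferOutputShape read the shape off the XLA builder, so broadcasting comes from XLA's own binary-op rules. What that buys in PyTorch terms, with hypothetical shapes:

# The output shape is the broadcast of all operand shapes, which is what
# replaying the ops through InferOutputShape computes.
import torch

inp = torch.randn(1, 3)
t1 = torch.randn(4, 1)
t2 = torch.ones(1)
print(torch.addcmul(inp, t1, t2, value=2.0).shape)  # torch.Size([4, 3])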

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 10 additions & 0 deletions
@@ -9,6 +9,16 @@ xla::Shape AcosOutputShape(const torch::lazy::Value& input);
 
 xla::Shape AcoshOutputShape(const torch::lazy::Value& input);
 
+xla::Shape AddcdivOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
+xla::Shape AddcmulOutputShape(const torch::lazy::Value& input,
+                              const torch::lazy::Value& t1,
+                              const torch::lazy::Value& t2,
+                              const torch::lazy::Value& value);
+
 xla::Shape AsinOutputShape(const torch::lazy::Value& input);
 
 xla::Shape AsinhOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 29 deletions
@@ -690,35 +690,6 @@ XLATensorPtr XLATensor::add(
       logical_element_type);
 }
 
-XLATensorPtr XLATensor::addcdiv(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + div * constant);
-}
-
-void XLATensor::addcdiv_(XLATensorPtr& input, const at::Scalar& value,
-                         const XLATensorPtr& tensor1,
-                         const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value div = tensor1->GetIrValue() / tensor2->GetIrValue();
-  input->SetInPlaceIrValue(input->GetIrValue() + div * constant);
-}
-
-XLATensorPtr XLATensor::addcmul(const XLATensorPtr& input,
-                                const at::Scalar& value,
-                                const XLATensorPtr& tensor1,
-                                const XLATensorPtr& tensor2) {
-  torch::lazy::Value constant = GetIrValueForScalar(
-      value, tensor1->shape().get().element_type(), input->GetDevice());
-  torch::lazy::Value mul = tensor1->GetIrValue() * tensor2->GetIrValue();
-  return input->CreateFrom(input->GetIrValue() + mul * constant);
-}
-
 XLATensorPtr XLATensor::addmm(const XLATensorPtr& input,
                               const XLATensorPtr& weight,
                               const XLATensorPtr& bias) {
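These removed methods are where the scalar value was converted to an IR constant via GetIrValueForScalar, typed after tensor1's element type; under codegen the equivalent conversion happens in generated code. The observable contract is unchanged, with the Python scalar following the tensor dtype, as this small CPU sketch illustrates:

# `value` is promoted to the operands' element type, so a Python float
# scales float16 tensors without an explicit cast:
# 0 + 0.5 * (1 / 4) = 0.125.
import torch

inp = torch.zeros(2, dtype=torch.float16)
t1 = torch.ones(2, dtype=torch.float16)
t2 = torch.full((2,), 4.0, dtype=torch.float16)
print(torch.addcdiv(inp, t1, t2, value=0.5))
# tensor([0.1250, 0.1250], dtype=torch.float16)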

xla_native_functions.yaml

Lines changed: 2 additions & 3 deletions
@@ -4,6 +4,8 @@ full_codegen:
 - acos
 - acosh
 - abs
+- addcdiv
+- addcmul
 - asin
 - asinh
 - atan
@@ -73,9 +75,6 @@ supported:
 - adaptive_max_pool2d_backward
 - add.Scalar
 - add.Tensor
-- addcdiv
-- addcdiv_
-- addcmul
 - addmm
 - alias
 - all
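Moving the two entries from supported to full_codegen is what makes the C++ deletions above possible: full_codegen ops get their ATen bindings and IR node classes generated from this file, leaving only the Lower and OutputShape bodies to write by hand. A smoke-test sketch that the ops still dispatch to XLA rather than falling back; the exact counter name is an assumption based on the usual xla:: prefix:

# Expect an addcmul counter under the xla:: prefix and no aten::addcmul
# fallback counter after running the op on an XLA device.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.ones(2, 2, device=device)
torch.addcmul(x, x, x, value=3.0)
xm.mark_step()
print([name for name in met.counter_names() if "addcmul" in name])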
