Skip to content

Commit 76df130

Browse files
authored
Codegen elu and elu_ (#3893)
* Codegen elu * Add MaybeCastTo in type casting
1 parent 3f47aba commit 76df130

File tree

11 files changed

+32
-65
lines changed

11 files changed

+32
-65
lines changed

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 18 deletions
Original file line number · Diff line number · Diff line change
@@ -1157,24 +1157,6 @@ at::Tensor XLANativeFunctions::dot(const at::Tensor& self,
11571157
bridge::GetXlaTensor(self), bridge::GetXlaTensor(tensor)));
11581158
}
11591159

1160-
at::Tensor XLANativeFunctions::elu(const at::Tensor& self,
1161-
const at::Scalar& alpha,
1162-
const at::Scalar& scale,
1163-
const at::Scalar& input_scale) {
1164-
XLA_FN_COUNTER("xla::");
1165-
return bridge::AtenFromXlaTensor(
1166-
XLATensor::elu(bridge::GetXlaTensor(self), alpha, scale, input_scale));
1167-
}
1168-
1169-
at::Tensor& XLANativeFunctions::elu_(at::Tensor& self, const at::Scalar& alpha,
1170-
const at::Scalar& scale,
1171-
const at::Scalar& input_scale) {
1172-
XLA_FN_COUNTER("xla::");
1173-
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
1174-
XLATensor::elu_(self_tensor, alpha, scale, input_scale);
1175-
return self;
1176-
}
1177-
11781160
at::Tensor XLANativeFunctions::elu_backward(const at::Tensor& grad_output,
11791161
const at::Scalar& alpha,
11801162
const at::Scalar& scale,

torch_xla/csrc/elementwise.cpp

Lines changed: 8 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -370,22 +370,19 @@ xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
370370
return grad_output * (xla::Neg(max_deriv) - sign * (buffer - one) / buffer);
371371
}
372372

373-
xla::XlaOp BuildElu(xla::XlaOp input, const at::Scalar& alpha,
374-
const at::Scalar& scale, const at::Scalar& input_scale) {
373+
xla::XlaOp BuildElu(xla::XlaOp input, xla::XlaOp alpha, xla::XlaOp scale,
374+
xla::XlaOp input_scale) {
375375
const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
376-
xla::XlaOp scaled_input =
377-
input * XlaHelpers::ScalarValue(input_scale, shape.element_type(),
378-
input.builder());
376+
alpha = MaybeConvertTo(alpha, shape.element_type());
377+
scale = MaybeConvertTo(scale, shape.element_type());
378+
input_scale = MaybeConvertTo(input_scale, shape.element_type());
379+
xla::XlaOp scaled_input = input * input_scale;
379380
xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
380381
xla::XlaOp one = XlaHelpers::ScalarValue<float>(1.0, shape.element_type(),
381382
input.builder());
382-
xla::XlaOp alpha_scalar =
383-
XlaHelpers::ScalarValue(alpha, shape.element_type(), input.builder());
384-
xla::XlaOp scale_scalar =
385-
XlaHelpers::ScalarValue(scale, shape.element_type(), input.builder());
386383
return xla::Select(xla::Le(input, zero),
387-
alpha_scalar * (xla::Exp(scaled_input) - one), input) *
388-
scale_scalar;
384+
alpha * (xla::Exp(scaled_input) - one), input) *
385+
scale;
389386
}
390387

391388
xla::XlaOp BuildEluBackward(xla::XlaOp grad_output, xla::XlaOp output,

torch_xla/csrc/elementwise.h

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -97,8 +97,8 @@ xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
9797
xla::XlaOp buffer);
9898

9999
// Computes the Elu function of input.
100-
xla::XlaOp BuildElu(xla::XlaOp input, const at::Scalar& alpha,
101-
const at::Scalar& scale, const at::Scalar& input_scale);
100+
xla::XlaOp BuildElu(xla::XlaOp input, xla::XlaOp alpha, xla::XlaOp scale,
101+
xla::XlaOp input_scale);
102102

103103
// Computes the backward of Elu.
104104
xla::XlaOp BuildEluBackward(xla::XlaOp grad_output, xla::XlaOp output,

torch_xla/csrc/ops/ops.cpp

Lines changed: 0 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -532,18 +532,6 @@ torch::lazy::NodePtr Identity(int64_t lines, int64_t cols,
532532
torch::lazy::MHash(lines, cols));
533533
}
534534

535-
torch::lazy::NodePtr Elu(const torch::lazy::Value& input,
536-
const at::Scalar& alpha, const at::Scalar& scale,
537-
const at::Scalar& input_scale) {
538-
auto lower_fn = [=](const XlaNode& node,
539-
LoweringContext* loctx) -> XlaOpVector {
540-
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
541-
return node.ReturnOp(BuildElu(xla_input, alpha, scale, input_scale), loctx);
542-
};
543-
return GenericOp(torch::lazy::OpKind(at::aten::elu), {input},
544-
GetXlaShape(input), std::move(lower_fn));
545-
}
546-
547535
torch::lazy::NodePtr EluBackward(const torch::lazy::Value& grad_output,
548536
const torch::lazy::Value& output,
549537
const at::Scalar& alpha,

torch_xla/csrc/ops/ops.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,6 @@ torch::lazy::NodePtr Norm(const torch::lazy::Value& input,
175175
torch::lazy::NodePtr Identity(int64_t lines, int64_t cols,
176176
xla::PrimitiveType element_type);
177177

178-
torch::lazy::NodePtr Elu(const torch::lazy::Value& input,
179-
const at::Scalar& alpha, const at::Scalar& scale,
180-
const at::Scalar& input_scale);
181-
182178
torch::lazy::NodePtr EluBackward(const torch::lazy::Value& grad_output,
183179
const torch::lazy::Value& output,
184180
const at::Scalar& alpha,

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 9 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -189,6 +189,15 @@ torch_xla::XlaOpVector Cosh::Lower(LoweringContext* loctx) const {
189189
return ReturnOp(xla::Cosh(xla_input), loctx);
190190
}
191191

192+
torch_xla::XlaOpVector Elu::Lower(LoweringContext* loctx) const {
193+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
194+
xla::XlaOp xla_alpha = loctx->GetOutputOp(operand(1));
195+
xla::XlaOp xla_scale = loctx->GetOutputOp(operand(2));
196+
xla::XlaOp xla_input_scale = loctx->GetOutputOp(operand(3));
197+
return ReturnOp(BuildElu(xla_input, xla_alpha, xla_scale, xla_input_scale),
198+
loctx);
199+
}
200+
192201
torch_xla::XlaOpVector Erf::Lower(LoweringContext* loctx) const {
193202
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
194203
return ReturnOp(xla::Erf(xla_input), loctx);

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 7 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -256,6 +256,13 @@ xla::Shape CoshOutputShape(const torch::lazy::Value& input) {
256256
return GetXlaShape(input);
257257
}
258258

259+
xla::Shape EluOutputShape(const torch::lazy::Value& input,
260+
const torch::lazy::Value& alpha,
261+
const torch::lazy::Value& scale,
262+
const torch::lazy::Value& input_scale) {
263+
return GetXlaShape(input);
264+
}
265+
259266
xla::Shape ErfOutputShape(const torch::lazy::Value& input) {
260267
return GetXlaShape(input);
261268
}

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ xla::Shape CosOutputShape(const torch::lazy::Value& input);
7272

7373
xla::Shape CoshOutputShape(const torch::lazy::Value& input);
7474

75+
xla::Shape EluOutputShape(const torch::lazy::Value& input,
76+
const torch::lazy::Value& alpha,
77+
const torch::lazy::Value& scale,
78+
const torch::lazy::Value& input_scale);
79+
7580
xla::Shape ErfOutputShape(const torch::lazy::Value& input);
7681

7782
xla::Shape ErfcOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/tensor.h

Lines changed: 0 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -550,11 +550,6 @@ class XLATensor : public c10::intrusive_ptr_target {
550550
static XLATensorPtr einsum(const std::string& equation,
551551
absl::Span<const XLATensorPtr> tensors);
552552

553-
static XLATensorPtr elu(const XLATensorPtr& input, const at::Scalar& alpha,
554-
const at::Scalar& scale,
555-
const at::Scalar& input_scale);
556-
static void elu_(XLATensorPtr& input, const at::Scalar& alpha,
557-
const at::Scalar& scale, const at::Scalar& input_scale);
558553
static XLATensorPtr elu_backward(const XLATensorPtr& grad_output,
559554
const at::Scalar& alpha,
560555
const at::Scalar& scale,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -1151,17 +1151,6 @@ XLATensorPtr XLATensor::eq(const XLATensorPtr& input,
11511151
return DispatchComparisonOp(at::aten::eq, input, other);
11521152
}
11531153

1154-
XLATensorPtr XLATensor::elu(const XLATensorPtr& input, const at::Scalar& alpha,
1155-
const at::Scalar& scale,
1156-
const at::Scalar& input_scale) {
1157-
return input->CreateFrom(Elu(input->GetIrValue(), alpha, scale, input_scale));
1158-
}
1159-
1160-
void XLATensor::elu_(XLATensorPtr& input, const at::Scalar& alpha,
1161-
const at::Scalar& scale, const at::Scalar& input_scale) {
1162-
input->SetInPlaceIrValue(Elu(input->GetIrValue(), alpha, scale, input_scale));
1163-
}
1164-
11651154
XLATensorPtr XLATensor::elu_backward(const XLATensorPtr& grad_output,
11661155
const at::Scalar& alpha,
11671156
const at::Scalar& scale,

0 commit comments

Comments (0)