
Commit 18b56f3

Lower Elu and EluBackward
1 parent 8a5b336 commit 18b56f3

File tree

3 files changed: 62 additions, 18 deletions


torch_xla/csrc/elementwise.cpp

Lines changed: 32 additions & 0 deletions
@@ -342,4 +342,36 @@ xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output,
   return grad_output * (xla::Neg(max_deriv) - sign * (buffer - one) / buffer);
 }
 
+xla::XlaOp BuildElu(xla::XlaOp input, const at::Scalar& alpha,
+                    const at::Scalar& scale,
+                    const at::Scalar& input_scale) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp scaled_input =
+      input * XlaHelpers::ScalarValue(input_scale, shape.element_type(),
+                                      input.builder());
+  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
+  xla::XlaOp one = XlaHelpers::ScalarValue<float>(1.0, shape.element_type(),
+                                                  input.builder());
+  xla::XlaOp alpha_scalar =
+      XlaHelpers::ScalarValue(alpha, shape.element_type(), input.builder());
+  xla::XlaOp scale_scalar =
+      XlaHelpers::ScalarValue(scale, shape.element_type(), input.builder());
+  return xla::Select(xla::Le(input, zero),
+                     alpha_scalar * (xla::Exp(scaled_input) - one), input) *
+         scale_scalar;
+}
+
+xla::XlaOp BuildEluBackward(xla::XlaOp grad_output, xla::XlaOp output,
+                            const at::Scalar& alpha, const at::Scalar& scale,
+                            const at::Scalar& input_scale) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(output);
+  xla::XlaOp zero = xla::Zero(output.builder(), shape.element_type());
+  xla::XlaOp alpha_scalar =
+      XlaHelpers::ScalarValue(alpha, shape.element_type(), output.builder());
+  xla::XlaOp scale_scalar =
+      XlaHelpers::ScalarValue(scale, shape.element_type(), output.builder());
+  xla::XlaOp input_scale_scalar = XlaHelpers::ScalarValue(
+      input_scale, shape.element_type(), output.builder());
+  xla::XlaOp negative_output_branch =
+      input_scale_scalar * (output + alpha_scalar * scale_scalar);
+  return grad_output * xla::Select(xla::Gt(output, zero), scale_scalar,
+                                   negative_output_branch);
+}
+
} // namespace torch_xla
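For reference, a quick derivation (not stated in the commit) of why BuildEluBackward can compute the negative branch from the output alone, assuming PyTorch's standard ELU definition:

$$
\mathrm{elu}(x) = \mathrm{scale} \cdot \begin{cases} x, & x > 0 \\ \alpha \left( e^{\mathrm{input\_scale} \cdot x} - 1 \right), & x \le 0 \end{cases}
$$

For $\mathrm{out} \le 0$ we have $\mathrm{out} = \mathrm{scale} \cdot \alpha \cdot (e^{\mathrm{input\_scale} \cdot x} - 1)$, so

$$
\frac{d\,\mathrm{elu}}{dx} = \mathrm{scale} \cdot \alpha \cdot \mathrm{input\_scale} \cdot e^{\mathrm{input\_scale} \cdot x} = \mathrm{input\_scale} \cdot (\mathrm{out} + \alpha \cdot \mathrm{scale}),
$$

which is exactly negative_output_branch; for $\mathrm{out} > 0$ the derivative is simply $\mathrm{scale}$.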

torch_xla/csrc/elementwise.h

Lines changed: 13 additions & 0 deletions
@@ -84,8 +84,21 @@ xla::XlaOp BuildGeluBackward(xla::XlaOp grad_output, xla::XlaOp input);
 // Computes the LogSigmoid function of input.
 std::vector<xla::XlaOp> BuildLogSigmoid(xla::XlaOp input);
 
+// Computes the backward of LogSigmoid.
 xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output,
                                    xla::XlaOp input,
                                    xla::XlaOp buffer);
 
+// Computes the Elu function of input.
+xla::XlaOp BuildElu(xla::XlaOp input, const at::Scalar& alpha,
+                    const at::Scalar& scale,
+                    const at::Scalar& input_scale);
+
+// Computes the backward of Elu.
+xla::XlaOp BuildEluBackward(xla::XlaOp grad_output,
+                            xla::XlaOp output,
+                            const at::Scalar& alpha,
+                            const at::Scalar& scale,
+                            const at::Scalar& input_scale);
+
} // namespace torch_xla
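A minimal usage sketch of the new declarations (hypothetical call site, not part of this commit; the helper name LowerEluWithDefaults and PyTorch's defaults alpha = scale = input_scale = 1 are assumptions):

  // Hypothetical helper: builds ELU with PyTorch's default scalars.
  xla::XlaOp LowerEluWithDefaults(xla::XlaOp input) {
    at::Scalar alpha(1.0);        // PyTorch default
    at::Scalar scale(1.0);        // PyTorch default
    at::Scalar input_scale(1.0);  // PyTorch default
    return BuildElu(input, alpha, scale, input_scale);
  }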

torch_xla/csrc/ops/ops.cpp

Lines changed: 17 additions & 18 deletions
@@ -709,31 +709,30 @@ torch::lazy::NodePtr Identity(int64_t lines, int64_t cols,
 torch::lazy::NodePtr Elu(const XlaValue& input, const at::Scalar& alpha,
                          const at::Scalar& scale,
                          const at::Scalar& input_scale) {
-  ScopePusher ir_scope(at::aten::elu.toQualString());
-  const xla::Shape& shape = input.xla_shape();
-  torch::lazy::NodePtr scaled_input = input * ScalarOp(input_scale, shape);
-  torch::lazy::NodePtr zero = ScalarOp(0, shape);
-  torch::lazy::NodePtr one = ScalarOp(1, shape);
-  torch::lazy::NodePtr alpha_scalar = ScalarOp(alpha, shape);
-  return Where(ComparisonOp(at::aten::le, input, zero),
-               alpha_scalar * (Exp(scaled_input) - one), input) *
-         ScalarOp(scale, shape);
+  auto lower_fn = [=](const XlaNode& node,
+                      LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    return node.ReturnOp(BuildElu(xla_input, alpha, scale, input_scale),
+                         loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::elu), {input},
+                   input.xla_shape(), std::move(lower_fn));
 }
 
 torch::lazy::NodePtr EluBackward(const XlaValue& grad_output,
                                  const XlaValue& output,
                                  const at::Scalar& alpha,
                                  const at::Scalar& scale,
                                  const at::Scalar& input_scale) {
-  ScopePusher ir_scope(at::aten::elu_backward.toQualString());
-  const xla::Shape& shape = grad_output.xla_shape();
-  torch::lazy::NodePtr negative_output_branch =
-      ScalarOp(input_scale, shape) *
-      (output + ScalarOp(alpha, shape) * ScalarOp(scale, shape));
-  torch::lazy::NodePtr positive_output_branch = ScalarOp(scale, shape);
-  return grad_output *
-         Where(ComparisonOp(at::aten::gt, output, ScalarOp(0, shape)),
-               positive_output_branch, negative_output_branch);
+  auto lower_fn = [=](const XlaNode& node,
+                      LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_output = loctx->GetOutputOp(node.operand(1));
+    return node.ReturnOp(
+        BuildEluBackward(xla_grad_output, xla_output, alpha, scale,
+                         input_scale),
+        loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::elu_backward),
+                   {grad_output, output}, output.xla_shape(),
+                   std::move(lower_fn));
 }
 
 torch::lazy::NodePtr Gelu(const XlaValue& input) {
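The rewrite above swaps an IR-level composition (Where/ComparisonOp/Exp on lazy nodes) for a single GenericOp whose lambda lowers directly to XLA via the Build* helpers; the scalars are captured by value because lower_fn runs later, during lowering. A minimal sketch of the xla::Select pattern those helpers use in place of the old Where node, shown on ReLU as a stand-in (illustration only, not part of this commit):

  // xla::Select(pred, on_true, on_false) chooses elementwise by predicate.
  xla::XlaOp BuildReluExample(xla::XlaOp input) {
    const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
    xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
    return xla::Select(xla::Gt(input, zero), input, zero);
  }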
