diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp
index c6993f2aff3d..e2bffd1e68db 100644
--- a/torch_xla/csrc/elementwise.cpp
+++ b/torch_xla/csrc/elementwise.cpp
@@ -345,4 +345,64 @@ xla::XlaOp BuildSelu(xla::XlaOp input) {
                      xla::Min(zero, alpha * (xla::Exp(input) - one)));
 }
 
+std::vector<xla::XlaOp> BuildLogSigmoid(xla::XlaOp input) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp neg_input = xla::Neg(input);
+  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
+  xla::XlaOp max_elem = xla::Max(zero, neg_input);
+  xla::XlaOp buffer =
+      xla::Exp(xla::Neg(max_elem)) + xla::Exp(neg_input - max_elem);
+  xla::XlaOp output = xla::Neg(max_elem + xla::Log(buffer));
+  return {output, buffer};
+}
+
+xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
+                                   xla::XlaOp buffer) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
+  xla::XlaOp one = XlaHelpers::ScalarValue(1.0, shape.element_type(),
+                                           input.builder());
+  xla::XlaOp minus_one = XlaHelpers::ScalarValue(
+      -1.0, shape.element_type(), input.builder());
+
+  xla::XlaOp max_deriv = xla::Select(xla::Lt(input, zero), minus_one, zero);
+  xla::XlaOp sign = xla::Select(xla::Lt(input, zero), one, minus_one);
+  return grad_output * (xla::Neg(max_deriv) - sign * (buffer - one) / buffer);
+}
+
+xla::XlaOp BuildElu(xla::XlaOp input, const at::Scalar& alpha,
+                    const at::Scalar& scale, const at::Scalar& input_scale) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp scaled_input =
+      input * XlaHelpers::ScalarValue(input_scale, shape.element_type(),
+                                      input.builder());
+  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
+  xla::XlaOp one = XlaHelpers::ScalarValue(1.0, shape.element_type(),
+                                           input.builder());
+  xla::XlaOp alpha_scalar =
+      XlaHelpers::ScalarValue(alpha, shape.element_type(), input.builder());
+  xla::XlaOp scale_scalar =
+      XlaHelpers::ScalarValue(scale, shape.element_type(), input.builder());
+  return xla::Select(xla::Le(input, zero),
+                     alpha_scalar * (xla::Exp(scaled_input) - one), input) *
+         scale_scalar;
+}
+
+xla::XlaOp BuildEluBackward(xla::XlaOp grad_output, xla::XlaOp output,
+                            const at::Scalar& alpha, const at::Scalar& scale,
+                            const at::Scalar& input_scale) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(output);
+  xla::XlaOp zero = xla::Zero(output.builder(), shape.element_type());
+  xla::XlaOp alpha_scalar =
+      XlaHelpers::ScalarValue(alpha, shape.element_type(), output.builder());
+  xla::XlaOp scale_scalar =
+      XlaHelpers::ScalarValue(scale, shape.element_type(), output.builder());
+  xla::XlaOp input_scale_scalar = XlaHelpers::ScalarValue(
+      input_scale, shape.element_type(), output.builder());
+  xla::XlaOp negative_output_branch =
+      input_scale_scalar * (output + alpha_scalar * scale_scalar);
+  return grad_output * xla::Select(xla::Gt(output, zero), scale_scalar,
+                                   negative_output_branch);
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h
index 907ef0d3f999..3ffcf899cfce 100644
--- a/torch_xla/csrc/elementwise.h
+++ b/torch_xla/csrc/elementwise.h
@@ -89,4 +89,20 @@ xla::XlaOp BuildCelu(xla::XlaOp input, const at::Scalar& alpha);
 // SELU(x)=scale*(max(0,x)+min(0,a*(exp(x)−1)))
 xla::XlaOp BuildSelu(xla::XlaOp input);
 
+// Computes the LogSigmoid function of input.
+std::vector<xla::XlaOp> BuildLogSigmoid(xla::XlaOp input);
+
+// Computes the backward of LogSigmoid.
+xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
+                                   xla::XlaOp buffer);
+
+// Computes the Elu function of input.
+xla::XlaOp BuildElu(xla::XlaOp input, const at::Scalar& alpha,
+                    const at::Scalar& scale, const at::Scalar& input_scale);
+
+// Computes the backward of Elu.
+xla::XlaOp BuildEluBackward(xla::XlaOp grad_output, xla::XlaOp output,
+                            const at::Scalar& alpha, const at::Scalar& scale,
+                            const at::Scalar& input_scale);
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index e251e6da08c6..acf4520e4172 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -244,30 +244,30 @@ torch::lazy::NodePtr HardSwishBackward(const XlaValue& grad_output,
                    std::move(lower_fn));
 }
 
-std::tuple<torch::lazy::NodePtr, torch::lazy::NodePtr> LogSigmoid(
-    const XlaValue& input) {
-  ScopePusher ir_scope(at::aten::log_sigmoid.toQualString());
-  // Use log-sum-exp trick to avoid overflow.
-  torch::lazy::NodePtr neg_input = Neg(input);
-  torch::lazy::NodePtr max_elem =
-      Max(ScalarOp(0, input.xla_shape()), neg_input);
-  torch::lazy::NodePtr buffer = Exp(Neg(max_elem)) + Exp(neg_input - max_elem);
-  torch::lazy::NodePtr output = Neg(max_elem + Log(buffer));
-  return std::make_tuple(output, buffer);
+torch::lazy::NodePtr LogSigmoid(const XlaValue& input) {
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    return node.ReturnOps(BuildLogSigmoid(xla_input), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::log_sigmoid), {input},
+                   input.xla_shape(), std::move(lower_fn), /*num_outputs=*/2);
 }
 
 torch::lazy::NodePtr LogSigmoidBackward(const XlaValue& grad_output,
                                         const XlaValue& input,
                                         const XlaValue& buffer) {
-  ScopePusher ir_scope(at::aten::log_sigmoid_backward.toQualString());
-  torch::lazy::NodePtr zero = ScalarOp(0, input.xla_shape());
-  torch::lazy::NodePtr one = ScalarOp(1, input.xla_shape());
-  torch::lazy::NodePtr minus_one = ScalarOp(-1, input.xla_shape());
-  torch::lazy::NodePtr max_deriv =
-      Where(ComparisonOp(at::aten::lt, input, zero), minus_one, zero);
-  torch::lazy::NodePtr sign =
-      Where(ComparisonOp(at::aten::lt, input, zero), one, minus_one);
-  return grad_output * (Neg(max_deriv) - sign * (buffer - one) / buffer);
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
+    xla::XlaOp xla_buffer = loctx->GetOutputOp(node.operand(2));
+    return node.ReturnOp(
+        BuildLogSigmoidBackward(xla_grad_output, xla_input, xla_buffer), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::log_sigmoid_backward),
+                   {grad_output, input, buffer}, input.xla_shape(),
+                   std::move(lower_fn));
 }
 
 torch::lazy::NodePtr SiLU(const XlaValue& input) {
@@ -718,15 +718,13 @@ torch::lazy::NodePtr Identity(int64_t lines, int64_t cols,
 
 torch::lazy::NodePtr Elu(const XlaValue& input, const at::Scalar& alpha,
                          const at::Scalar& scale,
                          const at::Scalar& input_scale) {
-  ScopePusher ir_scope(at::aten::elu.toQualString());
-  const xla::Shape& shape = input.xla_shape();
-  torch::lazy::NodePtr scaled_input = input * ScalarOp(input_scale, shape);
-  torch::lazy::NodePtr zero = ScalarOp(0, shape);
-  torch::lazy::NodePtr one = ScalarOp(1, shape);
-  torch::lazy::NodePtr alpha_scalar = ScalarOp(alpha, shape);
-  return Where(ComparisonOp(at::aten::le, input, zero),
-               alpha_scalar * (Exp(scaled_input) - one), input) *
-         ScalarOp(scale, shape);
+  auto lower_fn = [=](const XlaNode& node,
+                      LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    return node.ReturnOp(BuildElu(xla_input, alpha, scale, input_scale), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::elu), {input},
+                   input.xla_shape(), std::move(lower_fn));
 }
 
@@ -734,15 +734,17 @@ torch::lazy::NodePtr EluBackward(const XlaValue& grad_output,
                                  const XlaValue& output,
                                  const at::Scalar& alpha, const at::Scalar& scale,
                                  const at::Scalar& input_scale) {
-  ScopePusher ir_scope(at::aten::elu_backward.toQualString());
-  const xla::Shape& shape = grad_output.xla_shape();
-  torch::lazy::NodePtr negative_output_branch =
-      ScalarOp(input_scale, shape) *
-      (output + ScalarOp(alpha, shape) * ScalarOp(scale, shape));
-  torch::lazy::NodePtr positive_output_branch = ScalarOp(scale, shape);
-  return grad_output *
-         Where(ComparisonOp(at::aten::gt, output, ScalarOp(0, shape)),
-               positive_output_branch, negative_output_branch);
+  auto lower_fn = [=](const XlaNode& node,
+                      LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_output = loctx->GetOutputOp(node.operand(1));
+    return node.ReturnOp(BuildEluBackward(xla_grad_output, xla_output, alpha,
+                                          scale, input_scale),
+                         loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::elu_backward),
+                   {grad_output, output}, output.xla_shape(),
+                   std::move(lower_fn));
 }
 
 torch::lazy::NodePtr Gelu(const XlaValue& input) {
diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h
index 419d28f649de..4d8b54950f27 100644
--- a/torch_xla/csrc/ops/ops.h
+++ b/torch_xla/csrc/ops/ops.h
@@ -136,8 +136,7 @@ torch::lazy::NodePtr HardSwish(const XlaValue& input);
 torch::lazy::NodePtr HardSwishBackward(const XlaValue& grad_output,
                                        const XlaValue& input);
 
-std::tuple<torch::lazy::NodePtr, torch::lazy::NodePtr> LogSigmoid(
-    const XlaValue& input);
+torch::lazy::NodePtr LogSigmoid(const XlaValue& input);
 
 torch::lazy::NodePtr LogSigmoidBackward(const XlaValue& grad_output,
                                         const XlaValue& input,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 49cf25391d1f..20d7ed9cf359 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -1698,14 +1698,15 @@ XLATensor XLATensor::log_base(const XLATensor& input, torch::lazy::OpKind op,
 }
 
 XLATensor XLATensor::log_sigmoid(const XLATensor& input) {
-  return input.CreateFrom(std::get<0>(LogSigmoid(input.GetIrValue())));
+  torch::lazy::NodePtr node = LogSigmoid(input.GetIrValue());
+  return input.CreateFrom(XlaValue(node, 0));
 }
 
 std::tuple<XLATensor, XLATensor> XLATensor::log_sigmoid_forward(
     const XLATensor& input) {
-  auto output_and_buffer = LogSigmoid(input.GetIrValue());
-  return std::make_tuple(input.CreateFrom(std::get<0>(output_and_buffer)),
-                         input.CreateFrom(std::get<1>(output_and_buffer)));
+  torch::lazy::NodePtr node = LogSigmoid(input.GetIrValue());
+  return std::make_tuple(input.CreateFrom(XlaValue(node, 0)),
+                         input.CreateFrom(XlaValue(node, 1)));
 }
 
 XLATensor XLATensor::log_sigmoid_backward(const XLATensor& grad_output,