60 changes: 60 additions & 0 deletions torch_xla/csrc/elementwise.cpp
@@ -345,4 +345,64 @@ xla::XlaOp BuildSelu(xla::XlaOp input) {
                   xla::Min(zero, alpha * (xla::Exp(input) - one)));
 }
 
+std::vector<xla::XlaOp> BuildLogSigmoid(xla::XlaOp input) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp neg_input = xla::Neg(input);
+  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
+  xla::XlaOp max_elem = xla::Max(zero, neg_input);
+  xla::XlaOp buffer =
+      xla::Exp(xla::Neg(max_elem)) + xla::Exp(neg_input - max_elem);
+  xla::XlaOp output = xla::Neg(max_elem + xla::Log(buffer));
+  return {output, buffer};
+}
+
+xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
+                                   xla::XlaOp buffer) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
+  xla::XlaOp one = XlaHelpers::ScalarValue<float>(1.0, shape.element_type(),
+                                                  input.builder());
+  xla::XlaOp minus_one = XlaHelpers::ScalarValue<float>(
+      -1.0, shape.element_type(), input.builder());
+
+  xla::XlaOp max_deriv = xla::Select(xla::Lt(input, zero), minus_one, zero);
+  xla::XlaOp sign = xla::Select(xla::Lt(input, zero), one, minus_one);
+  return grad_output * (xla::Neg(max_deriv) - sign * (buffer - one) / buffer);
+}
+
+xla::XlaOp BuildElu(xla::XlaOp input, const at::Scalar& alpha,
+                    const at::Scalar& scale, const at::Scalar& input_scale) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp scaled_input =
+      input * XlaHelpers::ScalarValue(input_scale, shape.element_type(),
+                                      input.builder());
+  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
+  xla::XlaOp one = XlaHelpers::ScalarValue<float>(1.0, shape.element_type(),
+                                                  input.builder());
+  xla::XlaOp alpha_scalar =
+      XlaHelpers::ScalarValue(alpha, shape.element_type(), input.builder());
+  xla::XlaOp scale_scalar =
+      XlaHelpers::ScalarValue(scale, shape.element_type(), input.builder());
+  return xla::Select(xla::Le(input, zero),
+                     alpha_scalar * (xla::Exp(scaled_input) - one), input) *
+         scale_scalar;
+}
+
+xla::XlaOp BuildEluBackward(xla::XlaOp grad_output, xla::XlaOp output,
+                            const at::Scalar& alpha, const at::Scalar& scale,
+                            const at::Scalar& input_scale) {
+  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(output);
+  xla::XlaOp zero = xla::Zero(output.builder(), shape.element_type());
+  xla::XlaOp alpha_scalar =
+      XlaHelpers::ScalarValue(alpha, shape.element_type(), output.builder());
+  xla::XlaOp scale_scalar =
+      XlaHelpers::ScalarValue(scale, shape.element_type(), output.builder());
+  xla::XlaOp input_scale_scalar = XlaHelpers::ScalarValue(
+      input_scale, shape.element_type(), output.builder());
+  xla::XlaOp negative_output_branch =
+      input_scale_scalar * (output + alpha_scalar * scale_scalar);
+  return grad_output * xla::Select(xla::Gt(output, zero), scale_scalar,
+                                   negative_output_branch);
+}
+
 }  // namespace torch_xla
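For readers following the new lowering: below is a minimal scalar sketch (not part of the diff; the name LogSigmoidReference and float-only signature are hypothetical) of the log-sum-exp formulation that BuildLogSigmoid implements, and of the buffer it returns for reuse by the backward pass.

// Illustrative sketch only: log_sigmoid(x) = -log(1 + exp(-x)) rewritten as
// -(m + log(exp(-m) + exp(-x - m))) with m = max(0, -x), so exp() is never
// called with a large positive argument. The parenthesized sum is the
// "buffer" that BuildLogSigmoid returns alongside the output.
#include <algorithm>
#include <cmath>

float LogSigmoidReference(float x) {
  float max_elem = std::max(0.0f, -x);
  float buffer = std::exp(-max_elem) + std::exp(-x - max_elem);
  return -(max_elem + std::log(buffer));
}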
16 changes: 16 additions & 0 deletions torch_xla/csrc/elementwise.h
@@ -89,4 +89,20 @@ xla::XlaOp BuildCelu(xla::XlaOp input, const at::Scalar& alpha);
 // SELU(x)=scale*(max(0,x)+min(0,a*(exp(x)−1)))
 xla::XlaOp BuildSelu(xla::XlaOp input);
 
+// Computes the LogSigmoid function of input.
+std::vector<xla::XlaOp> BuildLogSigmoid(xla::XlaOp input);
+
+// Computes the backward of LogSigmoid.
+xla::XlaOp BuildLogSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input,
+                                   xla::XlaOp buffer);
+
+// Computes the Elu function of input.
+xla::XlaOp BuildElu(xla::XlaOp input, const at::Scalar& alpha,
+                    const at::Scalar& scale, const at::Scalar& input_scale);
+
+// Computes the backward of Elu.
+xla::XlaOp BuildEluBackward(xla::XlaOp grad_output, xla::XlaOp output,
+                            const at::Scalar& alpha, const at::Scalar& scale,
+                            const at::Scalar& input_scale);
+
 }  // namespace torch_xla
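As a cross-check on the BuildElu / BuildEluBackward declarations above, here is a scalar reference sketch (illustrative only; the helper names are hypothetical and not part of the diff) of the math they lower. The backward is written in terms of the forward output, which is why only the output, not the input, is an operand.

// Illustrative sketch only, mirroring the select logic in elementwise.cpp.
#include <cmath>

float EluReference(float x, float alpha, float scale, float input_scale) {
  // ELU(x) = scale * x                                    for x > 0
  //        = scale * alpha * (exp(input_scale * x) - 1)   otherwise
  return x > 0.0f ? scale * x
                  : scale * alpha * (std::exp(input_scale * x) - 1.0f);
}

float EluBackwardReference(float grad_output, float output, float alpha,
                           float scale, float input_scale) {
  // For x <= 0, d(output)/dx = input_scale * (output + alpha * scale),
  // so the gradient can be recovered from the saved output alone.
  return grad_output * (output > 0.0f
                            ? scale
                            : input_scale * (output + alpha * scale));
}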
74 changes: 37 additions & 37 deletions torch_xla/csrc/ops/ops.cpp
@@ -244,30 +244,30 @@ torch::lazy::NodePtr HardSwishBackward(const XlaValue& grad_output,
                    std::move(lower_fn));
 }
 
-std::tuple<torch::lazy::NodePtr, torch::lazy::NodePtr> LogSigmoid(
-    const XlaValue& input) {
-  ScopePusher ir_scope(at::aten::log_sigmoid.toQualString());
-  // Use log-sum-exp trick to avoid overflow.
-  torch::lazy::NodePtr neg_input = Neg(input);
-  torch::lazy::NodePtr max_elem =
-      Max(ScalarOp(0, input.xla_shape()), neg_input);
-  torch::lazy::NodePtr buffer = Exp(Neg(max_elem)) + Exp(neg_input - max_elem);
-  torch::lazy::NodePtr output = Neg(max_elem + Log(buffer));
-  return std::make_tuple(output, buffer);
+torch::lazy::NodePtr LogSigmoid(const XlaValue& input) {
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    return node.ReturnOps(BuildLogSigmoid(xla_input), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::log_sigmoid), {input},
+                   input.xla_shape(), std::move(lower_fn), /*num_outputs=*/2);
 }
 
 torch::lazy::NodePtr LogSigmoidBackward(const XlaValue& grad_output,
                                         const XlaValue& input,
                                         const XlaValue& buffer) {
-  ScopePusher ir_scope(at::aten::log_sigmoid_backward.toQualString());
-  torch::lazy::NodePtr zero = ScalarOp(0, input.xla_shape());
-  torch::lazy::NodePtr one = ScalarOp(1, input.xla_shape());
-  torch::lazy::NodePtr minus_one = ScalarOp(-1, input.xla_shape());
-  torch::lazy::NodePtr max_deriv =
-      Where(ComparisonOp(at::aten::lt, input, zero), minus_one, zero);
-  torch::lazy::NodePtr sign =
-      Where(ComparisonOp(at::aten::lt, input, zero), one, minus_one);
-  return grad_output * (Neg(max_deriv) - sign * (buffer - one) / buffer);
+  auto lower_fn = [](const XlaNode& node,
+                     LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(1));
+    xla::XlaOp xla_buffer = loctx->GetOutputOp(node.operand(2));
+    return node.ReturnOp(
+        BuildLogSigmoidBackward(xla_grad_output, xla_input, xla_buffer), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::log_sigmoid_backward),
+                   {grad_output, input, buffer}, input.xla_shape(),
+                   std::move(lower_fn));
 }
 
 torch::lazy::NodePtr SiLU(const XlaValue& input) {
@@ -718,31 +718,31 @@ torch::lazy::NodePtr Identity(int64_t lines, int64_t cols,
 torch::lazy::NodePtr Elu(const XlaValue& input, const at::Scalar& alpha,
                          const at::Scalar& scale,
                          const at::Scalar& input_scale) {
-  ScopePusher ir_scope(at::aten::elu.toQualString());
-  const xla::Shape& shape = input.xla_shape();
-  torch::lazy::NodePtr scaled_input = input * ScalarOp(input_scale, shape);
-  torch::lazy::NodePtr zero = ScalarOp(0, shape);
-  torch::lazy::NodePtr one = ScalarOp(1, shape);
-  torch::lazy::NodePtr alpha_scalar = ScalarOp(alpha, shape);
-  return Where(ComparisonOp(at::aten::le, input, zero),
-               alpha_scalar * (Exp(scaled_input) - one), input) *
-         ScalarOp(scale, shape);
+  auto lower_fn = [=](const XlaNode& node,
+                      LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    return node.ReturnOp(BuildElu(xla_input, alpha, scale, input_scale), loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::elu), {input},
+                   input.xla_shape(), std::move(lower_fn));
 }
 
 torch::lazy::NodePtr EluBackward(const XlaValue& grad_output,
                                  const XlaValue& output,
                                  const at::Scalar& alpha,
                                  const at::Scalar& scale,
                                  const at::Scalar& input_scale) {
-  ScopePusher ir_scope(at::aten::elu_backward.toQualString());
-  const xla::Shape& shape = grad_output.xla_shape();
-  torch::lazy::NodePtr negative_output_branch =
-      ScalarOp(input_scale, shape) *
-      (output + ScalarOp(alpha, shape) * ScalarOp(scale, shape));
-  torch::lazy::NodePtr positive_output_branch = ScalarOp(scale, shape);
-  return grad_output *
-         Where(ComparisonOp(at::aten::gt, output, ScalarOp(0, shape)),
-               positive_output_branch, negative_output_branch);
+  auto lower_fn = [=](const XlaNode& node,
+                      LoweringContext* loctx) -> XlaOpVector {
+    xla::XlaOp xla_grad_output = loctx->GetOutputOp(node.operand(0));
+    xla::XlaOp xla_output = loctx->GetOutputOp(node.operand(1));
+    return node.ReturnOp(BuildEluBackward(xla_grad_output, xla_output, alpha,
+                                          scale, input_scale),
+                         loctx);
+  };
+  return GenericOp(torch::lazy::OpKind(at::aten::elu_backward),
+                   {grad_output, output}, output.xla_shape(),
+                   std::move(lower_fn));
 }
 
 torch::lazy::NodePtr Gelu(const XlaValue& input) {
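The branch arithmetic that moved from the old IR-level LogSigmoidBackward into BuildLogSigmoidBackward is easy to misread, so here is a scalar sketch (illustrative only; the name is hypothetical and the code is not part of the diff) showing that both branches of the max_deriv/sign select reduce to grad_output * sigmoid(-x).

// Illustrative sketch only. With buffer = exp(-m) + exp(-x - m), m = max(0, -x):
//   x >= 0: m = 0,  buffer = 1 + exp(-x), result = grad * (buffer - 1) / buffer
//   x <  0: m = -x, buffer = exp(x) + 1,  result = grad * 1 / buffer
// Both cases equal grad * sigmoid(-x), the derivative of log_sigmoid(x).
float LogSigmoidBackwardReference(float grad_output, float x, float buffer) {
  float max_deriv = x < 0.0f ? -1.0f : 0.0f;
  float sign = x < 0.0f ? 1.0f : -1.0f;
  return grad_output * (-max_deriv - sign * (buffer - 1.0f) / buffer);
}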
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/ops.h
@@ -136,8 +136,7 @@ torch::lazy::NodePtr HardSwish(const XlaValue& input);
 torch::lazy::NodePtr HardSwishBackward(const XlaValue& grad_output,
                                        const XlaValue& input);
 
-std::tuple<torch::lazy::NodePtr, torch::lazy::NodePtr> LogSigmoid(
-    const XlaValue& input);
+torch::lazy::NodePtr LogSigmoid(const XlaValue& input);
 
 torch::lazy::NodePtr LogSigmoidBackward(const XlaValue& grad_output,
                                         const XlaValue& input,
9 changes: 5 additions & 4 deletions torch_xla/csrc/tensor_methods.cpp
@@ -1698,14 +1698,15 @@ XLATensor XLATensor::log_base(const XLATensor& input, torch::lazy::OpKind op,
 }
 
 XLATensor XLATensor::log_sigmoid(const XLATensor& input) {
-  return input.CreateFrom(std::get<0>(LogSigmoid(input.GetIrValue())));
+  torch::lazy::NodePtr node = LogSigmoid(input.GetIrValue());
+  return input.CreateFrom(XlaValue(node, 0));
 }
 
 std::tuple<XLATensor, XLATensor> XLATensor::log_sigmoid_forward(
     const XLATensor& input) {
-  auto output_and_buffer = LogSigmoid(input.GetIrValue());
-  return std::make_tuple(input.CreateFrom(std::get<0>(output_and_buffer)),
-                         input.CreateFrom(std::get<1>(output_and_buffer)));
+  torch::lazy::NodePtr node = LogSigmoid(input.GetIrValue());
+  return std::make_tuple(input.CreateFrom(XlaValue(node, 0)),
+                         input.CreateFrom(XlaValue(node, 1)));
 }
 
 XLATensor XLATensor::log_sigmoid_backward(const XLATensor& grad_output,