50 changes: 50 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -5933,6 +5933,24 @@ TEST_F(AtenXlaTensorTest, TestHardshrink) {
});
}

TEST_F(AtenXlaTensorTest, TestHardshrinkWithMixedDataType) {
torch::Tensor lambdaTensor =
torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
// It seems the below .item() will convert a kFloat64 to a kFloat32 if I
// make the scalar tensor a kFloat32 type.
torch::Scalar lambda = lambdaTensor.item();
torch::Tensor input =
torch::randn({10}, torch::TensorOptions(torch::kFloat64));

torch::Tensor output = torch::hardshrink(input, lambda);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::hardshrink(xla_input, lambda);
AllClose(output, xla_output);
});
}

// Unlike Softshrink, a negative lambda is a valid input for Hardshrink.
TEST_F(AtenXlaTensorTest, TestHardshrinkWithNegativeLambda) {
torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
torch::Scalar lambd = -0.5;
@@ -10399,6 +10417,22 @@ TEST_F(AtenXlaTensorTest, TestHardshrinkBackward) {
});
}

TEST_F(AtenXlaTensorTest, TestHardshrinkBackwardWithMixedDataType) {
torch::Tensor lambdaTensor =
torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
torch::Scalar lambda = lambdaTensor.item();

auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::hardshrink(inputs[0], lambda);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::randn(
{100}, torch::TensorOptions(torch::kFloat64).requires_grad(true))},
device, testfn);
});
}

TEST_F(AtenXlaTensorTest, TestSoftshrinkBackward) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::softshrink(inputs[0]);
@@ -10411,6 +10445,22 @@ TEST_F(AtenXlaTensorTest, TestSoftshrinkBackward) {
});
}

TEST_F(AtenXlaTensorTest, TestSoftshrinkBackwardWithMixedDataType) {
torch::Tensor lambdaTensor =
torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
torch::Scalar lambda = lambdaTensor.item();

auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::softshrink(inputs[0], lambda);
};
ForEachDevice([&](const torch::Device& device) {
TestBackward(
{torch::randn(
{100}, torch::TensorOptions(torch::kFloat64).requires_grad(true))},
device, testfn);
});
}

TEST_F(AtenXlaTensorTest, TestHardtanhBackward) {
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
return torch::hardtanh(inputs[0]);
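For readers skimming the new tests: "mixed data type" here means the Scalar lambda is produced from a kFloat32 tensor while the op input is kFloat64. The snippet below is a minimal standalone sketch of that scenario against a plain libtorch build (CPU only, no XLA device, not part of this PR); the file name and main() wrapper are illustrative only.

// mixed_dtype_hardshrink_sketch.cpp -- illustration only, assumes libtorch.
#include <iostream>
#include <torch/torch.h>

int main() {
  // Lambda comes from a kFloat32 scalar tensor, exactly as in the tests above.
  torch::Tensor lambda_tensor =
      torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
  torch::Scalar lambda = lambda_tensor.item();

  // The input is kFloat64, so forward and backward see mixed float widths.
  torch::Tensor input = torch::randn(
      {10}, torch::TensorOptions(torch::kFloat64).requires_grad(true));

  torch::Tensor output = torch::hardshrink(input, lambda);
  output.sum().backward();

  std::cout << "output dtype: " << output.dtype() << "\n"
            << "grad dtype:   " << input.grad().dtype() << std::endl;
  return 0;
}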
16 changes: 0 additions & 16 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1259,14 +1259,6 @@ at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self), approximate));
}

at::Tensor XLANativeFunctions::hardshrink_backward(const at::Tensor& grad_out,
const at::Tensor& self,
const at::Scalar& lambda) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::hardshrink_backward(
bridge::GetXlaTensor(grad_out), bridge::GetXlaTensor(self), lambda));
}

at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
const at::Scalar& min_val,
const at::Scalar& max_val) {
@@ -2570,14 +2562,6 @@ at::Tensor XLANativeFunctions::softshrink(const at::Tensor& self,
XLATensor::softshrink(bridge::GetXlaTensor(self), lambda));
}

at::Tensor XLANativeFunctions::softshrink_backward(const at::Tensor& grad_out,
const at::Tensor& self,
const at::Scalar& lambda) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::softshrink_backward(
bridge::GetXlaTensor(grad_out), bridge::GetXlaTensor(self), lambda));
}

std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::sort(
const at::Tensor& self, int64_t dim, bool descending) {
XLA_FN_COUNTER("xla::");
16 changes: 13 additions & 3 deletions torch_xla/csrc/elementwise.cpp
@@ -141,10 +141,20 @@ xla::XlaOp BuildSoftshrink(xla::XlaOp input, const at::Scalar& lambda) {
}

xla::XlaOp BuildShrinkBackward(xla::XlaOp grad_output, xla::XlaOp input,
const at::Scalar& lambda) {
xla::XlaOp lambda) {
Collaborator: nit: we may want to do MaybeCast similar to 4ccfc24
Collaborator (Author): done

const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
return xla::Select(Between(input, -lambda, lambda), zero, grad_output);
xla::PrimitiveType input_element_type = shape.element_type();
xla::XlaOp zero = xla::Zero(input.builder(), input_element_type);

// The conversion here is needed because when we do computation such as
// broadcast or subtraction for input and lambda, XLA disallows mixed
// precision for float point types.
lambda = MaybeConvertTo(lambda, input_element_type);
xla::XlaOp check_low = BuildComparisonOp(at::aten::ge, input, zero - lambda);
xla::XlaOp check_high = BuildComparisonOp(at::aten::le, input, lambda);
xla::XlaOp between = xla::And(check_low, check_high);

return xla::Select(between, zero, grad_output);
}

xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input,
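For context on the review exchange above: XLA rejects element-wise and comparison ops whose operands have different floating-point element types, which is why the rewritten BuildShrinkBackward casts lambda to the input's element type before building the range check. The fragment below is only a rough, self-contained illustration of that pattern written against the public XLA client builder API (xla::ConvertElementType standing in for the repo's MaybeConvertTo helper); the include paths follow the TensorFlow-based XLA layout this project builds against and may differ by version.

// Illustration only -- not the PR's implementation.
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp ShrinkBackwardSketch(xla::XlaOp grad_output, xla::XlaOp input,
                                xla::XlaOp lambda,
                                xla::PrimitiveType input_type) {
  // With mismatched float types (e.g. F32 lambda vs. F64 input), the Ge/Le
  // comparisons below would fail shape inference, so cast lambda first.
  lambda = xla::ConvertElementType(lambda, input_type);
  xla::XlaOp zero = xla::Zero(input.builder(), input_type);
  // The gradient is zeroed inside [-lambda, lambda] and passed through outside.
  xla::XlaOp between =
      xla::And(xla::Ge(input, xla::Neg(lambda)), xla::Le(input, lambda));
  return xla::Select(between, zero, grad_output);
}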
2 changes: 1 addition & 1 deletion torch_xla/csrc/elementwise.h
@@ -35,7 +35,7 @@ xla::XlaOp BuildHardSwish(xla::XlaOp input);
xla::XlaOp BuildHardSwishBackward(xla::XlaOp grad_output, xla::XlaOp input);
xla::XlaOp BuildSoftshrink(xla::XlaOp input, const at::Scalar& lambda);
xla::XlaOp BuildShrinkBackward(xla::XlaOp grad_output, xla::XlaOp input,
const at::Scalar& lambda);
xla::XlaOp lambda);

xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input,
const at::Scalar& min_val,
14 changes: 14 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -336,6 +336,13 @@ torch_xla::XlaOpVector Hardshrink::Lower(LoweringContext* loctx) const {
return ReturnOp(BuildHardshrink(xla_input, lambd), loctx);
}

torch_xla::XlaOpVector HardshrinkBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp input = loctx->GetOutputOp(operand(1));
xla::XlaOp lambda = loctx->GetOutputOp(operand(2));
return ReturnOp(BuildShrinkBackward(grad_output, input, lambda), loctx);
}

torch_xla::XlaOpVector Hardsigmoid::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(BuildHardSigmoid(xla_input), loctx);
@@ -532,6 +539,13 @@ torch_xla::XlaOpVector Sinh::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::Sinh(xla_input), loctx);
}

torch_xla::XlaOpVector SoftshrinkBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp input = loctx->GetOutputOp(operand(1));
xla::XlaOp lambda = loctx->GetOutputOp(operand(2));
return ReturnOp(BuildShrinkBackward(grad_output, input, lambda), loctx);
}

/* Blocked on https://github.com/pytorch/xla/issues/3596 */
// torch_xla::XlaOpVector Slogdet::Lower(LoweringContext* loctx) const {
// xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
12 changes: 12 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -414,6 +414,12 @@ xla::Shape HardshrinkOutputShape(const torch::lazy::Value& self,
return GetXlaShape(self);
}

xla::Shape HardshrinkBackwardOutputShape(const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& lambd) {
return GetXlaShape(input);
}

xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}
@@ -623,6 +629,12 @@ xla::Shape SinhOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
}

xla::Shape SoftshrinkBackwardOutputShape(const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& lambd) {
return GetXlaShape(input);
}

/* Blocked on https://github.com/pytorch/xla/issues/3596 */
// xla::Shape SlogdetOutputShape(const torch::lazy::Value& input) {
// auto lower_for_shape_fn =
7 changes: 7 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -134,6 +134,10 @@ xla::Shape GtTensorOutputShape(const torch::lazy::Value& self,
xla::Shape HardshrinkOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& lambd);

xla::Shape HardshrinkBackwardOutputShape(const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& lambd);

xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input);

xla::Shape HardsigmoidBackwardOutputShape(const torch::lazy::Value& grad_output,
@@ -214,6 +218,9 @@ xla::Shape SinOutputShape(const torch::lazy::Value& input);

xla::Shape SinhOutputShape(const torch::lazy::Value& input);

xla::Shape SoftshrinkBackwardOutputShape(const torch::lazy::Value& grad_out,
const torch::lazy::Value& input,
const torch::lazy::Value& lambd);
/* Blocked on https://github.com/pytorch/xla/issues/3596 */
// xla::Shape SlogdetOutputShape(const torch::lazy::Value& input);

35 changes: 0 additions & 35 deletions torch_xla/csrc/ops/shrink_backward.cpp

This file was deleted.

27 changes: 0 additions & 27 deletions torch_xla/csrc/ops/shrink_backward.h

This file was deleted.

17 changes: 0 additions & 17 deletions torch_xla/csrc/tensor_methods.cpp
@@ -108,7 +108,6 @@
#include "torch_xla/csrc/ops/scatter_add.h"
#include "torch_xla/csrc/ops/send.h"
#include "torch_xla/csrc/ops/sgd_optimizer_step.h"
#include "torch_xla/csrc/ops/shrink_backward.h"
#include "torch_xla/csrc/ops/softmax.h"
#include "torch_xla/csrc/ops/softshrink.h"
#include "torch_xla/csrc/ops/split.h"
@@ -1383,14 +1382,6 @@ XLATensorPtr XLATensor::le(const XLATensorPtr& input,
return DispatchComparisonOp(at::aten::le, input, other);
}

XLATensorPtr XLATensor::hardshrink_backward(const XLATensorPtr& grad_out,
const XLATensorPtr& input,
const at::Scalar& lambda) {
return input->CreateFrom(torch::lazy::MakeNode<ShrinkBackward>(
torch::lazy::OpKind(at::aten::hardshrink_backward),
grad_out->GetIrValue(), input->GetIrValue(), lambda));
}

XLATensorPtr XLATensor::hardtanh_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
const at::Scalar& min_val,
@@ -2322,14 +2313,6 @@ XLATensorPtr XLATensor::softshrink(const XLATensorPtr& input,
torch::lazy::MakeNode<Softshrink>(input->GetIrValue(), lambda));
}

XLATensorPtr XLATensor::softshrink_backward(const XLATensorPtr& grad_out,
const XLATensorPtr& input,
const at::Scalar& lambda) {
return input->CreateFrom(torch::lazy::MakeNode<ShrinkBackward>(
torch::lazy::OpKind(at::aten::softshrink_backward),
grad_out->GetIrValue(), input->GetIrValue(), lambda));
}

std::vector<XLATensorPtr> XLATensor::split(const XLATensorPtr& input,
int64_t split_size, int64_t dim) {
auto input_shape = input->shape();
4 changes: 2 additions & 2 deletions xla_native_functions.yaml
@@ -44,6 +44,7 @@ full_codegen:
- gt.Scalar
- gt.Tensor
- hardshrink
- hardshrink_backward
- hardsigmoid
- hardsigmoid_backward
- hardswish
@@ -76,6 +77,7 @@ full_codegen:
- silu_backward
- sin
- sinh
- softshrink_backward
- take
- tan
- tanh
@@ -173,7 +175,6 @@ supported:
- gather
- gelu
- gelu_backward
- hardshrink_backward
- hardtanh
- hardtanh_backward
- index.Tensor
@@ -281,7 +282,6 @@ supported:
- softplus
- softplus_backward
- softshrink
- softshrink_backward
- sort
- sort.stable
- split.Tensor