Commit ab15402

Codegen HardshrinkBackward and SoftshrinkBackward (#4002)
* Codegen HardshrinkBackward and SoftshrinkBackward.
* reformated
* cleaned up a bit.
* reformated
* fixed a typo in yaml file.
* Added failing tests for mixed data type cases. All tests pass.
* reformated.
1 parent 0d316f3 commit ab15402
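
In pytorch/xla, moving an op pair like this to codegen means listing the ops under the full_codegen section of xla_native_functions.yaml so the IR node classes and the XLANativeFunctions/tensor-method plumbing are generated; only the XLA lowering (ops_lower_fn.cpp) and the output-shape function (ops_xla_shape_fn.cpp) remain hand-written. The yaml change itself is not reproduced in this excerpt, but it presumably amounts to something like the following sketch (illustrative entries only, not the exact diff):

# xla_native_functions.yaml (sketch, assuming the usual full_codegen list)
full_codegen:
  - hardshrink_backward
  - softshrink_backward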

File tree

11 files changed: +99 −101 lines

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 50 additions & 0 deletions
@@ -5967,6 +5967,24 @@ TEST_F(AtenXlaTensorTest, TestHardshrink) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestHardshrinkWithMixedDataType) {
+  torch::Tensor lambdaTensor =
+      torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
+  // It seems the below .item() will convert a kFloat64 to a kFloat32 if I
+  // make the scalar tensor a kFloat32 type.
+  torch::Scalar lambda = lambdaTensor.item();
+  torch::Tensor input =
+      torch::randn({10}, torch::TensorOptions(torch::kFloat64));
+
+  torch::Tensor output = torch::hardshrink(input, lambda);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input, device);
+    torch::Tensor xla_output = torch::hardshrink(xla_input, lambda);
+    AllClose(output, xla_output);
+  });
+}
+
+// Unlike Softshrink, a negative lambda is a valid input for Hardshrink.
 TEST_F(AtenXlaTensorTest, TestHardshrinkWithNegativeLambda) {
   torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
   torch::Scalar lambd = -0.5;

@@ -10433,6 +10451,22 @@ TEST_F(AtenXlaTensorTest, TestHardshrinkBackward) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestHardshrinkBackwardWithMixedDataType) {
+  torch::Tensor lambdaTensor =
+      torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
+  torch::Scalar lambda = lambdaTensor.item();
+
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::hardshrink(inputs[0], lambda);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward(
+        {torch::randn(
+            {100}, torch::TensorOptions(torch::kFloat64).requires_grad(true))},
+        device, testfn);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestSoftshrinkBackward) {
   auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
     return torch::softshrink(inputs[0]);

@@ -10445,6 +10479,22 @@ TEST_F(AtenXlaTensorTest, TestSoftshrinkBackward) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestSoftshrinkBackwardWithMixedDataType) {
+  torch::Tensor lambdaTensor =
+      torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
+  torch::Scalar lambda = lambdaTensor.item();
+
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::softshrink(inputs[0], lambda);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward(
+        {torch::randn(
+            {100}, torch::TensorOptions(torch::kFloat64).requires_grad(true))},
+        device, testfn);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestHardtanhBackward) {
   auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
     return torch::hardtanh(inputs[0]);

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 16 deletions
@@ -1243,14 +1243,6 @@ at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
       bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self), approximate));
 }
 
-at::Tensor XLANativeFunctions::hardshrink_backward(const at::Tensor& grad_out,
-                                                   const at::Tensor& self,
-                                                   const at::Scalar& lambda) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::hardshrink_backward(
-      bridge::GetXlaTensor(grad_out), bridge::GetXlaTensor(self), lambda));
-}
-
 at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
                                         const at::Scalar& min_val,
                                         const at::Scalar& max_val) {

@@ -2554,14 +2546,6 @@ at::Tensor XLANativeFunctions::softshrink(const at::Tensor& self,
       XLATensor::softshrink(bridge::GetXlaTensor(self), lambda));
 }
 
-at::Tensor XLANativeFunctions::softshrink_backward(const at::Tensor& grad_out,
-                                                   const at::Tensor& self,
-                                                   const at::Scalar& lambda) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::softshrink_backward(
-      bridge::GetXlaTensor(grad_out), bridge::GetXlaTensor(self), lambda));
-}
-
 std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::sort(
     const at::Tensor& self, int64_t dim, bool descending) {
   XLA_FN_COUNTER("xla::");

torch_xla/csrc/elementwise.cpp

Lines changed: 13 additions & 3 deletions
@@ -141,10 +141,20 @@ xla::XlaOp BuildSoftshrink(xla::XlaOp input, const at::Scalar& lambda) {
 }
 
 xla::XlaOp BuildShrinkBackward(xla::XlaOp grad_output, xla::XlaOp input,
-                               const at::Scalar& lambda) {
+                               xla::XlaOp lambda) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
-  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
-  return xla::Select(Between(input, -lambda, lambda), zero, grad_output);
+  xla::PrimitiveType input_element_type = shape.element_type();
+  xla::XlaOp zero = xla::Zero(input.builder(), input_element_type);
+
+  // The conversion here is needed because when we do computation such as
+  // broadcast or subtraction for input and lambda, XLA disallows mixed
+  // precision for float point types.
+  lambda = MaybeConvertTo(lambda, input_element_type);
+  xla::XlaOp check_low = BuildComparisonOp(at::aten::ge, input, zero - lambda);
+  xla::XlaOp check_high = BuildComparisonOp(at::aten::le, input, lambda);
+  xla::XlaOp between = xla::And(check_low, check_high);
+
+  return xla::Select(between, zero, grad_output);
 }
 
 xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input,
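
For reference, the select built here implements the usual shrink gradient: the incoming gradient passes through wherever the input falls outside [-lambda, lambda] and is zeroed inside that band, which is the same rule for both hardshrink and softshrink backward. A minimal scalar sketch of that rule in plain C++ (a reference only, not the XLA lowering; the function name is made up for illustration):

#include <cstddef>
#include <vector>

// Elementwise reference for hardshrink/softshrink backward: the gradient is
// zeroed where -lambda <= x <= lambda and passed through elsewhere, matching
// the predicate the xla::Select above encodes.
std::vector<double> ShrinkBackwardReference(const std::vector<double>& grad_output,
                                            const std::vector<double>& input,
                                            double lambda) {
  std::vector<double> grad_input(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    const bool between = input[i] >= -lambda && input[i] <= lambda;
    grad_input[i] = between ? 0.0 : grad_output[i];
  }
  return grad_input;
}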

torch_xla/csrc/elementwise.h

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ xla::XlaOp BuildHardSwish(xla::XlaOp input);
 xla::XlaOp BuildHardSwishBackward(xla::XlaOp grad_output, xla::XlaOp input);
 xla::XlaOp BuildSoftshrink(xla::XlaOp input, const at::Scalar& lambda);
 xla::XlaOp BuildShrinkBackward(xla::XlaOp grad_output, xla::XlaOp input,
-                               const at::Scalar& lambda);
+                               xla::XlaOp lambda);
 
 xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input,
                                  const at::Scalar& min_val,

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 14 additions & 0 deletions
@@ -336,6 +336,13 @@ torch_xla::XlaOpVector Hardshrink::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildHardshrink(xla_input, lambd), loctx);
 }
 
+torch_xla::XlaOpVector HardshrinkBackward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp input = loctx->GetOutputOp(operand(1));
+  xla::XlaOp lambda = loctx->GetOutputOp(operand(2));
+  return ReturnOp(BuildShrinkBackward(grad_output, input, lambda), loctx);
+}
+
 torch_xla::XlaOpVector Hardsigmoid::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildHardSigmoid(xla_input), loctx);

@@ -532,6 +539,13 @@ torch_xla::XlaOpVector Sinh::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Sinh(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector SoftshrinkBackward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp input = loctx->GetOutputOp(operand(1));
+  xla::XlaOp lambda = loctx->GetOutputOp(operand(2));
+  return ReturnOp(BuildShrinkBackward(grad_output, input, lambda), loctx);
+}
+
 /* Blocked on https://github.com/pytorch/xla/issues/3596 */
 // torch_xla::XlaOpVector Slogdet::Lower(LoweringContext* loctx) const {
 //   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 12 additions & 0 deletions
@@ -414,6 +414,12 @@ xla::Shape HardshrinkOutputShape(const torch::lazy::Value& self,
   return GetXlaShape(self);
 }
 
+xla::Shape HardshrinkBackwardOutputShape(const torch::lazy::Value& grad_out,
+                                         const torch::lazy::Value& input,
+                                         const torch::lazy::Value& lambd) {
+  return GetXlaShape(input);
+}
+
 xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }

@@ -623,6 +629,12 @@ xla::Shape SinhOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape SoftshrinkBackwardOutputShape(const torch::lazy::Value& grad_out,
+                                         const torch::lazy::Value& input,
+                                         const torch::lazy::Value& lambd) {
+  return GetXlaShape(input);
+}
+
 /* Blocked on https://github.com/pytorch/xla/issues/3596 */
 // xla::Shape SlogdetOutputShape(const torch::lazy::Value& input) {
 //   auto lower_for_shape_fn =

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 7 additions & 0 deletions
@@ -134,6 +134,10 @@ xla::Shape GtTensorOutputShape(const torch::lazy::Value& self,
 xla::Shape HardshrinkOutputShape(const torch::lazy::Value& self,
                                  const torch::lazy::Value& lambd);
 
+xla::Shape HardshrinkBackwardOutputShape(const torch::lazy::Value& grad_out,
+                                         const torch::lazy::Value& input,
+                                         const torch::lazy::Value& lambd);
+
 xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input);
 
 xla::Shape HardsigmoidBackwardOutputShape(const torch::lazy::Value& grad_output,

@@ -214,6 +218,9 @@ xla::Shape SinOutputShape(const torch::lazy::Value& input);
 
 xla::Shape SinhOutputShape(const torch::lazy::Value& input);
 
+xla::Shape SoftshrinkBackwardOutputShape(const torch::lazy::Value& grad_out,
+                                         const torch::lazy::Value& input,
+                                         const torch::lazy::Value& lambd);
 /* Blocked on https://github.com/pytorch/xla/issues/3596 */
 // xla::Shape SlogdetOutputShape(const torch::lazy::Value& input);

torch_xla/csrc/ops/shrink_backward.cpp

Lines changed: 0 additions & 35 deletions
This file was deleted.

torch_xla/csrc/ops/shrink_backward.h

Lines changed: 0 additions & 27 deletions
This file was deleted.
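
These two deleted files held the hand-written ShrinkBackward IR node that both backward ops previously shared; with the node classes now generated by codegen, the files are no longer needed. A rough sketch of roughly what the deleted header declared, inferred from the MakeNode<ShrinkBackward> calls removed from tensor_methods.cpp below (member and method names are assumptions, not the exact deleted code):

// shrink_backward.h (sketch only, not compilable in isolation)
class ShrinkBackward : public XlaNode {
 public:
  ShrinkBackward(torch::lazy::OpKind kind, const torch::lazy::Value& grad_output,
                 const torch::lazy::Value& input, const at::Scalar& lambda);

  std::string ToString() const override;
  XlaOpVector Lower(LoweringContext* loctx) const override;

  const at::Scalar& lambda() const { return lambda_; }

 private:
  // Unlike the codegen'd nodes, lambda was carried as a scalar attribute
  // rather than as a third operand.
  at::Scalar lambda_;
};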

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 17 deletions
@@ -108,7 +108,6 @@
 #include "torch_xla/csrc/ops/scatter_add.h"
 #include "torch_xla/csrc/ops/send.h"
 #include "torch_xla/csrc/ops/sgd_optimizer_step.h"
-#include "torch_xla/csrc/ops/shrink_backward.h"
 #include "torch_xla/csrc/ops/softmax.h"
 #include "torch_xla/csrc/ops/softshrink.h"
 #include "torch_xla/csrc/ops/split.h"

@@ -1374,14 +1373,6 @@ XLATensorPtr XLATensor::le(const XLATensorPtr& input,
   return DispatchComparisonOp(at::aten::le, input, other);
 }
 
-XLATensorPtr XLATensor::hardshrink_backward(const XLATensorPtr& grad_out,
-                                            const XLATensorPtr& input,
-                                            const at::Scalar& lambda) {
-  return input->CreateFrom(torch::lazy::MakeNode<ShrinkBackward>(
-      torch::lazy::OpKind(at::aten::hardshrink_backward),
-      grad_out->GetIrValue(), input->GetIrValue(), lambda));
-}
-
 XLATensorPtr XLATensor::hardtanh_backward(const XLATensorPtr& grad_output,
                                           const XLATensorPtr& input,
                                           const at::Scalar& min_val,

@@ -2313,14 +2304,6 @@ XLATensorPtr XLATensor::softshrink(const XLATensorPtr& input,
       torch::lazy::MakeNode<Softshrink>(input->GetIrValue(), lambda));
 }
 
-XLATensorPtr XLATensor::softshrink_backward(const XLATensorPtr& grad_out,
-                                            const XLATensorPtr& input,
-                                            const at::Scalar& lambda) {
-  return input->CreateFrom(torch::lazy::MakeNode<ShrinkBackward>(
-      torch::lazy::OpKind(at::aten::softshrink_backward),
-      grad_out->GetIrValue(), input->GetIrValue(), lambda));
-}
-
 std::vector<XLATensorPtr> XLATensor::split(const XLATensorPtr& input,
                                            int64_t split_size, int64_t dim) {
   auto input_shape = input->shape();
