Commit a1691bd

Codegen hardshrink (#3999)
* compiles
* final cleanup
* reformat cpp files.
* fixed a typo
* Do the necessary type conversion.
1 parent: 97686a6

File tree

11 files changed: +40 −76 lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 11 additions & 0 deletions
@@ -5854,6 +5854,17 @@ TEST_F(AtenXlaTensorTest, TestHardshrink) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestHardshrinkWithNegativeLambda) {
+  torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
+  torch::Scalar lambd = -0.5;
+  torch::Tensor output = torch::hardshrink(input, lambd);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input, device);
+    torch::Tensor xla_output = torch::hardshrink(xla_input, lambd);
+    AllClose(output, xla_output);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestHardSigmoid) {
   torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
   torch::Tensor output = torch::hardsigmoid(input);
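The new test pins down an edge case: hardshrink(x, λ) returns 0 where −λ ≤ x ≤ λ and passes x through otherwise, so a negative λ makes the shrink window empty and the op degenerate to the identity. A minimal reference sketch of that semantics (plain C++, not the XLA lowering; names are illustrative):

#include <iostream>

// Reference semantics of hardshrink for one element. With lambd < 0 the
// closed interval [-lambd, lambd] is empty, so every value passes through.
float HardshrinkRef(float x, float lambd) {
  return (x >= -lambd && x <= lambd) ? 0.0f : x;
}

int main() {
  std::cout << HardshrinkRef(0.3f, 0.5f) << "\n";   // 0   (inside the window)
  std::cout << HardshrinkRef(0.7f, 0.5f) << "\n";   // 0.7 (outside the window)
  std::cout << HardshrinkRef(0.3f, -0.5f) << "\n";  // 0.3 (empty window)
}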

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 0 additions & 7 deletions
@@ -1246,13 +1246,6 @@ at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
       bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self), approximate));
 }
 
-at::Tensor XLANativeFunctions::hardshrink(const at::Tensor& self,
-                                          const at::Scalar& lambda) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(
-      XLATensor::hardshrink(bridge::GetXlaTensor(self), lambda));
-}
-
 at::Tensor XLANativeFunctions::hardshrink_backward(const at::Tensor& grad_out,
                                                    const at::Tensor& self,
                                                    const at::Scalar& lambda) {

torch_xla/csrc/elementwise.cpp

Lines changed: 13 additions & 3 deletions
@@ -69,10 +69,20 @@ xla::XlaOp BuildRelu(xla::XlaOp input) {
       0, input_shape.element_type(), input.builder()));
 }
 
-xla::XlaOp BuildHardshrink(xla::XlaOp input, const at::Scalar& lambda) {
+xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
-  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
-  return xla::Select(Between(input, -lambda, lambda), zero, input);
+  xla::PrimitiveType input_element_type = shape.element_type();
+  xla::XlaOp zero = xla::Zero(input.builder(), input_element_type);
+
+  // The conversion here is needed because when we do computations such as
+  // broadcast or subtraction on input and lambda, XLA disallows mixed
+  // precision for floating-point types.
+  lambda = MaybeConvertTo(lambda, input_element_type);
+  xla::XlaOp check_low = BuildComparisonOp(at::aten::ge, input, zero - lambda);
+  xla::XlaOp check_high = BuildComparisonOp(at::aten::le, input, lambda);
+  xla::XlaOp between = xla::And(check_low, check_high);
+
+  return xla::Select(between, zero, input);
 }
 
 xla::XlaOp BuildHardSigmoid(xla::XlaOp input) {
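The signature change (at::Scalar → xla::XlaOp) means lambda now arrives as a graph value whose element type may differ from the input's, for example an F64 scalar against an F32 tensor, and XLA rejects mixed floating-point types in the subtraction and comparisons above. A rough sketch of what a conversion helper like MaybeConvertTo does (an assumption about its behavior, not its actual code; the real helper lives elsewhere in torch_xla):

// Hypothetical reconstruction: cast `op` to `type` only when the element
// types differ, so downstream ops see matching floating-point types.
xla::XlaOp MaybeConvertToSketch(xla::XlaOp op, xla::PrimitiveType type) {
  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(op);
  return shape.element_type() == type ? op
                                      : xla::ConvertElementType(op, type);
}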

torch_xla/csrc/elementwise.h

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ xla::XlaOp BuildRreluBackward(xla::XlaOp grad_output, xla::XlaOp input,
                               xla::XlaOp noise, const at::Scalar& lower,
                               const at::Scalar& upper, bool training);
 
-xla::XlaOp BuildHardshrink(xla::XlaOp input, const at::Scalar& lambda);
+xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda);
 xla::XlaOp BuildHardSigmoid(xla::XlaOp input);
 xla::XlaOp BuildHardSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input);
 xla::XlaOp BuildHardSwish(xla::XlaOp input);

torch_xla/csrc/ops/hardshrink.cpp

Lines changed: 0 additions & 32 deletions
This file was deleted.

torch_xla/csrc/ops/hardshrink.h

Lines changed: 0 additions & 25 deletions
This file was deleted.

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 6 additions & 0 deletions
@@ -330,6 +330,12 @@ torch_xla::XlaOpVector GtTensor::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildComparisonOp(at::aten::gt, xla_input, xla_other), loctx);
 }
 
+torch_xla::XlaOpVector Hardshrink::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp lambd = loctx->GetOutputOp(operand(1));
+  return ReturnOp(BuildHardshrink(xla_input, lambd), loctx);
+}
+
 torch_xla::XlaOpVector Hardsigmoid::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildHardSigmoid(xla_input), loctx);
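With hardshrink fully code-generated, lambd reaches the lowering as a second IR operand rather than as an at::Scalar captured by a hand-written node, which is why Lower reads operand(1). Conceptually, the host-side scalar is materialized as a rank-0 value in the graph; a hedged sketch of that idea using the public builder API (illustrative only, not the actual codegen output):

// Illustrative only: turn a host-side scalar into a graph value. The real
// pipeline goes through the lazy IR's scalar handling, not this helper.
xla::XlaOp ScalarToXlaOpSketch(xla::XlaBuilder* builder, float lambd) {
  // ConstantR0 embeds a rank-0 (scalar) constant into the computation.
  return xla::ConstantR0<float>(builder, lambd);
}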

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 5 additions & 0 deletions
@@ -409,6 +409,11 @@ xla::Shape GtTensorOutputShape(const torch::lazy::Value& self,
   return GtScalarOutputShape(self, other);
 }
 
+xla::Shape HardshrinkOutputShape(const torch::lazy::Value& self,
+                                 const torch::lazy::Value& lambd) {
+  return GetXlaShape(self);
+}
+
 xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 3 additions & 0 deletions
@@ -131,6 +131,9 @@ xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
 xla::Shape GtTensorOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other);
 
+xla::Shape HardshrinkOutputShape(const torch::lazy::Value& self,
+                                 const torch::lazy::Value& lambd);
+
 xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input);
 
 xla::Shape HardsigmoidBackwardOutputShape(const torch::lazy::Value& grad_output,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 7 deletions
@@ -52,7 +52,6 @@
 #include "torch_xla/csrc/ops/gather.h"
 #include "torch_xla/csrc/ops/generic.h"
 #include "torch_xla/csrc/ops/get_dimensions_size.h"
-#include "torch_xla/csrc/ops/hardshrink.h"
 #include "torch_xla/csrc/ops/hardtanh_backward.h"
 #include "torch_xla/csrc/ops/index_ops.h"
 #include "torch_xla/csrc/ops/index_select.h"
@@ -1349,12 +1348,6 @@ XLATensorPtr XLATensor::le(const XLATensorPtr& input,
   return DispatchComparisonOp(at::aten::le, input, other);
 }
 
-XLATensorPtr XLATensor::hardshrink(const XLATensorPtr& input,
-                                   const at::Scalar& lambda) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<Hardshrink>(input->GetIrValue(), lambda));
-}
-
 XLATensorPtr XLATensor::hardshrink_backward(const XLATensorPtr& grad_out,
                                             const XLATensorPtr& input,
                                             const at::Scalar& lambda) {
