11 changes: 11 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -5839,6 +5839,17 @@ TEST_F(AtenXlaTensorTest, TestHardshrink) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestHardshrinkWithNegativeLambda) {
+  torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
+  torch::Scalar lambd = -0.5;
+  torch::Tensor output = torch::hardshrink(input, lambd);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input, device);
+    torch::Tensor xla_output = torch::hardshrink(xla_input, lambd);
+    AllClose(output, xla_output);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestHardSigmoid) {
   torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
   torch::Tensor output = torch::hardsigmoid(input);
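As background for the new test: hardshrink zeroes values inside [-lambd, lambd] and passes everything else through. With a negative lambd such as -0.5 that interval is empty, so the op degenerates to the identity, which is exactly the edge case the lowering below must reproduce. A minimal reference sketch of the semantics (a standalone illustration, not code from this PR):

```cpp
#include <vector>

// Reference semantics of hardshrink: zero out x when -lambd <= x <= lambd,
// otherwise pass x through unchanged.
std::vector<float> HardshrinkReference(const std::vector<float>& input,
                                       float lambd) {
  std::vector<float> out;
  out.reserve(input.size());
  for (float x : input) {
    // For lambd < 0 the interval [-lambd, lambd] is empty, so no element is
    // shrunk and the op reduces to the identity.
    bool in_shrink_band = x >= -lambd && x <= lambd;
    out.push_back(in_shrink_band ? 0.0f : x);
  }
  return out;
}
```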
7 changes: 0 additions & 7 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1244,13 +1244,6 @@ at::Tensor XLANativeFunctions::gelu_backward(const at::Tensor& grad,
       bridge::GetXlaTensor(grad), bridge::GetXlaTensor(self), approximate));
 }
 
-at::Tensor XLANativeFunctions::hardshrink(const at::Tensor& self,
-                                          const at::Scalar& lambda) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(
-      XLATensor::hardshrink(bridge::GetXlaTensor(self), lambda));
-}
-
 at::Tensor XLANativeFunctions::hardshrink_backward(const at::Tensor& grad_out,
                                                    const at::Tensor& self,
                                                    const at::Scalar& lambda) {
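This handwritten binding can be deleted because moving hardshrink to full_codegen (see xla_native_functions.yaml below) makes the codegen emit an equivalent entry point. Conceptually the generated code does something like the sketch below; this is an illustrative reconstruction, and the helper names (in particular GetIrValueForScalar) are assumptions about the generated source, not a copy of it:

```cpp
// Illustrative sketch of the generated binding, not the literal codegen output.
at::Tensor XLANativeFunctions::hardshrink(const at::Tensor& self,
                                          const at::Scalar& lambd) {
  XLA_FN_COUNTER("xla::");
  XLATensorPtr xla_self = bridge::GetXlaTensor(self);
  // The scalar is captured as a second IR operand; Hardshrink::Lower later
  // reads it back as operand(1).
  torch::lazy::NodePtr node = torch::lazy::MakeNode<Hardshrink>(
      xla_self->GetIrValue(),
      XLATensor::GetIrValueForScalar(lambd, xla_self->GetDevice()));
  return bridge::AtenFromXlaTensor(xla_self->CreateFrom(node));
}
```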
12 changes: 9 additions & 3 deletions torch_xla/csrc/elementwise.cpp
@@ -69,10 +69,16 @@ xla::XlaOp BuildRelu(xla::XlaOp input) {
       0, input_shape.element_type(), input.builder()));
 }
 
-xla::XlaOp BuildHardshrink(xla::XlaOp input, const at::Scalar& lambda) {
+xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
Inline review thread on this change:

Collaborator: We need to be careful about these at::Scalar -> xla::XlaOp conversions, because the default element type for a scalar is currently f64. We can do a MaybeCast here; see https://github.com/pytorch/xla/blame/f8b3dfd45d753a8844aca871cb39511022bb35ff/torch_xla/csrc/elementwise.cpp#L383 for an example.

@vanbasten23 (Collaborator, Author) Sep 14, 2022: If we don't cast, we may get an error such as:

Seen floating point types of different precisions in %subtract.4 = f64[] subtract(f32[] %constant.3, f64[] %constant.1), metadata={op_type="aten__hardshrink" op_name="aten__hardshrink" source_file="[email protected]" source_line=1026}, but mixed precision is disallowed.

Is that the reason we have to do a MaybeCast here? Also, shouldn't we cast the xla::XlaOp to the original element type of the at::Scalar instead?

Collaborator: Check XlaHelpers::ScalarValue(min_val, element_type, builder) for the existing behavior. A scalar often defaults to f64 when the user doesn't specify a type, and using the scalar's own type would cause additional type promotion issues. It is easier to use the other operand's type in this case.

@vanbasten23 (Collaborator, Author): Got it, thanks! By "promotion issue", do you mean that when an operation has two operands of mixed types, XLA converts one type to the other implicitly, just like a C++ operator would?

Collaborator: Yeah, we usually promote to a "more complex" type like f64 or s64, which is slower to compute.

   const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
-  xla::XlaOp zero = xla::Zero(input.builder(), shape.element_type());
-  return xla::Select(Between(input, -lambda, lambda), zero, input);
+  xla::PrimitiveType element_type = shape.element_type();
+  xla::XlaOp zero = xla::Zero(input.builder(), element_type);
+
+  xla::XlaOp check_low = BuildComparisonOp(at::aten::ge, input, zero - lambda);
+  xla::XlaOp check_high = BuildComparisonOp(at::aten::le, input, lambda);
+  xla::XlaOp between = xla::And(check_low, check_high);
+
+  return xla::Select(between, zero, input);
 }
 
 xla::XlaOp BuildHardSigmoid(xla::XlaOp input) {
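The review thread above is about avoiding an implicit f64: if lambda is lowered as an f64 constant while input is f32, the `zero - lambda` subtraction mixes precisions and XLA rejects it with exactly the error quoted in the thread. A minimal sketch of the cast the reviewer suggests, assuming a helper like this is applied to lambda before use (the name CastLambdaToInputType is illustrative, not from this PR):

```cpp
// Sketch: force the scalar operand onto the tensor's element type so the
// comparisons and the `zero - lambda` negation stay in one precision, instead
// of letting the whole computation get promoted to f64.
xla::XlaOp CastLambdaToInputType(xla::XlaOp input, xla::XlaOp lambda) {
  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
  const xla::Shape& lambda_shape = XlaHelpers::ShapeOfXlaOp(lambda);
  if (lambda_shape.element_type() != input_shape.element_type()) {
    lambda = xla::ConvertElementType(lambda, input_shape.element_type());
  }
  return lambda;
}
```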
2 changes: 1 addition & 1 deletion torch_xla/csrc/elementwise.h
@@ -28,7 +28,7 @@ xla::XlaOp BuildRreluBackward(xla::XlaOp grad_output, xla::XlaOp input,
                               xla::XlaOp noise, const at::Scalar& lower,
                               const at::Scalar& upper, bool training);
 
-xla::XlaOp BuildHardshrink(xla::XlaOp input, const at::Scalar& lambda);
+xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda);
 xla::XlaOp BuildHardSigmoid(xla::XlaOp input);
 xla::XlaOp BuildHardSigmoidBackward(xla::XlaOp grad_output, xla::XlaOp input);
 xla::XlaOp BuildHardSwish(xla::XlaOp input);
32 changes: 0 additions & 32 deletions torch_xla/csrc/ops/hardshrink.cpp

This file was deleted.

25 changes: 0 additions & 25 deletions torch_xla/csrc/ops/hardshrink.h

This file was deleted.

6 changes: 6 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -330,6 +330,12 @@ torch_xla::XlaOpVector GtTensor::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildComparisonOp(at::aten::gt, xla_input, xla_other), loctx);
 }
 
+torch_xla::XlaOpVector Hardshrink::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp lambd = loctx->GetOutputOp(operand(1));
+  return ReturnOp(BuildHardshrink(xla_input, lambd), loctx);
+}
+
 torch_xla::XlaOpVector Hardsigmoid::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildHardSigmoid(xla_input), loctx);
5 changes: 5 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -409,6 +409,11 @@ xla::Shape GtTensorOutputShape(const torch::lazy::Value& self,
   return GtScalarOutputShape(self, other);
 }
 
+xla::Shape HardshrinkOutputShape(const torch::lazy::Value& self,
+                                 const torch::lazy::Value& lambd) {
+  return GetXlaShape(self);
+}
+
 xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
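HardshrinkOutputShape can simply return the input's shape because hardshrink is elementwise; the lambd operand never affects the output shape. For orientation, the codegen ties this shape function and Hardshrink::Lower together through a generated IR node roughly shaped like the sketch below (illustrative only; the real class lives in the generated LazyIr.h and its exact constructor arguments are assumptions):

```cpp
// Rough shape of the generated node, not the literal codegen output.
class Hardshrink : public XlaNode {
 public:
  Hardshrink(const torch::lazy::Value& self, const torch::lazy::Value& lambd)
      : XlaNode(torch::lazy::OpKind(at::aten::hardshrink), {self, lambd},
                // Shape inference delegates to the handwritten function above.
                [&]() { return HardshrinkOutputShape(self, lambd); },
                /*num_outputs=*/1) {}

  // Lowering delegates to the handwritten Hardshrink::Lower in
  // ops_lower_fn.cpp, which calls BuildHardshrink.
  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
```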
3 changes: 3 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -131,6 +131,9 @@ xla::Shape GtScalarOutputShape(const torch::lazy::Value& self,
 xla::Shape GtTensorOutputShape(const torch::lazy::Value& self,
                                const torch::lazy::Value& other);
 
+xla::Shape HardshrinkOutputShape(const torch::lazy::Value& self,
+                                 const torch::lazy::Value& lambd);
+
 xla::Shape HardsigmoidOutputShape(const torch::lazy::Value& input);
 
 xla::Shape HardsigmoidBackwardOutputShape(const torch::lazy::Value& grad_output,
7 changes: 0 additions & 7 deletions torch_xla/csrc/tensor_methods.cpp
@@ -52,7 +52,6 @@
 #include "torch_xla/csrc/ops/gather.h"
 #include "torch_xla/csrc/ops/generic.h"
 #include "torch_xla/csrc/ops/get_dimensions_size.h"
-#include "torch_xla/csrc/ops/hardshrink.h"
 #include "torch_xla/csrc/ops/hardtanh_backward.h"
 #include "torch_xla/csrc/ops/index_ops.h"
 #include "torch_xla/csrc/ops/index_select.h"
@@ -1349,12 +1348,6 @@ XLATensorPtr XLATensor::le(const XLATensorPtr& input,
   return DispatchComparisonOp(at::aten::le, input, other);
 }
 
-XLATensorPtr XLATensor::hardshrink(const XLATensorPtr& input,
-                                   const at::Scalar& lambda) {
-  return input->CreateFrom(
-      torch::lazy::MakeNode<Hardshrink>(input->GetIrValue(), lambda));
-}
-
 XLATensorPtr XLATensor::hardshrink_backward(const XLATensorPtr& grad_out,
                                             const XLATensorPtr& input,
                                             const at::Scalar& lambda) {
2 changes: 1 addition & 1 deletion xla_native_functions.yaml
@@ -43,6 +43,7 @@ full_codegen:
   - ge.Tensor
   - gt.Scalar
   - gt.Tensor
+  - hardshrink
   - hardsigmoid
   - hardsigmoid_backward
   - hardswish
@@ -172,7 +173,6 @@ supported:
   - gather
   - gelu
   - gelu_backward
-  - hardshrink
   - hardshrink_backward
   - hardtanh
   - hardtanh_backward
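Net effect of the yaml move: for a full_codegen op, only the Lower function and the *OutputShape function are handwritten, while the ATen binding, the IR node class, and the XLATensor method are generated. That is why this PR deletes ops/hardshrink.h, ops/hardshrink.cpp, the aten_xla_type.cpp binding, and the tensor_methods.cpp wrapper. User-facing behavior is unchanged; a brief usage sketch (the device string is an assumption about the runtime setup):

```cpp
// Calling the op after the change looks the same as before; it now routes
// through the generated Hardshrink IR node and the handwritten lowering.
torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
torch::Tensor xla_input = input.to(torch::Device("xla:0"));  // assumed device
torch::Tensor out = torch::hardshrink(xla_input, /*lambd=*/0.5);
```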