@@ -5967,6 +5967,24 @@ TEST_F(AtenXlaTensorTest, TestHardshrink) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestHardshrinkWithMixedDataType) {
+  torch::Tensor lambdaTensor =
+      torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
+  // It seems the below .item() will convert a kFloat64 to a kFloat32 if I
+  // make the scalar tensor a kFloat32 type.
+  torch::Scalar lambda = lambdaTensor.item();
+  torch::Tensor input =
+      torch::randn({10}, torch::TensorOptions(torch::kFloat64));
+
+  torch::Tensor output = torch::hardshrink(input, lambda);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input, device);
+    torch::Tensor xla_output = torch::hardshrink(xla_input, lambda);
+    AllClose(output, xla_output);
+  });
+}
+
+// Unlike Softshrink, a negative lambda is a valid input for Hardshrink.
 TEST_F(AtenXlaTensorTest, TestHardshrinkWithNegativeLambda) {
   torch::Tensor input = torch::randn({10}, torch::TensorOptions(torch::kFloat));
   torch::Scalar lambd = -0.5;
@@ -10433,6 +10451,22 @@ TEST_F(AtenXlaTensorTest, TestHardshrinkBackward) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestHardshrinkBackwardWithMixedDataType) {
+  torch::Tensor lambdaTensor =
+      torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
+  torch::Scalar lambda = lambdaTensor.item();
+
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::hardshrink(inputs[0], lambda);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward(
+        {torch::randn(
+            {100}, torch::TensorOptions(torch::kFloat64).requires_grad(true))},
+        device, testfn);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestSoftshrinkBackward) {
   auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
     return torch::softshrink(inputs[0]);
@@ -10445,6 +10479,22 @@ TEST_F(AtenXlaTensorTest, TestSoftshrinkBackward) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestSoftshrinkBackwardWithMixedDataType) {
+  torch::Tensor lambdaTensor =
+      torch::scalar_tensor(0., torch::TensorOptions(torch::kFloat32));
+  torch::Scalar lambda = lambdaTensor.item();
+
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::softshrink(inputs[0], lambda);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward(
+        {torch::randn(
+            {100}, torch::TensorOptions(torch::kFloat64).requires_grad(true))},
+        device, testfn);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestHardtanhBackward) {
   auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
     return torch::hardtanh(inputs[0]);
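For reference (outside the patch): a minimal standalone sketch of the mixed data type case the new tests exercise, using only the public PyTorch C++ API on CPU. The 0.5 lambda value and the main() wrapper are illustrative and not taken from the test file.

#include <torch/torch.h>
#include <iostream>

int main() {
  // lambda comes from a kFloat32 scalar tensor, so the Scalar carries a
  // 32-bit float value.
  torch::Tensor lambda_tensor =
      torch::scalar_tensor(0.5, torch::TensorOptions(torch::kFloat32));
  torch::Scalar lambda = lambda_tensor.item();

  // The input is kFloat64, so hardshrink has to reconcile the two dtypes.
  torch::Tensor input =
      torch::randn({10}, torch::TensorOptions(torch::kFloat64));
  torch::Tensor output = torch::hardshrink(input, lambda);

  // The output keeps the input dtype (kFloat64): values in [-lambda, lambda]
  // become zero, everything else passes through unchanged.
  std::cout << output.dtype() << "\n" << output << std::endl;
  return 0;
}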