
Commit 4ab7a24

Revert "Fix some more core aten ops (pytorch#6342)" (pytorch#6377)
1 parent 1521316 commit 4ab7a24

File tree

codegen/xla_native_functions.yaml
test/test_core_aten_ops.py
torch_xla/csrc/aten_xla_type.cpp
torch_xla/csrc/ops/ops_lower_fn.cpp
torch_xla/csrc/ops/ops_xla_shape_fn.cpp
torch_xla/csrc/ops/ops_xla_shape_fn.h

6 files changed: +14 -44 lines changed

codegen/xla_native_functions.yaml

Lines changed: 1 addition & 1 deletion

@@ -77,7 +77,6 @@ full_codegen:
   - rsqrt
   - selu
   - sgn
-  - sigmoid
   - sign
   - silu
   - silu_backward
@@ -303,6 +302,7 @@ supported:
   - select_scatter
   - selu_
   - set_.source_Tensor
+  - sigmoid
   - sigmoid_backward
   - slice_copy.Tensor
   - slice_scatter
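
Moving `sigmoid` out of `full_codegen` and into `supported` means codegen no longer emits its IR node, lowering, and shape function; the op instead dispatches to the hand-written kernel restored in torch_xla/csrc/aten_xla_type.cpp below. A minimal sketch of how one might confirm the manual path is taken, assuming a working torch_xla install and that the TORCH_LAZY_FN_COUNTER_TIMED_TRACING macro registers a counter named "xla::sigmoid" (an assumption based on the "xla::" prefix in the code below):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(4, 4, device=device)
y = torch.sigmoid(x)  # should route through XLANativeFunctions::sigmoid

# If the assumed counter name is right, it appears in the metrics report.
print(met.counter_value('xla::sigmoid'))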

test/test_core_aten_ops.py

Lines changed: 6 additions & 16 deletions

@@ -1904,17 +1904,11 @@ def test_aten_gelu_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)
 
+  @unittest.skip
   def test_aten_gelu_1(self):
     args = (torch.randn((10, 10)).to(torch.float16),)
     kwargs = dict()
-    run_export_and_compare(
-        self,
-        torch.ops.aten.gelu,
-        args,
-        kwargs,
-        rtol=0.001,
-        atol=0.01,
-    )
+    run_export_and_compare(self, torch.ops.aten.gelu, args, kwargs)
 
   def test_aten_glu_0(self):
     args = (
@@ -3091,6 +3085,7 @@ def test_aten_native_group_norm_0(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)
 
+  @unittest.skip
   def test_aten_native_group_norm_1(self):
     args = (
         torch.randn((1, 3, 2, 10)).to(torch.float16),
@@ -3103,14 +3098,7 @@ def test_aten_native_group_norm_1(self):
         0.0,
     )
     kwargs = dict()
-    run_export_and_compare(
-        self,
-        torch.ops.aten.native_group_norm,
-        args,
-        kwargs,
-        rtol=0.001,
-        atol=0.01,
-    )
+    run_export_and_compare(self, torch.ops.aten.native_group_norm, args, kwargs)
 
   def test_aten_native_layer_norm_0(self):
     args = (
@@ -3417,6 +3405,7 @@ def test_aten_reciprocal_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.reciprocal, args, kwargs)
 
+  @unittest.skip
   def test_aten_reciprocal_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
@@ -4014,6 +4003,7 @@ def test_aten_sigmoid_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.sigmoid, args, kwargs)
 
+  @unittest.skip
   def test_aten_sigmoid_2(self):
     args = (torch.randint(0, 10, (10, 10)).to(torch.int32),)
     kwargs = dict()
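
The revert also drops the loosened rtol/atol overrides; the float16 and int32 variants are skipped outright instead of being compared under relaxed tolerances. For orientation, run_export_and_compare checks an exported op against eager execution; the following is a simplified sketch of that idea, not the actual helper from this file:

import torch

def run_export_and_compare_sketch(op, args, kwargs, rtol=1e-5, atol=1e-8):
  # Simplified stand-in: run the op eagerly, run it again through
  # torch.export, and require the two results to match within tolerances.
  expected = op(*args, **kwargs)

  class Wrapper(torch.nn.Module):

    def forward(self, *a):
      return op(*a, **kwargs)

  exported = torch.export.export(Wrapper(), args)
  actual = exported.module()(*args)
  torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

Under tight default tolerances, float16 results from the compiled path can drift from eager by more than the comparison allows, which is why the now-reverted change had passed rtol=0.001, atol=0.01 for these tests.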

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 6 additions & 0 deletions

@@ -2758,6 +2758,12 @@ at::Tensor& XLANativeFunctions::set_(at::Tensor& self,
   return self;
 }
 
+at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+  return bridge::AtenFromXlaTensor(
+      tensor_methods::sigmoid(bridge::GetXlaTensor(self)));
+}
+
 at::Tensor XLANativeFunctions::sigmoid_backward(const at::Tensor& grad_output,
                                                 const at::Tensor& output) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 0 additions & 12 deletions

@@ -684,10 +684,6 @@ torch_xla::XlaOpVector NeTensor::Lower(LoweringContext* loctx) const {
 
 torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
-  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
-    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
-  }
   return ReturnOp(BuildReciprocal(xla_input), loctx);
 }
 
@@ -730,14 +726,6 @@ torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildSgn(xla_input), loctx);
 }
 
-torch_xla::XlaOpVector Sigmoid::Lower(LoweringContext* loctx) const {
-  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
-  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
-    xla_input = xla::ConvertElementType(xla_input, xla::PrimitiveType::F32);
-  }
-  return ReturnOp(xla::Logistic(xla_input), loctx);
-}
-
 torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildSign(xla_input), loctx);
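
The deleted branches up-cast integral inputs to F32 before BuildReciprocal and xla::Logistic, mirroring eager PyTorch's promotion of integer inputs for these ops. A quick eager-mode illustration of the behavior those conversions emulated:

import torch

i = torch.arange(1, 4, dtype=torch.int32)
print(torch.reciprocal(i).dtype)  # torch.float32: eager promotes integral inputs
print(torch.sigmoid(i).dtype)     # torch.float32 as well

With the conversions reverted, the int32 variants of these tests are skipped above.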

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 1 addition & 13 deletions

@@ -762,11 +762,7 @@ xla::Shape NeTensorOutputShape(const torch::lazy::Value& self,
 }
 
 xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) {
-  xla::Shape result_shape = GetXlaShape(input);
-  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
-    result_shape.set_element_type(xla::PrimitiveType::F32);
-  }
-  return result_shape;
+  return GetXlaShape(input);
 }
 
 xla::Shape ReluOutputShape(const torch::lazy::Value& input) {
@@ -808,14 +804,6 @@ xla::Shape SgnOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
-xla::Shape SigmoidOutputShape(const torch::lazy::Value& input) {
-  xla::Shape result_shape = GetXlaShape(input);
-  if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
-    result_shape.set_element_type(xla::PrimitiveType::F32);
-  }
-  return result_shape;
-}
-
 xla::Shape SignOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
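
Shape inference has to stay consistent with the lowering: now that Reciprocal::Lower no longer converts integral inputs to F32 (and Sigmoid::Lower is gone entirely), the shape functions return the input shape, element type included, unchanged. For illustration only, a Python sketch of the dtype rule the revert removes (the real logic is the C++ above):

import torch

def pre_revert_output_dtype(input_dtype: torch.dtype) -> torch.dtype:
  # Pre-revert rule: integral inputs were reported as float32 so that the
  # inferred shape agreed with the F32 conversion in the lowering.
  if input_dtype.is_floating_point or input_dtype.is_complex:
    return input_dtype
  return torch.float32

assert pre_revert_output_dtype(torch.int32) == torch.float32
assert pre_revert_output_dtype(torch.float16) == torch.float16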

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 0 additions & 2 deletions

@@ -248,8 +248,6 @@ xla::Shape SeluOutputShape(const torch::lazy::Value& input);
 
 xla::Shape SgnOutputShape(const torch::lazy::Value& input);
 
-xla::Shape SigmoidOutputShape(const torch::lazy::Value& input);
-
 xla::Shape SignOutputShape(const torch::lazy::Value& input);
 
 xla::Shape SiluOutputShape(const torch::lazy::Value& input);
