It seems the Python tests fail on non-SymInt expand calls. Here we call torch.testing.make_non_contiguous on idx, which is a dynamic object produced by torch.randint(). Under the hood, this op calls Tensor.expand() (in the trace below, copy_ builds a torch_xla::Expand node).
It appears that upstream somehow confuses expand.SymInt with expand. Correct me if I am wrong.
Requirement:
torch.randint is not in scope of the bounded dynamic shape design because it is an unbounded op. So I propose we materialize the value of the randint call so that the regular expand won't fail. What do you think, @Krovatkin?
Below is a failure example.
Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 390, in instantiated_test
raise rte
File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 377, in instantiated_test
result = test(self, **param_kwargs)
File "/workspace/pytorch/xla/test/../../test/test_torch.py", line 2967, in test_index_reduce
idx = torch.testing.make_non_contiguous(idx)
File "/opt/conda/lib/python3.7/site-packages/torch/testing/_deprecated.py", line 32, in inner_wrapper
return_value = fn(*args, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/testing/_deprecated.py", line 138, in make_non_contiguous
input.copy_(tensor)
RuntimeError: /workspace/pytorch/xla/third_party/tensorflow/bazel-tensorflow/tensorflow/compiler/xla/xla_client/debug_macros.h:27 : Check failed: status.status() == ::tensorflow::Status::OK() (INVALID_ARGUMENT: Input dimension should be either 1 or equal to the output dimension it is broadcasting into; the 0th operand dimension is 3, the 0th output dimension is 10. vs. OK)
*** Begin stack trace ***
tensorflow::CurrentStackTrace[abi:cxx11]()
xla::Shape const* ConsumeValue<xla::Shape const*>(tensorflow::StatusOr<xla::Shape const*>&&)
torch_xla::XlaHelpers::ShapeOfXlaOp(xla::XlaOp)
torch_xla::InferOutputShape(absl::lts_20211102::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20211102::Span<xla::XlaOp const>)> const&)
torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
torch_xla::Expand::Expand(torch::lazy::Value const&, std::vector<long, std::allocator<long> >)
torch_xla::XLATensor::copy_(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >&)
torch_xla::XLANativeFunctions::_copy_from(at::Tensor const&, at::Tensor const&, bool)
at::_ops::_copy_from::call(at::Tensor const&, at::Tensor const&, bool)
at::native::copy_(at::Tensor&, at::Tensor const&, bool)
at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool)
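
For anyone not familiar with the check that fires here, the rule quoted in the RuntimeError above can be sketched in plain Python. This is only an illustration of that rule, not torch_xla or XLA code; `check_expand_shapes` is a hypothetical helper, and aligning shapes on their trailing dimensions is an assumption borrowed from the usual expand semantics.

```python
def check_expand_shapes(input_shape, output_shape):
    """Validate an expand/broadcast of input_shape into output_shape.

    Hypothetical helper: each input dimension must be 1 or equal to the
    output dimension it broadcasts into (shapes aligned on trailing dims).
    Returns the output shape as a tuple, or raises ValueError.
    """
    if len(input_shape) > len(output_shape):
        raise ValueError("input rank exceeds output rank")
    # Leading output dimensions with no matching input dimension are new.
    offset = len(output_shape) - len(input_shape)
    for i, in_dim in enumerate(input_shape):
        out_dim = output_shape[offset + i]
        if in_dim != 1 and in_dim != out_dim:
            raise ValueError(
                f"operand dimension {i} is {in_dim}, "
                f"output dimension {offset + i} is {out_dim}"
            )
    return tuple(output_shape)


# Size-1 and equal-size dimensions are accepted.
print(check_expand_shapes((1,), (10,)))
print(check_expand_shapes((10,), (10,)))

# The combination reported in the failure (3 into 10) is rejected.
try:
    check_expand_shapes((3,), (10,))
except ValueError as err:
    print("rejected:", err)
```

With the sizes from the error message, a 3-element dimension cannot be expanded into a 10-element one, which is why a static expand on a tensor whose size is not the expected one fails.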
Labels: dynamism (Dynamic Shape Features)