Skip to content

Commit 98ed451

Browse files
committed
[Relax][Frontend][ONNX] Error converting operator Expand: TVMError: broadcast_to expects the input tensor shape is broadcastable to the target shape
1 parent 874de94 commit 98ed451

File tree

2 files changed

+156
-7
lines changed

2 files changed

+156
-7
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,15 +1910,35 @@ def _impl_v13(cls, bb, inputs, attr, params):
19101910
if isinstance(shape, relax.ShapeExpr):
19111911
data_shape = list(data.struct_info.shape)
19121912
target_shape = list(shape.values)
1913+
original_data_shape = [dim.value if hasattr(dim, 'value') else str(dim) for dim in data_shape]
1914+
original_target_shape = [dim.value if hasattr(dim, 'value') else str(dim) for dim in target_shape]
19131915
data_shape = [1] * (len(target_shape) - len(data_shape)) + data_shape
19141916
assert len(data_shape) == len(target_shape)
1915-
# Fix small target shapes or target shapes assigned to -1
1917+
# Apply ONNX v13 Expand broadcasting rules
19161918
for i, s in enumerate(target_shape):
1917-
if isinstance(s, tvm.tir.IntImm) and (
1918-
(isinstance(data_shape[i], tvm.tir.IntImm) and s < data_shape[i])
1919-
or s.value == -1
1920-
):
1921-
target_shape[i] = data_shape[i]
1919+
if isinstance(s, tvm.tir.IntImm):
1920+
if s.value == -1:
1921+
# -1 means preserve the input dimension
1922+
target_shape[i] = data_shape[i]
1923+
elif isinstance(data_shape[i], tvm.tir.IntImm) and data_shape[i].value == 1:
1924+
# Input dimension is 1, can broadcast to any target dimension >= 1
1925+
if s.value < 1:
1926+
raise ValueError(f"ONNX Expand: Invalid target dimension {s.value} at position {i}. Target dimensions must be >= 1.")
1927+
elif isinstance(data_shape[i], tvm.tir.IntImm) and s.value == data_shape[i].value:
1928+
# Dimensions match, no change needed
1929+
pass
1930+
elif s.value == 1:
1931+
# Target dimension is 1 but input dimension is not 1
1932+
# This would "squeeze" the dimension - preserve input for safety
1933+
target_shape[i] = data_shape[i]
1934+
else:
1935+
if isinstance(data_shape[i], tvm.tir.IntImm):
1936+
raise ValueError(
1937+
f"ONNX Expand: Cannot broadcast input shape {original_data_shape} to target shape {original_target_shape}. "
1938+
f"At dimension {i}: input size {data_shape[i].value} is incompatible with target size {s.value}. "
1939+
f"ONNX broadcasting requires corresponding dimensions to have the same value or one of them to be 1."
1940+
)
1941+
# For dynamic shapes, let broadcast_to handle it
19221942
if target_shape == data_shape:
19231943
return data
19241944
return relax.op.broadcast_to(data, relax.ShapeExpr(target_shape))
@@ -1929,6 +1949,8 @@ def _impl_v13(cls, bb, inputs, attr, params):
19291949
# ONNX Expand operator requires preserving target rank and broadcasting
19301950
# according to standard rules. Dimensions are right-aligned.
19311951
data_shape = [dim.value for dim in data.struct_info.shape]
1952+
original_data_shape = data_shape.copy()
1953+
original_new_shape = new_shape.copy()
19321954

19331955
# Right-align the shapes
19341956
if len(new_shape) > len(data_shape):
@@ -1938,8 +1960,26 @@ def _impl_v13(cls, bb, inputs, attr, params):
19381960
# Fix small target shapes - if target dim is smaller than input dim
19391961
# use the input dim (ONNX-specific behavior).
19401962
for i in range(len(new_shape)):
1941-
if new_shape[i] < data_shape[i]:
1963+
if new_shape[i] == -1:
1964+
# -1 means preserve the input dimension
1965+
new_shape[i] = data_shape[i]
1966+
elif data_shape[i] == 1:
1967+
# Input dimension is 1, can broadcast to any target dimension >= 1
1968+
if new_shape[i] < 1:
1969+
raise ValueError(f"ONNX Expand: Invalid target dimension {new_shape[i]} at position {i}. Target dimensions must be >= 1.")
1970+
elif new_shape[i] == data_shape[i]:
1971+
# Dimensions match, no change needed
1972+
pass
1973+
elif new_shape[i] == 1:
1974+
# Target dimension is 1 but input dimension is not 1
1975+
# This would "squeeze" the dimension - preserve input for safety
19421976
new_shape[i] = data_shape[i]
1977+
else:
1978+
raise ValueError(
1979+
f"ONNX Expand: Cannot broadcast input shape {original_data_shape} to target shape {original_new_shape}. "
1980+
f"At dimension {i}: input size {data_shape[i]} is incompatible with target size {new_shape[i]}. "
1981+
f"ONNX broadcasting requires corresponding dimensions to have the same value or one of them to be 1."
1982+
)
19431983
return relax.op.broadcast_to(data, relax.ShapeExpr(new_shape))
19441984

19451985
# Otherwise handle dynamic shapes.
@@ -1957,6 +1997,18 @@ def _impl_v13(cls, bb, inputs, attr, params):
19571997
shape_vars.append(tvm.tir.Var("x_%d" % i, "int64"))
19581998
bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
19591999
return bb.normalize(relax.op.broadcast_to(data, relax.ShapeExpr(shape_vars)))
2000+
2001+
# Applying broadcasting rules for dynamic shapes
2002+
data_shape = list(data.struct_info.shape)
2003+
data_ndim = len(data_shape)
2004+
target_ndim = shape_ndim
2005+
padded_data = data
2006+
2007+
if target_ndim > data_ndim:
2008+
padded_data_shape = [tir.IntImm("int64", 1)] * (target_ndim - data_ndim) + data_shape
2009+
padded_data = bb.normalize(relax.op.reshape(data, relax.ShapeExpr(padded_data_shape)))
2010+
2011+
return bb.normalize(relax.op.broadcast_to(padded_data, relax.ShapeExpr(shape_vars)))
19602012

19612013

19622014
class Attention(OnnxOpConverter):

tests/python/relax/test_frontend_onnx.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1908,6 +1908,103 @@ def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data):
19081908
ref_data = np.tile(data, (64, 1, 1))
19091909
_test_expand_dynamic_shapeexpr("expand_with_dynamic_dim", data, shape_data, shape, ref_data)
19101910

1911+
def test_expand_incompatible_broadcasting():
1912+
"""
1913+
This test case reproduces the error where input tensor shape at dim 1 is 25
1914+
and target shape at dim 3 is 56, which violates ONNX broadcasting rules
1915+
"""
1916+
def _test_expand_error_case(name, data_shape, target_shape):
1917+
data = np.random.uniform(size=data_shape).astype(np.float32)
1918+
1919+
shape_array = np.array(target_shape, dtype=np.int64)
1920+
shape_node = onnx.helper.make_node(
1921+
"Constant",
1922+
inputs=[],
1923+
outputs=["shape"],
1924+
value=onnx.helper.make_tensor(
1925+
name="const_tensor",
1926+
data_type=onnx.TensorProto.INT64,
1927+
dims=shape_array.shape,
1928+
vals=shape_array.flatten(),
1929+
),
1930+
)
1931+
1932+
expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
1933+
1934+
graph = helper.make_graph(
1935+
[shape_node, expand_node],
1936+
"expand_error_test",
1937+
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
1938+
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(target_shape))],
1939+
)
1940+
1941+
model = helper.make_model(graph, producer_name=name)
1942+
1943+
with pytest.raises((tvm.error.TVMError, ValueError)) as exc_info:
1944+
from_onnx(model, keep_params_in_input=True)
1945+
1946+
error_msg = str(exc_info.value)
1947+
assert "broadcast" in error_msg.lower() or "incompatible" in error_msg.lower(), \
1948+
f"Expected broadcasting error, but got: {error_msg}"
1949+
1950+
# Test case 1: Reproduce the exact error from the issue-17769
1951+
# Input shape: (25,), target shape: (1, 1, 1, 56)
1952+
# This should fail because input dim 1 (25) != target dim 3 (56) and neither is 1
1953+
_test_expand_error_case(
1954+
"expand_incompatible_25_to_56",
1955+
data_shape=(25,),
1956+
target_shape=(1, 1, 1, 56),
1957+
)
1958+
1959+
# Test case 2: Another incompatible case
1960+
# Input shape: (1, 25), target shape: (1, 1, 1, 56)
1961+
# After right-alignment, input (1, 1, 1, 25) vs. target (1, 1, 1, 56)
1962+
# This should fail because 25 != 56 and neither is 1
1963+
_test_expand_error_case(
1964+
"expand_incompatible_aligned_25_to_56",
1965+
data_shape=(1, 25),
1966+
target_shape=(1, 1, 1, 56),
1967+
)
1968+
1969+
# Test case 3: Valid case for comparison - should not raise error
1970+
def _test_expand_valid_case():
1971+
"""Test a valid expand case to ensure our fix doesn't break valid operations"""
1972+
data_shape = (1, 25)
1973+
target_shape = [2, 25] # Valid: input (1, 25) can broadcast to (2, 25)
1974+
1975+
data = np.random.uniform(size=data_shape).astype(np.float32)
1976+
shape_array = np.array(target_shape, dtype=np.int64)
1977+
1978+
shape_node = onnx.helper.make_node(
1979+
"Constant",
1980+
inputs=[],
1981+
outputs=["shape"],
1982+
value=onnx.helper.make_tensor(
1983+
name="const_tensor",
1984+
data_type=onnx.TensorProto.INT64,
1985+
dims=shape_array.shape,
1986+
vals=shape_array.flatten(),
1987+
),
1988+
)
1989+
1990+
expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
1991+
1992+
graph = helper.make_graph(
1993+
[shape_node, expand_node],
1994+
"expand_valid_test",
1995+
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(data.shape))],
1996+
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(target_shape))],
1997+
)
1998+
1999+
model = helper.make_model(graph, producer_name="expand_valid_test_case")
2000+
2001+
try:
2002+
tvm_model = from_onnx(model, keep_params_in_input=True)
2003+
except Exception as e:
2004+
pytest.fail(f"Valid expand case should not fail, but got error: {e}")
2005+
2006+
_test_expand_valid_case()
2007+
19112008

19122009
# TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed.
19132010
@pytest.mark.skip("Produces ill-formed IR")

0 commit comments

Comments
 (0)