@@ -1908,6 +1908,103 @@ def _test_expand_dynamic_shapeexpr(name, data, shape_data, shape, ref_data):
19081908 ref_data = np .tile (data , (64 , 1 , 1 ))
19091909 _test_expand_dynamic_shapeexpr ("expand_with_dynamic_dim" , data , shape_data , shape , ref_data )
19101910
1911+ def test_expand_incompatible_broadcasting ():
1912+ """
1913+ This test case reproduces the error where input tensor shape at dim 1 is 25
1914+ and target shape at dim 3 is 56, which violates ONNX broadcasting rules
1915+ """
1916+ def _test_expand_error_case (name , data_shape , target_shape ):
1917+ data = np .random .uniform (size = data_shape ).astype (np .float32 )
1918+
1919+ shape_array = np .array (target_shape_vals , dtype = np .int64 )
1920+ shape_node = onnx .helper .make_node (
1921+ "Constant" ,
1922+ inputs = [],
1923+ outputs = ["shape" ],
1924+ value = onnx .helper .make_tensor (
1925+ name = "const_tensor" ,
1926+ data_type = onnx .TensorProto .INT64 ,
1927+ dims = shape_array .shape ,
1928+ vals = shape_array .flatten (),
1929+ ),
1930+ )
1931+
1932+ expand_node = helper .make_node ("Expand" , ["in" , "shape" ], ["out" ])
1933+
1934+ graph = helper .make_graph (
1935+ [shape_node , expand_node ],
1936+ "expand_error_test" ,
1937+ inputs = [helper .make_tensor_value_info ("in" , TensorProto .FLOAT , list (data .shape ))],
1938+ outputs = [helper .make_tensor_value_info ("out" , TensorProto .FLOAT , target_shape_vals )],
1939+ )
1940+
1941+ model = helper .make_model (graph , producer_name = name )
1942+
1943+ with pytest .raises (tvm .error .TVMError , ValueError )) as exc_info :
1944+ from_onnx (model , keep_params_in_input = True )
1945+
1946+ error_msg = str (exc_info .value )
1947+ assert "broadcast" in error_msg .lower () or "incompatible" in error_msg .lower (), \
1948+ f"Expected broadcasting error, but got: { error_msg } "
1949+
1950+ # Test case 1: Reproduce the exact error from the issue-17769
1951+ # Input shape: (25,), target shape: (1, 1, 1, 56)
1952+ # This should faill because input dim 1 (25) != target dim 3 (56) and neither is 1
1953+ _test_expand_error_case (
1954+ "expand_incompatible_25_to_56" ,
1955+ data_shape = (25 ,),
1956+ target_shape = (1 , 1 , 1 , 56 ),
1957+ )
1958+
1959+ # Test case 2: Another incompatible case
1960+ # Input shape: (1, 25), target shape: (1, 1, 1, 56)
1961+ # After right-alignment, input (1, 1, 1, 25) vs. target (1, 1, 1, 56)
1962+ # This should fail because 25 != 56 and neither is 1
1963+ _test_expand_error_case (
1964+ "expand_incompatible_aligned_25_to_56" ,
1965+ data_shape = (1 , 25 ),
1966+ target_shape = (1 , 1 , 1 , 56 ),
1967+ )
1968+
1969+ # Test case 3: Valid case for comparison - should not raise error
1970+ def _test_expand_valid_case ():
1971+ """Test a valid expand case to ensure our fix doesn't break valid operations"""
1972+ data_shape = (1 , 25 )
1973+ target_shape = [2 , 25 ] # Valid: input (1, 25) can broadcast to (2, 25)
1974+
1975+ data = np .random .uniform (size = data_shape ).astype (np .float32 )
1976+ shape_array = np .array (target_shape_vals , dtype = np .int64 )
1977+
1978+ shape_node = onnx .helper .make_node (
1979+ "Constant" ,
1980+ inputs = [],
1981+ outputs = ["shape" ],
1982+ value = onnx .helper .make_tensor (
1983+ name = "const_tensor" ,
1984+ data_type = onnx .TensorProto .INT64 ,
1985+ dims = shape_array .shape ,
1986+ vals = shape_array .flatten (),
1987+ ),
1988+ )
1989+
1990+ expand_node = helper .make_node ("Expand" , ["in" , "shape" ], ["out" ])
1991+
1992+ graph = helper .make_graph (
1993+ [shape_node , expand_node ],
1994+ "expand_valid_test" ,
1995+ inputs = [helper .make_tensor_value_info ("in" , TensorProto .FLOAT , list (data .shape ))],
1996+ outputs = [helper .make_tensor_value_info ("out" , TensorProto .FLOAT , target_shape_vals )],
1997+ )
1998+
1999+ model = helper .make_model (graph , producer_name = "expand_valid_test_case" )
2000+
2001+ try :
2002+ tvm_model = from_onnx (model , keep_params_in_input = True )
2003+ except Exception as e :
2004+ pytest .fail (f"Valid expand case should not fail, but got error: { e } " )
2005+
2006+ _test_expand_valid_case ()
2007+
19112008
19122009# TODO(jwfromm) Current approach to dynamic expand is technically not well formed. Reenable once fixed.
19132010@pytest .mark .skip ("Produces ill-formed IR" )
0 commit comments