Skip to content

Commit e9119cb

Browse files
committed
VLLM Workaround
stack-info: PR: #2165, branch: drisspg/stack/52
1 parent fb39817 commit e9119cb

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

torchao/prototype/mx_formats/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ def to_blocked(input_matrix) -> Tensor:
3535
padded_cols = n_col_blocks * 4
3636

3737
padded = input_matrix
38-
# if (rows, cols) != (padded_rows, padded_cols):
39-
padded = torch.zeros(
40-
(padded_rows, padded_cols),
41-
device=input_matrix.device,
42-
dtype=input_matrix.dtype,
43-
)
44-
padded[:rows, :cols] = input_matrix
38+
# TODO This is to work around VLLM's usage of compile w/ dynamic shapes
39+
if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
40+
padded = torch.zeros(
41+
(padded_rows, padded_cols),
42+
device=input_matrix.device,
43+
dtype=input_matrix.dtype,
44+
)
45+
padded[:rows, :cols] = input_matrix
4546

4647
# Rearrange the blocks
4748
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)

0 commit comments

Comments (0)