
Commit cac5261

Add exhaustive config option to intmm kernel (#1392)
* Add exhaustive config option to intmm kernel

Summary: Similar to pytorch/pytorch#126220, we add an exhaustive config option for the int8mm and scaled_mm kernels in torchao. Note that there seems to be native int8mm and scaled_mm support in pytorch (https://github.com/pytorch/pytorch/blob/0610b9730e27d066e26396a2d655ba0d98c2012d/torch/_inductor/kernel/mm.py#L305 for int8mm, https://github.com/pytorch/pytorch/blob/0610b9730e27d066e26396a2d655ba0d98c2012d/torch/_inductor/kernel/mm_scaled.py#L575 for scaled_mm); maybe we should use those at some point.

Test Plan:

```
cd benchmarks
TORCHAO_AUTOTUNER_ENABLE=1 python intmm.py --file_path intmm_shapes.csv
TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE TORCHAO_AUTOTUNER_ENABLE=1 python intmm.py --file_path intmm_shapes.csv
```

Reviewers:
Subscribers:
Tasks:
Tags:

* remove unused
* enable all autoquant qtensor
* guard float8 qtensor subclass
* guard exhaustive config torch version
1 parent f258d82 commit cac5261

File tree

7 files changed: +83 −29 lines changed

- torchao/_models/sam/eval_combo.py
- torchao/kernel/README.md
- torchao/kernel/intmm.py
- torchao/kernel/intmm_triton.py
- torchao/prototype/quantization/autoquant_v2.py
- torchao/quantization/__init__.py
- torchao/quantization/autoquant.py

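Before the per-file diffs, a minimal sketch of how the exhaustive gate is consulted from Python. The config field and env-var name come straight from the diffs below; the only assumption is the usual inductor behavior of reading the variable when its config module first loads.

```python
import os

# Must be set before inductor's config module is first imported.
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE"] = "EXHAUSTIVE"

import torch._inductor.config as inductor_config

# torchao's intmm_triton.py branches on exactly this value (see its diff below).
if inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE":
    print("exhaustive int8mm/scaled_mm search space enabled")
```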

torchao/_models/sam/eval_combo.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -350,6 +350,8 @@ def mlp_only(mod, name):
             autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
         elif "autoquant_v2-float8" == compress:
             autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST)
+        elif "autoquant_v2-all" == compress:
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.ALL_AUTOQUANT_CLASS_LIST)
         else:
             autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True)
 
@@ -362,6 +364,8 @@ def mlp_only(mod, name):
             autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
         elif "autoquant-float8" == compress:
             autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
+        elif "autoquant-all" == compress:
+            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST)
         else:
             autoquant(predictor.model.image_encoder, example_input=example_input, manual=True)
         predictor.model.image_encoder(example_input)
```
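The new branches reuse the existing `manual=True` flow: quantize, run example inputs to record shapes, then finalize. A hedged sketch on a stand-in module (the `encoder` module and shapes are placeholders and assume a CUDA device; `finalize_autoquant` is torchao's manual-mode hook):

```python
import torch
import torchao
from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST

encoder = torch.nn.Linear(256, 256).cuda().half()  # stand-in for image_encoder
encoder = torchao.autoquant(
    encoder, manual=True, qtensor_class_list=ALL_AUTOQUANT_CLASS_LIST
)
encoder(torch.randn(16, 256, device="cuda", dtype=torch.float16))  # record shapes
encoder.finalize_autoquant()  # benchmark candidates, pick per-linear winners
```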

torchao/kernel/README.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -6,6 +6,9 @@
 
 Set this to a nonzero value to enable the kernels generated by the autotuner. This is turned off by default, because it is still an experimental feature and also can take a long time to run.
 
+`TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE`
+Use this to enable exhaustive search for both int8mm and scaled_mm kernels.
+
 Searching a new config can take a long time and we'll save the updated data in `data.pkl`. If you'd like to contributed updated configs for your hardware or shapes, please open a pull request.
 
 `TORCHAO_AUTOTUNER_DATA_PATH=torchao/kernel/configs/data_a100.pkl`
```
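For illustration, a sketch of setting both knobs from Python rather than the shell; the variable names are from this README, and the import-time reading order is the only assumption.

```python
import os

# Both knobs are read at import time, so they must be set first.
os.environ["TORCHAO_AUTOTUNER_ENABLE"] = "1"
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE"] = "EXHAUSTIVE"

import torch  # inductor reads its env var when its config module loads
from torchao.kernel import intmm  # torchao reads TORCHAO_AUTOTUNER_ENABLE here
```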

torchao/kernel/intmm.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -10,7 +10,8 @@
         from torchao.kernel import intmm_triton
     else:
         intmm_triton = None
-except ImportError:
+except ImportError as e:
+    print("import error:", e)
     # On cpu-only builds might not be available.
     intmm_triton = None
 
@@ -56,7 +57,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
             and j_is_nonzero_multiple_of_8
             and k_is_nonzero_multiple_of_8
         )
-
+
         if device_cpu or bad_dimensions_for_cublas:
             # fallback path
             return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
@@ -75,8 +76,8 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         try:
             return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
         except Exception:
-        # fallback path, would run on H100 for float8 dtypes
-        # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
+            # fallback path, would run on H100 for float8 dtypes
+            # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
             return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
```
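As a usage sketch of the function this diff touches: `safe_int_mm` multiplies int8 matrices and returns int32, taking the fallback paths above on CPU or on cuBLAS-unfriendly shapes (the shapes below are arbitrary).

```python
import torch
from torchao.kernel.intmm import safe_int_mm

a = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
b = torch.randint(-128, 128, (32, 16), dtype=torch.int8)

c = safe_int_mm(a, b)  # on CPU this takes the torch.matmul int32 fallback
print(c.dtype, c.shape)  # torch.int32 torch.Size([64, 16])
```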

torchao/kernel/intmm_triton.py

Lines changed: 40 additions & 25 deletions
```diff
@@ -7,35 +7,50 @@
 import triton.language as tl
 
 from torchao.kernel.autotuner import get_best_config_fn
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
-int8_powers_of_two = [32, 64, 128, 256]
-int8_mm_kernel_configs = sum(
-    [
-        # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
+# TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option
+int8_mm_kernel_configs = (
+    sum(
         [
-            (i, j, k, 1, 1),
-            (i, j, k, 1, 2),
-            (i, j, k, 2, 2),
-            (i, j, k, 1, 4),
-            (i, j, k, 2, 4),
-            (i, j, k, 3, 4),
-            (i, j, k, 4, 4),
-            (i, j, k, 1, 8),
-            (i, j, k, 2, 8),
-            (i, j, k, 3, 8),
-            (i, j, k, 4, 8),
-            (i, j, k, 5, 8),
-            (i, j, k, 6, 8),
-            (i, j, k, 7, 8),
-            (i, j, k, 8, 8),
-        ]
-        for (i, j, k) in itertools.product(
-            int8_powers_of_two, int8_powers_of_two, int8_powers_of_two
-        )
-    ],
-    [],
+            # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
+            [
+                (i, j, k, 1, 1),
+                (i, j, k, 1, 2),
+                (i, j, k, 2, 2),
+                (i, j, k, 1, 4),
+                (i, j, k, 2, 4),
+                (i, j, k, 3, 4),
+                (i, j, k, 4, 4),
+                (i, j, k, 1, 8),
+                (i, j, k, 2, 8),
+                (i, j, k, 3, 8),
+                (i, j, k, 4, 8),
+                (i, j, k, 5, 8),
+                (i, j, k, 6, 8),
+                (i, j, k, 7, 8),
+                (i, j, k, 8, 8),
+            ]
+            for (i, j, k) in itertools.product(
+                [32, 64, 128, 256], repeat=3
+            )
+        ],
+        []
+    )
 )
 
+if TORCH_VERSION_AFTER_2_5:
+    if torch._inductor.config.max_autotune_gemm_search_space == "EXHAUSTIVE":
+        int8_mm_kernel_configs = [
+            (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
+            for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
+                [16, 32, 64, 128, 256], repeat=3
+            )
+            for num_stages in [1, 2, 3, 4, 5, 6, 7, 8]
+            for num_warps in [2, 4, 8]
+        ]
+
+
 # Baseline configs from pytorch/pytorch
 # https://github.com/pytorch/pytorch/blob/7718a1cd4f8e0b794c18a31ebd6353d6273c534e/torch/_inductor/kernel/mm_common.py#L132-L147
 # int8_mm_kernel_configs = [
```
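A back-of-envelope check on what the hunk above changes: the default space crosses 4 block-size choices per dimension with 15 (num_stages, num_warps) pairs, while the exhaustive space crosses 5 block sizes per dimension with 8 stage counts and 3 warp counts. The numbers below are derived directly from the lists in the diff.

```python
import itertools

# Default: 4^3 block-size triples x 15 (num_stages, num_warps) pairs = 960.
stage_warp_pairs = [(1, 1), (1, 2), (2, 2), (1, 4), (2, 4), (3, 4), (4, 4),
                    (1, 8), (2, 8), (3, 8), (4, 8), (5, 8), (6, 8), (7, 8), (8, 8)]
default = [
    (m, n, k, s, w)
    for (m, n, k) in itertools.product([32, 64, 128, 256], repeat=3)
    for (s, w) in stage_warp_pairs
]

# Exhaustive: 5^3 block-size triples x 8 stage counts x 3 warp counts = 3000.
exhaustive = [
    (m, n, k, s, w)
    for (m, n, k) in itertools.product([16, 32, 64, 128, 256], repeat=3)
    for s in range(1, 9)
    for w in [2, 4, 8]
]

print(len(default), len(exhaustive))  # 960 3000
```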

torchao/prototype/quantization/autoquant_v2.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -31,6 +31,8 @@
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
 )
 
 from torchao.quantization.granularity import (
@@ -63,6 +65,7 @@
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
     "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
+    "ALL_AUTOQUANT_CLASS_LIST",
     "_is_linear",
 ]
 
@@ -1087,6 +1090,13 @@ def get_weight_block_size(x):
     AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
 ]
 
+ALL_AUTOQUANT_CLASS_LIST = list(set(DEFAULT_AUTOQUANT_CLASS_LIST + DEFAULT_INT4_AUTOQUANT_CLASS_LIST + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST))
+if is_sm_at_least_89():
+    ALL_AUTOQUANT_CLASS_LIST += [AQFloat8WeightOnlyQuantizedLinearWeight, AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight]
+
+if is_sm_at_least_90():
+    ALL_AUTOQUANT_CLASS_LIST += [AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight]
+
 
 def _replace_with_custom_fn_if_matches_filter(
     model,
```
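The `is_sm_at_least_89` / `is_sm_at_least_90` helpers come from `torchao.utils`; a rough sketch of the check they perform (the exact implementation here is an assumption based on the standard `torch.cuda.get_device_capability` API):

```python
import torch

def sm_at_least(major: int, minor: int) -> bool:
    """Approximation of torchao.utils.is_sm_at_least_89/_90."""
    return (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (major, minor)
    )

# SM 8.9 (Ada-class) gates the float8 weight-only and per-tensor classes;
# SM 9.0 (Hopper) additionally gates the per-row scaling class.
print(sm_at_least(8, 9), sm_at_least(9, 0))
```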

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 )
 
 from .autoquant import (
+    ALL_AUTOQUANT_CLASS_LIST,
     DEFAULT_AUTOQUANT_CLASS_LIST,
     DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
     DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
@@ -92,6 +93,7 @@
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
     "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
+    "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
     "int8_dynamic_activation_int4_weight",
```

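With the re-export above, the aggregate list is reachable from `torchao.quantization`; a quick check (no fixed length is asserted, since the membership varies with the detected SM version):

```python
from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST

print(len(ALL_AUTOQUANT_CLASS_LIST))  # member count depends on the GPU's SM version
```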
torchao/quantization/autoquant.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -26,6 +26,8 @@
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
 )
 
 from .granularity import (
@@ -45,6 +47,7 @@
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
     "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
+    "ALL_AUTOQUANT_CLASS_LIST",
 ]
 
 
@@ -951,6 +954,22 @@ def get_weight_block_size(x):
     AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
 ]
 
+ALL_AUTOQUANT_CLASS_LIST = list(
+    set(
+        DEFAULT_AUTOQUANT_CLASS_LIST
+        + DEFAULT_INT4_AUTOQUANT_CLASS_LIST
+        + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
+    )
+)
+if is_sm_at_least_89():
+    ALL_AUTOQUANT_CLASS_LIST += [
+        AQFloat8WeightOnlyQuantizedLinearWeight,
+        AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
+    ]
+
+if is_sm_at_least_90():
+    ALL_AUTOQUANT_CLASS_LIST += [AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight]
+
 
 def _change_linears_to_autoquantizable(model, **kwargs):
     """
```

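A small sanity sketch of what the aggregation above produces: a de-duplicated union of the three default lists, plus the SM-gated float8 members. All names come from the diffs in this commit; the subset property is the only claim checked.

```python
from torchao.quantization import (
    ALL_AUTOQUANT_CLASS_LIST,
    DEFAULT_AUTOQUANT_CLASS_LIST,
    DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
    DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
)

base = (
    set(DEFAULT_AUTOQUANT_CLASS_LIST)
    | set(DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
    | set(DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST)
)
# Every default class appears in the aggregate; float8 extras show up only
# when the SM gates fire on the current device.
assert base.issubset(set(ALL_AUTOQUANT_CLASS_LIST))
```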