
Commit f55c0cc

nmacchioni authored and ZelboK committed
variable search spaces for gemm autotuning (pytorch#126220)
Add a switch to change the GEMM autotuning search space between the default (the current set of hardcoded configs) and an exhaustive search space that enumerates all block sizes in [16, 32, 64, 128, 256], stages in [1, 2, 3, 4, 5], and warps in [2, 4, 8].

Pull Request resolved: pytorch#126220
Approved by: https://github.com/eellison
1 parent 747bdea commit f55c0cc

File tree

2 files changed: +53 −33 lines changed


torch/_inductor/config.py

Lines changed: 8 additions & 0 deletions

@@ -232,6 +232,7 @@ def is_fbcode():
 force_same_precision = (
     True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1"
 )
+
 # Specify candidate backends for gemm autotune.
 # Possible choices are combinations of: ATen, Triton, CUTLASS.
 # ATen: default Pytorch ATen kernels.
@@ -241,6 +242,13 @@ def is_fbcode():
     "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON"
 ).upper()
 
+# Specify the size of the search space for GEMM autotuning.
+# DEFAULT - balance between compile time overhead and performance
+# EXHAUSTIVE - maximize performance
+max_autotune_gemm_search_space = os.environ.get(
+    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
+).upper()
+
 # the value used as a fallback for the unbacked SymInts
 # that can appear in the input shapes (e.g., in autotuning)
 unbacked_symint_fallback = 8192
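
For orientation, a minimal usage sketch (not part of the commit): it assumes the config lands as shown above and flips the new knob either through the environment variable or the config attribute; the toy matmul and tensor shapes are purely illustrative.

# Hypothetical usage sketch: select the exhaustive GEMM search space via the
# environment variable introduced above (set before torch is imported so the
# config module picks it up), or by assigning the config attribute directly.
import os
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE"] = "EXHAUSTIVE"

import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune_gemm_search_space = "EXHAUSTIVE"  # same effect as the env var
inductor_config.max_autotune = True  # GEMM autotuning only runs under max-autotune

@torch.compile
def matmul(a, b):
    return a @ b

# On a CUDA device, compiling this benchmarks Triton mm configs drawn from the
# exhaustive search space instead of the hardcoded DEFAULT list.
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
out = matmul(a, b)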

torch/_inductor/kernel/mm_common.py

Lines changed: 45 additions & 33 deletions

@@ -1,4 +1,5 @@
 import functools
+import itertools
 import logging
 from typing import cast, List, Tuple
 
@@ -113,39 +114,50 @@ def filtered_configs(
 
 
 # List of dictionaries to store the kernel configs. Configs that evaluate to true
-# will be utilised on the target platform
-mm_kernel_configs = [
-    # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
-    {"config": (16, 32, 16, 3, 2), "cond": True},
-    {"config": (16, 32, 32, 4, 2), "cond": True},
-    {"config": (16, 32, 32, 5, 2), "cond": True},
-    {"config": (32, 32, 16, 1, 2), "cond": True},
-    {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
-    {"config": (32, 64, 32, 5, 8), "cond": True},
-    {"config": (64, 32, 32, 5, 8), "cond": True},
-    {"config": (64, 32, 128, 5, 4), "cond": True},
-    {"config": (64, 64, 16, 2, 4), "cond": True},
-    {"config": (64, 64, 32, 2, 4), "cond": True},
-    {"config": (64, 64, 64, 3, 8), "cond": True},
-    {"config": (64, 64, 128, 3, 4), "cond": True},
-    {"config": (64, 64, 128, 5, 4), "cond": True},
-    {"config": (64, 128, 32, 3, 4), "cond": True},
-    {"config": (64, 128, 32, 4, 8), "cond": True},
-    {"config": (64, 128, 64, 4, 4), "cond": True},
-    {"config": (64, 128, 128, 4, 4), "cond": True},
-    {"config": (128, 64, 32, 2, 2), "cond": True},
-    {"config": (128, 64, 32, 3, 4), "cond": True},
-    {"config": (128, 64, 32, 4, 8), "cond": True},
-    {"config": (128, 64, 64, 3, 8), "cond": True},
-    {"config": (128, 64, 128, 4, 8), "cond": True},
-    {"config": (128, 128, 32, 2, 8), "cond": True},
-    {"config": (128, 128, 32, 3, 4), "cond": True},
-    {"config": (128, 128, 32, 4, 4), "cond": True},
-    {"config": (128, 128, 64, 3, 4), "cond": True},
-    {"config": (128, 128, 64, 3, 8), "cond": True},
-    {"config": (128, 128, 64, 5, 4), "cond": True},
-    {"config": (128, 128, 64, 5, 8), "cond": True},
-]
+# will be utilised on the target platform. The configs are as follows:
+# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
+mm_kernel_configs = (
+    [
+        {"config": (16, 32, 16, 3, 2), "cond": True},
+        {"config": (16, 32, 32, 4, 2), "cond": True},
+        {"config": (16, 32, 32, 5, 2), "cond": True},
+        {"config": (32, 32, 16, 1, 2), "cond": True},
+        {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None},
+        {"config": (32, 64, 32, 5, 8), "cond": True},
+        {"config": (64, 32, 32, 5, 8), "cond": True},
+        {"config": (64, 32, 128, 5, 4), "cond": True},
+        {"config": (64, 64, 16, 2, 4), "cond": True},
+        {"config": (64, 64, 32, 2, 4), "cond": True},
+        {"config": (64, 64, 64, 3, 8), "cond": True},
+        {"config": (64, 64, 128, 3, 4), "cond": True},
+        {"config": (64, 64, 128, 5, 4), "cond": True},
+        {"config": (64, 128, 32, 3, 4), "cond": True},
+        {"config": (64, 128, 32, 4, 8), "cond": True},
+        {"config": (64, 128, 64, 4, 4), "cond": True},
+        {"config": (64, 128, 128, 4, 4), "cond": True},
+        {"config": (128, 64, 32, 2, 2), "cond": True},
+        {"config": (128, 64, 32, 3, 4), "cond": True},
+        {"config": (128, 64, 32, 4, 8), "cond": True},
+        {"config": (128, 64, 64, 3, 8), "cond": True},
+        {"config": (128, 64, 128, 4, 8), "cond": True},
+        {"config": (128, 128, 32, 2, 8), "cond": True},
+        {"config": (128, 128, 32, 3, 4), "cond": True},
+        {"config": (128, 128, 32, 4, 4), "cond": True},
+        {"config": (128, 128, 64, 3, 4), "cond": True},
+        {"config": (128, 128, 64, 3, 8), "cond": True},
+        {"config": (128, 128, 64, 5, 4), "cond": True},
+        {"config": (128, 128, 64, 5, 8), "cond": True},
+    ]
+    if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE"
+    else [
+        {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True}
+        for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product(
+            [16, 32, 64, 128, 256], repeat=3
+        )
+        for num_stages in [1, 2, 3, 4, 5]
+        for num_warps in [2, 4, 8]
+    ]
+)
 
 int8_mm_kernel_configs = [
     {"config": (64, 64, 32, 2, 4), "cond": True},
