|
1 | 1 | import functools
|
| 2 | +import itertools |
2 | 3 | import logging
|
3 | 4 | from typing import cast, List, Tuple
|
4 | 5 |
|
@@ -113,39 +114,50 @@ def filtered_configs(
|
113 | 114 |
|
114 | 115 |
|
115 | 116 | # List of dictionaries to store the kernel configs. Configs that evaluate to true
|
116 |
| -# will be utilised on the target platform |
117 |
| -mm_kernel_configs = [ |
118 |
| - # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps" |
119 |
| - {"config": (16, 32, 16, 3, 2), "cond": True}, |
120 |
| - {"config": (16, 32, 32, 4, 2), "cond": True}, |
121 |
| - {"config": (16, 32, 32, 5, 2), "cond": True}, |
122 |
| - {"config": (32, 32, 16, 1, 2), "cond": True}, |
123 |
| - {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None}, |
124 |
| - {"config": (32, 64, 32, 5, 8), "cond": True}, |
125 |
| - {"config": (64, 32, 32, 5, 8), "cond": True}, |
126 |
| - {"config": (64, 32, 128, 5, 4), "cond": True}, |
127 |
| - {"config": (64, 64, 16, 2, 4), "cond": True}, |
128 |
| - {"config": (64, 64, 32, 2, 4), "cond": True}, |
129 |
| - {"config": (64, 64, 64, 3, 8), "cond": True}, |
130 |
| - {"config": (64, 64, 128, 3, 4), "cond": True}, |
131 |
| - {"config": (64, 64, 128, 5, 4), "cond": True}, |
132 |
| - {"config": (64, 128, 32, 3, 4), "cond": True}, |
133 |
| - {"config": (64, 128, 32, 4, 8), "cond": True}, |
134 |
| - {"config": (64, 128, 64, 4, 4), "cond": True}, |
135 |
| - {"config": (64, 128, 128, 4, 4), "cond": True}, |
136 |
| - {"config": (128, 64, 32, 2, 2), "cond": True}, |
137 |
| - {"config": (128, 64, 32, 3, 4), "cond": True}, |
138 |
| - {"config": (128, 64, 32, 4, 8), "cond": True}, |
139 |
| - {"config": (128, 64, 64, 3, 8), "cond": True}, |
140 |
| - {"config": (128, 64, 128, 4, 8), "cond": True}, |
141 |
| - {"config": (128, 128, 32, 2, 8), "cond": True}, |
142 |
| - {"config": (128, 128, 32, 3, 4), "cond": True}, |
143 |
| - {"config": (128, 128, 32, 4, 4), "cond": True}, |
144 |
| - {"config": (128, 128, 64, 3, 4), "cond": True}, |
145 |
| - {"config": (128, 128, 64, 3, 8), "cond": True}, |
146 |
| - {"config": (128, 128, 64, 5, 4), "cond": True}, |
147 |
| - {"config": (128, 128, 64, 5, 8), "cond": True}, |
148 |
| -] |
| 117 | +# will be utilised on the target platform. The configs are as follows: |
| 118 | +# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) |
| 119 | +mm_kernel_configs = ( |
| 120 | + [ |
| 121 | + {"config": (16, 32, 16, 3, 2), "cond": True}, |
| 122 | + {"config": (16, 32, 32, 4, 2), "cond": True}, |
| 123 | + {"config": (16, 32, 32, 5, 2), "cond": True}, |
| 124 | + {"config": (32, 32, 16, 1, 2), "cond": True}, |
| 125 | + {"config": (32, 32, 128, 2, 4), "cond": torch.version.hip is None}, |
| 126 | + {"config": (32, 64, 32, 5, 8), "cond": True}, |
| 127 | + {"config": (64, 32, 32, 5, 8), "cond": True}, |
| 128 | + {"config": (64, 32, 128, 5, 4), "cond": True}, |
| 129 | + {"config": (64, 64, 16, 2, 4), "cond": True}, |
| 130 | + {"config": (64, 64, 32, 2, 4), "cond": True}, |
| 131 | + {"config": (64, 64, 64, 3, 8), "cond": True}, |
| 132 | + {"config": (64, 64, 128, 3, 4), "cond": True}, |
| 133 | + {"config": (64, 64, 128, 5, 4), "cond": True}, |
| 134 | + {"config": (64, 128, 32, 3, 4), "cond": True}, |
| 135 | + {"config": (64, 128, 32, 4, 8), "cond": True}, |
| 136 | + {"config": (64, 128, 64, 4, 4), "cond": True}, |
| 137 | + {"config": (64, 128, 128, 4, 4), "cond": True}, |
| 138 | + {"config": (128, 64, 32, 2, 2), "cond": True}, |
| 139 | + {"config": (128, 64, 32, 3, 4), "cond": True}, |
| 140 | + {"config": (128, 64, 32, 4, 8), "cond": True}, |
| 141 | + {"config": (128, 64, 64, 3, 8), "cond": True}, |
| 142 | + {"config": (128, 64, 128, 4, 8), "cond": True}, |
| 143 | + {"config": (128, 128, 32, 2, 8), "cond": True}, |
| 144 | + {"config": (128, 128, 32, 3, 4), "cond": True}, |
| 145 | + {"config": (128, 128, 32, 4, 4), "cond": True}, |
| 146 | + {"config": (128, 128, 64, 3, 4), "cond": True}, |
| 147 | + {"config": (128, 128, 64, 3, 8), "cond": True}, |
| 148 | + {"config": (128, 128, 64, 5, 4), "cond": True}, |
| 149 | + {"config": (128, 128, 64, 5, 8), "cond": True}, |
| 150 | + ] |
| 151 | + if inductor_config.max_autotune_gemm_search_space != "EXHAUSTIVE" |
| 152 | + else [ |
| 153 | + {"config": (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps), "cond": True} |
| 154 | + for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( |
| 155 | + [16, 32, 64, 128, 256], repeat=3 |
| 156 | + ) |
| 157 | + for num_stages in [1, 2, 3, 4, 5] |
| 158 | + for num_warps in [2, 4, 8] |
| 159 | + ] |
| 160 | +) |
149 | 161 |
|
150 | 162 | int8_mm_kernel_configs = [
|
151 | 163 | {"config": (64, 64, 32, 2, 4), "cond": True},
|
|
0 commit comments