@@ -116,43 +116,48 @@ def filtered_configs(
116
116
# List of dictionaries to store the kernel configs. Configs whose "cond" evaluates to true
117
117
# will be utilised on the target platform. The configs are as follows:
118
118
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
119
- mm_kernel_configs = [
120
- # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
121
- {"config" : (16 , 32 , 16 , 3 , 2 ), "cond" : True },
122
- {"config" : (16 , 32 , 32 , 4 , 2 ), "cond" : True },
123
- {"config" : (16 , 32 , 32 , 5 , 2 ), "cond" : True },
124
- {"config" : (32 , 32 , 16 , 1 , 2 ), "cond" : True },
125
- {"config" : (32 , 32 , 128 , 2 , 4 ), "cond" : torch .version .hip is None },
126
- {"config" : (32 , 64 , 32 , 5 , 8 ), "cond" : True },
127
- {"config" : (64 , 32 , 32 , 5 , 8 ), "cond" : True },
128
- {"config" : (64 , 32 , 128 , 5 , 4 ), "cond" : True },
129
- {"config" : (64 , 64 , 16 , 2 , 4 ), "cond" : True },
130
- {"config" : (64 , 64 , 32 , 2 , 4 ), "cond" : True },
131
- {"config" : (64 , 64 , 64 , 3 , 8 ), "cond" : True },
132
- {"config" : (64 , 64 , 128 , 3 , 4 ), "cond" : True },
133
- {"config" : (64 , 64 , 128 , 5 , 4 ), "cond" : True },
134
- {"config" : (64 , 128 , 32 , 3 , 4 ), "cond" : True },
135
- {"config" : (64 , 128 , 32 , 4 , 8 ), "cond" : True },
136
- {"config" : (64 , 128 , 64 , 4 , 4 ), "cond" : True },
137
- {"config" : (64 , 128 , 128 , 4 , 4 ), "cond" : True },
138
- {"config" : (128 , 64 , 32 , 2 , 2 ), "cond" : True },
139
- {"config" : (128 , 64 , 32 , 3 , 4 ), "cond" : True },
140
- {"config" : (128 , 64 , 32 , 4 , 8 ), "cond" : True },
141
- {"config" : (128 , 64 , 64 , 3 , 8 ), "cond" : True },
142
- {"config" : (128 , 64 , 128 , 4 , 8 ), "cond" : True },
143
- {"config" : (128 , 128 , 32 , 2 , 8 ), "cond" : True },
144
- {"config" : (128 , 128 , 32 , 3 , 4 ), "cond" : True },
145
- {"config" : (128 , 128 , 32 , 4 , 4 ), "cond" : True },
146
- {"config" : (128 , 128 , 64 , 3 , 4 ), "cond" : True },
147
- {"config" : (128 , 128 , 64 , 3 , 8 ), "cond" : True },
148
- {"config" : (128 , 128 , 64 , 5 , 4 ), "cond" : True },
149
- {"config" : (128 , 128 , 64 , 5 , 8 ), "cond" : True },
150
- ] if inductor_config .max_autotune_gemm_search_space != "EXHAUSTIVE" else [
151
- {"config" : (BLOCK_M , BLOCK_N , BLOCK_K , num_stages , num_warps ), "cond" : True }
152
- for BLOCK_M , BLOCK_N , BLOCK_K in itertools .product (
153
- [16 , 32 , 64 , 128 , 256 ], repeat = 3
154
- ) for num_stages in [1 , 2 , 3 , 4 , 5 ] for num_warps in [2 , 4 , 8 ]
155
- ]
119
+ mm_kernel_configs = (
120
+ [
121
+ {"config" : (16 , 32 , 16 , 3 , 2 ), "cond" : True },
122
+ {"config" : (16 , 32 , 32 , 4 , 2 ), "cond" : True },
123
+ {"config" : (16 , 32 , 32 , 5 , 2 ), "cond" : True },
124
+ {"config" : (32 , 32 , 16 , 1 , 2 ), "cond" : True },
125
+ {"config" : (32 , 32 , 128 , 2 , 4 ), "cond" : torch .version .hip is None },
126
+ {"config" : (32 , 64 , 32 , 5 , 8 ), "cond" : True },
127
+ {"config" : (64 , 32 , 32 , 5 , 8 ), "cond" : True },
128
+ {"config" : (64 , 32 , 128 , 5 , 4 ), "cond" : True },
129
+ {"config" : (64 , 64 , 16 , 2 , 4 ), "cond" : True },
130
+ {"config" : (64 , 64 , 32 , 2 , 4 ), "cond" : True },
131
+ {"config" : (64 , 64 , 64 , 3 , 8 ), "cond" : True },
132
+ {"config" : (64 , 64 , 128 , 3 , 4 ), "cond" : True },
133
+ {"config" : (64 , 64 , 128 , 5 , 4 ), "cond" : True },
134
+ {"config" : (64 , 128 , 32 , 3 , 4 ), "cond" : True },
135
+ {"config" : (64 , 128 , 32 , 4 , 8 ), "cond" : True },
136
+ {"config" : (64 , 128 , 64 , 4 , 4 ), "cond" : True },
137
+ {"config" : (64 , 128 , 128 , 4 , 4 ), "cond" : True },
138
+ {"config" : (128 , 64 , 32 , 2 , 2 ), "cond" : True },
139
+ {"config" : (128 , 64 , 32 , 3 , 4 ), "cond" : True },
140
+ {"config" : (128 , 64 , 32 , 4 , 8 ), "cond" : True },
141
+ {"config" : (128 , 64 , 64 , 3 , 8 ), "cond" : True },
142
+ {"config" : (128 , 64 , 128 , 4 , 8 ), "cond" : True },
143
+ {"config" : (128 , 128 , 32 , 2 , 8 ), "cond" : True },
144
+ {"config" : (128 , 128 , 32 , 3 , 4 ), "cond" : True },
145
+ {"config" : (128 , 128 , 32 , 4 , 4 ), "cond" : True },
146
+ {"config" : (128 , 128 , 64 , 3 , 4 ), "cond" : True },
147
+ {"config" : (128 , 128 , 64 , 3 , 8 ), "cond" : True },
148
+ {"config" : (128 , 128 , 64 , 5 , 4 ), "cond" : True },
149
+ {"config" : (128 , 128 , 64 , 5 , 8 ), "cond" : True },
150
+ ]
151
+ if inductor_config .max_autotune_gemm_search_space != "EXHAUSTIVE"
152
+ else [
153
+ {"config" : (BLOCK_M , BLOCK_N , BLOCK_K , num_stages , num_warps ), "cond" : True }
154
+ for BLOCK_M , BLOCK_N , BLOCK_K in itertools .product (
155
+ [16 , 32 , 64 , 128 , 256 ], repeat = 3
156
+ )
157
+ for num_stages in [1 , 2 , 3 , 4 , 5 ]
158
+ for num_warps in [2 , 4 , 8 ]
159
+ ]
160
+ )
156
161
157
162
int8_mm_kernel_configs = [
158
163
{"config" : (64 , 64 , 32 , 2 , 4 ), "cond" : True },
0 commit comments