import torch._inductor.config

from torchao.utils import get_model_size_in_bytes
+from torchao.prototype.moe_quant import MoEFeedForwardAOQuantizable
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from model import MoEFeedForward

torch.manual_seed(0)

@@ -199,7 +202,9 @@ def main(
    checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
    compile: bool = True,
    compile_prefill: bool = False,
+    compile_mode: str = "reduce-overhead",
    moe_quant: Optional[str] = None,
+    decompose_grouped_mm: bool = False,
    profile: Optional[Path] = None,
    memory_profile: Optional[Path] = None,
    device="cuda",
@@ -212,6 +217,13 @@ def main(
    precision = torch.bfloat16
    is_chat = "chat" in str(checkpoint_path)

+    if batch_size > 1 and moe_quant is None:
+        print(
+            "Warning: batch_size > 1 but no moe_quant was specified. The default MoE implementation uses a lot of"
+            " memory when batched; if it OOMs, pass --moe_quant noquant to run the AO quantizable MoE module"
+            " without quantization."
+        )
+
    if device == "cuda" and memory_profile is not None:
        torch.cuda.memory._record_memory_history(
            True, trace_alloc_max_entries=500000, trace_alloc_record_context=True
@@ -236,10 +248,11 @@ def main(
        ]
    )

-    from torchao.prototype.moe_quant.utils import (
+    from torchao.prototype.moe_quant import (
        MoEQuantConfig,
+        MoEMapping,
        UseFakeExtraDimTensor,
-        cond_ffn_filter,
+        MoEFeedForwardAOQuantizable,
    )
    from torchao.quantization.quant_api import (
        Float8DynamicActivationFloat8WeightConfig,
@@ -255,71 +268,61 @@ def main(

    if moe_quant:
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        config = None
+        config = MoEQuantConfig(mapping=MoEMapping(target_module_type=MoEFeedForward, decompose_grouped_mm=decompose_grouped_mm))
        if "int8wo-base" in moe_quant:
-            config = MoEQuantConfig(Int8WeightOnlyConfig())
+            config.base_config = Int8WeightOnlyConfig()

        elif "int8wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int8WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

        elif "int8dq-base" in moe_quant:
-            config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
+            config.base_config = Int8DynamicActivationInt8WeightConfig()

        elif "int8dq" in moe_quant:
-            config = MoEQuantConfig(
-                Int8DynamicActivationInt8WeightConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8DynamicActivationInt8WeightConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

        elif "int4wo-base" in moe_quant:
-            config = MoEQuantConfig(Int4WeightOnlyConfig())
+            config.base_config = Int4WeightOnlyConfig()

        elif "int4wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int4WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int4WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

        elif "fp8wo-base" in moe_quant:
-            config = MoEQuantConfig(Float8WeightOnlyConfig())
+            config.base_config = Float8WeightOnlyConfig()

        elif "fp8wo" in moe_quant:
-            config = MoEQuantConfig(
-                Float8WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Float8WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

        elif "fp8dq-base" in moe_quant:
-            config = MoEQuantConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
-            )
+            config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

        elif "fp8dq" in moe_quant:
-            config = MoEQuantConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

        elif "intxdq" in moe_quant:
-            config = MoEQuantConfig(
-                Int8DynamicActivationIntxWeightConfig(
+            config.base_config = Int8DynamicActivationIntxWeightConfig(
                    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                ),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
+        elif "noquant" in moe_quant:
+            pass
        else:
            assert config is not None, (
                f"expected moe_quant to match one of the options but got {moe_quant}"
            )

-        if config is not None:
-            quantize_(model, config, filter_fn=cond_ffn_filter, device=device)
-            print(
-                f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
-            )
+        def filter_fn(mod, fqn):
+            return isinstance(mod, MoEFeedForward)
+
+        quantize_(model, config, filter_fn=filter_fn, device=device)
+        print(
+            f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
+        )

    model.to(device=device)
    device_sync(device=device)
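
(Aside: the hunk above is easier to read with the diff markers stripped. The following is a minimal sketch of the new configuration path, using only names introduced in this change — MoEQuantConfig, MoEMapping, UseFakeExtraDimTensor, MoEFeedForward — and assuming `model` is the already-loaded Mixtral model; the prototype signatures are taken from the hunk itself, not from released torchao documentation.)

# Sketch only: mirrors the new code path shown above.
from torchao.prototype.moe_quant import MoEQuantConfig, MoEMapping, UseFakeExtraDimTensor
from torchao.quantization.quant_api import Int8WeightOnlyConfig, quantize_
from model import MoEFeedForward  # the MoE block targeted by the mapping

# One config object describes which module type to swap and whether to
# decompose grouped_mm into per-expert linear ops.
config = MoEQuantConfig(
    mapping=MoEMapping(target_module_type=MoEFeedForward, decompose_grouped_mm=False)
)
# The base weight-quantization scheme is attached afterwards; the non-"-base"
# variants additionally opt into the fake-extra-dim tensor path.
config.base_config = Int8WeightOnlyConfig()
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

# Apply quantization only to the MoE feed-forward modules ("model" assumed loaded).
quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, MoEFeedForward), device="cuda")
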
@@ -335,12 +338,12 @@ def main(

        global decode_one_token, prefill

-        if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant):
+        if not decompose_grouped_mm or (batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant)):
            decode_one_token = torch.compile(
-                decode_one_token, mode="reduce-overhead", fullgraph=True
+                decode_one_token, mode=compile_mode, fullgraph=True
            )
        else:
-            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
+            decode_one_token = torch.compile(decode_one_token, mode=compile_mode)

        if args.compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
@@ -474,11 +477,22 @@ def callback(x):
        action="store_true",
        help="Whether to compile the prefill (improves prefill perf, but higher compile times)",
    )
-    # parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8')
+    parser.add_argument(
+        "--compile_mode",
+        type=str,
+        default="reduce-overhead",
+        help="Which torch.compile mode to use: reduce-overhead or max-autotune; does nothing if --compile is not set.",
+    )
    parser.add_argument(
        "--moe_quant",
        type=str,
-        help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq",
+        help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq, noquant",
+    )
+    parser.add_argument(
+        "--decompose_grouped_mm",
+        action="store_true",
+        default=False,
+        help="Whether to decompose grouped_mm into linear ops for the MoE module; only relevant when moe_quant is set",
    )
    parser.add_argument("--profile", type=Path, default=None, help="Profile path.")
    parser.add_argument(
@@ -499,7 +513,9 @@ def callback(x):
        args.checkpoint_path,
        args.compile,
        args.compile_prefill,
+        args.compile_mode,
        args.moe_quant,
+        args.decompose_grouped_mm,
        args.profile,
        args.memory_profile,
        args.device,
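
With the new options wired through argparse and main(), a run exercising them might look like the following (the script filename generate.py is an assumption; it is not shown in this excerpt):

python generate.py --compile --compile_mode max-autotune --moe_quant int8wo-base --decompose_grouped_mm

--moe_quant also accepts the new noquant value, which routes through the AO quantizable MoE module without applying any quantization.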