
Commit 4efb10f

wip MoE refactor
Summary: now that the pytorch grouped_mm kernels don't require padding, refactor the MoE implementation to use them rather than what was there before.

DONE
- implement moe with grouped_mm [x]
- add handling for generic module swap to AOQuantizable (MoEMapping) [x]
- refactor MoEQuantConfig to swap generic modules [x]

TODO
- add dispatch from grouped_mm to linear decomposition of quantized kernel
- compare linear decomposition vs new linear decomposition vs grouped_mm for eager, compile, and autotuned compile of the linear decomposition
- compare linear decomposition vs new linear decomposition for quantized kernels
- add scaled_group_gemm and fbgemm kernel (probably in a new PR)

ISSUE: the autotuned grouped_mm kernels don't give the correct output, but work in eager and compile with reduce-overhead. Why? See new_run.log: the first 2 runs are fine, line 144 is nonsense.

Test Plan: sh run.sh

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 3b4bc98 commit 4efb10f
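For context on the two execution paths the summary compares (grouped_mm vs. a linear decomposition), here is a minimal, self-contained sketch of the linear-decomposition view of an MoE expert computation. This is not the torchao implementation; the function name, tensor names, and shapes are illustrative assumptions. The grouped_mm path the summary refers to would replace the Python loop below with a single grouped-matmul kernel over the same sorted-tokens / per-expert-offsets / stacked-weights inputs.

# Illustrative sketch only (not torchao code): per-expert "linear decomposition" of an MoE layer.
import torch

def moe_experts_decomposed(x, expert_weights, expert_ids):
    # x:              [num_tokens, hidden]        routed activations
    # expert_weights: [num_experts, hidden, out]  stacked per-expert weights
    # expert_ids:     [num_tokens]                expert chosen for each token
    num_experts = expert_weights.shape[0]
    out = x.new_empty(x.shape[0], expert_weights.shape[-1])
    # sort tokens by expert so each expert sees a contiguous block,
    # which is also the layout a grouped matmul kernel consumes
    order = torch.argsort(expert_ids)
    offsets = torch.cumsum(torch.bincount(expert_ids, minlength=num_experts), dim=0)
    start = 0
    for e in range(num_experts):
        end = int(offsets[e])
        idx = order[start:end]                 # tokens assigned to expert e
        out[idx] = x[idx] @ expert_weights[e]  # one ordinary matmul per expert
        start = end
    return out

# tiny smoke test with made-up sizes
x = torch.randn(8, 16)
w = torch.randn(4, 16, 32)
ids = torch.randint(0, 4, (8,))
print(moe_experts_decomposed(x, w, ids).shape)  # torch.Size([8, 32])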

File tree

13 files changed: +866 -707 lines changed

13 files changed

+866
-707
lines changed

test/quantization/test_moe_quant.py

Lines changed: 139 additions & 138 deletions
Large diffs are not rendered by default.

torchao/_models/mixtral-moe/generate.py

Lines changed: 60 additions & 44 deletions
@@ -14,6 +14,9 @@
 import torch._inductor.config
 
 from torchao.utils import get_model_size_in_bytes
+from torchao.prototype.moe_quant import MoEFeedForwardAOQuantizable
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from model import MoEFeedForward
 
 torch.manual_seed(0)
 
@@ -199,7 +202,9 @@ def main(
     checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
     compile: bool = True,
     compile_prefill: bool = False,
+    compile_mode: str = "reduce-overhead",
     moe_quant: Optional[str] = None,
+    decompose_grouped_mm: bool = False,
     profile: Optional[Path] = None,
     memory_profile: Optional[Path] = None,
     device="cuda",
@@ -212,6 +217,13 @@ def main(
     precision = torch.bfloat16
     is_chat = "chat" in str(checkpoint_path)
 
+    if batch_size > 1 and moe_quant is None:
+        print(
+            "Warning: Detected batch_size>1 but no moe_quant. The default MoE implementation"
+            " uses a lot of memory when batched; if it OOMs, you can instead run the AO"
+            " quantizable module without quantization by specifying --moe_quant noquant."
+        )
+
     if device == "cuda" and memory_profile is not None:
         torch.cuda.memory._record_memory_history(
             True, trace_alloc_max_entries=500000, trace_alloc_record_context=True
@@ -236,10 +248,11 @@ def main(
         ]
     )
 
-    from torchao.prototype.moe_quant.utils import (
+    from torchao.prototype.moe_quant import (
         MoEQuantConfig,
+        MoEMapping,
         UseFakeExtraDimTensor,
-        cond_ffn_filter,
+        MoEFeedForwardAOQuantizable,
     )
     from torchao.quantization.quant_api import (
         Float8DynamicActivationFloat8WeightConfig,
@@ -255,71 +268,61 @@ def main(
 
     if moe_quant:
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        config = None
+        config = MoEQuantConfig(mapping=MoEMapping(target_module_type=MoEFeedForward, decompose_grouped_mm=decompose_grouped_mm))
         if "int8wo-base" in moe_quant:
-            config = MoEQuantConfig(Int8WeightOnlyConfig())
+            config.base_config = Int8WeightOnlyConfig()
 
         elif "int8wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int8WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "int8dq-base" in moe_quant:
-            config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
+            config.base_config = Int8DynamicActivationInt8WeightConfig()
 
         elif "int8dq" in moe_quant:
-            config = MoEQuantConfig(
-                Int8DynamicActivationInt8WeightConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8DynamicActivationInt8WeightConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "int4wo-base" in moe_quant:
-            config = MoEQuantConfig(Int4WeightOnlyConfig())
+            config.base_config = Int4WeightOnlyConfig()
 
         elif "int4wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int4WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int4WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "fp8wo-base" in moe_quant:
-            config = MoEQuantConfig(Float8WeightOnlyConfig())
+            config.base_config = Float8WeightOnlyConfig()
 
         elif "fp8wo" in moe_quant:
-            config = MoEQuantConfig(
-                Float8WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Float8WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "fp8dq-base" in moe_quant:
-            config = MoEQuantConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
-            )
+            config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
 
         elif "fp8dq" in moe_quant:
-            config = MoEQuantConfig(
-                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
 
         elif "intxdq" in moe_quant:
-            config = MoEQuantConfig(
-                Int8DynamicActivationIntxWeightConfig(
+            config.base_config = Int8DynamicActivationIntxWeightConfig(
                 layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
             ),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
+        elif "noquant" in moe_quant:
+            pass
         else:
             assert config is not None, (
                 f"expected moe_quant to match one of the options but got {moe_quant}"
             )
 
-        if config is not None:
-            quantize_(model, config, filter_fn=cond_ffn_filter, device=device)
-            print(
-                f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
-            )
+        def filter_fn(mod, fqn):
+            return isinstance(mod, MoEFeedForward)
+
+        quantize_(model, config, filter_fn=filter_fn, device=device)
+        print(
+            f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
+        )
 
     model.to(device=device)
     device_sync(device=device)
@@ -335,12 +338,12 @@ def main(
 
         global decode_one_token, prefill
 
-        if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant):
+        if not decompose_grouped_mm or (batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant)):
            decode_one_token = torch.compile(
-                decode_one_token, mode="reduce-overhead", fullgraph=True
+                decode_one_token, mode=compile_mode, fullgraph=True
            )
        else:
-            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
+            decode_one_token = torch.compile(decode_one_token, mode=compile_mode)
 
        if args.compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
@@ -474,11 +477,22 @@ def callback(x):
         action="store_true",
         help="Whether to compile the prefill (improves prefill perf, but higher compile times)",
     )
-    # parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8')
+    parser.add_argument(
+        "--compile_mode",
+        type=str,
+        default="reduce-overhead",
+        help="which torch.compile mode to use: reduce-overhead or max-autotune, does nothing if --compile is not set.",
+    )
     parser.add_argument(
         "--moe_quant",
         type=str,
-        help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq",
+        help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq, noquant",
+    )
+    parser.add_argument(
+        "--decompose_grouped_mm",
+        action="store_true",
+        default=False,
+        help="Whether to decompose grouped_mm into linear ops for the MoE module, only relevant when moe_quant is set",
     )
     parser.add_argument("--profile", type=Path, default=None, help="Profile path.")
     parser.add_argument(
@@ -499,7 +513,9 @@ def callback(x):
         args.checkpoint_path,
         args.compile,
         args.compile_prefill,
+        args.compile_mode,
         args.moe_quant,
+        args.decompose_grouped_mm,
         args.profile,
         args.memory_profile,
         args.device,
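The run.sh referenced in the test plan is not shown in this diff. As a rough illustration only, the new options added above might be exercised like the following; the flag names --compile_mode, --moe_quant, and --decompose_grouped_mm are taken from this diff, while the script path and everything else about these invocations are assumptions.

python generate.py --compile_mode max-autotune --moe_quant fp8dq-base
python generate.py --moe_quant noquant --decompose_grouped_mm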
