MoE refactor to use grouped_mm and scaled_grouped_mm #2600


Open · wants to merge 12 commits into main
277 changes: 139 additions & 138 deletions test/quantization/test_moe_quant.py

Large diffs are not rendered by default.

106 changes: 66 additions & 40 deletions torchao/_models/mixtral-moe/generate.py
@@ -12,6 +12,7 @@
import torch
import torch._dynamo.config
import torch._inductor.config
from model import MoEFeedForward

from torchao.utils import get_model_size_in_bytes

@@ -199,7 +200,9 @@ def main(
checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
compile: bool = True,
compile_prefill: bool = False,
compile_mode: str = "reduce-overhead",
moe_quant: Optional[str] = None,
decompose_grouped_mm: bool = False,
profile: Optional[Path] = None,
memory_profile: Optional[Path] = None,
device="cuda",
@@ -212,6 +215,13 @@
precision = torch.bfloat16
is_chat = "chat" in str(checkpoint_path)

if batch_size > 1 and moe_quant is None:
print(
"Warning: Detected no moe_quant but batchsize>1. The default MoE implementation uses a lot of memory when batched,"
+ " if it OOMs you can instead run without quantization by specifying --moe_quant noquant which uses the AO quantizable"
+ "module without quantization to run the quantizable module without quantization"
)

if device == "cuda" and memory_profile is not None:
torch.cuda.memory._record_memory_history(
True, trace_alloc_max_entries=500000, trace_alloc_record_context=True
@@ -236,10 +246,10 @@
]
)

from torchao.prototype.moe_quant.utils import (
from torchao.prototype.moe_quant import (
MoEMapping,
MoEQuantConfig,
UseFakeExtraDimTensor,
cond_ffn_filter,
)
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
@@ -255,71 +265,72 @@

if moe_quant:
torch._dynamo.config.capture_dynamic_output_shape_ops = True
config = None
config = MoEQuantConfig(
Contributor:
minor point but the training config is called MoETrainingConfig so it would be more clear to call this MoEInferenceConfig, to remove any potential ambiguity present in the name MoEQuantConfig

mapping=MoEMapping(
target_module_type=MoEFeedForward,
decompose_grouped_mm=decompose_grouped_mm,
)
)
if "int8wo-base" in moe_quant:
config = MoEQuantConfig(Int8WeightOnlyConfig())
config.base_config = Int8WeightOnlyConfig()

elif "int8wo" in moe_quant:
config = MoEQuantConfig(
Int8WeightOnlyConfig(),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Int8WeightOnlyConfig()
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

elif "int8dq-base" in moe_quant:
config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
config.base_config = Int8DynamicActivationInt8WeightConfig()

elif "int8dq" in moe_quant:
config = MoEQuantConfig(
Int8DynamicActivationInt8WeightConfig(),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Int8DynamicActivationInt8WeightConfig()
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

elif "int4wo-base" in moe_quant:
config = MoEQuantConfig(Int4WeightOnlyConfig())
config.base_config = Int4WeightOnlyConfig()

elif "int4wo" in moe_quant:
config = MoEQuantConfig(
Int4WeightOnlyConfig(),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Int4WeightOnlyConfig()
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

elif "fp8wo-base" in moe_quant:
config = MoEQuantConfig(Float8WeightOnlyConfig())
config.base_config = Float8WeightOnlyConfig()
Contributor:
I'd personally focus on float8 dynamic quant for now, it's more important to get good perf + accuracy on a small set of techniques than to achieve broad coverage of techniques.

Contributor Author:
yeah, I have done that. I thought we'd be able to dispatch grouped_mm -> linear -> quantized linear for the rest of the tensor subclasses, but it doesn't work, so I've gone back to what it was initially and this PR mostly focuses on fp8dq and bf16
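
For context on that comment, the fp8dq path in this refactor builds the config in two steps: a MoEQuantConfig carrying an MoEMapping for the target MoE module type, and then a per-linear base_config. Below is a minimal sketch of that flow, assuming the import paths and config fields exactly as this diff defines them and an already-loaded `model`; it is illustrative, not code from the PR.

```python
from model import MoEFeedForward  # Mixtral MoE block shipped with this example script

from torchao.prototype.moe_quant import MoEMapping, MoEQuantConfig
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Describe which module type is the MoE feed-forward and whether to decompose
# its grouped_mm into per-expert linear ops.
config = MoEQuantConfig(
    mapping=MoEMapping(
        target_module_type=MoEFeedForward,
        decompose_grouped_mm=False,  # keep the fused grouped_mm path
    )
)

# The "fp8dq-base" branch: fp8 dynamic activations and fp8 weights with row-wise scales.
config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# Apply only to MoE blocks, mirroring the filter_fn defined later in this file.
quantize_(
    model,  # assumed: a Mixtral-style model already loaded in bfloat16
    config,
    filter_fn=lambda mod, fqn: isinstance(mod, MoEFeedForward),
    device="cuda",
)
```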


elif "fp8wo" in moe_quant:
config = MoEQuantConfig(
Float8WeightOnlyConfig(),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.base_config = Float8WeightOnlyConfig()
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

elif "fp8dq-base" in moe_quant:
config = MoEQuantConfig(
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
config.base_config = Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow()
)

elif "fp8dq" in moe_quant:
config = MoEQuantConfig(
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
config.base_config = Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow()
)
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

elif "intxdq" in moe_quant:
config = MoEQuantConfig(
config.base_config = (
Int8DynamicActivationIntxWeightConfig(
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
),
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE
elif "noquant" in moe_quant:
pass
else:
assert config is not None, (
f"expected moe_quant to match one of the options but got {moe_quant}"
)

if config is not None:
quantize_(model, config, filter_fn=cond_ffn_filter, device=device)
print(
f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
)
def filter_fn(mod, fqn):
return isinstance(mod, MoEFeedForward)

quantize_(model, config, filter_fn=filter_fn, device=device)
print(
f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds"
)

model.to(device=device)
device_sync(device=device)
@@ -335,12 +346,14 @@ def main(

global decode_one_token, prefill

if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant):
if not decompose_grouped_mm or (
batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant)
):
decode_one_token = torch.compile(
decode_one_token, mode="reduce-overhead", fullgraph=True
decode_one_token, mode=compile_mode, fullgraph=True
)
else:
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
decode_one_token = torch.compile(decode_one_token, mode=compile_mode)

if args.compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
@@ -474,11 +487,22 @@ def callback(x):
action="store_true",
help="Whether to compile the prefill (improves prefill perf, but higher compile times)",
)
# parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8')
parser.add_argument(
"--compile_mode",
type=str,
default="reduce-overhead",
help="which torch.compile mode to use: reduce-overhead or max-autotune, does nothing if --compile is not set.",
Contributor:
do we expect a meaningful % of users to use max-autotune? if not, I'd say p1 and do it later.

Contributor Author:
@alexsamardzic added it so I wanted to test and see the speedup; at least for these shapes it doesn't make a huge difference, so it's not hugely important.

)
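
The --compile_mode value added above is forwarded directly to torch.compile for the decode step, as in the decode_one_token compilation earlier in this diff. A minimal sketch of that call; the helper name here is illustrative, not from the PR.

```python
import torch

def compile_decode(decode_one_token, compile_mode: str = "reduce-overhead"):
    # "reduce-overhead" relies on CUDA graphs to cut per-token launch overhead;
    # "max-autotune" additionally searches more kernel configurations at compile time.
    return torch.compile(decode_one_token, mode=compile_mode, fullgraph=True)
```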
parser.add_argument(
"--moe_quant",
type=str,
help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq",
help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq, noquant",
)
parser.add_argument(
"--decompose_grouped_mm",
action="store_true",
default=False,
help="Whether to decompose grouped_mm into linear ops for the MoE module, only relevant when moe_quant is set",
)
parser.add_argument("--profile", type=Path, default=None, help="Profile path.")
parser.add_argument(
@@ -499,7 +523,9 @@ def callback(x):
args.checkpoint_path,
args.compile,
args.compile_prefill,
args.compile_mode,
args.moe_quant,
args.decompose_grouped_mm,
args.profile,
args.memory_profile,
args.device,
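
To make the new --decompose_grouped_mm flag concrete: per its help text, it decomposes the MoE block's grouped_mm into plain linear ops, presumably so existing per-linear quantization paths can apply; otherwise the fused grouped matmul runs over all experts at once. The sketch below is illustrative only; the function, argument names, and shapes are assumptions, not code from this PR.

```python
import torch
import torch.nn.functional as F

def experts_decomposed(x, w, expert_ids):
    """Decomposed expert execution: one F.linear call per expert.

    x:          (num_tokens, dim) activations routed to the MoE block
    w:          (num_experts, out_dim, dim) stacked expert weights
    expert_ids: (num_tokens,) index of the expert chosen for each token
    """
    out = x.new_empty(x.shape[0], w.shape[1])
    for e in range(w.shape[0]):
        sel = expert_ids == e
        if sel.any():
            # A plain linear per expert, so per-linear quantization can hook in here.
            out[sel] = F.linear(x[sel], w[e])
    return out
```

The fused path instead sorts tokens by expert and issues a single grouped matmul over per-expert offsets (the torch._grouped_mm / torch._scaled_grouped_mm operators named in the PR title); those are PyTorch internals whose exact signatures vary by version, so they are only named here rather than called.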