MoE refactor to use grouped_mm and scaled_grouped_mm #2600
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Changes from all commits
f0cb3bb
2b187ad
200d862
96cb29b
0e001d0
9bb0ad1
eda28db
2fc42cd
f9f6c9f
bb597cd
5109a26
2bb805b
@@ -12,6 +12,7 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
+from model import MoEFeedForward

 from torchao.utils import get_model_size_in_bytes


@@ -199,7 +200,9 @@ def main(
     checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
     compile: bool = True,
     compile_prefill: bool = False,
+    compile_mode: str = "reduce-overhead",
     moe_quant: Optional[str] = None,
+    decompose_grouped_mm: bool = False,
     profile: Optional[Path] = None,
     memory_profile: Optional[Path] = None,
     device="cuda",
@@ -212,6 +215,13 @@ def main(
     precision = torch.bfloat16
     is_chat = "chat" in str(checkpoint_path)

+    if batch_size > 1 and moe_quant is None:
+        print(
+            "Warning: batch_size > 1 but no moe_quant specified. The default MoE implementation uses a lot of memory"
+            " when batched; if it OOMs, you can instead run unquantized by specifying --moe_quant noquant, which uses"
+            " the AO quantizable MoE module without applying any quantization."
+        )
+
     if device == "cuda" and memory_profile is not None:
         torch.cuda.memory._record_memory_history(
             True, trace_alloc_max_entries=500000, trace_alloc_record_context=True
@@ -236,10 +246,10 @@ def main(
         ]
     )

-    from torchao.prototype.moe_quant.utils import (
+    from torchao.prototype.moe_quant import (
+        MoEMapping,
         MoEQuantConfig,
         UseFakeExtraDimTensor,
-        cond_ffn_filter,
     )
     from torchao.quantization.quant_api import (
         Float8DynamicActivationFloat8WeightConfig,
@@ -255,71 +265,72 @@ def main(

     if moe_quant:
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
-        config = None
+        config = MoEQuantConfig(
+            mapping=MoEMapping(
+                target_module_type=MoEFeedForward,
+                decompose_grouped_mm=decompose_grouped_mm,
+            )
+        )
         if "int8wo-base" in moe_quant:
-            config = MoEQuantConfig(Int8WeightOnlyConfig())
+            config.base_config = Int8WeightOnlyConfig()

         elif "int8wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int8WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

         elif "int8dq-base" in moe_quant:
-            config = MoEQuantConfig(Int8DynamicActivationInt8WeightConfig())
+            config.base_config = Int8DynamicActivationInt8WeightConfig()

         elif "int8dq" in moe_quant:
-            config = MoEQuantConfig(
-                Int8DynamicActivationInt8WeightConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int8DynamicActivationInt8WeightConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

         elif "int4wo-base" in moe_quant:
-            config = MoEQuantConfig(Int4WeightOnlyConfig())
+            config.base_config = Int4WeightOnlyConfig()

         elif "int4wo" in moe_quant:
-            config = MoEQuantConfig(
-                Int4WeightOnlyConfig(),
-                use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
-            )
+            config.base_config = Int4WeightOnlyConfig()
+            config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE

         elif "fp8wo-base" in moe_quant:
-            config = MoEQuantConfig(Float8WeightOnlyConfig())
+            config.base_config = Float8WeightOnlyConfig()
Review comment (on the fp8wo-base branch above): I'd personally focus on float8 dynamic quant for now; it's more important to get good perf + accuracy on a small set of techniques than to achieve broad coverage of techniques.

Author reply: Yeah, I have done that. I thought we'd be able to dispatch grouped_mm -> linear -> quantized linear for the rest of the tensor subclasses, but it doesn't work, so I've gone back to what it was initially; this PR mostly focuses on fp8dq and bf16.
elif "fp8wo" in moe_quant: | ||
config = MoEQuantConfig( | ||
Float8WeightOnlyConfig(), | ||
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, | ||
) | ||
config.base_config = Float8WeightOnlyConfig() | ||
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE | ||
|
||
elif "fp8dq-base" in moe_quant: | ||
config = MoEQuantConfig( | ||
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) | ||
config.base_config = Float8DynamicActivationFloat8WeightConfig( | ||
granularity=PerRow() | ||
) | ||
|
||
elif "fp8dq" in moe_quant: | ||
config = MoEQuantConfig( | ||
Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), | ||
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, | ||
config.base_config = Float8DynamicActivationFloat8WeightConfig( | ||
granularity=PerRow() | ||
) | ||
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE | ||
|
||
elif "intxdq" in moe_quant: | ||
config = MoEQuantConfig( | ||
config.base_config = ( | ||
Int8DynamicActivationIntxWeightConfig( | ||
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), | ||
), | ||
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE, | ||
) | ||
config.use_fake_extra_dim_tensor = UseFakeExtraDimTensor.TRUE | ||
elif "noquant" in moe_quant: | ||
pass | ||
else: | ||
assert config is not None, ( | ||
f"expected moe_quant to match one of the options but got {moe_quant}" | ||
) | ||
|
||
if config is not None: | ||
quantize_(model, config, filter_fn=cond_ffn_filter, device=device) | ||
print( | ||
f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds" | ||
) | ||
def filter_fn(mod, fqn): | ||
return isinstance(mod, MoEFeedForward) | ||
|
||
quantize_(model, config, filter_fn=filter_fn, device=device) | ||
print( | ||
f"Time to apply quantization with config {config} to model: {time.time() - t0:.02f} seconds" | ||
) | ||
|
||
model.to(device=device) | ||
device_sync(device=device) | ||
|
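To summarize the refactored flow above in one place, here is a minimal sketch of how the new config appears to be used, based solely on this diff. The `torchao.prototype.moe_quant` import path and the `MoEMapping`/`MoEQuantConfig` attributes come from this PR branch, not a released torchao API; the `torchao.quantization` imports reflect torchao's public API as I understand it, and `model` is assumed to be an already-loaded MoE model such as Mixtral.

```python
from torchao.prototype.moe_quant import MoEMapping, MoEQuantConfig  # per this PR branch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow, quantize_

from model import MoEFeedForward  # the MoE block defined in this repo's model.py

# The mapping tells the config which module type is the MoE feed-forward block and
# whether to decompose grouped_mm into per-expert linear ops.
config = MoEQuantConfig(
    mapping=MoEMapping(
        target_module_type=MoEFeedForward,
        decompose_grouped_mm=False,  # keep the fused grouped_mm / scaled_grouped_mm path
    )
)
# Equivalent of --moe_quant fp8dq-base: per-row float8 dynamic activation + weight quant.
config.base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# Only MoE feed-forward modules are transformed; everything else is left untouched.
# `model` is an already-loaded MoE transformer (assumed to exist in this sketch).
quantize_(
    model,
    config,
    filter_fn=lambda mod, fqn: isinstance(mod, MoEFeedForward),
    device="cuda",
)
```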
@@ -335,12 +346,14 @@ def main(

         global decode_one_token, prefill

-        if batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant):
+        if not decompose_grouped_mm or (
+            batch_size == 1 and (isinstance(moe_quant, str) and "base" in moe_quant)
+        ):
             decode_one_token = torch.compile(
-                decode_one_token, mode="reduce-overhead", fullgraph=True
+                decode_one_token, mode=compile_mode, fullgraph=True
             )
         else:
-            decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead")
+            decode_one_token = torch.compile(decode_one_token, mode=compile_mode)

         if args.compile_prefill:
             prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
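For context on the fullgraph condition above: the fused path keeps expert computation as a single grouped matmul over expert-sorted tokens, while --decompose_grouped_mm falls back to one matmul per expert with data-dependent slice sizes, which is what makes fullgraph compilation harder. A self-contained sketch of the decomposed computation (names and shapes are illustrative only, not code from this PR):

```python
import torch

def decomposed_expert_mm(x, expert_weights, counts):
    # x:              (total_tokens, hidden), tokens already sorted by expert
    # expert_weights: (num_experts, hidden, out), one weight matrix per expert
    # counts:         (num_experts,), number of tokens routed to each expert
    outputs, start = [], 0
    for e in range(expert_weights.shape[0]):
        n = int(counts[e])  # data-dependent slice size -> dynamic shapes under compile
        outputs.append(x[start : start + n] @ expert_weights[e])
        start += n
    return torch.cat(outputs)

# Example: 3 experts, 10 tokens, hidden=8, out=16
x = torch.randn(10, 8)
w = torch.randn(3, 8, 16)
out = decomposed_expert_mm(x, w, torch.tensor([4, 3, 3]))
assert out.shape == (10, 16)
```

The fused path this PR targets replaces the loop with a single grouped_mm (or scaled_grouped_mm for the float8 case) over the same sorted tokens and per-expert group boundaries.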
@@ -474,11 +487,22 @@ def callback(x): | |
action="store_true", | ||
help="Whether to compile the prefill (improves prefill perf, but higher compile times)", | ||
) | ||
# parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8') | ||
parser.add_argument( | ||
"--compile_mode", | ||
type=str, | ||
default="reduce-overhead", | ||
help="which torch.compile mode to use: reduce-overhead or max-autotune, does nothing if --compile is not set.", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we expect a meaningful % of users to use max-autotune? if not, I'd say p1 and do it later. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alexsamardzic added it so i wanted to test and see the speedup, at least for these shapes it doesn't make a huge difference so its not hugely important. |
||
) | ||
parser.add_argument( | ||
"--moe_quant", | ||
type=str, | ||
help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq", | ||
help="Which quantization techniques to apply: int8dq, int8wo, int4wo, fp8wo, fp8dq, noquant", | ||
) | ||
parser.add_argument( | ||
"--decompose_grouped_mm", | ||
action="store_true", | ||
default=False, | ||
help="Whether to decompose grouped_mm into linear ops for the MoE module, only relevant when moe_quant is set", | ||
) | ||
parser.add_argument("--profile", type=Path, default=None, help="Profile path.") | ||
parser.add_argument( | ||
|
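Putting the new flags together, an end-to-end run on the fused path with float8 dynamic quantization might be launched as `python generate.py --compile --compile_mode max-autotune --moe_quant fp8dq-base`, and adding `--decompose_grouped_mm` would switch to the per-expert linear fallback. The script name is an assumption; the flags are the ones added in this diff.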
@@ -499,7 +523,9 @@ def callback(x): | |
args.checkpoint_path, | ||
args.compile, | ||
args.compile_prefill, | ||
args.compile_mode, | ||
args.moe_quant, | ||
args.decompose_grouped_mm, | ||
args.profile, | ||
args.memory_profile, | ||
args.device, | ||
|
Review comment: Minor point, but the training config is called MoETrainingConfig, so it would be clearer to call this MoEInferenceConfig, to remove any potential ambiguity in the name MoEQuantConfig.