With the addition of INT8 mixed-precision training, there are now 2 implementations of scaled INT8 matmul (INT8 matmul + dequant):
- https://github.com/pytorch/ao/blob/main/torchao/kernel/intmm_triton.py
- https://github.com/pytorch/ao/blob/main/torchao/prototype/quantized_training/int8_mm.py
I have identified the key differences:

| intmm_triton.py | int8_mm.py |
|---|---|
| Only fuses act scale | Fuses both act scale and weight scale |
| Scale step is `acc_i32 x scale` | Scale step casts to fp32: `acc_i32.to(f32) x scale.to(f32)` |
| Different autotune configs | Different autotune configs |
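For reference, here is a rough emulation of the two epilogues in plain PyTorch (not the actual Triton kernels; the shapes, dtypes, and tensor names are illustrative assumptions):

```python
import torch

# Assumed layout: A is (M, K) int8 activations with row-wise scales (M, 1),
# B is (K, N) int8 weights with column-wise scales (1, N).
A = torch.randint(-128, 128, (16, 32), dtype=torch.int8)
B = torch.randint(-128, 128, (32, 8), dtype=torch.int8)
scale_A = torch.rand(16, 1, dtype=torch.bfloat16)
scale_B = torch.rand(1, 8, dtype=torch.bfloat16)

# Stand-in for the INT8 matmul with int32 accumulation.
acc_i32 = A.to(torch.int32) @ B.to(torch.int32)

# intmm_triton.py-style: only the activation scale is fused; the int32
# accumulator is multiplied by the scale directly, and the weight scale
# is applied by the caller afterwards.
out1 = (acc_i32 * scale_A) * scale_B

# int8_mm.py-style: both scales are fused, with explicit fp32 casts before scaling.
out2 = (acc_i32.to(torch.float32)
        * scale_A.to(torch.float32)
        * scale_B.to(torch.float32)).to(torch.bfloat16)
```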
Ideally we should keep only one. The tedious part is validating that there is no accuracy or speed regression, whichever implementation we adopt.
Here are the places that use `intmm_triton.py`:
- https://github.com/pytorch/ao/blob/main/benchmarks/intmm.py
- https://github.com/pytorch/ao/blob/main/torchao/quantization/utils.py (I think this file is legacy from the module-swap API?)
- https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py
- https://github.com/pytorch/ao/blob/main/test/kernel/test_autotuner.py
-> Basically, ensure the INT8 dynamic quantization Llama and SAM benchmarks don't regress.

Here are the places that use `int8_mm.py`:
- https://github.com/pytorch/ao/blob/main/torchao/prototype/quantized_training/int8_mixed_precision.py
- https://github.com/pytorch/ao/blob/main/benchmarks/quantized_training/benchmark_int8mm.py (this is similar to the benchmark script for `intmm_triton.py` above)
-> Ensure INT8 mixed-precision training doesn't regress.
Another question: is it OK to change the `int_scaled_matmul()` signature to accept scales for both A and B instead of only for A?
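If it helps the discussion, here is a rough sketch of what a backward-compatible signature could look like (the `scales2` parameter name and its default are hypothetical, and the body is only a plain-PyTorch illustration, not the Triton kernel):

```python
import torch
from typing import Optional

def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor,
                      scales1: torch.Tensor,
                      scales2: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Hypothetical extension: keep the current (a, b, scales1) behavior and add an
    # optional second scale for B, so existing callers don't need to change.
    acc = a.to(torch.int32) @ b.to(torch.int32)  # stand-in for the fused INT8 kernel
    out = acc.to(torch.float32) * scales1.to(torch.float32)
    if scales2 is not None:
        out = out * scales2.to(torch.float32)
    return out.to(scales1.dtype)
```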