diff --git a/torchao/_models/llama/README.md b/torchao/_models/llama/README.md
index 466901e74f..99f1919fc9 100644
--- a/torchao/_models/llama/README.md
+++ b/torchao/_models/llama/README.md
@@ -27,3 +27,7 @@ To see how these techniques scale generally we've run `generate.py` with subsets
 | 32768 | 23.83 | 21.72 | 20.64 |
 | 65536 | 33.5 | 29.54 | 25.24 |
 | 131072 | 59.27 | 52.62 | 34.18 |
+
+## Adding Benchmarks For New Techniques
+
+If you want to add benchmarks that you think should be kept up to date, please try to keep the format consistent. For performance focused techniques (e.g. if they require fine-tuning or something else), add an option to run them in generate.py and an execution command in benchmarks.sh in the relevant section. If it's a technique that's still in development, add it in the `OTHER BENCHMARKS` section; if there's a finalized api and you want those numbers in the main quantization README, add them in the `README BENCHMARKS` section. For accuracy focused techniques, add them in eval.py and evaluations.sh in a similar vein. Ideally techniques in the main readme will have both benchmarks and evaluations set up here so they can be monitored and reproduced easily.
diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt
index 0868c964db..d59c5f552e 100644
--- a/torchao/_models/llama/benchmark_results.txt
+++ b/torchao/_models/llama/benchmark_results.txt
@@ -1,28 +1,23 @@
+README BENCHMARKS
 llama 2
-20240831224311, tok/s= 26.75, mem/s= 707.01 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240831224512, tok/s= 22.97, mem/s= 303.53 GB/s, peak_mem=13.64 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240831224958, tok/s=108.48, mem/s=1433.57 GB/s, peak_mem=13.90 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-
 20240831225155, tok/s=107.38, mem/s=1418.93 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
 20240831225810, tok/s= 9.61, mem/s= 63.67 GB/s, peak_mem= 8.61 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path 
../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240831230013, tok/s=170.83, mem/s=1131.18 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910152454, tok/s=117.89, mem/s= 584.57 GB/s, peak_mem= 6.52 GB, model_size= 4.96 GB quant: fp6, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240831230205, tok/s=201.14, mem/s= 751.42 GB/s, peak_mem= 4.87 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240831230736, tok/s=177.45, mem/s=1194.35 GB/s, peak_mem= 8.64 GB, model_size= 6.73 GB quant: autoquant, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240902100527, tok/s=209.19, mem/s= 804.32 GB/s, peak_mem= 4.89 GB, model_size= 3.84 GB quant: autoquant-int4, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant-int4 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 llama 3 -20240831231514, tok/s= 26.54, mem/s= 796.59 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240831231725, tok/s= 23.67, mem/s= 355.33 GB/s, peak_mem=16.19 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240831232327, tok/s= 96.59, mem/s=1449.85 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, 
kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 - 20240831232535, tok/s= 95.64, mem/s=1435.54 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240831233224, tok/s= 8.61, mem/s= 64.75 GB/s, peak_mem= 9.24 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240831233853, tok/s=153.03, mem/s=1150.80 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910153353, tok/s=161.58, mem/s= 910.02 GB/s, peak_mem= 7.72 GB, model_size= 5.63 GB quant: fp6, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240831234218, tok/s=180.80, mem/s= 763.33 GB/s, peak_mem= 6.88 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240831235355, tok/s=158.10, mem/s=1193.24 GB/s, peak_mem=10.04 GB, model_size= 7.55 GB quant: autoquant, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 20240902101015, tok/s=188.41, mem/s= 800.58 GB/s, peak_mem= 7.14 GB, model_size= 4.25 GB quant: autoquant-int4, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant-int4 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 
--top_k 200 --temperature 0.8 -kv cache quantization: +KV CACHE QUANTIZATION: 20240826161508, tok/s= 19.71, mem/s= 295.80 GB/s, peak_mem=17.86 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 8192 20240826161747, tok/s= 13.52, mem/s= 202.96 GB/s, peak_mem=17.52 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 8192--kv_cache_quantization 20240826162028, tok/s= 13.30, mem/s= 199.66 GB/s, peak_mem=17.47 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 8192--kv_cache_quantization --linear_causal_mask @@ -38,4 +33,20 @@ kv cache quantization: 20240826171015, tok/s= 1.95, mem/s= 29.21 GB/s, peak_mem=59.27 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072 20240826172121, tok/s= 1.73, mem/s= 26.02 GB/s, peak_mem=52.62 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization 20240826173230, tok/s= 1.73, mem/s= 25.95 GB/s, peak_mem=34.18 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3.1-8B, kv_quant: True, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 --cache_size 131072--kv_cache_quantization --linear_causal_mask -20240906054415, tok/s=226.02, mem/s= 689.20 GB/s, peak_mem= 5.32 GB, model_size= 3.05 GB quant: marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 + +OTHER BENCHMARKS +20240831224311, tok/s= 26.75, mem/s= 707.01 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, 
kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224512, tok/s= 22.97, mem/s= 303.53 GB/s, peak_mem=13.64 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831224958, tok/s=108.48, mem/s=1433.57 GB/s, peak_mem=13.90 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910004030, tok/s= 22.72, mem/s= 112.66 GB/s, peak_mem=10.41 GB, model_size= 4.96 GB quant: fp6, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910004539, tok/s= 50.99, mem/s= 200.08 GB/s, peak_mem= 6.29 GB, model_size= 3.92 GB quant: uintx-4-64, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization uintx-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910005147, tok/s= 40.25, mem/s= 265.95 GB/s, peak_mem= 9.24 GB, model_size= 6.61 GB quant: uintx-2-8, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization uintx-2-8 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910110554, tok/s=245.07, mem/s= 657.93 GB/s, peak_mem= 4.05 GB, model_size= 2.68 GB quant: sparse-marlin, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 + +20240831231514, tok/s= 26.54, mem/s= 796.59 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831231725, tok/s= 23.67, mem/s= 
355.33 GB/s, peak_mem=16.19 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240831232327, tok/s= 96.59, mem/s=1449.85 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910005537, tok/s= 20.22, mem/s= 113.89 GB/s, peak_mem=23.17 GB, model_size= 5.63 GB quant: fp6, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910010056, tok/s= 47.85, mem/s= 213.24 GB/s, peak_mem=11.85 GB, model_size= 4.46 GB quant: uintx-4-64, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization uintx-4-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910010647, tok/s= 34.83, mem/s= 261.42 GB/s, peak_mem=14.99 GB, model_size= 7.51 GB quant: uintx-2-8, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization uintx-2-8 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 +20240910110958, tok/s=223.95, mem/s= 682.88 GB/s, peak_mem= 5.59 GB, model_size= 3.05 GB quant: sparse-marlin, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.float16, device: cuda repro: python generate.py --quantization sparse-marlin --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 7d2bda1d6e..48c75931fb 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -1,44 +1,28 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder - +# README BENCHMARKS export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt -# in readme python generate.py --checkpoint_path 
$CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt -# auto-round w/ quant_lm_head -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround -# auto-round w/o quant_lm_head -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 export MODEL_REPO=meta-llama/Meta-Llama-3-8B -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt -# in readme python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt -# sparse marlin (NOTE: float16) -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt -# auto-round w/ quant_lm_head -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround -# auto-round w/o quant_lm_head -python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 +# OTHER BENCHMARKS +# kv cache quantization export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 
8192 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization @@ -55,3 +39,29 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization --linear_causal_mask + +export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt +# TODO: this is an accuracy technique with same perf as int4, should be in evaluations instead of generate.py +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround # auto-round w/o quant_lm_head +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 # auto-round w/o quant_lm_head + +export MODEL_REPO=meta-llama/Meta-Llama-3-8B +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile 
--quantization uintx-2-8 --write_result benchmark_results.txt +# TODO: this is an accuracy technique with same perf as int4, should be in evaluations instead of generate.py +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround # auto-round w/o quant_lm_head +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 # auto-round w/o quant_lm_head diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index fa2e77cf05..f3300e3315 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -44,7 +44,11 @@ def run_evaluation( pad_calibration_inputs: Optional[bool] = False, ): """Runs the evaluation of a model using LM Eval.""" - + print( + f"\nEvaluating model {checkpoint_path} on tasks: {tasks}, limit: {limit}, device: {device}, precision: {precision}, " + +f"quantization: {quantization}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, " + +f"calibration_seq_length: {calibration_seq_length}, pad_calibration_inputs: {pad_calibration_inputs}\n" + ) torchao.quantization.utils.recommended_inductor_config_setter() assert checkpoint_path.is_file(), checkpoint_path @@ -73,11 +77,10 @@ def run_evaluation( quantize_(model, fpx_weight_only(3, 2)) if "int4wo" in quantization and not "gptq" in quantization: if "hqq" in quantization: - quantization = quantization[:-4] use_hqq = True else: use_hqq = False - groupsize=int(quantization.split("-")[-1]) + groupsize=int(quantization.split("-")[1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq)) if "uintx" in quantization: @@ -85,15 +88,17 @@ def run_evaluation( # "uintx-2-64" if "hqq" in quantization: use_hqq = True - quantization = quantization[:-4] else: use_hqq = False _quant_args = quantization.split("-") - nbits = int(_quant_args[0]) + nbits = int(_quant_args[1]) _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8} dtype = _NBITS_TO_DTYPE[nbits] - group_size = int(_quant_args[1]) + group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) + if "marlin" in quantization: + from torchao.dtypes import MarlinSparseLayoutType + quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType())) if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" @@ -140,7 +145,12 @@ def run_evaluation( parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use') parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') - parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-, int4wo--gptq, int4wo--hqq, uintx--, uintx---hqq") + parser.add_argument('-q', '--quantization', type=str, + help=( + 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--gptq, autoquant, autoquant-int4, '+ + 'int4wo--hqq, uintx--, uintx---hqq, sparse-marlin' + ) + ) 
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.') parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time') parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq') diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 30f5874d15..1ce3ef6b67 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -219,10 +219,9 @@ def main( if "int4wo" in quantization: if "hqq" in quantization: use_hqq=True - quantization = quantization[:-4] else: use_hqq=False - groupsize=int(quantization.split("-")[-1]) + groupsize=int(quantization.split("-")[1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" quantize_(model, int4_weight_only(group_size=groupsize)) if "marlin" in quantization: @@ -273,22 +272,22 @@ def main( ) model.to(device) model.reset_caches() - if "fp6" in quantization: + # TODO this needs to be expanded to all of fpx so they can + if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) if "uintx" in quantization: # uintx-nbits-groupsize, e.g. "uintx-2-64" if "hqq" in quantization: # uintx-nbits-groupsize-hqq - quantization = quantization[:-4] use_hqq = True else: use_hqq = False _quant_args = quantization.split("-") - nbits = int(_quant_args[0]) + nbits = int(_quant_args[1]) assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8" _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8} dtype = _NBITS_TO_DTYPE[nbits] - group_size = int(_quant_args[1]) + group_size = int(_quant_args[2]) quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq)) if "autoquant" in quantization: if "autoquant-int4" == quantization: @@ -459,7 +458,13 @@ def callback(x): parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.') parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.') parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.') - parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-, autoquant, autoquant-int4, int4wo--hqq, autoround-------, uintx--, uintx---hqq') + parser.add_argument('-q', '--quantization', type=str, + help=( + 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, ' + +'autoquant-int4, autoround-------, ' + +'uintx--, uintx---hqq, sparse-marlin' + ) + ) parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 052ba73a37..673aae1b5f 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -2,27 +2,23 @@ Typically quantization algorithms will have different schemes for how the activation and 
weights are quantized so A16W8 for instance means the activations are quantized to 16 bits wheras the weights are quantized to 8 bits. Trying out different quantization schemes in `torchao` is generally a 1 line change. Note: exact APIs are not stable, we may change them in the future. ## Benchmarks -Benchmarks are run on a machine with a single A100 GPU using the script in _models/llama which generates text in a latency optimized way (batchsize=1), evaluation was done -Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-llama/Meta-Llama-3-8B. +Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. The models used were meta-llama/Llama-2-7b-chat-hf and meta-llama/Meta-Llama-3-8B. | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- | | Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 | | | int8dq | 12.262 | 9.61 | 63.67 | 8.61 | 6.62 | | | int8wo | 12.204 | 170.83 | 1131.18 | 8.95 | 6.62 | +| | fp6 | 12.369 | 117.89 | 584.57 | 6.52 | 4.96 | | | int4wo-64 | 12.843 | 201.14 | 751.42 | 4.87 | 3.74 | | | int4wo-64-GPTQ | 12.527 | 201.14 | 751.42 | 4.87 | 3.74 | -| | uintx-4-64 | 12.891 | 48.25 | 189.32 | 6.29 | 3.92 | -| | uintx-2-8 | 28.766 | 36.11 | 238.58 | 9.26 | 6.61 | | | autoquant-int4hqq | 12.825 | 209.19 | 804.32 | 4.89 | 3.84 | | Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 | | | int8dq | 7.581 | 8.61 | 64.75 | 9.24 | 7.52 | | | int8wo | 7.447 | 153.03 | 1150.80 | 10.42 | 7.52 | +| | fp6 | 7.661 | 161.58 | 910.02 | 7.72 | 5.63 | | | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 | | | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 | -| | int4wo-64-sparse-marlin | N/A | 226.02 | 689.20 | 5.32 | 3.05 | -| | uintx-4-64 | 8.113 | 47.77 | 212.90 | 11.85 | 4.46 | -| | uintx-2-8 | 39.368 | 33.21 | 249.22 | 15.04 | 7.51 | | | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 | note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance. @@ -33,7 +29,7 @@ And a quick crash course on inference quantization to help parse the above table ## Autoquantization -When used as in the example below, autoquant first identifies the shapes of the activations that the different linear layers see, it then benchmarks these shapes across different types of quantized and non-quantized layers in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, it swaps the linear. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `int4_weight_only()`. 
+Autoquantization is a tool to automatically determine the best way to apply quantization to your model by comparing the performance of each quantization technique on each layer for the input types and shapes you care about.
 
 ```python
 import torch
@@ -57,6 +53,10 @@ elif not use_autoquant_default:
 model(input)
 ```
 
+When the `autoquant` api is used as in the example above, alongside torch.compile, it sets up the model so that when it's run on the next input, the autoquantization and torch.compile processes leave you with a heavily optimized model.
+
+When `model(input)` is called, under the hood the tool does a preliminary run with the input, during which each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Once the best class is found for each layer, the necessary quantization technique is applied to that layer before the normal `torch.compile` process runs on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option to add int4 quantization which can be used for maximum performance or to avoid perf regressions from `int4_weight_only()`, since for certain (compute bound) regimes int4 weight only quantization can be very slow.
+
 Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods.
 
 ```python
@@ -73,8 +73,66 @@ from torchao.quantization.autoquant import AUTOQUANT_CACHE
 with open("quantization-cache.pkl", "rb") as f:
     AUTOQUANT_CACHE.update(pickle.load(f))
 ```
-## Affine Quantization
-Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_preicsion_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`)
+
+## Quantization Techniques
+While the above `autoquant` api tries multiple quantization techniques to find the best combination for your model, the techniques themselves can
+be applied individually. While there are a large variety of quantization apis, the following techniques have been thoroughly tested and perform well for the metrics they seek to optimize. Each is an example of affine quantization.
+
+#### A16W4 WeightOnly Quantization
+
+```python
+# for torch 2.4+
+from torchao.quantization import quantize_, int4_weight_only
+group_size = 32
+
+# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improve accuracy through the
+# use_hqq flag for `int4_weight_only` quantization
+use_hqq = False
+quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+
+# for torch 2.2.2 and 2.3
+from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+change_linear_weights_to_int4_woqtensors(model)
+```
+
+Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
+
+#### A16W8 WeightOnly Quantization
+
+```python
+# for torch 2.4+
+from torchao.quantization import quantize_, int8_weight_only
+quantize_(model, int8_weight_only())
+
+# for torch 2.2.2 and 2.3
+from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+change_linear_weights_to_int8_woqtensors(model)
+```
+
+#### A8W8 Dynamic Quantization
+
+```python
+# for torch 2.4+
+from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
+quantize_(model, int8_dynamic_activation_int8_weight())
+
+# for torch 2.2.2 and 2.3
+from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+change_linear_weights_to_int8_dqtensors(model)
+```
+
+#### A16W6 Floating Point WeightOnly Quantization
+
+```python
+# for torch 2.4+
+from torchao.quantization import quantize_, fpx_weight_only
+quantize_(model, fpx_weight_only(3, 2))
+```
+
+You can find more information [here](../dtypes/fpx/README.md). It should be noted that, where most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, fp6 works primarily with the fp16 dtype.
+
+## Affine Quantization Details
+Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualifies as affine quantization.
 
 ### Quantization Primitives
 We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass.
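+
+As a rough illustration of how these primitives fit together, the sketch below quantizes and dequantizes a tensor with a per-row `block_size`. The exact signatures of `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` are assumptions inferred from the names above, not taken from this PR, so verify them against `torchao.quantization.quant_primitives` before relying on this.
+
+```python
+# Hedged sketch only: argument order/names for the affine primitives are assumed.
+import torch
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    choose_qparams_affine,
+    quantize_affine,
+    dequantize_affine,
+)
+
+x = torch.randn(128, 256)
+block_size = (1, 256)  # one scale/zero_point per row, i.e. per-channel granularity
+
+scale, zero_point = choose_qparams_affine(x, MappingType.SYMMETRIC, block_size, torch.int8)
+xq = quantize_affine(x, block_size, scale, zero_point, torch.int8)      # int8 values
+xdq = dequantize_affine(xq, block_size, scale, zero_point, torch.int8)  # back to float
+print((x - xdq).abs().max())  # quantization error introduced by the round trip
+```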
@@ -87,7 +145,7 @@ We also have a unified quantized tensor subclass that implements how to get a qu
 
 #### Layouts
 We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.
 
-### Quantization Flow Example
+### Full Affine Quantization Flow Example
 Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul as an example:
 
 ```python
@@ -152,20 +210,20 @@ speedup: 2.2715200981216173
 ```
 
 What we do underlying the APIs are roughly the following:
-```
+```python
 from torchao.dtypes import to_affine_quantized_intx
 def int8wo_quant(weight):
     return to_affine_quantized_intx(weight, MappingType.SYMMETRIC, (1, weight.shape[1]), torch.int8, eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64)
 
-for n, m in model.named_modules():
-    if isinstance(m, torch.nn.Linear):
+for name, module in model.named_modules():
+    if isinstance(module, torch.nn.Linear):
         # optional filtering for module name, shape etc.
-        m.weight = nn.Parameter(int8wo_quant(m.weight))
+        module.weight = nn.Parameter(int8wo_quant(module.weight))
 
         # note: quantization for activation need to be applied after the weight quantization
         # quantization activation (needed by dynamic quantization)
         input_quant_func = int8wo_quant # specify how input activation is quantized
-        m.weight = nn.Parameter(to_linear_activation_quantized(m.weight, input_quant_func))
+        module.weight = nn.Parameter(to_linear_activation_quantized(module.weight, input_quant_func))
 ```
 
 #### Workaround with `unwrap_tensor_subclass` for `export`, `AOTI` and `torch.compile` (pytorch 2.4 and before only)
@@ -188,64 +246,46 @@ but if you use 2.4 or before, you'll need to use `unwrap_tensor_subclass` as well
 
 Note that the workaround will not be needed after https://github.com/pytorch/pytorch/issues/129682 is fixed.
 
-### Automatic Inductor Configuration
-The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues.
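+
+As a concrete sketch of the workaround mentioned above (not part of this PR), the flow on pytorch 2.4 and earlier would look roughly like the following; the `torchao.utils.unwrap_tensor_subclass` import path is an assumption and may differ between torchao versions.
+
+```python
+# Hedged sketch of the unwrap_tensor_subclass workaround for pytorch 2.4 and before.
+import torch
+from torchao.quantization import quantize_, int4_weight_only
+from torchao.utils import unwrap_tensor_subclass  # assumed import path
+
+quantize_(model, int4_weight_only(group_size=32))
+model = unwrap_tensor_subclass(model)  # flatten tensor subclass weights for export/compile
+model = torch.compile(model, mode="max-autotune")  # or torch.export / AOTI from here
+```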
-### Other Available Quantization Techniques -#### A8W8 Dynamic Quantization -```python -# Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor -torch._inductor.config.force_fuse_int_mm_with_mul = True +## Other Available Quantization Techniques -# for torch 2.4+ -from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight -quantize_(model, int8_dynamic_activation_int8_weight()) - -# for torch 2.2.2 and 2.3 -from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors -change_linear_weights_to_int8_dqtensors(model) -``` - -#### A16W8 WeightOnly Quantization - -```python -# for torch 2.4+ -from torchao.quantization import quantize_, int8_weight_only -quantize_(model, int8_weight_only()) +### KV Cache Quantization +We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference. -# for torch 2.2.2 and 2.3 -from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors -change_linear_weights_to_int8_woqtensors(model) -``` +In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](../../torchao/_models/llama/README.md#KV-Cache-Quantization-Memory-Efficient-Inference) -This technique works best when the torch._inductor.config.use_mixed_mm option is enabled. This avoids dequantizing the weight tensor before the matmul, instead fusing the dequantization into the matmul, thereby avoiding materialization of a large floating point weight tensor. +### Sparse-Marlin +Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regressive Linear (Marlin) dense kernel to support 4-bit quantized weights and 2:4 sparsity for extremely high performance. -#### A16W4 WeightOnly Quantization +| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-3-8B | Base (bfloat16) | 95.64 | 1435.54 | 16.43 | 15.01 | +| | int8wo | 153.03 | 1150.80 | 10.42 | 7.52 | +| | int4wo-64 | 180.80 | 763.33 | 6.88 | 4.22 | +| | int4wo-64-sparse-marlin | 226.02 | 689.20 | 5.32 | 3.05 | -```python -# for torch 2.4+ -from torchao.quantization import quantize_, int4_weight_only -group_size = 32 +More details can be found [here](../sparsity/README.md) -# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through -# use_hqq flag for `int4_weight_only` quantization -use_hqq = False -quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) +### UINTx Quantization +We're trying to develop kernels for low bit quantization for intx quantization formats. While the current performance is not ideal, we're hoping to continue to iterate on these kernels to improve their performance. 
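+
+As a quick sketch of how the uintx configs in the table below map onto the `quantize_` api (mirroring the pattern used in `torchao/_models/llama/generate.py`; the import path for `uintx_weight_only` is assumed here and may differ by torchao version):
+
+```python
+# Hedged sketch: "uintx-<nbits>-<groupsize>[-hqq]" from the benchmarks expressed in code.
+import torch
+from torchao.quantization import quantize_, uintx_weight_only  # assumed import path
+
+# pick one of the following configs:
+# "uintx-4-64": 4-bit weights with group size 64
+quantize_(model, uintx_weight_only(torch.uint4, 64))
+# "uintx-2-8-hqq": 2-bit weights, group size 8, with hqq used to pick quantization params
+# quantize_(model, uintx_weight_only(torch.uint2, 8, use_hqq=True))
+```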
-
+| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
+| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
+| Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 |
+| | uintx-4-64-hqq | 12.775 | 50.99 | 200.08 | 6.29 | 3.92 |
+| | uintx-2-8-hqq | 24.500 | 40.25 | 265.95 | 9.24 | 6.61 |
+| Llama-3-8B | Base (bfloat16) | 7.441 | 95.64 | 1435.54 | 16.43 | 15.01 |
+| | uintx-4-64-hqq | 8.124 | 47.85 | 213.24 | 11.85 | 4.46 |
+| | uintx-2-8-hqq | 39.605 | 34.83 | 261.42 | 14.99 | 7.51 |
-Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
+You can try out these apis with the `quantize_` api as above, alongside the constructor `uintx_weight_only`; an example can be found in `torchao/_models/llama/generate.py`.
-### KV Cache Quantization
-We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference.
-In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md)
+
+### Automatic Inductor Configuration
+The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration settings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the keyword argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues.
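+
+For example, opting out and applying the recommended settings manually might look like the sketch below (assuming the `set_inductor_config` keyword argument described above keeps its current name):
+
+```python
+# Hedged sketch only: disable automatic inductor configuration, then apply the
+# recommended settings (and any manual overrides) before the first compiled call.
+import torch
+from torchao.quantization import quantize_, int8_weight_only
+from torchao.quantization.utils import recommended_inductor_config_setter
+
+quantize_(model, int8_weight_only(), set_inductor_config=False)
+recommended_inductor_config_setter()
+# ...overwrite individual torch._inductor.config settings here if desired...
+model = torch.compile(model, mode="max-autotune")
+```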
 ## (To be moved to prototype) A16W4 WeightOnly Quantization with GPTQ
diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md
index 2e3c604d30..7748c93bea 100644
--- a/torchao/sparsity/README.md
+++ b/torchao/sparsity/README.md
@@ -37,7 +37,6 @@ On Meta LLama3, we observe a 25% tok/s increase (180 -> 226) compared to our exi
 | Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
 | ----------- | ----------------------- | ------------- | ----------------------- | ---------------- | --------------- |
 | Llama-3-8B | Base (bfloat16) | 95.64 | 1435.54 | 16.43 | 15.01 |
-| | int8dq | 8.61 | 64.75 | 9.24 | 7.52 |
 | | int8wo | 153.03 | 1150.80 | 10.42 | 7.52 |
 | | int4wo-64 | 180.80 | 763.33 | 6.88 | 4.22 |
 | | int4wo-64-sparse-marlin | 226.02 | 689.20 | 5.32 | 3.05 |
@@ -61,6 +60,9 @@ model = model.cuda().half()
 quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
 ```
 
+Note that the existing API results in extremely high accuracy degradation and is intended to be used in concert with an already sparsified+finetuned checkpoint where possible, until we develop
+the necessary supporting flows in torchao.
+
 ### int8 dynamic quant + 2:4 sparasity
 We support composing int8 dynaic quantization with 2:4 sparsity. We fuse one of the scalar dequant multiplications into our cuSPARSELt sparse mm in order to remain performant.
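+
+A sketch of what composing the two might look like is shown below; `SemiSparseLayoutType` is assumed to live in `torchao.dtypes` by analogy with `MarlinSparseLayoutType` above, so verify the name and path against your torchao version.
+
+```python
+# Hedged sketch of int8 dynamic activation + int8 weight quantization composed with 2:4 sparsity.
+import torch
+from torchao.dtypes import SemiSparseLayoutType  # assumed name/path
+from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
+
+model = model.cuda().to(torch.bfloat16)
+quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
+model = torch.compile(model, mode="max-autotune")
+```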