
Commit a382752

BitNet b1.58 training (#930)

* first upstream of BitNet
* fix type annotation
* skip bitnet test on cpu. add bitnet to benchmark script
* add bitnet option to example training script. update backward
* add FSDP2 test
* remove FSDP2 mixed-precision workaround. cleanup test
* fix typo
* adjust tolerance
* update command
* add precompute scale for FSDP2
* fix typing
* add test for precompute scale
* rename
* separate BitNet model surgery
* minor fixes. add note on packing

1 parent ae3e7c6 commit a382752
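
At a high level, the new API follows the existing quantized-training entry points: BitNet b1.58 training is enabled by passing the new bitnet_training() config to quantize_(), as the benchmark and test diffs below do. A minimal sketch of that usage on a toy model (the nn.Sequential here is illustrative and not part of the commit):

import torch
from torch import nn

from torchao import quantize_
from torchao.prototype.quantized_training import bitnet_training

# toy stand-in for the Llama2 transformer layers used in the benchmark below
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)).cuda()

# weights are kept in high precision; BitNet-style ternary-weight / int8-activation
# quantization is applied on the fly in forward and backward (see the BitLinear
# reference implementation in the test diff below)
quantize_(model, bitnet_training(), set_inductor_config=False)

optim = torch.optim.AdamW(model.parameters())
loss = model(torch.randn(4, 64, device="cuda")).sum()
loss.backward()
optim.step()

In the benchmark script only model.layers is quantized, leaving the LM head in high precision, matching the existing int8 mixed-precision path.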

8 files changed: +619 lines, −108 lines

benchmarks/quantized_training/pretrain_llama2.py

49 additions, 13 deletions

@@ -1,9 +1,10 @@
 # pre-train a mini Llama2 on TinyStories with INT8 quantized training
 # pip install huggingface_hub sentencepiece wandb
 #
-# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
-# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
-# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision
+# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile
+# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_weight_only
+# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize int8_mixed_precision
+# BitNet: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --quantize bitnet --modify_rmsnorm_for_bitnet
 
 import os
 
@@ -20,14 +21,14 @@
 from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
 
-from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs
+from torchao import quantize_
+from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs, RMSNorm
 from torchao.prototype import low_bit_optim
 from torchao.prototype.quantized_training import (
+    bitnet_training,
     int8_mixed_precision_training,
     int8_weight_only_quantized_training,
 )
-from torchao.quantization.quant_api import quantize_
-
 
 # not official models
 transformer_configs.update(
@@ -92,10 +93,14 @@ def get_tinystories():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", default="470M", choices=transformer_configs.keys())
+    parser.add_argument("--bf16_model", action="store_true")
+    parser.add_argument("--bf16_amp", action="store_true")
     parser.add_argument("--quantize")
     parser.add_argument("--activation_checkpointing", action="store_true")
     parser.add_argument("--compile", action="store_true")
 
+    parser.add_argument("--modify_rmsnorm_for_bitnet", action="store_true")
+
     parser.add_argument("--n_steps", type=int, default=1000)
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--seq_len", type=int, default=2048)
@@ -104,7 +109,7 @@ def get_tinystories():
     parser.add_argument("--lr", type=float, default=3e-4)
     parser.add_argument("--weight_decay", type=float, default=1e-2)
 
-    parser.add_argument("--project", default="int8_quantized_training")
+    parser.add_argument("--project", default="quantized_training")
     parser.add_argument("--run_name")
     parser.add_argument("--seed", type=int)
     parser.add_argument("--log_interval", type=int, default=10)
@@ -115,19 +120,47 @@ def get_tinystories():
 
     config = ModelArgs.from_name(args.model)
     config.block_size = args.seq_len
-    model = Transformer(config).bfloat16().cuda()
+    model = Transformer(config)
+    if args.bf16_model:
+        model.bfloat16()
+    model.cuda()
     with torch.device("cuda"):
         model.setup_caches(args.batch_size, args.seq_len, training=True)
     if args.activation_checkpointing:
         for layer in model.layers:
            enable_activation_checkpointing(layer)
 
+    # as recommended by https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
+    # section 3
+    if args.modify_rmsnorm_for_bitnet:
+        # remove old RMSNorm
+        for layer in model.layers:
+            layer.attention_norm = torch.nn.Identity()
+            layer.ffn_norm = torch.nn.Identity()
+
+        # insert new RMSNorm
+        def insert_rmsnorm(module: torch.nn.Module):
+            for name, child in module.named_children():
+                if isinstance(child, torch.nn.Linear):
+                    w = child.weight
+                    norm = RMSNorm(child.in_features).to(device=w.device, dtype=w.dtype)
+                    setattr(module, name, torch.nn.Sequential(norm, child))
+                else:
+                    insert_rmsnorm(child)
+
+        insert_rmsnorm(model.layers)
+
     # don't apply int8_mixed_precision to LM head, since it can cause convergence issue.
     # TODO: might want to do the same for int8_weight_only to standardize.
     if args.quantize == "int8_weight_only":
         quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
+
     elif args.quantize == "int8_mixed_precision":
         quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False)
+
+    elif args.quantize == "bitnet":
+        quantize_(model.layers, bitnet_training(), set_inductor_config=False)
+
     elif args.quantize is not None:
         raise ValueError(f"Unsupported quantize={args.quantize}")
 
@@ -155,7 +188,8 @@ def get_tinystories():
         idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item()
         batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long()
 
-        loss = _get_loss(model, batch)
+        with torch.autocast("cuda", torch.bfloat16, enabled=args.bf16_amp):
+            loss = _get_loss(model, batch)
         loss.backward()
 
         if step % args.log_interval == 0:
@@ -165,10 +199,6 @@ def get_tinystories():
                 max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                 max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9,
             )
-            if step > 0:
-                time1 = time.time()
-                log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0)
-                time0 = time1
             run.log(log_dict, step=step)
             pbar.set_postfix(loss=log_dict["loss"])
 
@@ -178,4 +208,10 @@ def get_tinystories():
         step += 1
         pbar.update()
 
+        if step % args.log_interval == 0:
+            time1 = time.time()
+            log_dict = dict(tokens_per_second=(args.log_interval * args.batch_size * args.seq_len) / (time1 - time0))
+            time0 = time1
+            run.log(log_dict, step=step)
+
     run.finish()
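
The --modify_rmsnorm_for_bitnet surgery above removes each block's attention_norm/ffn_norm and instead places an RMSNorm directly in front of every nn.Linear, as recommended in section 3 of the BitNet training-tips note. A standalone sketch of the same idea on a toy module, using torch.nn.RMSNorm for illustration instead of the Llama RMSNorm the script imports:

from torch import nn


def insert_rmsnorm(module: nn.Module) -> None:
    # recursively wrap every Linear as Sequential(RMSNorm, Linear),
    # mirroring the surgery in the benchmark script above
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            w = child.weight
            norm = nn.RMSNorm(child.in_features).to(device=w.device, dtype=w.dtype)
            setattr(module, name, nn.Sequential(norm, child))
        else:
            insert_rmsnorm(child)


block = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32))
insert_rmsnorm(block)
print(block)  # each Linear is now wrapped as Sequential(RMSNorm, Linear)

The net effect is that every linear that will be quantized for BitNet sees normalized activations, rather than relying only on the block-level norms of the standard Llama2 architecture.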

test/prototype/test_quantized_training.py

114 additions, 30 deletions

@@ -1,6 +1,6 @@
 import pytest
 
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
 
 if not TORCH_VERSION_AT_LEAST_2_4:
     pytest.skip("Requires torch>=2.4", allow_module_level=True)
@@ -11,7 +11,7 @@
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests
@@ -20,6 +20,7 @@
 from torchao.prototype.low_bit_optim import _AdamW
 from torchao.prototype.quantized_training import (
     Int8MixedPrecisionTrainingConfig,
+    bitnet_training,
     int8_mixed_precision_training,
     int8_weight_only_quantized_training,
     quantize_int8_rowwise,
@@ -165,7 +166,7 @@ def test_int8_mixed_precision_training(self, compile, config):
         embed_dim = 64
         device = "cuda"
 
-        linear = nn.Linear(embed_dim, embed_dim).cuda()
+        linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
         quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)
 
@@ -187,6 +188,70 @@ def snr(ref, actual):
         assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
         assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20
 
+    @parametrize("compile", [False, True])
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_bitnet_training(self, compile):
+        # reference implementation
+        # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
+        # Figure 3
+        class BitLinear(nn.Linear):
+            def activation_quant(self, x):
+                scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
+                return (x * scale).round().clamp_(-128, 127) / scale
+
+            def weight_quant(self, x):
+                scale = 1.0 / x.abs().mean().clamp_(min=1e-5)
+                return (x * scale).round().clamp_(-1, 1) / scale
+
+            def forward(self, x):
+                w = self.weight
+                x = x + (self.activation_quant(x) - x).detach()
+                w = w + (self.weight_quant(w) - w).detach()
+                return F.linear(x, w, self.bias)
+
+        _reset()
+        bsize = 4
+        embed_dim = 32
+        device = "cuda"
+
+        # only use 1 matmul shape to reduce triton autotune time
+        model_ref = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim, bias=False),
+            nn.GELU(),
+            nn.Linear(embed_dim, embed_dim),
+        ).to(device)
+        model = copy.deepcopy(model_ref)
+        quantize_(model, bitnet_training(), set_inductor_config=False)
+
+        # change model_ref to use BitLinear
+        model_ref[0].__class__ = BitLinear
+        model_ref[2].__class__ = BitLinear
+
+        if compile:
+            model_ref.compile()
+            model.compile()
+
+        optim_ref = torch.optim.AdamW(model_ref.parameters())
+        optim = torch.optim.AdamW(model.parameters())
+
+        for i in range(5):
+            inputs = torch.randn(bsize, embed_dim, device=device)
+            labels = torch.randint(embed_dim, size=(bsize,), device=device)
+            loss_ref = F.cross_entropy(model_ref(inputs), labels)
+            loss = F.cross_entropy(model(inputs), labels)
+
+            torch.testing.assert_close(loss, loss_ref)
+
+            loss_ref.backward()
+            optim_ref.step()
+            optim_ref.zero_grad()
+
+            loss.backward()
+            for p in model.parameters():
+                assert p.grad is not None
+            optim.step()
+            optim.zero_grad()
+
 
 _FSDP_WORLD_SIZE = 2
 
@@ -198,35 +263,36 @@ def world_size(self) -> int:
 
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
     def test_fsdp2_correctness(self):
+        mp_policy = MixedPrecisionPolicy()
+
+        # quantize_fn, mp_policy, tolerance
         test_args = [
-            (
-                int8_weight_only_quantized_training(),  # quantize_fn for base model
-                int8_weight_only_quantized_training(),  # quantize_fn for FSDP model
-                MixedPrecisionPolicy(),
-                0.05,  # tolerance. due to stochastic rounding, use a pretty large tolerance here
-            ),
-            (
-                int8_mixed_precision_training(),
-                int8_mixed_precision_training(),
-                MixedPrecisionPolicy(),
-                1e-6,
-            ),
-            (
-                # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model.
-                # We would need to cast all params to BF16 in forward and backward pass, while keeping
-                # the params in FP32 for optim step.
-                # torch.autocast() will only do this for F.linear() layer (and its backward).
-                # To keep it simple, we just use a larger tolerance here.
-                int8_mixed_precision_training(),
-                int8_mixed_precision_training(Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=torch.bfloat16)),
-                MixedPrecisionPolicy(param_dtype=torch.bfloat16),
-                1e-2,
-            ),
+            # high tolerance due to stochastic rounding
+            (int8_weight_only_quantized_training, mp_policy, 0.05),
+            (int8_mixed_precision_training, mp_policy, 1e-6),
+            (bitnet_training, mp_policy, 1e-5),
         ]
+
+        # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129
+        if TORCH_VERSION_AT_LEAST_2_6:
+            # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model.
+            # We would need to cast all params to BF16 in forward and backward pass, while keeping
+            # the params in FP32 for optim step.
+            # torch.autocast() will only do this for F.linear() layer (and its backward).
+            # To keep it simple, we just use a larger tolerance here.
+            bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
+
+            extra_args = [
+                (int8_weight_only_quantized_training, bf16_mp_policy, 1e-2),
+                (int8_mixed_precision_training, bf16_mp_policy, 1e-2),
+                (bitnet_training, bf16_mp_policy, 1e-2),
+            ]
+            test_args.extend(extra_args)
+
         self.run_subtests({"args": test_args}, self._run_subtest)
 
     def _run_subtest(self, args):
-        base_quantize_fn, fsdp_quantize_fn, mp_policy, tolerance = args
+        quantize_fn, mp_policy, tolerance = args
 
         batch_size = 3
         vocab_size = 32
@@ -245,8 +311,8 @@ def _run_subtest(self, args):
         base_model = Transformer(model_args).cuda()
         fsdp_model = copy.deepcopy(base_model)
 
-        quantize_(base_model.layers, base_quantize_fn, set_inductor_config=False)
-        quantize_(fsdp_model.layers, fsdp_quantize_fn, set_inductor_config=False)
+        quantize_(base_model.layers, quantize_fn(), set_inductor_config=False)
+        quantize_(fsdp_model.layers, quantize_fn(), set_inductor_config=False)
 
         for layer in fsdp_model.layers:
             fully_shard(layer, mp_policy=mp_policy)
@@ -275,7 +341,25 @@ def _run_subtest(self, args):
             base_optim.step()
 
             rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
-            assert rel_error < tolerance, (iter_idx, rel_error)
+            assert rel_error < tolerance, (quantize_fn.__name__, mp_policy, iter_idx, rel_error)
+
+    @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
+    def test_precompute_bitnet_scale(self):
+        from torchao.prototype.quantized_training.bitnet import get_bitnet_scale, precompute_bitnet_scale_for_fsdp
+
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
+        model_fsdp = copy.deepcopy(model)
+        quantize_(model_fsdp, bitnet_training())
+        fully_shard(model_fsdp)
+
+        precompute_bitnet_scale_for_fsdp(model_fsdp)
+
+        torch.testing.assert_close(
+            get_bitnet_scale(model[0].weight), model_fsdp[0].weight._local_tensor._precomputed_scale
+        )
+        torch.testing.assert_close(
+            get_bitnet_scale(model[2].weight), model_fsdp[2].weight._local_tensor._precomputed_scale
+        )
 
 
 instantiate_parametrized_tests(TestQuantizedTraining)
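
For FSDP2, the commit also adds precompute_bitnet_scale_for_fsdp (exercised by test_precompute_bitnet_scale above), which caches the BitNet weight scale on the sharded parameters. A sketch of where such a call would plausibly sit in a training loop, by analogy with torchao's float8 precompute-scale helper; the exact placement is an assumption, and the snippet presumes an already-initialized process group (e.g. launched with torchrun), as FSDPTest provides in the test:

import torch
from torch import nn
from torch.distributed._composable.fsdp import fully_shard

from torchao import quantize_
from torchao.prototype.quantized_training import bitnet_training
from torchao.prototype.quantized_training.bitnet import precompute_bitnet_scale_for_fsdp

# toy model mirroring the one in test_precompute_bitnet_scale
model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
quantize_(model, bitnet_training())
fully_shard(model)

optim = torch.optim.AdamW(model.parameters())

for _ in range(10):
    loss = model(torch.randn(4, 32, device="cuda")).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
    # assumed placement, mirroring the float8 precompute-scale pattern:
    # refresh the cached BitNet weight scales after each optimizer step
    precompute_bitnet_scale_for_fsdp(model)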
