import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6

if not TORCH_VERSION_AT_LEAST_2_4:
    pytest.skip("Requires torch>=2.4", allow_module_level=True)

import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests

from torchao.prototype.low_bit_optim import _AdamW
from torchao.prototype.quantized_training import (
    Int8MixedPrecisionTrainingConfig,
+    bitnet_training,
    int8_mixed_precision_training,
    int8_weight_only_quantized_training,
    quantize_int8_rowwise,
@@ -165,7 +166,7 @@ def test_int8_mixed_precision_training(self, compile, config):
        embed_dim = 64
        device = "cuda"

-        linear = nn.Linear(embed_dim, embed_dim).cuda()
+        linear = nn.Linear(embed_dim, embed_dim, device=device)
        linear_int8mp = copy.deepcopy(linear)
        quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)

@@ -187,6 +188,70 @@ def snr(ref, actual):
        assert snr(inputs_ref.grad, inputs_int8mp.grad) > 20
        assert snr(linear.weight.grad, linear_int8mp.weight.grad) > 20

+    @parametrize("compile", [False, True])
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    def test_bitnet_training(self, compile):
+        # reference implementation
+        # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
+        # Figure 3
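+        # (per-token absmax int8 activation quantization, mean-|w|-scaled ternary weights, straight-through estimator)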
+        class BitLinear(nn.Linear):
+            def activation_quant(self, x):
+                scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
+                return (x * scale).round().clamp_(-128, 127) / scale
+
+            def weight_quant(self, x):
+                scale = 1.0 / x.abs().mean().clamp_(min=1e-5)
+                return (x * scale).round().clamp_(-1, 1) / scale
+
+            def forward(self, x):
+                w = self.weight
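+                # straight-through estimator: the forward pass sees the fake-quantized x and w,
+                # while gradients flow through as if no rounding had happened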
+                x = x + (self.activation_quant(x) - x).detach()
+                w = w + (self.weight_quant(w) - w).detach()
+                return F.linear(x, w, self.bias)
+
+        _reset()
+        bsize = 4
+        embed_dim = 32
+        device = "cuda"
+
+        # only use 1 matmul shape to reduce triton autotune time
+        model_ref = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim, bias=False),
+            nn.GELU(),
+            nn.Linear(embed_dim, embed_dim),
+        ).to(device)
+        model = copy.deepcopy(model_ref)
+        quantize_(model, bitnet_training(), set_inductor_config=False)
+
+        # change model_ref to use BitLinear
+        model_ref[0].__class__ = BitLinear
+        model_ref[2].__class__ = BitLinear
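+        # swapping __class__ keeps the copied weights, so both models start from identical parameters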
+
+        if compile:
+            model_ref.compile()
+            model.compile()
+
+        optim_ref = torch.optim.AdamW(model_ref.parameters())
+        optim = torch.optim.AdamW(model.parameters())
+
+        for i in range(5):
+            inputs = torch.randn(bsize, embed_dim, device=device)
+            labels = torch.randint(embed_dim, size=(bsize,), device=device)
+            loss_ref = F.cross_entropy(model_ref(inputs), labels)
+            loss = F.cross_entropy(model(inputs), labels)
+
+            torch.testing.assert_close(loss, loss_ref)
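+            # the quantized-training path should reproduce the reference loss within assert_close's default tolerance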
+
+            loss_ref.backward()
+            optim_ref.step()
+            optim_ref.zero_grad()
+
+            loss.backward()
+            for p in model.parameters():
+                assert p.grad is not None
+            optim.step()
+            optim.zero_grad()
+
_FSDP_WORLD_SIZE = 2
@@ -198,35 +263,36 @@ def world_size(self) -> int:

    @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
    def test_fsdp2_correctness(self):
+        mp_policy = MixedPrecisionPolicy()
+
+        # quantize_fn, mp_policy, tolerance
        test_args = [
-            (
-                int8_weight_only_quantized_training(),  # quantize_fn for base model
-                int8_weight_only_quantized_training(),  # quantize_fn for FSDP model
-                MixedPrecisionPolicy(),
-                0.05,  # tolerance. due to stochastic rounding, use a pretty large tolerance here
-            ),
-            (
-                int8_mixed_precision_training(),
-                int8_mixed_precision_training(),
-                MixedPrecisionPolicy(),
-                1e-6,
-            ),
-            (
-                # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model.
-                # We would need to cast all params to BF16 in forward and backward pass, while keeping
-                # the params in FP32 for optim step.
-                # torch.autocast() will only do this for F.linear() layer (and its backward).
-                # To keep it simple, we just use a larger tolerance here.
-                int8_mixed_precision_training(),
-                int8_mixed_precision_training(Int8MixedPrecisionTrainingConfig(fsdp_param_dtype=torch.bfloat16)),
-                MixedPrecisionPolicy(param_dtype=torch.bfloat16),
-                1e-2,
-            ),
+            # high tolerance due to stochastic rounding
+            (int8_weight_only_quantized_training, mp_policy, 0.05),
+            (int8_mixed_precision_training, mp_policy, 1e-6),
+            (bitnet_training, mp_policy, 1e-5),
        ]
+
+        # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129
+        if TORCH_VERSION_AT_LEAST_2_6:
+            # It's complicated (though possible) to simulate FSDP BF16 mixed-precision for base_model.
+            # We would need to cast all params to BF16 in forward and backward pass, while keeping
+            # the params in FP32 for optim step.
+            # torch.autocast() will only do this for F.linear() layer (and its backward).
+            # To keep it simple, we just use a larger tolerance here.
+            bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
+
+            extra_args = [
+                (int8_weight_only_quantized_training, bf16_mp_policy, 1e-2),
+                (int8_mixed_precision_training, bf16_mp_policy, 1e-2),
+                (bitnet_training, bf16_mp_policy, 1e-2),
+            ]
+            test_args.extend(extra_args)
+
        self.run_subtests({"args": test_args}, self._run_subtest)

    def _run_subtest(self, args):
-        base_quantize_fn, fsdp_quantize_fn, mp_policy, tolerance = args
+        quantize_fn, mp_policy, tolerance = args

        batch_size = 3
        vocab_size = 32
@@ -245,8 +311,8 @@ def _run_subtest(self, args):
        base_model = Transformer(model_args).cuda()
        fsdp_model = copy.deepcopy(base_model)

-        quantize_(base_model.layers, base_quantize_fn, set_inductor_config=False)
-        quantize_(fsdp_model.layers, fsdp_quantize_fn, set_inductor_config=False)
+        quantize_(base_model.layers, quantize_fn(), set_inductor_config=False)
+        quantize_(fsdp_model.layers, quantize_fn(), set_inductor_config=False)
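+        # calling quantize_fn() separately for each model gives both the same, freshly constructed config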

        for layer in fsdp_model.layers:
            fully_shard(layer, mp_policy=mp_policy)
@@ -275,7 +341,25 @@ def _run_subtest(self, args):
            base_optim.step()

            rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
-            assert rel_error < tolerance, (iter_idx, rel_error)
+            assert rel_error < tolerance, (quantize_fn.__name__, mp_policy, iter_idx, rel_error)
+
+    @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
+    def test_precompute_bitnet_scale(self):
+        from torchao.prototype.quantized_training.bitnet import get_bitnet_scale, precompute_bitnet_scale_for_fsdp
+
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
+        model_fsdp = copy.deepcopy(model)
+        quantize_(model_fsdp, bitnet_training())
+        fully_shard(model_fsdp)
+
+        precompute_bitnet_scale_for_fsdp(model_fsdp)
+
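+        # the scale cached on each FSDP shard should match the scale computed from the full, unsharded weight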
+        torch.testing.assert_close(
+            get_bitnet_scale(model[0].weight), model_fsdp[0].weight._local_tensor._precomputed_scale
+        )
+        torch.testing.assert_close(
+            get_bitnet_scale(model[2].weight), model_fsdp[2].weight._local_tensor._precomputed_scale
+        )


instantiate_parametrized_tests(TestQuantizedTraining)