Add BF16 stochastic rounding option for optimizers #1124

Merged · 7 commits · Oct 23, 2024

15 changes: 13 additions & 2 deletions benchmarks/benchmark_low_bit_adam.py
@@ -90,6 +90,7 @@ def get_parser():
parser.add_argument("--optim", default="AdamW", choices=OPTIM_MAP.keys())
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--optim_kwargs", type=json.loads, default=dict())
parser.add_argument("--cosine_lr_scheduler", action="store_true")
parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])

@@ -206,7 +207,12 @@ def evaluate_model(model, args):
train_batch_size=args.batch_size,
optimizer=dict(
type="Adam",
params=dict(lr=args.lr, weight_decay=args.weight_decay, fp32_optimizer_states=False),
params=dict(
lr=args.lr,
weight_decay=args.weight_decay,
fp32_optimizer_states=False,
**args.optim_kwargs,
),
),
bf16=dict(enabled=args.full_bf16),
zero_optimization=dict(
@@ -225,7 +231,12 @@ def evaluate_model(model, args):
elif args.optim_cpu_offload == "ao_offload_grads":
optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)

optim = optim_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
optim = optim_cls(
model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay,
**args.optim_kwargs,
)

lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
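
The new `--optim_kwargs` flag is parsed with `json.loads`, so extra optimizer arguments (such as the `bf16_stochastic_round` flag added by this PR) can be passed as a JSON string on the command line and forwarded to the optimizer constructor. A minimal sketch of what the parsing does (the JSON string here is only illustrative):

```python
import json

# argparse applies json.loads to the value of --optim_kwargs, e.g.
# --optim_kwargs '{"bf16_stochastic_round": true}'
optim_kwargs = json.loads('{"bf16_stochastic_round": true}')
print(optim_kwargs)  # {'bf16_stochastic_round': True}

# the script then splats these into the optimizer: optim_cls(..., **args.optim_kwargs)
```
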
9 changes: 8 additions & 1 deletion benchmarks/quantized_training/pretrain_llama2.py
@@ -11,6 +11,7 @@
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import argparse
import json
import time
from functools import partial
from pathlib import Path
@@ -108,6 +109,7 @@ def get_tinystories():
parser.add_argument("--optim", default="AdamW")
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--weight_decay", type=float, default=1e-2)
parser.add_argument("--optim_kwargs", type=json.loads, default=dict())

parser.add_argument("--project", default="quantized_training")
parser.add_argument("--run_name")
@@ -171,7 +173,12 @@ def insert_rmsnorm(module: torch.nn.Module):
# only use optimizers from torchao.prototype.low_bit_optim to support quantized training
if args.optim == "AdamW":
args.optim = "_AdamW"
optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
optim = getattr(low_bit_optim, args.optim)(
model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay,
**args.optim_kwargs,
)

data = get_tinystories().cuda()
args.torch_version = torch.__version__
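
For the full-BF16 pretraining script, the same mechanism turns on stochastic rounding without code changes, e.g. `--optim_kwargs '{"bf16_stochastic_round": true}'`. A sketch of the resulting optimizer construction with the defaults above (the tiny model is only a placeholder for the real one):

```python
import torch
from torchao.prototype import low_bit_optim

model = torch.nn.Linear(64, 64).bfloat16()  # placeholder for the actual model

# "AdamW" is remapped to "_AdamW" and looked up on low_bit_optim;
# the parsed --optim_kwargs are forwarded as keyword arguments
optim = getattr(low_bit_optim, "_AdamW")(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-2,
    bf16_stochastic_round=True,
)
```
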
57 changes: 54 additions & 3 deletions test/prototype/test_low_bit_optim.py
@@ -14,8 +14,12 @@
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim.quant_utils import quantize_8bit_with_qmap, quantize_4bit_with_qmap
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5
from torchao.prototype.low_bit_optim.quant_utils import (
quantize_8bit_with_qmap,
quantize_4bit_with_qmap,
_fp32_to_bf16_sr,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6

try:
import bitsandbytes as bnb
@@ -74,6 +78,22 @@ def test_quantize_4bit_with_qmap_compile(self, device):

torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
@parametrize("compile", [False, True])
def test_bf16_stochastic_round(self, device, compile):
x = torch.rand(32, device=device) * 100
x_rep = x.view(-1, 1).repeat(1, 100_000)

if compile:
x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep)
else:
x_rep_bf16 = _fp32_to_bf16_sr(x_rep)

assert x_rep_bf16.dtype is torch.bfloat16

# must cast BF16 tensor back to FP32 so that .mean() is accurate
torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)


class TestOptim(TestCase):
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
@@ -249,13 +269,44 @@ def test_optim_cpu_offload_save_load(self):
for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1)

def test_optim_bf16_stochastic_round_correctness(self):
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(2024)
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model2 = copy.deepcopy(model1).bfloat16()

# small LR so that weight update is small
# when bf16_stochastic_round=False, the test will fail after 1 iteration
optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5)
optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True)

# overfit on this sample
x = torch.randn(4, 32, device=device)

for idx in range(5):
# mixed-precision training
with torch.autocast(device, dtype=torch.bfloat16):
loss1 = model1(x)
loss1 = loss1.sum() # under autocast context, bf16.sum() will return fp32
loss1.backward()
optim1.step()
optim1.zero_grad()

# full BF16 training with stochastic round weight update
loss2 = model2(x.bfloat16()).sum()
loss2.backward()
optim2.step()
optim2.zero_grad()

torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}")


class TestFSDP2(FSDPTest):
@property
def world_size(self) -> int:
return 2

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="OptimState8bit dispatch: attempting to run unimplemented operator/function: aten.as_strided.default")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.")
@skip_if_lt_x_gpu(2)
def test_fsdp2(self):
optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
22 changes: 22 additions & 0 deletions torchao/prototype/low_bit_optim/README.md
@@ -5,6 +5,7 @@ This folder implements:
- 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507
- FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental)
- Stochastic rounding for BF16 weight (https://arxiv.org/abs/2010.06192, experimental)

The implementation is done fully in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels. Thus, your platform must support `torch.compile()` to use these optimizers. We only test on CPU and CUDA, so there may be bugs or errors on other platforms.

@@ -56,6 +57,27 @@ ao 4-bit | 33.2 | 2900 | 42.27

NOTE: lpmm's 4-bit AdamW does not support BF16 weights.

## Stochastic rounding for BF16 weight

BF16 has only around 3 decimal digits of precision. This means that if a weight update is smaller than ~1e-3 of the weight's magnitude, the weight will not change at all (with nearest rounding). This is highly problematic for full BF16 training, where we don't keep an FP32 copy of the model weights.
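
A quick illustration of this underflow: with nearest rounding, an update of about 1e-3 on a weight of magnitude 1 is below half the BF16 spacing near 1.0 (2^-8 ≈ 0.004) and is simply dropped, while FP32 keeps it. (Values here are chosen only for illustration.)

```python
import torch

w = torch.tensor(1.0, dtype=torch.bfloat16)
update = 1e-3  # ~1e-3 of the weight magnitude

print(w + update)          # tensor(1., dtype=torch.bfloat16) -> the update is lost
print(w.float() + update)  # tensor(1.0010) -> FP32 retains it
```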

Note that our optimizer step calculations are always done in FP32 to ensure accurate results. The "underflow" only happens when we copy the new weight value (in FP32) to the existing BF16 weight. One way to combat this problem is to perform **stochastic rounding** for the FP32->BF16 cast.
- In stochastic rounding, we round up with probability `(x - round_down(x)) / (round_up(x) - round_down(x))`, and round down otherwise.
- It follows that successive weight updates with stochastic rounding will, in expectation, correctly approximate high-precision weight updates.
- Since BF16 is simply a truncation of FP32, there is an efficient implementation of FP32->BF16 stochastic rounding (the same is not true for FP32->FP16); a sketch is shown after this list.
- More detailed discussion can be found at https://arxiv.org/abs/2010.06192. [llm.c](https://github.com/karpathy/llm.c/blob/7ecd8906afe6ed7a2b2cdb731c042f26d525b820/llmc/adamw.cuh#L43) also implements this approach.
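
Below is a minimal sketch of this bit trick (for illustration only; the actual implementation is `_fp32_to_bf16_sr` in `torchao.prototype.low_bit_optim.quant_utils`): add a uniform random 16-bit integer to the FP32 bit pattern and truncate the low 16 bits, so a carry into the retained bits occurs with exactly the fractional probability above (rounding away from zero for negative values).

```python
import torch

def fp32_to_bf16_stochastic_round(x: torch.Tensor) -> torch.Tensor:
    # reinterpret the FP32 bits as int32 (BF16 is the upper 16 bits of FP32)
    bits = x.view(torch.int32)
    # add a uniform random integer in [0, 2^16); a carry into bit 16 happens with
    # probability (low 16 bits) / 2^16, which is the stochastic-rounding probability
    bits = bits + torch.randint_like(bits, 1 << 16)
    # zero the low 16 bits (-65536 == 0xFFFF_0000 as signed int32), so the
    # subsequent FP32->BF16 cast is an exact truncation
    bits = bits & -65536
    return bits.view(torch.float32).to(torch.bfloat16)

x = torch.rand(8) * 100
print(fp32_to_bf16_stochastic_round(x))
```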

```python
# a clone of torch.optim.AdamW with extra features
from torchao.prototype.low_bit_optim import _AdamW

model = ...
model_bf16 = model.bfloat16()
optim = _AdamW(model_bf16.parameters(), bf16_stochastic_round=True)
```

All of the low-bit optimizers mentioned above also support the `bf16_stochastic_round` flag. Note that this flag only applies to BF16 weights.
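
For example, to combine 8-bit optimizer states with stochastic-rounded BF16 weight updates (a short sketch; the flag is passed the same way as for `_AdamW` above):

```python
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

model_bf16 = torch.nn.Linear(1024, 1024).bfloat16()  # stand-in for a real model
optim = AdamW8bit(model_bf16.parameters(), bf16_stochastic_round=True)
```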

## Optimizer CPU offload

This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. Only CUDA is supported. For multi-GPU training, you can use FSDP's built-in CPU offload.
Expand Down