[Llama] Make torchao's Llama trainable #728

Merged 4 commits on Aug 22, 2024
benchmarks/quantized_training/pretrain_llama2.py (26 additions, 12 deletions)
@@ -1,5 +1,5 @@
 # pre-train a mini Llama2 on TinyStories with INT8 quantized training
-# pip install transformers sentencepiece wandb
+# pip install huggingface_hub sentencepiece wandb
 #
 # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
 # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
@@ -9,21 +9,33 @@
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

 import argparse
+from functools import partial
 from pathlib import Path

 import numpy as np
 import torch
 import wandb
+from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
-from transformers import LlamaConfig, LlamaForCausalLM

+from torchao._models.llama.model import ModelArgs, Transformer
 from torchao.prototype import low_bit_optim
 from torchao.prototype.quantized_training import int8_weight_only_quantized_training
 from torchao.quantization.quant_api import quantize_


-def get_loss(model: LlamaForCausalLM, batch: torch.Tensor):
-    return model(batch, labels=batch).loss
+# hack from fairseq
+# https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py
+def enable_activation_checkpointing(m: torch.nn.Module):
+    assert not hasattr(m, "_forward")
+    m._forward = m.forward
+    m.forward = partial(checkpoint, m.forward)
+
+
+def get_loss(model: Transformer, batch: torch.Tensor):
+    logits = model(batch)[:, :-1].flatten(0, 1)
+    labels = batch[:, 1:].flatten()
+    return torch.nn.functional.cross_entropy(logits, labels)


 def get_tinystories():
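Aside (not part of the diff): the new `get_loss` computes the next-token loss explicitly, since torchao's `Transformer` returns raw logits rather than shifting labels internally the way `LlamaForCausalLM(..., labels=batch)` did. A minimal sketch of the alignment, using arbitrary toy shapes:

```python
# Toy check of the label shift used in get_loss above: logits at position t are scored
# against the token at position t + 1. Shapes and vocab size are arbitrary.
import torch
import torch.nn.functional as F

vocab_size, batch_size, seq_len = 11, 2, 5
logits = torch.randn(batch_size, seq_len, vocab_size)     # stand-in for model(batch)
batch = torch.randint(vocab_size, (batch_size, seq_len))  # token ids

pred = logits[:, :-1].flatten(0, 1)   # (batch * (seq_len - 1), vocab_size)
target = batch[:, 1:].flatten()       # next tokens, aligned with pred
loss = F.cross_entropy(pred, target)
print(loss.shape)                     # scalar
```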
@@ -91,17 +103,19 @@ def get_tinystories():
 if args.seed is not None:
     torch.manual_seed(args.seed)

-config = LlamaConfig(
-    hidden_size=args.d_model,
+config = ModelArgs(
+    block_size=args.seq_len,
+    n_layer=args.depth,
+    n_head=args.d_model // args.head_dim,
+    dim=args.d_model,
     intermediate_size=args.ffn_size,
-    num_hidden_layers=args.depth,
-    num_attention_heads=args.d_model // args.head_dim,
-    max_position_embeddings=args.seq_len,
-    use_cache=False,
 )
-model = LlamaForCausalLM(config).bfloat16().cuda()
+model = Transformer(config).bfloat16().cuda()
+with torch.device("cuda"):
+    model.setup_caches(args.batch_size, args.seq_len, training=True)
 if args.activation_checkpointing:
-    model.gradient_checkpointing_enable()
+    for layer in model.layers:
+        enable_activation_checkpointing(layer)
 if args.quantize == "int8_weight_only":
     quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
 elif args.quantize is not None:
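Aside (not part of the diff): a minimal sketch of what the `--quantize int8_weight_only` branch does, applied to a toy model instead of the Llama `Transformer`. The `quantize_` call and `int8_weight_only_quantized_training()` config are the same symbols the benchmark imports above; the toy model and data are placeholders, and the benchmark pairs this with an optimizer from `torchao.prototype.low_bit_optim`.

```python
# Sketch, assuming a CUDA device: swap nn.Linear weights to INT8 quantized-training
# tensor subclasses, then run one forward/backward in BF16.
import torch
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
).bfloat16().cuda()

# same call as the benchmark above, minus the Llama-specific model
quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)

x = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")
loss = model(x).square().mean()
loss.backward()   # gradients flow through the quantized weights
print(loss.item())
```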
scripts/download.py (1 addition, 1 deletion)
@@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
     from huggingface_hub import snapshot_download
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
     try:
-        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
+        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
     except HTTPError as e:
         if e.response.status_code == 401:
             print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
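Aside (not part of the diff): `ignore_patterns` is a standard `huggingface_hub.snapshot_download` argument; files matching the pattern are simply not downloaded, so a repo that ships both safetensors and original PyTorch weights only fetches the latter (presumably what the checkpoint-conversion script consumes). A minimal sketch, with an example repo id:

```python
# Sketch of the new download behavior; the repo id is only an example.
from huggingface_hub import snapshot_download

snapshot_download(
    "openlm-research/open_llama_7b",                       # example repo id
    local_dir="checkpoints/openlm-research/open_llama_7b",
    ignore_patterns="*.safetensors",                       # skip safetensors shards
)
```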
torchao/_models/llama/model.py (16 additions, 8 deletions)
@@ -150,7 +150,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.max_batch_size = -1
         self.max_seq_length = -1

-    def setup_caches(self, max_batch_size, max_seq_length):
+    def setup_caches(self, max_batch_size, max_seq_length, training: bool = False):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
             return
         head_dim = self.config.dim // self.config.n_head
@@ -163,16 +163,21 @@ def setup_caches(self, max_batch_size, max_seq_length):
             dtype = self.output.scales.dtype
         elif hasattr(self.output, "scales_and_zeros"):
             dtype = self.output.scales_and_zeros.dtype
-        for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
+        if not training:
+            for b in self.layers:
+                b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
-        mask = self.causal_mask[None, None, input_pos]
-        freqs_cis = self.freqs_cis[input_pos]
+        if input_pos is not None:
+            mask = self.causal_mask[None, None, input_pos]
+            freqs_cis = self.freqs_cis[input_pos]
+        else:
+            mask = None
+            freqs_cis = self.freqs_cis[:idx.shape[1]]
         x = self.tok_embeddings(idx)

         for i, layer in enumerate(self.layers):
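Aside (not part of the diff): with these changes, `setup_caches(..., training=True)` skips KV-cache allocation but still builds `freqs_cis` and `causal_mask`, and calling the model without `input_pos` runs a full-sequence, cache-free forward suitable for training. A minimal sketch on a tiny, arbitrary config:

```python
# Sketch of the training path; the ModelArgs values are arbitrary toy settings.
import torch
from torchao._models.llama.model import ModelArgs, Transformer

config = ModelArgs(block_size=128, n_layer=2, n_head=4, dim=64, vocab_size=512)
model = Transformer(config)
model.setup_caches(max_batch_size=2, max_seq_length=128, training=True)

idx = torch.randint(config.vocab_size, (2, 128))
logits = model(idx)            # no input_pos -> no KV cache, causal SDPA path
print(logits.shape)            # (2, 128, vocab_size)

loss = torch.nn.functional.cross_entropy(
    logits[:, :-1].flatten(0, 1), idx[:, 1:].flatten()
)
loss.backward()                # gradients reach every layer
```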
@@ -194,7 +199,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)

-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Optional[Tensor], freqs_cis: Tensor, mask: Optional[Tensor]) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -224,7 +229,7 @@ def load_hook(self, state_dict, prefix, *args):
             wv = state_dict.pop(prefix + "wv.weight")
             state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
@@ -244,7 +249,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_pos: Optional[Tensor] = None) -> Tensor:

         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        if mask is not None:
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        else:
+            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

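Aside (not part of the diff): when `mask is None` (the training path), `is_causal=True` gives the same result as the explicit lower-triangular boolean mask, since queries and keys cover the same full sequence. A minimal sketch with arbitrary shapes:

```python
# Check that is_causal=True matches an explicit causal boolean mask for a full sequence,
# which is why the training path above can drop the sliced causal_mask.
import torch
import torch.nn.functional as F

bsz, n_head, seq_len, head_dim = 2, 4, 16, 8
q = torch.randn(bsz, n_head, seq_len, head_dim)
k = torch.randn(bsz, n_head, seq_len, head_dim)
v = torch.randn(bsz, n_head, seq_len, head_dim)

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
y_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal[None, None], dropout_p=0.0)
y_causal = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

print(torch.allclose(y_mask, y_causal, atol=1e-6))  # True, up to numerical noise
```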