Closed
Labels: enhancement (New feature or request)
Description
While working on #644, @msaroufim suggested using the built-in Llama model for testing the mini training recipe. I looked into it, and these are the two main changes needed:
- Initialize `freqs_cis` without initializing the KV cache and causal mask (ao/torchao/_models/llama/model.py, lines 153 to 170 in e7fc0ed):

```python
def setup_caches(self, max_batch_size, max_seq_length):
    if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
        return
    head_dim = self.config.dim // self.config.n_head
    max_seq_length = find_multiple(max_seq_length, 8)
    self.max_seq_length = max_seq_length
    self.max_batch_size = max_batch_size
    dtype = self.output.weight.dtype
    # For quantized layers, dtype is encoded in scales
    if hasattr(self.output, "scales"):
        dtype = self.output.scales.dtype
    elif hasattr(self.output, "scales_and_zeros"):
        dtype = self.output.scales_and_zeros.dtype
    for b in self.layers:
        b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

    self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
    self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
```

- Don't use an attention mask; use `is_causal=True` directly (ao/torchao/_models/llama/model.py, line 247 in e7fc0ed):

```python
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
```
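The second change relies on `is_causal=True` producing the same result as an explicit lower-triangular boolean mask. A minimal standalone check (random tensors, CPU, float32; shapes chosen arbitrarily) might look like this. Note that PyTorch aligns the `is_causal` mask to the top-left corner, so the two only coincide when the query and key lengths are equal, which is the case in a training forward pass without a KV cache.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_head, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch, n_head, seq_len, head_dim)
k = torch.randn(batch, n_head, seq_len, head_dim)
v = torch.randn(batch, n_head, seq_len, head_dim)

# Explicit boolean mask, built the same way as causal_mask in setup_caches()
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
y_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, dropout_p=0.0)

# No mask tensor: SDPA applies causal masking itself
y_causal = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

print(torch.allclose(y_mask, y_causal, atol=1e-5))
```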
This would make it convenient for some of our training recipes (e.g. QAT) to have mini training scripts directly in torchao, which would also serve as self-contained examples.
API-wise, I think we can add a `training` flag to the `Transformer.setup_caches()` method:
- When `training=False` (default), the old behavior is maintained.
- When `training=True`, only `freqs_cis` is initialized, and the `.forward()` method does not pass `mask` to `TransformerBlock`/`Attention`.
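A rough sketch of how the proposed flag could behave, using a hypothetical one-layer `ToyTransformer` rather than torchao's actual `Transformer`. Everything here is illustrative: `precompute_freqs_cis` is a simplified stand-in (no `dtype` argument, returns cos/sin pairs), the KV cache is a plain zero tensor, and applying the rotary embedding and updating the cache are omitted. It only shows the control flow: `training=True` builds `freqs_cis` alone and lets SDPA handle causality, while `training=False` also allocates the cache and mask.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000.0) -> torch.Tensor:
    # Simplified stand-in: RoPE angles as (seq_len, n_elem // 2, 2) cos/sin pairs
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    angles = torch.outer(torch.arange(seq_len).float(), freqs)
    return torch.stack([angles.cos(), angles.sin()], dim=-1)


class ToyTransformer(nn.Module):
    """Hypothetical one-layer model illustrating the proposed `training` flag."""

    def __init__(self, dim: int = 32, n_head: int = 4, block_size: int = 64):
        super().__init__()
        self.dim, self.n_head, self.block_size = dim, n_head, block_size
        self.wqkv = nn.Linear(dim, 3 * dim)
        self.freqs_cis = None
        self.kv_cache = None
        self.causal_mask = None

    def setup_caches(self, max_batch_size: int, max_seq_length: int, training: bool = False):
        head_dim = self.dim // self.n_head
        # Both modes need the rotary embedding table
        self.freqs_cis = precompute_freqs_cis(self.block_size, head_dim)
        if training:
            # Training: no KV cache and no causal mask tensor
            self.kv_cache = None
            self.causal_mask = None
            return
        # Inference: allocate the (simplified) KV cache and the boolean causal mask
        self.kv_cache = torch.zeros(2, max_batch_size, self.n_head, max_seq_length, head_dim)
        self.causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        q, k, v = self.wqkv(x).split(self.dim, dim=-1)
        # (bsz, seqlen, dim) -> (bsz, n_head, seqlen, head_dim); RoPE and KV-cache updates omitted
        q, k, v = (t.view(bsz, seqlen, self.n_head, -1).transpose(1, 2) for t in (q, k, v))
        if self.causal_mask is None:
            # training=True: no mask is passed, SDPA applies causal masking itself
            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
        else:
            # training=False: slice the precomputed mask to the current sequence length
            mask = self.causal_mask[:seqlen, :seqlen]
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
        return y.transpose(1, 2).reshape(bsz, seqlen, self.dim)
```

Calling `setup_caches(..., training=True)` and then `setup_caches(..., training=False)` on the same model should give matching outputs for a full-prefix input, since both paths apply the same lower-triangular mask.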