
[Llama] Make Llama in torchao trainable #674

@gau-nernst

Description

While working on #644, @msaroufim suggested using the built-in Llama for testing the mini train recipe. I looked into it, and here are the two main changes needed.

  1. Initialize freqs_cis without initializing the KV cache and the causal mask. The current setup_caches() always creates all three (a training-only variant is sketched after this list):
    def setup_caches(self, max_batch_size, max_seq_length):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.output.weight.dtype
        # For quantized layers, dtype is encoded in scales
        if hasattr(self.output, "scales"):
            dtype = self.output.scales.dtype
        elif hasattr(self.output, "scales_and_zeros"):
            dtype = self.output.scales_and_zeros.dtype
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
  2. Don't use an attention mask; just pass is_causal=True directly (the proposed call is shown after this list). The current call is:
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
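
For change 1, the training path of setup_caches() would reduce to roughly the following sketch, reusing the dtype computed above and skipping the KVCache and causal_mask setup:

    # Training path (sketch): only the RoPE frequencies are needed.
    self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)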
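
For change 2, the call would become the line below; is_causal=True makes scaled_dot_product_attention apply causal masking internally, so no mask tensor has to be materialized or stored on the module:

    y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)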

This will make it convenient for some of our training recipes (e.g. QAT) to have mini training scripts directly in torchao, which would also act as self-contained examples.

API-wise, I think we can add a training flag to the Transformer.setup_caches() method.

  • When training=False (the default), the old behavior is maintained.
  • When training=True, only freqs_cis is initialized, and the .forward() method does not pass a mask to TransformerBlock/Attention.
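
A rough usage sketch from a training recipe's point of view (the training flag and the mask-free forward call are the proposal here, not current behavior; the config name and shapes are just for illustration):

    import torch
    import torch.nn.functional as F

    model = Transformer.from_name("Llama-2-7b-chat-hf")  # torchao's built-in Llama
    # Proposed: only freqs_cis is created; no KVCache, no causal_mask.
    model.setup_caches(max_batch_size=8, max_seq_length=2048, training=True)

    input_ids = torch.randint(0, model.config.vocab_size, (8, 2048))
    # Proposed: in training mode, forward() relies on is_causal=True instead of a mask.
    logits = model(input_ids)
    # Next-token prediction loss
    loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)), input_ids[:, 1:].reshape(-1))
    loss.backward()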
