
[Feature]: Context Parallelism #7519

@huseinzol05


🚀 The feature, motivation and pitch

Google Gemini already supports context lengths of up to a million tokens. To serve such long contexts, we need context parallelism: splitting the input along the sequence dimension across multiple GPUs.

I have hands-on experience with both context parallelism and vLLM (I maintained a Whisper fork), and I am ready to support Context Parallelism development for vLLM. This issue is long because I want to explain the idea from my side in full.

Context Parallelism builds on Blockwise Attention, https://arxiv.org/abs/2305.19370, in which:

  • Each Q block must access all KV blocks, computing a local attention output and LSE (log-sum-exp) per block, and the partial results are then aggregated (see the sketch below).
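
As a concrete reference for "local attention + LSE, then aggregate", here is a minimal PyTorch sketch; the names block_attention and merge_blocks are mine (not vLLM's), and causal masking is omitted for brevity:

```python
import torch

def block_attention(q, k, v):
    # Naive reference attention over a single KV block that also returns the
    # per-query log-sum-exp (LSE), so partial results from different blocks
    # can be merged later. Causal masking omitted for brevity.
    # q: [num_q, num_head, head_dim], k/v: [num_kv, num_head, head_dim]
    scores = torch.einsum("qhd,khd->hqk", q, k) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1).transpose(0, 1)        # [num_q, num_head]
    out = torch.einsum("hqk,khd->qhd", torch.softmax(scores, dim=-1), v)
    return out, lse

def merge_blocks(out_a, lse_a, out_b, lse_b):
    # Aggregate two partial attention results computed over disjoint KV blocks.
    lse = torch.logaddexp(lse_a, lse_b)                          # global LSE
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)                   # block A's softmax share
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)                   # block B's softmax share
    return w_a * out_a + w_b * out_b, lse
```

Merging is associative, so KV blocks can arrive in any order (e.g. around a ring) and be folded in one at a time.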

Improvement of Blockwise Attention

Assume the following variables:

num_gpus = 2
batch_size = 1
num_head = 32
head_dim = 128

  • Why only 2 GPUs? It keeps the diagrams below simple, but the scheme works for any number of devices.
  • We also assume a single device can fit at most a 50k context length.

Long context

Assume the user submits a 100k-token input. This is the long-input case: one device can no longer fit it. Since we have 2 GPUs, we chunk the input so that each GPU receives 50k tokens.

Prefill

Feed into multiple GPUs

The user input [100k tokens] starts on the CPU, so there are 2 ways to pass it to multiple GPUs (both sketched below):

  1. CPU [100k] -> GPU 0 -> chunk [50k, 50k] -> dist.scatter -> [50k] GPU 0, [50k] GPU 1; this adds a communication hop through GPU 0.
  2. CPU [100k] -> chunk [50k, 50k] -> loop over the chunks and assign each directly using 'cuda:{rank}'.
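
A rough sketch of both options, assuming one process per GPU and an already-initialized torch.distributed process group (token values and sizes are illustrative):

```python
import torch
import torch.distributed as dist

rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")

# Option 1: rank 0 stages the full 100k prompt on GPU 0, chunks it, and
# dist.scatter sends one 50k chunk to every rank (extra hop through GPU 0).
local = torch.empty(50_000, dtype=torch.long, device=device)
if rank == 0:
    tokens = torch.randint(0, 32_000, (100_000,))                 # full prompt on CPU
    scatter_list = [c.contiguous().to(device) for c in tokens.chunk(2)]
else:
    scatter_list = None
dist.scatter(local, scatter_list, src=0)

# Option 2: every rank slices its own 50k chunk from the CPU prompt and
# copies it straight to its device, with no collective at all.
tokens = torch.randint(0, 32_000, (100_000,))                     # full prompt on CPU
local = tokens.chunk(2)[rank].to(device)
```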

Blockwise Attention

Each GPU calculates its own blockwise attention and can decide which communication pattern to use (a ring-pass sketch follows this list):

  1. if num_gpus > 2 and there are multiple nodes, use a tree;
  2. otherwise, use a ring.
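
For the ring case, here is a minimal sketch of one ring pass using torch.distributed point-to-point ops; ring_pass_kv is my name for it and this is not vLLM's existing communication layer:

```python
import torch
import torch.distributed as dist

def ring_pass_kv(block: torch.Tensor) -> torch.Tensor:
    # Send the local KV block to the next rank and receive the previous
    # rank's block. Calling this world_size - 1 times lets every Q block
    # eventually see every KV block.
    rank, world = dist.get_rank(), dist.get_world_size()
    recv = torch.empty_like(block)
    ops = [
        dist.P2POp(dist.isend, block, (rank + 1) % world),
        dist.P2POp(dist.irecv, recv, (rank - 1) % world),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return recv
```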

Data movement

Now, on each GPU:

  1. [50k] -> embedding [50k, dim] -> QKV linear -> QKV [3, 50k, num_head, head_dim].
  2. Do blockwise attention, as illustrated in the attached screenshot (a code sketch also follows this list). As stated, we actually only need the last timestep for next-token prediction.
  3. Store K/V in the cache. As you can see, the KV cache now exists on both GPU 0 and GPU 1: GPU 0 stores the first 50k of K/V and GPU 1 stores the second 50k.
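
Putting steps 1-3 together, a per-GPU prefill sketch that reuses block_attention, merge_blocks, and ring_pass_kv from the sketches above (the embedding and QKV projection modules are supplied by the caller; causal masking and the paged KV-cache layout are omitted):

```python
import torch
import torch.nn as nn
import torch.distributed as dist

def prefill_on_gpu(tokens, embed: nn.Embedding, qkv_proj: nn.Linear,
                   num_head: int, head_dim: int):
    # tokens: this GPU's [50k] chunk of the prompt.
    hidden = embed(tokens)                                        # [50k, dim]
    qkv = qkv_proj(hidden).view(tokens.shape[0], 3, num_head, head_dim)
    q, k, v = qkv.unbind(dim=1)                                   # each [50k, num_head, head_dim]

    # Attend to the local KV block first.
    out, lse = block_attention(q, k, v)

    # Rotate the remaining KV blocks around the ring and merge the partials.
    k_r, v_r = k, v
    for _ in range(dist.get_world_size() - 1):
        k_r, v_r = ring_pass_kv(k_r), ring_pass_kv(v_r)
        out_r, lse_r = block_attention(q, k_r, v_r)
        out, lse = merge_blocks(out, lse, out_r, lse_r)

    # This GPU keeps its own 50k slice of K/V as its share of the KV cache;
    # only the last position is needed for next-token prediction.
    return out[-1], (k, v)
```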

Example prefill blockwise attention using Xformers, https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/prefill-sdpa.ipynb

Next token / step

  1. Because the next-token prediction input is just [1, dim], the computation can only be done on one GPU (GPU 0), but the KV cache is stored on both GPU 0 and GPU 1.
  2. [1, dim] -> QKV linear -> QKV [3, 1, num_head, head_dim], then do blockwise attention. GPU 0 needs to gather the KV cache blocks from GPU 0 and GPU 1 (see the sketch after this list).
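
A decode-step sketch along those lines, with the assumption that every rank still participates in the ring passes even though only GPU 0 consumes the merged result (again reusing the earlier block_attention / merge_blocks / ring_pass_kv sketches):

```python
import torch.distributed as dist

def decode_step(q_new, k_local, v_local):
    # q_new: [1, num_head, head_dim], only meaningful on rank 0.
    # k_local/v_local: this rank's [50k, num_head, head_dim] KV cache slice.
    out, lse = block_attention(q_new, k_local, v_local)
    k_r, v_r = k_local, v_local
    for _ in range(dist.get_world_size() - 1):
        # Pull the next rank's KV slice over to this GPU.
        k_r, v_r = ring_pass_kv(k_r), ring_pass_kv(v_r)
        out_b, lse_b = block_attention(q_new, k_r, v_r)
        out, lse = merge_blocks(out, lse, out_b, lse_b)
    return out   # [1, num_head, head_dim]; rank 0 feeds this to the output projection
```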

Example step blockwise attention using Xformers, https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/step-sdpa.ipynb

Flash attention also returns the LSE, so the same blockwise aggregation should work with it.

During prefill, context parallelism uses all GPUs; during each decode step, only one GPU does the computation.

Short context

Assume the user inputs a 100-token context; a single GPU (GPU 0) is enough.

Prefill

Because the user input fits on one GPU (GPU 0), we just do normal attention.

Next token / step

  1. During decode steps, it is possible that the generated KV cache grows beyond 50k, so we need to spill new KV cache entries onto GPU 1.
  2. Once the KV cache is stored on more than one GPU, we have to switch to blockwise attention (see the sketch after this list).
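
A sketch of that switching logic, where kv_blocks is a hypothetical list of (k, v) tensors, one per GPU currently holding part of this sequence's cache (the communication needed to bring the query and remote blocks together is omitted; reuses the earlier sketches):

```python
def attention_for_step(q_new, kv_blocks):
    # Normal attention while the whole KV cache fits on one GPU.
    if len(kv_blocks) == 1:
        out, _ = block_attention(q_new, *kv_blocks[0])
        return out
    # KV cache spans multiple GPUs: compute per-block partials and merge.
    out, lse = block_attention(q_new, *kv_blocks[0])
    for k, v in kv_blocks[1:]:
        out_b, lse_b = block_attention(q_new, k, v)
        out, lse = merge_blocks(out, lse, out_b, lse_b)
    return out
```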

Discussion

  1. We need to start developing a KV cache that spans multiple GPUs.
  2. We need to start developing blockwise attention for prefill.
  3. We need to start developing mode switching between normal attention and blockwise attention, for long prefills and for decode steps over a long KV cache.

I am super excited about serving context lengths beyond what a single GPU can hold, and I am ready to develop this together.
