🚀 The feature, motivation and pitch
As we can see, Google Gemini can support up to a million tokens of context. To serve longer context lengths, we have to do context parallelism, which means splitting the input along the sequence dimension across multiple GPUs.
I have experience with context parallelism and vLLM (I did the Whisper fork), and I am ready to support context parallelism development for vLLM. This issue is long because I want to fully explain the idea from my side.
Context parallelism builds on Blockwise Attention, https://arxiv.org/abs/2305.19370, in which:
- Each Q block must access all KV blocks to compute local attention outputs and LSEs (log-sum-exps), which are then aggregated (a minimal sketch of this aggregation is shown after the list below).
Improvements over plain Blockwise Attention:
- Ring Attention, which only makes sense for more than 2 GPUs, https://arxiv.org/abs/2310.01889
- Tree Attention, which is topology-aware and only makes sense across multiple worlds, i.e. multi-node setups, https://arxiv.org/abs/2408.04093
- Ring and Tree only reduce communication overhead; each Q block still must access all KV blocks and aggregate.
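To show what the aggregation looks like, here is a minimal single-GPU PyTorch sketch, assuming non-causal attention for simplicity. The function names are mine, not an existing vLLM or xFormers API; it verifies that the blockwise result matches full attention.

```python
import torch
import torch.nn.functional as F

def attention_with_lse(q, k, v):
    # q: [B, H, Tq, D], k/v: [B, H, Tk, D]
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    lse = torch.logsumexp(scores, dim=-1)                      # [B, H, Tq]
    out = torch.einsum("bhqk,bhkd->bhqd", scores.softmax(-1), v)
    return out, lse

def blockwise_attention(q, k_blocks, v_blocks):
    # Each Q block attends to every KV block; the partial outputs are then
    # combined with softmax weights over the per-block LSEs, which reproduces
    # full attention exactly.
    outs, lses = [], []
    for k, v in zip(k_blocks, v_blocks):
        o, lse = attention_with_lse(q, k, v)
        outs.append(o)
        lses.append(lse)
    lses = torch.stack(lses)                                   # [num_blocks, B, H, Tq]
    weights = (lses - torch.logsumexp(lses, dim=0)).exp()      # sums to 1 over blocks
    return (torch.stack(outs) * weights.unsqueeze(-1)).sum(dim=0)

# sanity check against full attention
q = torch.randn(1, 32, 16, 128)
k = torch.randn(1, 32, 64, 128)
v = torch.randn(1, 32, 64, 128)
ref = F.scaled_dot_product_attention(q, k, v)
blk = blockwise_attention(q, list(k.chunk(4, dim=2)), list(v.chunk(4, dim=2)))
assert torch.allclose(ref, blk, atol=1e-4)
```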
Assume the following variables:
num_gpus = 2
batch_size = 1
num_head = 32
head_dim = 128
- Why only 2 GPUs? It is easier to illustrate with 2 below, but it can be any number of devices.
- We also assume one device can only fit a max of 50k context length.
Long context
Assume the user inputs a 100k context length, a long input that no longer fits on one device. Because we have 2 GPUs, we must chunk the input so that each GPU receives 50k tokens.
Prefill
Feeding into multiple GPUs
The user input [100k tokens] starts on CPU, so there are 2 ways to pass it to multiple GPUs:
- CPU [100k] -> GPU 0 -> chunk [50k, 50k] -> dist.scatter -> [50k] on GPU 0, [50k] on GPU 1; this adds communication overhead.
- CPU [100k] -> chunk [50k, 50k] -> loop and assign each chunk using `device:{rank}` (see the sketch below).
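A hedged sketch of the second option: chunk on CPU, then place each shard directly on its own device, so the full 100k is never materialized on a single GPU. The function name and vocab size are illustrative, and it assumes 2 visible GPUs.

```python
import torch

def shard_to_gpus(input_ids: torch.Tensor, num_gpus: int):
    # input_ids: [batch, seq_len] on CPU, e.g. [1, 100_000]
    chunks = input_ids.chunk(num_gpus, dim=1)              # [50k] + [50k] for 2 GPUs
    return [chunk.to(f"cuda:{rank}") for rank, chunk in enumerate(chunks)]

shards = shard_to_gpus(torch.randint(0, 32_000, (1, 100_000)), num_gpus=2)
# shards[0] lives on cuda:0, shards[1] on cuda:1, 50k tokens each
```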
Blockwise Attention
Each GPU calculates its own blockwise attention, and we can decide which communication pattern to use (a toy sketch of this choice follows the list):
- if num_gpus > 2 and multiple nodes, use tree attention.
- else, use ring attention.
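The decision itself is simple; the helper below is purely hypothetical, not an existing vLLM function.

```python
def pick_comm_pattern(num_gpus: int, num_nodes: int) -> str:
    # Tree attention is topology-aware and pays off across nodes;
    # inside a single node, a ring over the KV blocks is enough.
    if num_gpus > 2 and num_nodes > 1:
        return "tree"
    return "ring"
```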
Data movement
Now, on each GPU:
1. [50k] -> embedding -> [50k, dim] -> QKV linear -> QKV [3, 50k, num_head, head_dim]
2. Do blockwise attention. As stated, we actually only need the last timestep for next-token prediction.
3. Store the KV in the cache. As you can see, the KV cache exists on both GPU 0 and GPU 1: GPU 0 stores the first 50k of KV and GPU 1 stores the second 50k (a per-rank sketch is included after the example link below).
An example of prefill blockwise attention using xFormers: https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/prefill-sdpa.ipynb
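To make the data movement above concrete, here is a hedged per-rank sketch of the prefill path, reusing `blockwise_attention` from the earlier sketch. It assumes `torch.distributed` is initialized with one process per GPU and uses a plain `all_gather` for the KV exchange; a real implementation would overlap computation with a ring or tree exchange instead. The module and tensor names are illustrative, not vLLM APIs.

```python
import torch
import torch.distributed as dist

def prefill_rank(local_ids, embed, qkv_proj, num_head, head_dim):
    # local_ids: [1, 50_000] token ids owned by this rank
    world_size = dist.get_world_size()
    h = embed(local_ids)                                     # [1, 50k, dim]
    q, k, v = qkv_proj(h).chunk(3, dim=-1)                   # each [1, 50k, num_head * head_dim]

    def to_bhtd(x):
        # reshape to [B, H, T, D] for the blockwise_attention sketch above
        return x.reshape(1, -1, num_head, head_dim).transpose(1, 2)

    q, k, v = map(to_bhtd, (q, k, v))

    # every rank exchanges KV shards so its local Q block can see all KV blocks
    k_blocks = [torch.empty_like(k) for _ in range(world_size)]
    v_blocks = [torch.empty_like(v) for _ in range(world_size)]
    dist.all_gather(k_blocks, k)
    dist.all_gather(v_blocks, v)

    out = blockwise_attention(q, k_blocks, v_blocks)         # local Q block vs all KV blocks
    # only the last position of the last rank feeds next-token sampling;
    # the KV cache stays sharded: this rank stores only its own 50k slice
    return out, (k, v)
```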
Next token / step
- Because next-token prediction is just [1, dim], the computation can be done on a single GPU, GPU 0, but the KV cache is stored on both GPU 0 and GPU 1.
- [1, dim] -> QKV linear -> QKV [3, 1, num_head, head_dim], then do blockwise attention. GPU 0 needs to gather the KV cache from GPU 0 and GPU 1 (see the sketch below).
An example of step blockwise attention using xFormers: https://github.com/mesolitica/context-parallelism-xformers/blob/master/playground/step-sdpa.ipynb
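A hedged sketch of that step, again reusing `blockwise_attention` from the earlier sketch: the single-token query lives on GPU 0, the KV cache shards are gathered, and blockwise attention runs over all of them. Names are illustrative, `torch.distributed` is assumed to be initialized, and the shards are assumed to be equally sized.

```python
import torch
import torch.distributed as dist

def decode_step(q_new, local_k_cache, local_v_cache):
    # q_new: [1, num_head, 1, head_dim], only meaningful on rank 0
    world_size = dist.get_world_size()
    k_blocks = [torch.empty_like(local_k_cache) for _ in range(world_size)]
    v_blocks = [torch.empty_like(local_v_cache) for _ in range(world_size)]
    dist.all_gather(k_blocks, local_k_cache)   # GPU 1's 50k KV shard arrives on GPU 0
    dist.all_gather(v_blocks, local_v_cache)
    if dist.get_rank() == 0:
        # one query token against all KV blocks, combined via their LSEs
        return blockwise_attention(q_new, k_blocks, v_blocks)
    return None
```

Gathering full KV shards is the simplest variant; since only the per-block partial outputs and LSEs are needed for the aggregation, a cheaper variant would compute the partials where the KV lives and move only those small tensors to GPU 0.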
FlashAttention also returns the LSE, so the same blockwise aggregation should work with it.
During prefill, context parallelism uses all GPUs; during the decode step, only one GPU is used.
Short context
Assume the user inputs a 100-token context length; one GPU, GPU 0, is enough.
Prefill
Because the user input fits on one GPU, GPU 0, we just do normal attention.
Next token / step
- During decoding steps, it is possible that the generated KV cache grows beyond 50k, so we need to start adding new KV cache entries on GPU 1.
- If the KV cache is stored on more than one GPU, we have to do blockwise attention (see the sketch below).
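A toy sketch of that switch, reusing `blockwise_attention` from the earlier sketch. The cache layout and the single-shard check are assumptions for illustration, not existing vLLM structures; it also assumes the shards have already been brought to the query's device, whereas a real multi-GPU path would move only partial outputs and LSEs.

```python
import torch.nn.functional as F

def attend(q, kv_cache_shards):
    # kv_cache_shards: list of (k, v) pairs, one per GPU currently holding cache,
    # each shaped [B, num_head, T_shard, head_dim]
    if len(kv_cache_shards) == 1:
        k, v = kv_cache_shards[0]
        return F.scaled_dot_product_attention(q, k, v)   # cache fits on one GPU: normal attention
    k_blocks = [k for k, _ in kv_cache_shards]
    v_blocks = [v for _, v in kv_cache_shards]
    return blockwise_attention(q, k_blocks, v_blocks)    # cache spans GPUs: blockwise attention
```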
Discussion
- We need to start developing a KV cache that spans multiple GPUs.
- We need to start developing blockwise attention for prefilling.
- We need to start developing mode switching between normal attention and blockwise attention, for long prefills and for steps over a long KV cache.
I am super excited to serve context lengths beyond a single GPU, and I am ready to develop this together.