diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fc23c9cff0d8..8f8ff423a773 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -475,6 +475,18 @@ steps: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py +- label: Disaggregated Prefill Test # 4min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/parallel_state.py + - vllm/distributed/kv_transfer + - vllm/worker/worker_base.py + - vllm/worker/model_runner.py + commands: + - pytest -v -s kv_transfer/module_test.py + - pytest -v -s kv_transfer/disagg_test.py + - label: LoRA TP Test (Distributed) num_gpus: 4 soft_fail: true diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh new file mode 100644 index 000000000000..dec00c2c9fe0 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -0,0 +1,140 @@ +#!/bin/bash + +# benchmark the overhead of disaggregated prefill. +# methodology: +# - send all request to prefill vLLM instance. It will buffer KV cache. +# - then send all request to decode instance. +# - The TTFT of decode instance is the overhead. + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pkill pt_main_thread + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +benchmark() { + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + export VLLM_PORT=12345 + + # compare chunked prefill with disaggregated prefill + + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=10 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + + + VLLM_DISTRIBUTED_KV_ROLE=producer CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & + + VLLM_DISTRIBUTED_KV_ROLE=consumer CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & + + wait_for_server 8100 + wait_for_server 8200 + + # let the prefill instance finish prefill + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8100 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate "inf" + + + # send the request to decode. + # The TTFT of this command will be the overhead of disagg prefill impl. 
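  # (Because the prefill instance above has already buffered the KV cache for
  # these exact prompts, the decode instance only needs to receive that KV
  # cache before emitting its first token. The TTFT reported by the run below
  # therefore approximates the KV-transfer overhead rather than a full
  # prefill.)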
+ python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8200 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate $qps + kill_gpu_processes + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=1 + default_output_len=1 + benchmark $default_qps $default_output_len + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh new file mode 100644 index 000000000000..0e6875363f4d --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -0,0 +1,172 @@ +#!/bin/bash + +# Requirement: 8x H100 GPUs. + + +# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV +# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests +# Resource: 8x H100 +# Approaches: +# 1. Chunked prefill: 1 vllm instance with tp=8 +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. 
+ pkill -f pt_main_thread + pkill -f python3 + ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 + for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done + sleep 1 +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +launch_chunked_prefill() { + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + -tp 4 \ + --max-model-len 10000 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.8 & + CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + -tp 4 \ + --max-model-len 10000 \ + --disable-log-stats \ + --disable-log-requests \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.8 & + wait_for_server 8100 + wait_for_server 8200 + python3 round_robin_proxy.py & + sleep 1 +} + + +launch_disagg_prefill() { + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + # disagg prefill + VLLM_PORT=12345 VLLM_DISTRIBUTED_KV_ROLE=producer CUDA_VISIBLE_DEVICES=0,1,2,3 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + -tp 4 \ + --max-model-len 10000 \ + --disable-log-stats \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 & + VLLM_PORT=12345 VLLM_DISTRIBUTED_KV_ROLE=consumer CUDA_VISIBLE_DEVICES=4,5,6,7 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + -tp 4 \ + --max-model-len 10000 \ + --disable-log-stats \ + --disable-log-requests \ + --gpu-memory-utilization 0.8 & + wait_for_server 8100 + wait_for_server 8200 + python3 disagg_prefill_proxy_server.py & + sleep 1 +} + + +benchmark() { + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-70B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=200 + qps=$1 + prefix_len=50 + input_len=1024 + output_len=$2 + tag=$3 + + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len $output_len \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename $tag-qps-$qps.json \ + --request-rate $qps + + sleep 2 + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx matplotlib aiohttp + + cd "$(dirname "$0")" + + cd .. 
+ # create sonnet-4x.txt so that we can sample 2048 tokens for input + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_output_len=6 + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + launch_chunked_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len chunked_prefill + done + kill_gpu_processes + + launch_disagg_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len disagg_prefill + done + kill_gpu_processes + + python3 visualize_benchmark_results.py + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py new file mode 100644 index 000000000000..4058b1c0a3b7 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -0,0 +1,61 @@ +import os + +import aiohttp +from quart import Quart, make_response, request + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +app = Quart(__name__) + + +async def forward_request(url, data): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + async with session.post(url=url, json=data, + headers=headers) as response: + if response.status == 200: + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: + async for chunk_bytes in response.content.iter_chunked( + 1024): + yield chunk_bytes + else: + content = await response.read() + yield content + + +@app.route('/v1/completions', methods=['POST']) +async def handle_request(): + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request['max_tokens'] = 1 + + # finish prefill + async for _ in forward_request('http://localhost:8100/v1/completions', + prefill_request): + continue + + # return decode + generator = forward_request('http://localhost:8200/v1/completions', + original_request_data) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +if __name__ == '__main__': + app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py new file mode 100644 index 000000000000..6eb5f6398007 --- /dev/null +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -0,0 +1,60 @@ +import asyncio +import itertools + +import aiohttp +from aiohttp import web + + +class RoundRobinProxy: + + def __init__(self, target_ports): + self.target_ports = target_ports + self.port_cycle = itertools.cycle(self.target_ports) + + async def handle_request(self, request): + target_port = next(self.port_cycle) + target_url = f"http://localhost:{target_port}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + try: + # Forward the request + async with session.request( + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, + ) as response: + # Start sending the response + resp = web.StreamResponse(status=response.status, + headers=response.headers) + await resp.prepare(request) + + # Stream the response content + async 
for chunk in response.content.iter_any(): + await resp.write(chunk) + + await resp.write_eof() + return resp + + except Exception as e: + return web.Response(text=f"Error: {str(e)}", status=500) + + +async def main(): + proxy = RoundRobinProxy([8100, 8200]) + app = web.Application() + app.router.add_route('*', '/{path:.*}', proxy.handle_request) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8000) + await site.start() + + print("Proxy server started on http://localhost:8000") + + # Keep the server running + await asyncio.Event().wait() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py new file mode 100644 index 000000000000..6c5bf5c791dc --- /dev/null +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -0,0 +1,46 @@ +import json + +import matplotlib.pyplot as plt +import pandas as pd + +if __name__ == "__main__": + + data = [] + for name in ['disagg_prefill', 'chunked_prefill']: + for qps in [2, 4, 6, 8]: + with open(f"results/{name}-qps-{qps}.json", "r") as f: + x = json.load(f) + x['name'] = name + x['qps'] = qps + data.append(x) + + df = pd.DataFrame.from_dict(data) + dis_df = df[df['name'] == 'disagg_prefill'] + chu_df = df[df['name'] == 'chunked_prefill'] + + plt.style.use('bmh') + plt.rcParams['font.size'] = 20 + + for key in [ + 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', + 'median_itl_ms', 'p99_itl_ms' + ]: + + fig, ax = plt.subplots(figsize=(11, 7)) + plt.plot(dis_df['qps'], + dis_df[key], + label='disagg_prefill', + marker='o', + linewidth=4) + plt.plot(chu_df['qps'], + chu_df[key], + label='chunked_prefill', + marker='o', + linewidth=4) + ax.legend() + + ax.set_xlabel('QPS') + ax.set_ylabel(key) + ax.set_ylim(bottom=0) + fig.savefig(f'results/{key}.png') + plt.close(fig) diff --git a/examples/distributed_kv/disagg_prefill_example.sh b/examples/distributed_kv/disagg_prefill_example.sh new file mode 100644 index 000000000000..efec87855dbe --- /dev/null +++ b/examples/distributed_kv/disagg_prefill_example.sh @@ -0,0 +1,82 @@ +#!/bin/bash +# This file demonstrates the example usage of disaggregated prefilling +# We will launch 2 vllm instances (1 for prefill and 1 for decode), +# and then transfer the KV cache between them. + +export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') +export VLLM_PORT=12345 + +# install quart first -- required for disagg prefill proxy serve +if python3 -c "import quart" &> /dev/null; then + echo "Quart is already installed." +else + echo "Quart is not installed. Installing..." 
+ python3 -m pip install quart +fi + +# a function that waits vLLM server to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# prefilling instance, which is the KV producer +VLLM_DISTRIBUTED_KV_ROLE=producer CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & + +# decoding instance, which is the KV consumer +VLLM_DISTRIBUTED_KV_ROLE=consumer CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.8 & + +# wait until prefill and decode instances are ready +wait_for_server 8100 +wait_for_server 8200 + +# launch a proxy server that opens the service at port 8000 +# the workflow of this proxy: +# - send the request to prefill vLLM instance (port 8100), change max_tokens to 1 +# - after the prefill vLLM finishes prefill, send the request to decode vLLM instance +python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & +sleep 1 + +# serve two example requests +output1=$(curl -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "San Francisco is a", +"max_tokens": 10, +"temperature": 0 +}') + +output2=$(curl -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "Santa Clara is a", +"max_tokens": 10, +"temperature": 0 +}') + +# Print the outputs of the curl requests +echo "" +echo "Output of first request: $output1" +echo "Output of second request: $output2" + +echo "Successfully finished 2 test requests!" 
+echo "" + +# Cleanup commands, suppressing their output +ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 > /dev/null 2>&1 +pkill -f python3 > /dev/null 2>&1 diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py new file mode 100644 index 000000000000..3dfacbdc5fe8 --- /dev/null +++ b/tests/kv_transfer/disagg_test.py @@ -0,0 +1,120 @@ +import os +import subprocess +import sys +import time +from subprocess import Popen + +import pytest +import requests +import torch + + +# Fixture to set up environment variables and teardown servers after tests +@pytest.fixture(scope="module", autouse=True) +def setup_servers(): + if torch.cuda.device_count() < 4: + pytest.skip("Skipping test: fewer than 4 GPUs available") + + # Set up environment variables + VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", + shell=True).decode().strip() + os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP + os.environ["VLLM_PORT"] = "12345" + + # Start prefill instance + prefill_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "-tp", + "2", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8100", + "--gpu-memory-utilization", + "0.8", + "--max-model-len", + "1000", + ] + prefill_env = os.environ.copy() + prefill_env["VLLM_DISTRIBUTED_KV_ROLE"] = "producer" + prefill_env["CUDA_VISIBLE_DEVICES"] = "0,1" + prefill_proc = Popen(prefill_cmd, env=prefill_env) + + # Start decode instance + decode_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "-tp", + "2", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8200", + "--gpu-memory-utilization", + "0.8", + "--max-model-len", + "1000", + ] + decode_env = os.environ.copy() + decode_env["VLLM_DISTRIBUTED_KV_ROLE"] = "consumer" + decode_env["CUDA_VISIBLE_DEVICES"] = "2,3" + decode_proc = Popen(decode_cmd, env=decode_env) + + # Wait for servers to be ready + assert wait_for_server(8100), "Prefill server did not start in time" + assert wait_for_server(8200), "Decode server did not start in time" + + # Yield to the test function and handle teardown after tests + yield + + # Cleanup: kill the processes + prefill_proc.terminate() + decode_proc.terminate() + + # Additional cleanup if needed + prefill_proc.wait() + decode_proc.wait() + + +# Helper function to wait for server +def wait_for_server(port, timeout=240): + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/v1/completions") + if response.status_code in [200, 405]: + return True + except requests.ConnectionError: + time.sleep(1) + return False + + +# Test function to send curl requests and validate responses +@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) +def test_disaggregated_prefilling(prompt): + # Send to prefill + response = requests.post("http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0 + }) + assert response.status_code == 200 + + # Send to decode + response = requests.post("http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0 + }) + assert response.status_code == 200 diff --git a/tests/kv_transfer/module_test.py b/tests/kv_transfer/module_test.py new file 
mode 100644 index 000000000000..355461919cd7 --- /dev/null +++ b/tests/kv_transfer/module_test.py @@ -0,0 +1,64 @@ +import subprocess +import sys + +import pytest +import torch + + +def run_python_script(script_name, timeout): + script_name = f'kv_transfer/{script_name}' + try: + # Start both processes asynchronously using Popen + process0 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": + "0"}, # Set the RANK environment variable for process 0 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + process1 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": + "1"}, # Set the RANK environment variable for process 1 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + # Wait for both processes to complete, with a timeout + process0.wait(timeout=timeout) + process1.wait(timeout=timeout) + + # Check the return status of both processes + if process0.returncode != 0: + pytest.fail( + f"Test {script_name} failed for RANK=0, {process0.returncode}") + if process1.returncode != 0: + pytest.fail( + f"Test {script_name} failed for RANK=1, {process1.returncode}") + + except subprocess.TimeoutExpired: + # If either process times out, terminate both and fail the test + process0.terminate() + process1.terminate() + pytest.fail(f"Test {script_name} timed out") + except Exception as e: + pytest.fail(f"Test {script_name} failed with error: {str(e)}") + + +# Define the test cases using pytest's parametrize +@pytest.mark.parametrize( + "script_name,timeout", + [ + ("test_lookup_buffer.py", + 60), # Second test case with a 60-second timeout + ("test_send_recv.py", 120) # First test case with a 120-second timeout + ]) +def test_run_python_script(script_name, timeout): + # Check the number of GPUs + if torch.cuda.device_count() < 2: + pytest.skip( + f"Skipping test {script_name} because <2 GPUs are available") + + # Run the test if there are at least 2 GPUs + run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py new file mode 100644 index 000000000000..0730f091a34b --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -0,0 +1,131 @@ +import os +import random + +import torch +from tqdm import tqdm + +import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer as sklb +import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp + +# TODO: the test depends on a lot of fields in the current implementation. 
+# We should have standard interface instead direct field access + + +def test_run(my_rank, buffer, device): + + # buffer should be empty in the beginning + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + # insert + tokens = torch.tensor([1, 2, 3]).to(device) + roi = (tokens > 0) + if my_rank == 0: + key = 2.0 * torch.ones([5, 6]).to(device) + value = 3.0 * torch.ones([5, 6]).to(device) + + placeholder = torch.tensor([1]).to(device) + + buffer.insert(tokens, roi, key, value, placeholder) + + torch.distributed.barrier() + + # drop_select + if my_rank == 1: + tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) + assert torch.allclose(tokens, tok) + assert torch.allclose(roi, roi_) + assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device)) + assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device)) + torch.distributed.barrier() + + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + print("Test run passed!") + + +def stress_test(my_rank, buf, device): + + torch.distributed.barrier() + torch.manual_seed(100) + + reqs = [ + ( + torch.rand(100).to(device), # tokens + torch.ones(100).bool().to(device), # roi + torch.rand(100).to(device), # key + torch.rand(100).to(device), # value + torch.rand(100).to(device), # hidden + ) for i in tqdm(range(200)) + ] + + random.seed(my_rank) + random.shuffle(reqs) + + torch.distributed.barrier() + + n = 0 + + # the buffer size can only store 100 reqs + # so the sender will occasionally block to wait for the receiver. + for req in tqdm(reqs): + if my_rank == 0: + buf.insert(*req) + else: + tok, roi, k, v, h = req + tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi) + + if tok_ is None: + assert roi_ is None + assert k_ is None + assert v_ is None + assert h_ is None + n += 1 + else: + assert torch.allclose(tok, tok_) + assert torch.allclose(roi, roi_) + assert torch.allclose(k, k_) + assert torch.allclose(v, v_) + assert torch.allclose(h, h_) + print('Rank %d done' % my_rank) + torch.distributed.barrier() + + if my_rank == 0: + x = torch.tensor([0]) + torch.distributed.recv(x, 1) + # the # of None received is the kv that are not selected + assert x.item() == len(buf.buffer) + # and the size of the buffer should be 2000 * buffer len + print(buf.buffer_size) + assert buf.buffer_size == 1700 * len(buf.buffer) + else: + torch.distributed.send(torch.tensor([n]), 0) + + print("Passed stress test!") + + +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + torch.distributed.init_process_group(init_method="tcp://127.0.0.1:23456", + world_size=2, + rank=my_rank) + + print("initialized! 
My rank is %d" % my_rank) + + pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "nccl") + cpu_pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "gloo") + buffer = sklb.SimpleKVLookupBuffer(cpu_pipe, pipe, 170000) + + test_run(my_rank, buffer, pipe.device) + + stress_test(my_rank, buffer, pipe.device) + + buffer.close() + pipe.close() + cpu_pipe.close() + print('Done') diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py new file mode 100644 index 000000000000..ff771f34c032 --- /dev/null +++ b/tests/kv_transfer/test_send_recv.py @@ -0,0 +1,139 @@ +import os +import time +from typing import List + +import torch +from tqdm import tqdm + +import vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe as tdp + + +def test_run(my_rank, pipe): + # test run + x = torch.tensor([1]).to(pipe.device) + y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device) + if my_rank == 0: + pipe.send_tensor(x) + print("sent tensor x") + pipe.send_tensor(y) + print("sent tensor y") + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + + else: + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + pipe.send_tensor(x) + print("sent tensor x") + pipe.send_tensor(y) + print("sent tensor y") + + assert torch.allclose(x, x2) + assert torch.allclose(y, y2) + + +def stress_test(my_rank, pipe): + + torch.distributed.barrier() + + tensors: List[torch.Tensor] = [] + + for i in tqdm(range(500)): + mean = torch.rand(1).item() + std = torch.rand(1).item() + size = torch.randint(900, 1000, (2, )) + x = torch.normal(mean * 1.0, std * 1.0, + size=size.tolist()).to(pipe.device) + + # 5% probability of sending a None + if torch.rand(1).item() < 0.05: + tensors.append(None) + tensors.append(None) + tensors.append(None) + else: + tensors.append(x) + tensors.append(x.mean().unsqueeze(0)) + tensors.append(x.std().unsqueeze(0)) + + torch.distributed.barrier() + + for i in tqdm(range(500)): + if my_rank == int((i % 10) > 3): + pipe.send_tensor(tensors[3 * i]) + pipe.send_tensor(tensors[3 * i + 1]) + pipe.send_tensor(tensors[3 * i + 2]) + else: + x = pipe.recv_tensor() + mean = pipe.recv_tensor() + std = pipe.recv_tensor() + if x is None: + assert mean is None + assert std is None + else: + assert torch.allclose(x, tensors[3 * i]) + assert x.mean() == mean[0] + assert x.std() == std[0] + + torch.distributed.barrier() + + print("Stress test passed.") + + +def latency_test(my_rank, pipe, nelement, ntensor): + + latencies = [] + + torch.distributed.barrier() + + for i in tqdm(range(500)): + + tensors = [] + + if my_rank == 0: + # create tensor + tensors = [ + torch.rand(nelement).to(pipe.device) for _ in range(ntensor) + ] + + torch.distributed.barrier() + + if my_rank == 0: + t = torch.tensor([time.time()], + dtype=torch.float64).to(pipe.device) + for tensor in tensors: + pipe.send_tensor(tensor) + pipe.send_tensor(t) + else: + for _ in range(ntensor): + pipe.recv_tensor() + t = pipe.recv_tensor() + latencies.append(time.time() - t.item()) + + torch.distributed.barrier() + + print('Latency test passed.') + print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') + + +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + torch.distributed.init_process_group(init_method="tcp://127.0.0.1:23456", + world_size=2, + rank=my_rank) + + print("initialized! 
My rank is %d" % my_rank) + + pipe = tdp.TorchDistributedPipe([[0, 1]], my_rank, "nccl") + + torch.manual_seed(0) + test_run(my_rank, pipe) + stress_test(my_rank, pipe) + + # Use this function if you want to test the latency of pipe impl. + # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py new file mode 100644 index 000000000000..bad119a1aa92 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -0,0 +1,108 @@ +""" +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) +from the lookup buffer. + +All distributed communications are abstracted behind this class. +""" + +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch + + +class KVLookupBufferBase(ABC): + """ + Abstract base class for a lookup buffer. + + This class provides an abstraction for a key-value (KV) cache lookup buffer. + + The key of the lookup buffer: + - input_tokens: token IDs of the request + - roi: a binary mask on top of input_tokens. + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache + is associated with. + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV + due to TP and PP). This is not implemented for now. + + The value of the lookup buffer: + - key: the key tensor in the KV cache + - value: the value tensor in the KV cache + - hidden: the final hidden state generated by model forwarding. This allows + vLLM to bypass further model forwarding by transmitting the hidden state. + """ + + @abstractmethod + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + """Insert into the lookup buffer. + + The functionality is similar to the following python statement + ``` + buffer[input_tokens, roi] = [key, value, hidden] + ``` + + FIXME: in the future, we should only have two arguments, key and value, + where key is a tensor dict and value is a tensor dict. + + FIXME: we should transmit both sampler outputs and the hidden states. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + key (torch.Tensor): The key tensor in the KV cache. + value (torch.Tensor): The value tensor in the KV cache. + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model + forwarding. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + """Select and *drop* KV cache entries from the lookup buffer. 
+ + The functionality is similar to the following python statements + ``` + ret = buffer.pop(input_tokens, roi) + return ret + ``` + + If `input_tokens` and `roi` is `None`, it means selecting any of the + KV caches in the buffer, return, and remove it from the buffer, useful + when offloading KV cache to KV cache storage service. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + + Returns: + List[Optional[torch.Tensor]]: A list of tensors. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + lookup buffer when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py new file mode 100644 index 000000000000..eb052e2e41e1 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -0,0 +1,223 @@ +import threading +import time +from collections import deque +from typing import Deque, List, Optional, Union + +import torch + +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SimpleKVLookupBuffer(KVLookupBufferBase): + + def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, + buffer_size_thresh: int): + """ + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. 
GPU) + """ + + self.buffer: Deque[List[torch.Tensor]] = deque() + + self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh + self.buffer_lock = threading.Lock() + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe + self.request_handling_thread: Optional[threading.Thread] = None + + self.normal_signal = torch.tensor([0]) + self.end_signal = None + + def _matches(self, tokens_roi_sender: List[torch.Tensor], + tokens_roi_recver: List[torch.Tensor]): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + # Assuming that roi is a binary mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] + + # simple common prefix matching + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): + return min_length + + return 0 + + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.data_pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + self.data_pipe.send_tensor(tensor) + + def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): + + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 + + raise AssertionError("Unknown data type %s" % type(data)) + + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + + buffer_item = [input_tokens, roi, key, value, hidden] + + with self.buffer_lock: + for data in buffer_item: + self.buffer_size += self._get_element_size(data) + self.buffer.append(buffer_item) + + def _is_end_signal(self, signal): + return signal is None + + def drop_select_handler(self): + + try: + + while True: + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() + tokens_roi_recver = [input_tokens, roi] + + matched_length = 0 + + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. 
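            # Matching sketch: the head of the deque is compared against the
            # query; a non-matching head is rotated to the back so the deque
            # is scanned at most once, and a matching item stays at the front
            # where popleft() below removes it before its tensors are sent.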
+ with self.buffer_lock: + + for _ in range(len(self.buffer)): + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + + else: + # no match, just send None + for _ in range(5): + self.data_pipe.send_tensor(None) + + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e + + logger.debug("Closing drop_select_handler") + + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.request_handling_thread is None, \ + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) + + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() + + return [input_tokens, roi, key, value, hidden] + + def full_handler(self): + time.sleep(0.001) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + if self.buffer_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size > self.buffer_size_threshold: + self.full_handler() + + self._add_to_buffer(input_tokens, roi, key, value, hidden) + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. + if self.request_handling_thread is None: + self.request_handling_thread = threading.Thread( + target=self.drop_select_handler) + self.request_handling_thread.start() + + def close(self): + + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: + self.request_handling_thread.join() + + else: + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester + self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/vllm/distributed/kv_transfer/kv_pipe/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py new file mode 100644 index 000000000000..79e235b48fd7 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -0,0 +1,64 @@ +""" +This file defines +`KVPipeBase` +that provides an abstraction for sending and receiving tensors, or None, via +distributed communications. + +All distributed communications for disagg prefill & KV cache storage should be +handled by `KVPipeBase`. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class KVPipeBase(ABC): + """ + This class provides an interface for sending and receiving tensors, or + None, by distributed communications. 
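
    A rough usage sketch (illustrative only; `SomePipe` stands in for any
    concrete implementation of this interface):

        pipe = SomePipe(...)
        # sending side
        pipe.send_tensor(torch.ones(2, 3))
        pipe.send_tensor(None)       # sending None must also be supported
        # receiving side
        t = pipe.recv_tensor()       # a 2x3 tensor of ones
        n = pipe.recv_tensor()       # None
        pipe.close()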
+ """ + + @abstractmethod + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send a tensor, or None, via the pipe. + + Need to support sending None -- important for error handling. + + TODO: add a `key` argument so that we can use traditional + key-value database as the distributed communication mechanism behind + the pipe. + + Args: + tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive a tensor (can be None) from the pipeline. + + Returns: + Optional[torch.Tensor]: The tensor received from the pipeline. Can + be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the pipeline and release resources. + + This method is responsible for closing the communication pipeline + and releasing any resources associated with it. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_distributed_pipe.py new file mode 100644 index 000000000000..f8a52fdc929f --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_distributed_pipe.py @@ -0,0 +1,234 @@ +import json +import os +import pickle +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import List, Optional + +import mooncake_vllm_adaptor as mva +import torch +import zmq + +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger + +logger = init_logger(__name__) +NONE_INT = -150886311 + + +@dataclass +class MooncakeTransferEngineConfig: + prefill_url: str + decode_url: str + metadata_server: str + protocol: str + device_name: str + + @staticmethod + def from_file(file_path: str) -> 'MooncakeTransferEngineConfig': + """Load the config from a JSON file.""" + with open(file_path) as fin: + config = json.load(fin) + return MooncakeTransferEngineConfig( + prefill_url=config.get("prefill_url"), + decode_url=config.get("decode_url"), + metadata_server=config.get("metadata_server"), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + ) + + @staticmethod + def load_from_env() -> 'MooncakeTransferEngineConfig': + """Load config from a file specified in the environment variable.""" + config_file_path = os.getenv('MOONCAKE_CONFIG_PATH') + if config_file_path is None: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeTransferEngineConfig.from_file(config_file_path) + + +class MooncakeTransferEngine: + """Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ.""" + + def __init__(self, rank_in_group: int): + self.engine = mva.mooncake_vllm_adaptor() + + try: + self.config = MooncakeTransferEngineConfig.load_from_env() + logger.info("Configuration loaded successfully.") + except ValueError as e: + logger.error(e) + raise + except Exception as exc: + logger.error( + "An error occurred while loading the configuration: %s", exc) + raise + + self.initialize( + self.config.prefill_url if rank_in_group == 0 else + self.config.decode_url, self.config.metadata_server, + self.config.protocol, self.config.device_name) + + 
self.remote_url = (self.config.decode_url + if rank_in_group == 0 else self.config.prefill_url) + + # Initialize ZeroMQ context and sockets + self.context = zmq.Context() # type: ignore[attr-defined] + self.sender_socket = self.context.socket(zmq.constants.PUSH) + self.receiver_socket = self.context.socket(zmq.constants.PULL) + self.sender_ack = self.context.socket(zmq.constants.PULL) + self.receiver_ack = self.context.socket(zmq.constants.PUSH) + + host, port = self.remote_url.split(':') + self.buffer_cleaner = ThreadPoolExecutor(max_workers=1) + self._setup_sockets(rank_in_group, host, port) + + def _setup_sockets(self, rank_in_group: int, host: str, port: str) -> None: + """Set up ZeroMQ sockets for sending and receiving data.""" + if rank_in_group == 0: + self.sender_socket.bind(f"tcp://*:{int(port) + 1}") + self.receiver_socket.connect(f"tcp://{host}:{int(port) + 2}") + self.sender_ack.connect(f"tcp://{host}:{int(port) + 3}") + self.receiver_ack.bind(f"tcp://*:{int(port) + 4}") + else: + self.receiver_socket.connect(f"tcp://{host}:{int(port) + 1}") + self.sender_socket.bind(f"tcp://*:{int(port) + 2}") + self.receiver_ack.bind(f"tcp://*:{int(port) + 3}") + self.sender_ack.connect(f"tcp://{host}:{int(port) + 4}") + + def initialize(self, local_hostname: str, metadata_server: str, + protocol: str, device_name: str) -> None: + """Initialize the mooncake instance.""" + self.engine.initialize(local_hostname, metadata_server, protocol, + device_name) + + def allocate_managed_buffer(self, length: int) -> int: + """Allocate a managed buffer of the specified length.""" + ret = self.engine.allocateManagedBuffer(length) + if ret <= 0: + logger.error("Allocation Return Error") + raise Exception("Allocation Return Error") + return ret + + def free_managed_buffer(self, buffer: int, length: int) -> int: + """Free a previously allocated managed buffer.""" + return self.engine.freeManagedBuffer(buffer, length) + + def transfer_sync(self, buffer: int, peer_buffer_address: int, + length: int) -> int: + """Synchronously transfer data to the specified address.""" + ret = self.engine.transferSync(self.remote_url, buffer, + peer_buffer_address, length) + if ret < 0: + logger.error("Transfer Return Error") + raise Exception("Transfer Return Error") + return ret + + def write_bytes_to_buffer(self, buffer: int, user_data: bytes, + length: int) -> int: + """Write bytes to the allocated buffer.""" + return self.engine.writeBytesToBuffer(buffer, user_data, length) + + def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: + """Read bytes from the allocated buffer.""" + return self.engine.readBytesFromBuffer(buffer, length) + + def wait_for_ack(self, src_ptr: int, length: int) -> None: + """Asynchronously wait for ACK from the receiver.""" + ack = self.sender_ack.recv_pyobj() + if ack != b'ACK': + logger.error("Failed to receive ACK from the receiver") + + self.free_managed_buffer(src_ptr, length) + + def send_bytes(self, user_data: bytes) -> None: + """Send bytes to the remote process.""" + length = len(user_data) + src_ptr = self.allocate_managed_buffer(length) + self.write_bytes_to_buffer(src_ptr, user_data, length) + self.sender_socket.send_pyobj((src_ptr, length)) + self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) + + def recv_bytes(self) -> bytes: + """Receive bytes from the remote process.""" + src_ptr, length = self.receiver_socket.recv_pyobj() + dst_ptr = self.allocate_managed_buffer(length) + self.transfer_sync(dst_ptr, src_ptr, length) + ret = 
self.read_bytes_from_buffer(dst_ptr, length) + + # Buffer cleanup + self.receiver_ack.send_pyobj(b'ACK') + self.free_managed_buffer(dst_ptr, length) + + return ret + + +class MooncakeDistributedPipe(KVPipeBase): + """MooncakeTransferEngine based Pipe implementation.""" + + def __init__(self, group_ranks: List[List[int]], local_rank: int): + """Initialize the mooncake pipe and set related parameters.""" + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + + self.ranks = self.get_ranks(group_ranks) + self.world_size = len(self.ranks) + self.rank_in_group = self.ranks.index(self.rank) + + assert self.rank_in_group <= 1 + self.device = self._select_device() + + self.transfer_engine = MooncakeTransferEngine(self.rank_in_group) + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.none_tensor = torch.tensor([NONE_INT], device=self.device) + + def get_ranks(self, group_ranks: List[List[int]]) -> List[int]: + """Get the ranks for the current process.""" + for ranks in group_ranks: + if self.rank in ranks: + return ranks + raise ValueError("Rank not found in group") + + def _select_device(self) -> torch.device: + """Select available device (CUDA or CPU).""" + return torch.device( + f"cuda:{self.local_rank}") if torch.cuda.is_available() else "cpu" + + def tensor_hash(self, tensor: torch.Tensor) -> int: + """Calculate the hash value of the tensor.""" + return hash(tensor.data_ptr()) + + def _send_impl(self, tensor: torch.Tensor) -> None: + """Implement the tensor sending logic.""" + value_bytes = pickle.dumps(tensor) + self.transfer_engine.send_bytes(value_bytes) + + def _recv_impl(self) -> torch.Tensor: + """Implement the tensor receiving logic.""" + data = self.transfer_engine.recv_bytes() + return pickle.loads(data) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send tensor to the target process.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + tensor = tensor if tensor is not None else self.none_tensor + assert (len(tensor.shape) > 0) + self.transport_thread.submit(self._send_impl, tensor) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive tensor from other processes.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + tensor = self.transport_thread.submit(self._recv_impl).result() + if tensor.numel() == 1 and tensor.item() == NONE_INT: + return None + else: + return tensor + + def close(self) -> None: + """Cleanup logic when closing the pipe.""" + self.transfer_engine.sender_socket.close() + self.transfer_engine.receiver_socket.close() + self.transfer_engine.context.term() # Terminate the ZMQ context + logger.info("Closed the transfer engine and cleaned up resources.") diff --git a/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py new file mode 100644 index 000000000000..3fe3fa289c66 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/torch_distributed_pipe.py @@ -0,0 +1,289 @@ +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional, Union + +import torch +from torch.distributed import Backend + +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# if the tensor is only one-element and only contains NONE_INT +# this means that the sended object is None. 
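# For example, send_tensor(None) transmits torch.tensor([NONE_INT]) under the
# hood, and recv_tensor() maps a one-element tensor whose value equals
# NONE_INT back to None before returning it to the caller.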
+NONE_INT = -150886311 + +# Mapping tensor dtype to INT64, used for tensor metadata transmission +FLOAT16_INT = -543205003776624 +INT64_INT = -375623078607432 +BOOL_INT = -28035262008646 +BFLOAT16_INT = -452084912267662 +FLOAT32_INT = -1049557997456592 +FLOAT64_INT = -452201007054137 +FLOAT8_E4M3FN_INT = -1066697177659525 +FLOAT8_E5M2_INT = -618182574682355 + +DTYPE2INT = { + torch.float16: FLOAT16_INT, + torch.int64: INT64_INT, + torch.bool: BOOL_INT, + torch.bfloat16: BFLOAT16_INT, + torch.float32: FLOAT32_INT, + torch.float64: FLOAT64_INT, + torch.float8_e4m3fn: FLOAT8_E4M3FN_INT, + torch.float8_e5m2: FLOAT8_E5M2_INT, +} + +INT2DTYPE = { + FLOAT16_INT: torch.float16, + INT64_INT: torch.int64, + BOOL_INT: torch.bool, + BFLOAT16_INT: torch.bfloat16, + FLOAT32_INT: torch.float32, + FLOAT64_INT: torch.float64, + FLOAT8_E4M3FN_INT: torch.float8_e4m3fn, + FLOAT8_E5M2_INT: torch.float8_e5m2, +} + + +class BrokenPipeException(Exception): + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class TorchDistributedPipe(KVPipeBase): + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + ): + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + + assert self.device_group is not None + assert self.rank_in_group <= 1 + + self.device = self._select_device(torch_distributed_backend) + + self.target_rank_for_send = self.ranks[(self.rank_in_group + 1) % + self.world_size] + self.target_rank_for_recv = self.ranks[(self.rank_in_group - 1) % + self.world_size] + + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + + self.none_tensor = torch.tensor([NONE_INT], device=self.device) + + # On-device tensors to be reused for recv + self.rcv_metadata_buffer = torch.zeros(self.METADATA_LENGTH, + dtype=self.METADATA_DTYPE, + device=self.device) + + def _select_device(self, backend: Union[str, Backend]): + if torch.cuda.is_available() and backend == Backend.NCCL: + return torch.device(f"cuda:{self.local_rank}") + else: + return "cpu" + + def _make_metadata(self, tensor: torch.Tensor) -> torch.Tensor: + """Create the metadata on based on the input tensor, and move it to GPU. + The metadata's length is `TorchDistributedPipe.METADATA_LENGTH`. + + Currently, the metadata is a int64 tensor and it includes dtype, number + of dimensions, and the shape information of the input tensor. 
+ + + The information follows the layout below: + - metadata[0] -- dtype + - metadata[1] -- number of dimensions + - metadata[2 : 2+ndims] -- the shape of the input tensor + + Parameters: + - tensor: the input tensor + + Returns: + - metadata: the metadata tensor, on self.device + """ + buffer = torch.empty(self.METADATA_LENGTH, + dtype=self.METADATA_DTYPE, + device="cpu") + buffer[0] = DTYPE2INT[tensor.dtype] + ndims = len(tensor.shape) + buffer[1] = len(tensor.shape) + buffer[2:2 + ndims] = torch.tensor(tensor.shape, + dtype=self.METADATA_DTYPE) + return buffer.to(self.device) + + def _prepare_recv_buffer(self, + d_metadata_buffer: torch.Tensor) -> torch.Tensor: + """Create a buffer to receive the tensor based on the metadata. + + Parameters: + - d_metadata_buffer: the metadata tensor on self.device + + Returns: + - buffer: the buffer tensor to receive the tensor, on self.device + """ + h_buffer = d_metadata_buffer.cpu().numpy() + dtype = INT2DTYPE[h_buffer[0]] + ndims = h_buffer[1] + shape = tuple(h_buffer[2:2 + ndims]) + return torch.empty(shape, dtype=dtype, device=self.device) + + def _send_metadata(self, d_metadata_buffer: torch.Tensor): + """Send the metadata buffer to the target rank. + """ + torch.distributed.send( + d_metadata_buffer, + dst=self.target_rank_for_send, + group=self.device_group, + ) + + def _recv_metadata(self) -> torch.Tensor: + """Receive the metadata buffer from the target rank. + + Returns: + - metadata_buffer: the metadata buffer tensor, on self.device + + Note: + The current implementation uses the assumption that there is no + race conditions during sending/receiving. Therefore, the metadata + buffer can be reused + """ + torch.distributed.recv( + self.rcv_metadata_buffer, + src=self.target_rank_for_recv, + group=self.device_group, + ) + + return self.rcv_metadata_buffer + + def _send_impl(self, tensor): + """ + The actual implementation of sending the tensor to the target rank. + This function will first send the metadata, and then send the tensor. + + Parameters: + - tensor: the input tensor to be sent + """ + + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + torch.distributed.send(tensor.to(self.device), + dst=self.target_rank_for_send, + group=self.device_group) + + def _recv_impl(self) -> torch.Tensor: + """ + The actual implementation of receiving the tensor from the target rank. + This function will first receive the metadata, then receive the tensor. + + This function will block if there is no tensor to receive. + + Returns: + - buffer: the received tensor, on self.device + """ + d_metadata = self._recv_metadata() + buffer = self._prepare_recv_buffer(d_metadata) + + torch.distributed.recv(buffer, + src=self.target_rank_for_recv, + group=self.device_group) + + return buffer + + def send_tensor_wrapper(self, tensor): + try: + """Wrapper for send_tensor_dict""" + tensor_size = tensor.element_size() * tensor.numel() + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size - tensor_size + except Exception as e: + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), str(tensor), str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): + """Block the current thread if the buffer size is larger than 1e9.""" + # TODO: replace this 1e9 with a configurable parameter or a constant + while self.buffer_size > 1e9: + logger.debug("KV cache transfer pipe is full. 
Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Sends a tensor to the destination rank in a non-blocking way. + Flow: send tensor dim -- send tensor shape -- send tensor data + """ + + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is None: + tensor = self.none_tensor + + tensor_size = tensor.element_size() * tensor.numel() + + assert ( + 0 < len(tensor.shape) < self.MAX_TENSOR_DIMENSIONS + ), f"Only support dimensions within 1-{self.MAX_TENSOR_DIMENSIONS}" + + self.block_if_full() + + with self.buffer_size_lock: + self.buffer_size = self.buffer_size + tensor_size + + self.transport_thread.submit( + self.send_tensor_wrapper, + tensor, + ) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receives a tensor from the src rank. Blocking.""" + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + future = self.transport_thread.submit(self._recv_impl) + + try: + tensor = future.result() + except Exception as e: + # the underlying pipe is likely broken + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + # fault tolerance: if the pipe is broken, return None + return None + + if tensor.numel() == 1 and tensor.item() == NONE_INT: + return None + else: + return tensor + + def close(self): + """Close the pipe and release the resources.""" + if (hasattr(self, "transport_thread") + and self.transport_thread is not None): + self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/vllm_adapter.py b/vllm/distributed/kv_transfer/vllm_adapter.py new file mode 100644 index 000000000000..ee022da405d4 --- /dev/null +++ b/vllm/distributed/kv_transfer/vllm_adapter.py @@ -0,0 +1,488 @@ +"""vLLM distributed KV cache transfer API. +These APIs are used in `vllm/worker/model_runner.py`. + +Currently supporting TP. The TP between prefill and decode instance needs to be +the same. + +Workflow (disaggregated prefill) +- In prefill instance + - After prefill, vLLM `insert` its KV caches into a lookup buffer. + - The prefill instance will also open up a thread that listens to + `drop_select` request. +- In decode instance + - vLLM first runs `drop_select` to send input tokens and a mask on input + tokens (we call it roi, region of interest) to prefill instance + - The prefill instance then respond to `drop_select` request by + - Finding a match in current lookup buffer. + - Clone and send the matched item out + - Delete the matched item in the lookup buffer to free up GPU memory. + - The decode vLLM then store the KV cache into paged memory. 
+""" +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +import os +from copy import deepcopy + +import torch +from torch.distributed import Backend + +import vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer as sklb +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.distributed.kv_transfer.kv_pipe.mooncake_distributed_pipe import ( + MooncakeDistributedPipe) +from vllm.distributed.kv_transfer.kv_pipe.torch_distributed_pipe import ( + TorchDistributedPipe) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + +# check VLLM_DISTRIBUTERD_KV_ROLE and set corresponding flags +assert envs.VLLM_DISTRIBUTED_KV_ROLE in [None, "producer", "consumer", "both"],\ + "VLLM_DISTRIBUTERD_KV_ROLE can only be producer, consumer or both." +IS_DISTRIBUTED_KV_INSTANCE: bool = (envs.VLLM_DISTRIBUTED_KV_ROLE + in ["producer", "consumer", "both"]) +IS_KV_PRODUCER: bool = (envs.VLLM_DISTRIBUTED_KV_ROLE in ["producer", "both"]) +IS_KV_CONSUMER: bool = (envs.VLLM_DISTRIBUTED_KV_ROLE in ["consumer", "both"]) + +# When the current instance is both KV producer and KV consumer, +# it is likely connected to a KV storage service on CPU/disk +# so the communication backend needs to be "gloo" for that case. +DISTRIBUTED_BACKEND: str = "gloo" if (IS_KV_PRODUCER + and IS_KV_CONSUMER) else "nccl" +# corresponding device +DISTRIBUTED_DEVICE: str = "cpu" if (IS_KV_PRODUCER + and IS_KV_CONSUMER) else "cuda" + + +class KV_transfer_agent: + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend] = DISTRIBUTED_BACKEND, + # FIXME(Kuntai): remove this hardcoding + lookup_buffer_size: int = int(1e10)): + + self.lookup_buffer_size = lookup_buffer_size + + self.send_buffer: Optional[KVLookupBufferBase] = None + self.recv_buffer: Optional[KVLookupBufferBase] = None + + self.send_pipe: Optional[KVPipeBase] = None + self.recv_pipe: Optional[KVPipeBase] = None + self.send_signal_pipe: Optional[KVPipeBase] = None + self.recv_signal_pipe: Optional[KVPipeBase] = None + + SimpleKVLookupBuffer = sklb.SimpleKVLookupBuffer + + # Check if MOONCAKE_CONFIG_PATH is set + use_mooncake_distributed_pipe = os.getenv( + 'MOONCAKE_CONFIG_PATH') is not None + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + # In remote KV cache store, vLLM will use both send pipe and recv pipe + # So we build both send pipe and recv pipe for simplicity. 
+ if IS_KV_PRODUCER: + + if use_mooncake_distributed_pipe: + # Use MooncakeDistributedPipe if environment variable is set + self.send_pipe = MooncakeDistributedPipe( + group_ranks, + local_rank, + ) + self.recv_pipe = self.send_pipe + self.send_signal_pipe = self.send_pipe + self.recv_signal_pipe = self.send_pipe + else: + # Use TorchDistributedPipe as default + self.send_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + DISTRIBUTED_BACKEND, + ) + self.recv_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + DISTRIBUTED_BACKEND, + ) + + self.send_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) + self.recv_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) + self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) + self.tensor_device = DISTRIBUTED_DEVICE + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producder + + if use_mooncake_distributed_pipe: + # Use MooncakeDistributedPipe if environment variable is set + self.recv_pipe = MooncakeDistributedPipe( + group_ranks, + local_rank, + ) + # We only need to initialize MooncakeDistributedPipe once, it + # supports bidirectional transmission + self.send_pipe = self.recv_pipe + self.recv_signal_pipe = self.recv_pipe + self.send_signal_pipe = self.recv_pipe + else: + # Use TorchDistributedPipe as default + self.recv_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + DISTRIBUTED_BACKEND, + ) + self.send_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + DISTRIBUTED_BACKEND, + ) + + self.recv_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) + self.send_signal_pipe = TorchDistributedPipe( + group_ranks, + local_rank, + "gloo", + ) + self.send_buffer = SimpleKVLookupBuffer(self.send_signal_pipe, + self.send_pipe, + self.lookup_buffer_size) + self.recv_buffer = SimpleKVLookupBuffer(self.recv_signal_pipe, + self.recv_pipe, + self.lookup_buffer_size) + self.tensor_device = DISTRIBUTED_DEVICE + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + # fix potential bugs on Volta and Turing GPUs + model_config = model_executable.model.config + hidden_size = model_config.hidden_size + num_heads = model_config.num_key_value_heads + num_hidden_layers = model_config.num_attention_heads + head_size = int(hidden_size/num_hidden_layers) + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. 
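# Illustrative shapes (toy numbers): with seq_lens = [5, 3] the flattened
# batch has 8 tokens; request 0 covers flat positions [0, 5) and request 1
# covers [5, 8). For each request, the per-layer key/value slices are stacked
# into tensors of shape (num_layers, seq_len, num_key_value_heads, head_size)
# before being inserted into the lookup buffer together with the matching
# slice of the hidden states.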
+ for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + if self.send_buffer is not None: + self.send_buffer.insert( + current_tokens, torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def destroy(self) -> None: + if self.send_buffer is not None: + self.send_buffer.close() + if self.recv_buffer is not None: + self.recv_buffer.close() + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When this flag is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + if self.recv_buffer is None: + bypass_model_exec = False + break + + ret = self.recv_buffer.drop_select( + current_tokens, torch.ones_like(current_tokens, dtype=bool)) + if ret[0] is None: + # didn't find any match. + bypass_model_exec = False + num_computed_tokens_list.append(0) + continue + + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + + num_computed_tokens = roi.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. 
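# Illustrative case (toy numbers): if this request has num_tokens == 100 but
# the buffer only returned KV and hidden states for a 60-token prefix
# (roi.shape[0] == 60), bypass_model_exec was cleared above and only those 60
# slots are written back into paged memory, so the end position below covers
# just the cached prefix.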
+ end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # so we need to adjust model_input and redo the forwarding. + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + rebuilt_model_input = build_partial_prefill_input( + model_input, + input_tokens_list, + num_computed_tokens_list, + start_pos_list, + slot_mapping, + device=input_tokens_tensor.device, + ) + model_input = rebuilt_model_input + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + +def build_partial_prefill_input( + model_input: "ModelInputForGPUWithSamplingMetadata", + input_tokens_list: List[torch.Tensor], + num_computed_tokens_list: List[int], + start_pos_list: List[int], + slot_mapping_flat: torch.Tensor, + device: torch.device, +) -> "ModelInputForGPUWithSamplingMetadata": + """ + Helper function to rebuild the model input for the current request. + Goal: avoid running redundant prefill on those tokens that already has KV + caches received. + """ + rebuilt_input_tokens = [] + rebuilt_input_positions = [] + rebuilt_query_lens = [] + + rebuilt_num_prefills = 0 + rebuilt_num_prefill_tokens = 0 + rebuilt_slot_mapping = [] + rebuilt_max_query_len = 0 + + rebuilt_block_tables = [] + + rebuilt_query_start_loc = [0] + rebuilt_context_lens_tensor = [] + rebuilt_selected_token_indices = [] + + # recounting query and context lengths + for idx in range(len(input_tokens_list)): + token_tensor = input_tokens_list[idx] + num_token = len(token_tensor) + num_computed_token = num_computed_tokens_list[idx] + # currently attention kernel cannot handle the case where there is 0 + # query token. 
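# Illustrative rebuild (toy numbers): for two requests with
# num_token = [8, 8] and num_computed_token = [8, 5], the fully cached first
# request keeps one token to recompute (see the check below), giving
# rebuilt query_lens = [1, 3], query_start_loc = [0, 1, 4],
# context_lens_tensor = [7, 5] and selected_token_indices = [0, 3].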
+ if num_computed_token == num_token: + num_computed_token -= 1 + start_pos = start_pos_list[idx] + + rebuilt_input_tokens.append(token_tensor[num_computed_token:]) + # TODO(Jiayi): please check the correctness of next line + rebuilt_input_positions.append( + model_input.input_positions[start_pos + + num_computed_token:start_pos + + num_token]) + q_len = num_token - num_computed_token + rebuilt_query_lens.append(q_len) + + # Attn metadata-related + rebuilt_num_prefills += 1 + rebuilt_num_prefill_tokens += q_len + new_slot_mapping = slot_mapping_flat[start_pos + + num_computed_token:start_pos + + num_token] + rebuilt_slot_mapping.append(new_slot_mapping) + rebuilt_max_query_len = max(q_len, rebuilt_max_query_len) + # TODO(Jiayi): remove hard-code (block_size=16) + blk_size = 16 + temp_block_table = [ + slot_mapping_flat[i] // blk_size + for i in range(start_pos, start_pos + num_token, blk_size) + ] + rebuilt_block_tables.append(temp_block_table) + rebuilt_query_start_loc.append( + rebuilt_num_prefill_tokens) #start with 0 + rebuilt_context_lens_tensor.append(num_computed_token) + + # Sampling metadata related + #seq_groups (use rebuilt query lens) + rebuilt_selected_token_indices.append(rebuilt_num_prefill_tokens - 1) + + # rebuilt attn_metadata + rebuilt_attn_metadata = deepcopy(model_input.attn_metadata) + rebuilt_attn_metadata.num_prefills = rebuilt_num_prefills + rebuilt_attn_metadata.num_prefill_tokens = rebuilt_num_prefill_tokens + rebuilt_attn_metadata.slot_mapping = torch.cat(rebuilt_slot_mapping).to( + device) + rebuilt_attn_metadata.max_query_len = rebuilt_max_query_len + + rebuilt_attn_metadata.block_tables = torch.tensor( + rebuilt_block_tables, + dtype=model_input.attn_metadata.block_tables.dtype).to(device) + + rebuilt_attn_metadata.query_start_loc = torch.tensor( + rebuilt_query_start_loc, + dtype=model_input.attn_metadata.query_start_loc.dtype).to(device) + rebuilt_attn_metadata.context_lens_tensor = torch.tensor( + rebuilt_context_lens_tensor, + dtype=model_input.attn_metadata.context_lens_tensor.dtype, + ).to(device) + + rebuilt_attn_metadata._cached_prefill_metadata = None + + # rebuilt sampling_metadata + rebuilt_sampling_metadata = deepcopy(model_input.sampling_metadata) + for idx, q_len in enumerate(rebuilt_query_lens): + if rebuilt_sampling_metadata.seq_groups is not None: + rebuilt_sampling_metadata.seq_groups[idx].query_len = q_len + + rebuilt_sampling_metadata.selected_token_indices = torch.tensor( + rebuilt_selected_token_indices, + dtype=model_input.sampling_metadata.selected_token_indices.dtype, + ).to(device) + + # import here to avoid circular import. 
+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + rebuilt_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.cat(rebuilt_input_tokens).to(device), + input_positions=torch.cat(rebuilt_input_positions).to(device), + seq_lens=model_input.seq_lens, + query_lens=rebuilt_query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + attn_metadata=rebuilt_attn_metadata, + prompt_adapter_mapping=model_input.prompt_adapter_mapping, + prompt_adapter_requests=model_input.prompt_adapter_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + request_ids_to_seq_ids=model_input.request_ids_to_seq_ids, + finished_requests_ids=model_input.finished_requests_ids, + virtual_engine=model_input.virtual_engine, + sampling_metadata=rebuilt_sampling_metadata, + is_prompt=model_input.is_prompt, + ) + + return rebuilt_model_input diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ccbe00386c5d..74c6309a00e3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -22,6 +22,7 @@ import contextlib import gc import pickle +import time import weakref from collections import namedtuple from contextlib import contextmanager, nullcontext @@ -34,6 +35,9 @@ import torch.distributed from torch.distributed import Backend, ProcessGroup +# Use this import to check if disagg prefill is enabled. +# if enabled, need to adjust distributed group correspondingly. +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform @@ -842,10 +846,10 @@ def get_world_group() -> GroupCoordinator: return _WORLD -def init_world_group(ranks: List[int], local_rank: int, +def init_world_group(ranks: List[List[int]], local_rank: int, backend: str) -> GroupCoordinator: return GroupCoordinator( - group_ranks=[ranks], + group_ranks=ranks, local_rank=local_rank, torch_distributed_backend=backend, use_pynccl=False, @@ -904,6 +908,14 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group +_DISAGG: Optional[dist_kv.KV_transfer_agent] = None + + +def get_disagg_group() -> dist_kv.KV_transfer_agent: + assert _DISAGG is not None, ( + "disaggregated prefill parallel group is not initialized") + return _DISAGG + @contextmanager def graph_capture(): @@ -935,6 +947,34 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def include_decoding_groups_if_disagg_enabled( + groups: List[List[int]], + world_size: int, +) -> List[List[int]]: + """ + Include the distributed group for decode + Only for disaggregated prefill + + Example: + Original group: [ [0,1], [2,3] ], world_size = 4 + Extended: [ [0,1], [2,3], [4,5], [6,7] ] + Arguments: + groups: original distributed group + world_size: the vLLM world size, which is half of + torch.distributed.get_world_size() + """ + + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + new_groups = [] + for group in groups: + new_groups.append([rank for rank in group]) + for group in groups: + new_groups.append([rank + world_size for rank in group]) + return new_groups + else: + return groups + + def init_distributed_environment( world_size: int = -1, rank: int = -1, @@ -951,11 +991,30 @@ def init_distributed_environment( "distributed_init_method must be provided when initializing " "distributed environment") # this backend is used for WORLD + + # offset world size and rank in 
disaggregated prefill scenario + maybe_disagg_world_size = world_size + maybe_disagg_rank = rank + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + maybe_disagg_world_size = world_size * 2 + logger.debug("Distributed KV transfer enabled.") + if dist_kv.IS_KV_PRODUCER: + # for prefill, the ranks are [0, world_size) + logger.debug("rank %d is KV producer.", rank) + maybe_disagg_rank = rank + else: + # this is decode instance. + # offset global rank by tp * pp (which is world_size) + maybe_disagg_rank = rank + world_size + logger.debug("rank %d is KV consumer, adjust it to %d", rank, + maybe_disagg_rank) + torch.distributed.init_process_group( backend=backend, init_method=distributed_init_method, - world_size=world_size, - rank=rank) + world_size=maybe_disagg_world_size, + rank=maybe_disagg_rank) + logger.debug("torch.distributed initialized") # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -966,10 +1025,23 @@ def init_distributed_environment( local_rank = envs.LOCAL_RANK else: local_rank = rank + global _WORLD if _WORLD is None: - ranks = list(range(torch.distributed.get_world_size())) + # in single node single process the world size can be -1 + # need to infer the world size from torch.distributed.get_world_size() + torch_dist_world_size = torch.distributed.get_world_size() + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + # two vLLM instances in the world + # so this vLLM instance's world size is half of torch's world size + torch_dist_world_size = torch_dist_world_size // 2 + ranks = [[i for i in range(torch_dist_world_size)]] + ranks = include_decoding_groups_if_disagg_enabled(ranks, world_size) + _WORLD = init_world_group(ranks, local_rank, backend) + logger.debug("_WORLD initialized for rank %d", + torch.distributed.get_rank()) + time.sleep(5) else: assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") @@ -1001,12 +1073,37 @@ def initialize_model_parallel( are on the same DGX box. For example if we are using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. + + + Disaggregated prefill will also init its process group using this function. + Changes: + - vLLM world size: unchanged (tp * pp) + - torch.distributed.get_world_size(): + - 2 * tp * pp + - Why: both prefill vLLM and decode vLLM is in the world + - Global rank: + - [0, tp * pp) for prefill + - [tp * pp, 2 * tp * pp) for decode + - Parallel groups + - Extend _WORLD, _TP and _PP using + `include_decoding_groups_if_disagg_enabled` + - Add a new parallel group `_DISAGG` for disaggregated prefill + - [ [0, tp * pp], [1, tp * pp + 1], .. ] + - Local rank: unchanged """ + # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend( get_world_group().device_group) + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + # Disaggregated prefill enabled + # This vLLM instance thinks its word size is tp * pp, but + # torch.distributed contains 2 vLLM instances, + # its world size is 2 * tp * pp + # Adjust the world_size to match. 
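# Illustrative numbers (tp=2, pp=1): torch.distributed reports a world size
# of 4; the prefill (producer) instance owns global ranks [0, 1] and the
# decode (consumer) instance owns [2, 3]; the extended groups become
# _TP = [[0, 1], [2, 3]] and _DISAGG = [[0, 2], [1, 3]], while each vLLM
# instance keeps reasoning about its own world size of 2, which is why
# world_size is halved below.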
+ world_size = world_size // 2 if (world_size != tensor_model_parallel_size * pipeline_model_parallel_size): @@ -1026,13 +1123,15 @@ def initialize_model_parallel( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) group_ranks.append(ranks) - + group_ranks = include_decoding_groups_if_disagg_enabled( + group_ranks, world_size) # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_message_queue_broadcaster=True, group_name="tp") + logger.debug("_TP initialized for rank %d", torch.distributed.get_rank()) # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // @@ -1044,12 +1143,31 @@ def initialize_model_parallel( for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) + group_ranks = include_decoding_groups_if_disagg_enabled( + group_ranks, world_size) # pipeline parallel does not need custom allreduce _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False, group_name="pp") + logger.debug("_PP initialized for rank %d", torch.distributed.get_rank()) + + if dist_kv.IS_DISTRIBUTED_KV_INSTANCE: + global _DISAGG + logger.debug("Disaggregated prefill enabled, create _DISAGG group") + group_ranks = [] + for i in range(world_size): + # prefill local rank: i + # decode global rank: i + world_size + group_ranks.append([i, i + world_size]) + logger.debug("Distributed group is %s", str(group_ranks)) + _DISAGG = dist_kv.KV_transfer_agent( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + ) + logger.debug("_DISAGG initialized for rank %d", + torch.distributed.get_rank()) def ensure_model_parallel_initialized( @@ -1092,7 +1210,7 @@ def model_parallel_is_initialized(): def patch_tensor_parallel_group(tp_group: GroupCoordinator): """Patch the tp group temporarily until this function ends. - This method is for draft workers of speculative decoding to run draft model + This method is for draft workers of speculative decode to run draft model with different tp degree from that of target model workers. Args: @@ -1135,6 +1253,11 @@ def destroy_model_parallel(): _PP.destroy() _PP = None + global _DISAGG + if _DISAGG: + _DISAGG.destroy() + _DISAGG = None + def destroy_distributed_environment(): global _WORLD diff --git a/vllm/envs.py b/vllm/envs.py index c896770e5f6b..b65816238e67 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -380,6 +380,11 @@ def get_default_config_root(): "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), + # Specify the role of current vllm instance + # Value can be "producer", "consumer" or "both". + "VLLM_DISTRIBUTED_KV_ROLE": + lambda: os.getenv("VLLM_DISTRIBUTED_KV_ROLE", None), + # If set, vllm will skip the deprecation warnings. 
"VLLM_NO_DEPRECATION_WARNING": lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 7fa34456028d..41c2de86a60f 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -42,7 +43,8 @@ def _get_worker_kwargs( """Return worker init args for a given rank.""" if distributed_init_method is None: distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) + get_ip(), + get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) return dict( vllm_config=self.vllm_config, local_rank=local_rank, diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index a6c05a71d2b6..a6e0495b4834 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -5,6 +5,7 @@ import torch +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.gpu_executor import create_worker @@ -69,7 +70,8 @@ def _init_executor(self) -> None: # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. distributed_init_method = get_distributed_init_method( - "127.0.0.1", get_open_port()) + "127.0.0.1", + get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) self.workers: List[ProcessWorkerWrapper] = [] # This is the list of workers that are rank 0 of each TP group EXCEPT diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6542b18ae70b..c0378e793da0 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -6,6 +6,7 @@ import msgspec +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) @@ -251,8 +252,11 @@ def sort_by_driver_then_worker_ip(worker): # solves this issue, as it always works for communication inside # the node. driver_ip = "127.0.0.1" + # force vLLM to use the port specified by envs.VLLM_PORT + # this port will be binded by prefill instance + # but the decode instance must use that port to init torch.distributed distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) + driver_ip, get_open_port(force=dist_kv.IS_DISTRIBUTED_KV_INSTANCE)) # Initialize the actual workers inside worker wrapper. 
init_worker_all_kwargs = [ diff --git a/vllm/utils.py b/vllm/utils.py index 6f7a6f8c54e4..29adecd481b4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -39,6 +39,7 @@ from torch.library import Library from typing_extensions import ParamSpec, TypeIs, assert_never +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger from vllm.platforms import current_platform @@ -528,11 +529,39 @@ def get_open_zmq_ipc_path() -> str: return f"ipc://{base_rpc_path}/{uuid4()}" -def get_open_port() -> int: +def get_open_port(force: bool = False) -> int: port = envs.VLLM_PORT + + if force: + # This flag will only be True in disaggregated prefill scenario + # and VLLM_PORT must be set so that vLLM can connect prefill vLLM + # instance and decode vLLM instance. + assert port is not None, "Please set environment variable VLLM_PORT in" + " order to use disaggregated prefill and distributed KV cache transfer" + + # For prefill vLLM instance (KV producer), `port` must be available. + # For decode vLLM instance `port` can be not available. + if dist_kv.IS_KV_PRODUCER: + # `port` must be available. + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError as e: + logger.error( + "Port %d must be empty so that prefill vLLM " + "instance can use this port to initialize " + "distributed KV communication with decode " + "vLLM instance.", port) + raise e + else: + # `port` can be not available + return port + if port is not None: while True: try: + logger.info('Trying port %d', port) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", port)) return port diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1f654a9cce46..459d70671329 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,6 +14,7 @@ import torch.distributed import torch.nn as nn +import vllm.distributed.kv_transfer.vllm_adapter as dist_kv import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState @@ -21,7 +22,7 @@ from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_pp_group +from vllm.distributed import get_disagg_group, get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -1666,6 +1667,24 @@ def execute_model( else: model_executable = self.model + # Receive KV cache in distributed KV cache transfer setting + # In disagg prefill setting, it will also recv hidden states and bypass + # model forwarding + # In KV cache database setting, it will change the model input so that + # we can skip prefilling on tokens that successfully received KV caches + # NOTE: The receive operation is blocking + bypass_model_exec = False + if self.need_recv_kv(model_input, kv_caches): + hidden_or_intermediate_states, bypass_model_exec, model_input = \ + get_disagg_group().recv_kv_caches_and_hidden_states( + # model is used to know which layer the current worker + # is working on, so that we can receive KV for only those + # layers. 
+ model_executable, + model_input, + kv_caches=kv_caches + ) + multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -1677,21 +1696,35 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - with set_forward_context(model_input.attn_metadata, self.vllm_config): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + if not bypass_model_exec: + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() + # Sending KV cache in distributed KV cache transfer setting + # NOTE: the send operation is non-blocking + if self.need_send_kv(model_input, kv_caches): + get_disagg_group().send_kv_caches_and_hidden_states( + # model_executable is used to know which layer the current + # worker is working on, so that we can send KV for only those + # layers. + model_executable, + model_input, + kv_caches, + hidden_or_intermediate_states, + ) + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: if (self.is_driver_worker @@ -1759,6 +1792,50 @@ def execute_model( return [output] + def need_recv_kv(self, model_input, kv_caches) -> bool: + """Check if we need to receive kv-cache from the other worker. + We need to receive KV when + 1. current vLLM instance is KV cache consumer/decode vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return dist_kv.IS_KV_CONSUMER and ( + not is_profile_run) and is_prefill_run + + def need_send_kv(self, model_input, kv_caches) -> bool: + """Check if we need to send kv-cache to the other worker. + We need to send KV when + 1. current vLLM instance is KV cache producer/prefill vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return dist_kv.IS_KV_PRODUCER and ( + not is_profile_run) and is_prefill_run + # NOTE: this is nn.Module so the profiler can properly capture/group # kernels calls made within the graph