Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/run-tpu-v1-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ docker run --privileged --net host --shm-size=16G -it \
&& export VLLM_USE_V1=1 \
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
&& echo TEST_1 \
&& pytest /workspace/vllm/tests/tpu/test_compilation.py \
&& pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
&& echo TEST_2 \
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
&& echo TEST_3 \
Expand Down
57 changes: 17 additions & 40 deletions tests/tpu/test_compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@
import tempfile

import depyf
import pytest

from vllm.config import CompilationLevel


@pytest.mark.skip(reason="Not working; needs investigation.")
def test_tpu_compilation():
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
Expand All @@ -22,27 +18,24 @@ def test_tpu_compilation():
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, through inaction, allow a human being to come to harm.",
" what is essential is invisible to the eye.",
" but in rising every time we fall.",
" or, through inaction",
" what is essential ",
" but in rising ",
]
N = 1

# Currently, top-p sampling is disabled. `top_p` should be 1.0.
N = 1
sampling_params = SamplingParams(temperature=0.7,
top_p=1.0,
n=N,
max_tokens=16)

# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.

# Disable the custom dispatcher and let Dynamo take over
# all the control.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
max_model_len=512,
max_num_seqs=64,
enforce_eager=True,
compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
max_num_batched_tokens=256,
max_model_len=256,
max_num_seqs=32,
enforce_eager=False)

outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
prompt = output.prompt
Expand All @@ -56,16 +49,11 @@ def test_tpu_compilation():
for i, compiled_code in enumerate(compiled_codes):
print("{} file: {}".format(i + 1, compiled_code))

# We should only trigger Dynamo compilation 4 times:
# 1. forward pass (symbolic)
# 2. compute_logits (symbolic)
# 3. forward pass (shape 16)
# 4. forward pass (shape 32)
# and later calls should not trigger Dynamo compilation again.
# NOTE: It might still trigger XLA compilation.

# We should only trigger Dynamo compilation 2 times:
# 1. Forward pass without kv_caches
# 2. Forward pass with kv_caches
# Check we have 4 compiled codes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Check we have 4 compiled codes
# Check we have 2 compiled codes

assert len(compiled_codes) == 4
assert len(compiled_codes) == 2

kv_cache_prefix = "kv_cache"
attn_prefix = "ragged_paged_attention"
Expand All @@ -77,24 +65,13 @@ def test_tpu_compilation():
for i, compiled_fn in enumerate(compiled_fns):
print("{} file: {}".format(i + 1, compiled_fn))

# The first compilation is symbolic, so it should not have any kv_caches
# The first compilation should not have any kv_caches
with open(compiled_fns[0]) as f:
content = f.read()
assert kv_cache_prefix not in content

# The second compilation is symbolic, so it should not have any kv_caches
with open(compiled_fns[1]) as f:
content = f.read()
assert kv_cache_prefix not in content

# The third compilation is shape 16, so it should have kv_caches and the
# ragged_paged_attention
with open(compiled_fns[2]) as f:
content = f.read()
assert (kv_cache_prefix in content and attn_prefix in content)

# The fourth compilation is shape 32, so it should have kv_caches and the
# The second compilation should have kv_caches and the
# ragged_paged_attention
with open(compiled_fns[3]) as f:
with open(compiled_fns[1]) as f:
content = f.read()
assert (kv_cache_prefix in content and attn_prefix in content)