8 changes: 0 additions & 8 deletions vllm/config.py
@@ -670,14 +670,6 @@ def _verify_cuda_graph(self) -> None:
         self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                           self.max_model_len)
 
-        MODEL_NOT_SUPPORT_CUDA_GRAPH = ['mllama']
-        if (self.hf_config.model_type in MODEL_NOT_SUPPORT_CUDA_GRAPH
-                and not self.enforce_eager):
-            logger.warning(
-                "CUDA graph is not supported for %s yet, fallback to the eager "
-                "mode.", self.hf_config.model_type)
-            self.enforce_eager = True
-
     def _verify_bnb_config(self) -> None:
         """
         The current version of bitsandbytes (0.44.0) with 8-bit models does not
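With the hard-coded fallback removed, an mllama model no longer silently drops to eager mode. A minimal offline-inference sketch, assuming the standard vllm.LLM API; the checkpoint name is only an illustrative mllama model:

from vllm import LLM, SamplingParams

# Previously, ModelConfig._verify_cuda_graph() forced enforce_eager=True for
# mllama; after this change the default (CUDA graphs enabled) is honored.
llm = LLM(model="meta-llama/Llama-3.2-11B-Vision-Instruct",
          enforce_eager=False)  # no longer overridden for mllama
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)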
11 changes: 10 additions & 1 deletion vllm/model_executor/models/mllama.py
@@ -1187,6 +1187,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                 config.text_config.vocab_size)
         self.sampler = get_sampler()
 
+        self.capture_mode = False
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
@@ -1366,12 +1368,19 @@ def forward(
         cross_attention_mask = None
         kv_range_for_decode = None
 
+        skip_cross_attention = False
+
         # For 1) text-only prefill and decode, 2) image-present decode.
         if image_inputs is None:
             full_text_row_masked_out_mask = (
                 attn_metadata.encoder_seq_lens_tensor
                 != 0).reshape(-1, 1).to(input_ids.device)
-            skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
+
+            if not self.capture_mode:
+                # NOTE: we never want to skip cross attention during CUDA
+                # graph capture. Bypassing this assignment during capture
+                # leaves skip_cross_attention as False, which is what lets
+                # capture succeed.
+                skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
 
         # For image-present prefill.
         else:
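Why the flag exists: CUDA graph capture records a fixed kernel sequence, so a host-side branch computed from attn_metadata (here, skip_cross_attention) gets frozen at whatever value it held during capture. Pinning it to False while capturing keeps the cross-attention kernels in the graph. A self-contained sketch of the same pattern, with illustrative names rather than vLLM's actual classes:

import torch

class ToyModel(torch.nn.Module):
    """Minimal sketch of the capture_mode pattern (names are illustrative)."""

    def __init__(self):
        super().__init__()
        self.capture_mode = False
        self.linear = torch.nn.Linear(16, 16, device="cuda")

    def forward(self, x, seq_lens):
        skip_extra_path = False
        if not self.capture_mode:
            # Host-side, data-dependent branch: fine in eager mode, but a
            # CUDA graph would freeze whichever side was taken at capture.
            skip_extra_path = max(seq_lens) == 0
        out = self.linear(x)
        if not skip_extra_path:
            # Stand-in for the cross-attention path that must be captured.
            out = out + self.linear(out)
        return out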
6 changes: 6 additions & 0 deletions vllm/worker/model_runner.py
@@ -1935,6 +1935,9 @@ def capture(
         # Capture the graph.
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
+            if hasattr(self.model, "capture_mode"):
+                self.model.capture_mode = True
+
             output_hidden_or_intermediate_states = self.model(
                 input_ids=input_ids,
                 positions=positions,
@@ -1960,6 +1963,9 @@
         gc.collect()
         torch.cuda.synchronize()
 
+        if hasattr(self.model, "capture_mode"):
+            self.model.capture_mode = False
+
         # Save the input and output buffers.
         self.input_buffers = {
             "input_ids":
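The runner flips the flag only for the duration of capture (the hasattr guard makes this a no-op for models without the attribute), so eager warmup and normal execution keep the data-dependent skip. A hedged capture/replay sketch using the ToyModel above; the memory-pool and stream plumbing of vLLM's CUDAGraphRunner is omitted:

model = ToyModel()
x = torch.zeros(8, 16, device="cuda")

# Eager warmup before capture, so lazy kernel/library initialization
# does not happen inside the graph.
model(x, seq_lens=[1] * 8)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
if hasattr(model, "capture_mode"):
    model.capture_mode = True  # pin the branch: never skip during capture
with torch.cuda.graph(graph):
    static_out = model(x, seq_lens=[0] * 8)  # branch is pinned, value unused
if hasattr(model, "capture_mode"):
    model.capture_mode = False  # restore data-dependent behavior

# Replay re-runs the captured kernels on whatever is in the input buffer.
x.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_out)  # updated in place by the replayed graph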