Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def forward(

if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
Expand Down Expand Up @@ -705,7 +705,7 @@ def forward(
self.calc_kv_scales(q, kv_c_normed, k_pe)

if self.attn_backend.accept_output_buffer:
output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.impl.forward(
self,
q,
Expand All @@ -722,7 +722,7 @@ def forward(
)
else:
if self.attn_backend.accept_output_buffer:
output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def forward(

if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)

attn_type = self.attn_type

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ def forward(

if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)

if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ def forward(

if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)
# query = self.view_as_4d(query).permute(0, 2, 1, 3)
# return torch.empty_like(query)

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def forward(

if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)

# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/rocm_aiter_unified_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def forward(

if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)

assert attn_metadata.use_cascade is False

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/rocm_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def forward(

if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)

assert attn_metadata.use_cascade is False

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/tree_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def forward(

if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)

# Cache the input KVs.
key_cache, value_cache = kv_cache.unbind(0)
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def forward(

if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)

assert attn_metadata.use_cascade is False

Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def forward(

if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)

# Cache the input KVs.
key_cache, value_cache = kv_cache.unbind(0)
Expand Down