200 changes: 166 additions & 34 deletions vllm/model_executor/layers/linear.py
@@ -2,9 +2,10 @@

import itertools
from abc import abstractmethod
from typing import Optional, Union
from typing import Any, Literal, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter

@@ -84,6 +85,43 @@
return param[shard_id], loaded_weight


# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
"""
Split the leftmost BitsAndBytes 4-bit shard off from the remaining shards.

For example, given the bnb weight attributes below:
{
'bnb_shard_offsets': array([0, 4, 8, 16]),
'bnb_quant_state': {0: ..., 1: ..., 2: ...},
}

The function will return:
{
'bnb_shard_offsets': array([0, 4]),
'bnb_quant_state': {0: ...},
}
and
{
'bnb_shard_offsets': array([0, 4, 12]),
'bnb_quant_state': {0: ..., 1: ...},
}
"""
shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
offset_l = shard_offsets[:2]
offset_r = shard_offsets[1:] - shard_offsets[1]
quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
quant_state_r = {
i - 1: bnb_weight_attrs["bnb_quant_state"][i]
for i in range(1,
len(shard_offsets) - 1)
}
left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
return left, right
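
For reference, here is a minimal sketch of the split described in the docstring above, assuming the helper is importable from vllm.model_executor.layers.linear and using plain strings as stand-ins for the real bitsandbytes quant states:

```python
import numpy as np

from vllm.model_executor.layers.linear import (
    left_shift_bitsandbytes_4bit_shard)

# Stand-in quant states; in practice these are bitsandbytes QuantState
# objects keyed by shard index.
attrs = {
    "bnb_shard_offsets": np.array([0, 4, 8, 16]),
    "bnb_quant_state": {0: "qs_0", 1: "qs_1", 2: "qs_2"},
}
left, right = left_shift_bitsandbytes_4bit_shard(attrs)
print(left["bnb_shard_offsets"])   # [0 4]
print(left["bnb_quant_state"])     # {0: 'qs_0'}
print(right["bnb_shard_offsets"])  # [0 4 12]
print(right["bnb_quant_state"])    # {0: 'qs_1', 1: 'qs_2'}
```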


class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""

@@ -1229,7 +1267,24 @@
return s


class QKVCrossParallelLinear(torch.nn.Module):
class QKVCrossParallelLinear(LinearBase):
"""Linear layers for efficient cross-attention's QKV transformation.

Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We skip adding
bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization config.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""

def __init__(self,
hidden_size: int,
@@ -1241,12 +1296,33 @@
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# input_size and output_size are not used here; they are only
# computed to align with the LinearBase constructor.
input_size = hidden_size
output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
super().__init__(input_size=input_size,
output_size=output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)

self.quant_config = quant_config

if quant_config is None:
quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
quant_method = quant_config.get_quant_method(self, prefix=prefix)

# Empty placeholders for loading as a single module.
self.weight = torch.nn.Parameter()
set_weight_attrs(self.weight, {
"weight_loader": self.weight_loader_weight,
})
placeholder_size = 0
Check failure on line 1319 in vllm/model_executor/layers/linear.py (GitHub Actions / pre-commit): Item "None" of "Optional[QuantizeMethodBase]" has no attribute "create_weights" [union-attr]
quant_method.create_weights(self,
placeholder_size, [placeholder_size],
placeholder_size,
placeholder_size,
self.params_dtype,
weight_loader=self.weight_loader)

# Use a dictionary to avoid auto-registration of the submodules'
# parameters: drop-in replacement for a `QKVParallelLinear` module.
self.proj = dict()
@@ -1276,18 +1352,87 @@
if bias:
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"weight_loader": self.weight_loader_bias,
"weight_loader": self.weight_loader,
})

@property
def q_proj_decoder(self):
return self.proj["q_proj_decoder"]
def q_proj_decoder(self) -> ColumnParallelLinear:
layer = self.proj["q_proj_decoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
return layer

@property
def kv_proj_encoder(self):
return self.proj["kv_proj_encoder"]
def kv_proj_encoder(self) -> QKVParallelLinear:
layer = self.proj["kv_proj_encoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
return layer

def sync_weight_attrs(
self,
src_param: nn.Parameter,
tgt_param: nn.Parameter,
mode: Literal["q_proj_decoder", "kv_proj_encoder"],
):
missing_attrs_dict = {
k: getattr(src_param, k)
for k in (set(src_param.__dict__.keys()) -
set(tgt_param.__dict__.keys()))
}
# TODO(Isotr0py): handle bitsandbytes 8bit
use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
False)
if (missing_attrs_dict and use_bitsandbytes_4bit):
q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
missing_attrs_dict)
if mode == "q_proj_decoder":
set_weight_attrs(tgt_param, q_proj_attrs)
elif mode == "kv_proj_encoder":
set_weight_attrs(tgt_param, kv_proj_attrs)
else:
set_weight_attrs(tgt_param, missing_attrs_dict)

def _is_same_param(
self,
src_param: torch.nn.Parameter,
map_param: torch.nn.Parameter,
) -> bool:
"""Check if two parameters are exactly pointing to same things."""
# ignore weight_loader because it's always different
key_to_ignore = ["weight_loader", "_weight_loader"]
has_same_type_name = type(src_param) is type(map_param)
src_param_attrs = {
k: v
for k, v in src_param.__dict__.items() if k not in key_to_ignore
}
map_param_attrs = {
k: v
for k, v in map_param.__dict__.items() if k not in key_to_ignore
}
has_same_attrs = src_param_attrs == map_param_attrs
return has_same_type_name and has_same_attrs

def select_proj_params(
self,
layer: nn.Module,
param: nn.Parameter,
) -> nn.Parameter:
"""
Given the placeholder param, return the corresponding
param in the given proj layer.
"""
target_param_list = [
v for _, v in layer.named_parameters()
if self._is_same_param(param, v)
]
assert len(target_param_list) == 1
target_param = target_param_list[0]
return target_param

Check failure on line 1435 in vllm/model_executor/layers/linear.py (GitHub Actions / pre-commit): Signature of "forward" incompatible with supertype "LinearBase" [override]
def forward(self, decoder_hidden_states, encoder_hidden_states):
q, _ = self.q_proj_decoder(decoder_hidden_states)
if encoder_hidden_states is None:
# Encoder KV already cached.
Expand All @@ -1300,25 +1445,12 @@
k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v

def weight_loader_weight(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
# NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
else self.kv_proj_encoder.weight
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)

def weight_loader_bias(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
else self.kv_proj_encoder.bias
param.weight_loader(
param,
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
param, loaded_weight, loaded_shard_id)
def weight_loader(self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
layer = (self.q_proj_decoder
if loaded_shard_id == "q" else self.kv_proj_encoder)
target_param = self.select_proj_params(layer, param)
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
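
To make the drop-in contract of QKVCrossParallelLinear described in the class docstring easier to see in isolation, here is a rough, self-contained stand-in that uses plain nn.Linear layers (no tensor parallelism, quantization, or weight-loading machinery) and only mirrors the forward behaviour: q is computed from the decoder hidden states, k/v from the encoder hidden states, and k/v are None when the encoder output is already cached:

```python
import torch
import torch.nn as nn


class ToyQKVCross(nn.Module):
    """Simplified stand-in for QKVCrossParallelLinear's forward contract."""

    def __init__(self, hidden_size, head_size, num_heads, num_kv_heads):
        super().__init__()
        self.q_proj_decoder = nn.Linear(hidden_size, num_heads * head_size)
        self.kv_proj_encoder = nn.Linear(hidden_size,
                                         2 * num_kv_heads * head_size)
        self.kv_size = num_kv_heads * head_size

    def forward(self, decoder_hidden_states, encoder_hidden_states):
        q = self.q_proj_decoder(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Encoder KV already cached.
            return q, None, None
        kv = self.kv_proj_encoder(encoder_hidden_states)
        k, v = kv.split(self.kv_size, dim=-1)
        return q, k, v


layer = ToyQKVCross(hidden_size=32, head_size=8, num_heads=4, num_kv_heads=2)
q, k, v = layer(torch.randn(3, 32), torch.randn(5, 32))
print(q.shape, k.shape, v.shape)  # (3, 32), (5, 16), (5, 16)
```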
41 changes: 11 additions & 30 deletions vllm/model_executor/models/mllama.py
@@ -43,6 +43,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -813,20 +814,11 @@ def __init__(
self.q_local_size = self.num_local_heads * self.head_dim
self.kv_local_size = self.num_local_key_value_heads * self.head_dim

# TODO(Isotr0py): Use QKVCrossParallelLinear when it supports
# quantization
self.q_proj = ColumnParallelLinear(
input_size=self.hidden_size,
output_size=self.num_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_proj = QKVParallelLinear(
self.qkv_proj = QKVCrossParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads=0,
total_num_kv_heads=self.num_key_value_heads,
self.num_heads,
self.num_key_value_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
@@ -862,15 +854,11 @@ def forward(
kv_range_for_decode: Optional[List[Tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor],
) -> torch.Tensor:
q, _ = self.q_proj(hidden_states)
q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
if cross_attention_states is not None:
kv, _ = self.kv_proj(cross_attention_states)
k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1)
k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
k = self.k_norm(k)
else:
k = v = None

q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q)
@@ -1161,13 +1149,8 @@ def forward(
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
],
"cross_attn.kv_proj": ["cross_attn.k_proj", "cross_attn.v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -1437,11 +1420,9 @@ def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
(".cross_attn.kv_proj", ".cross_attn.k_proj", "k"),
(".cross_attn.kv_proj", ".cross_attn.v_proj", "v"),
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
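
For illustration, a simplified sketch (hypothetical checkpoint names; no tensors or vLLM loader machinery) of how the (param_name, shard_name, shard_id) triples above rewrite checkpoint keys so that both self-attention and cross-attention projections now target the fused qkv_proj parameter:

```python
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

# Hypothetical checkpoint keys, for illustration only.
checkpoint_names = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "language_model.model.layers.3.cross_attn.k_proj.weight",
]
for name in checkpoint_names:
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            # The weight loader receives the fused parameter plus the shard
            # id that selects the q/k/v slice to fill.
            print(name, "->", name.replace(shard_name, param_name), shard_id)
            break
```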
Expand Down Expand Up @@ -1570,4 +1551,4 @@ def convert_dense_cross_attention_mask_to_tensor(
full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
mask *= full_text_mask
# (num_prompt_tokens, num_encoder_tokens)
return mask
return mask