diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 1e202907171f..08e4b1b2f309 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -17,6 +17,7 @@ from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, _ImageAssets) +from ....quantization.utils import is_quant_method_supported from ....utils import large_gpu_test from ...utils import check_logprobs_close @@ -397,6 +398,50 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, ) +@large_gpu_test(min_gb=48) +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["float16"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +def test_bnb_regression( + image_assets: _ImageAssets, + model: str, + dtype: str, + max_tokens: int, +): + stop_sign = image_assets[0].pil_image + prompts = [ + { + "prompt": "<|begin_of_text|>The content of the image <|image|> is", + "multi_modal_data": { + "image": stop_sign + }, + }, + { + "prompt": + "The color of the sky is blue but sometimes it can also be", + }, + ] + # Test regression about QKVCrossParallelLinear + llm = LLM( + model=model, + dtype=dtype, + max_model_len=4096, + max_num_seqs=2, + enforce_eager=True, + quantization="bitsandbytes", + load_format="bitsandbytes", + ) + sampling_params = SamplingParams( + temperature=0, + max_tokens=max_tokens, + ) + outputs = llm.generate(prompts, sampling_params) + assert outputs + + @large_gpu_test(min_gb=48) @pytest.mark.core_model @pytest.mark.parametrize("model", models) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c96e2b220d6b..3912c53e183d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,9 +2,10 @@ import itertools from abc import abstractmethod -from typing import Optional, Union +from typing import Any, Literal, Optional, Union import torch +import torch.nn as nn import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter @@ -84,6 +85,43 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): return param[shard_id], loaded_weight +# TODO(Isotr0py): We might need a more flexible structure to handle +# bitsandbytes shard offsets. +def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]): + """ + Separate the BitsAndBytes 4-bit shard. 
+ + For example, given bnb weight attributes as below: + { + 'bnb_shard_offsets': array([0, 4, 8, 16]), + 'bnb_quant_state': {0: ..., 1: ..., 2: ...}, + } + + The function will return: + { + 'bnb_shard_offsets': array([0, 4]), + 'bnb_quant_state': {0: ...}, + } + and + { + 'bnb_shard_offsets': array([0, 4, 12]), + 'bnb_quant_state': {0: ..., 1: ...}, + } + """ + shard_offsets = bnb_weight_attrs["bnb_shard_offsets"] + offset_l = shard_offsets[:2] + offset_r = shard_offsets[1:] - shard_offsets[1] + quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]} + quant_state_r = { + i - 1: bnb_weight_attrs["bnb_quant_state"][i] + for i in range(1, + len(shard_offsets) - 1) + } + left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l) + right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r) + return left, right + + class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @@ -1229,7 +1267,24 @@ def extra_repr(self) -> str: return s -class QKVCrossParallelLinear(torch.nn.Module): +class QKVCrossParallelLinear(LinearBase): + """Linear layers for efficient cross-attention's QKV transformation. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ def __init__(self, hidden_size: int, @@ -1241,12 +1296,28 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): - super().__init__() + # input_size and output_size are not used, just for alignment + input_size = hidden_size + output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size + super().__init__(input_size=input_size, + output_size=output_size, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix) + + self.quant_config = quant_config + # Empty placeholders for loading as a single module. - self.weight = torch.nn.Parameter() - set_weight_attrs(self.weight, { - "weight_loader": self.weight_loader_weight, - }) + placeholder_size = 0 + assert self.quant_method is not None + self.quant_method.create_weights(self, + placeholder_size, [placeholder_size], + placeholder_size, + placeholder_size, + self.params_dtype, + weight_loader=self.weight_loader) + # Use a dictionary to avoid submodules parameters auto-registration: # drop-in replacement for a `QKVParallelLinear` module. 
self.proj = dict() @@ -1276,18 +1347,94 @@ def __init__(self, if bias: self.bias = torch.nn.Parameter() set_weight_attrs(self.bias, { - "weight_loader": self.weight_loader_bias, + "output_dim": 0, + "weight_loader": self.weight_loader, }) + else: + self.bias = None @property - def q_proj_decoder(self): - return self.proj["q_proj_decoder"] + def q_proj_decoder(self) -> ColumnParallelLinear: + layer = self.proj["q_proj_decoder"] + for name, param in self.named_parameters(): + target_param = getattr(layer, name) + self.sync_weight_attrs(param, target_param, mode="q_proj_decoder") + return layer @property - def kv_proj_encoder(self): - return self.proj["kv_proj_encoder"] + def kv_proj_encoder(self) -> QKVParallelLinear: + layer = self.proj["kv_proj_encoder"] + for name, param in self.named_parameters(): + target_param = getattr(layer, name) + self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder") + return layer + + def sync_weight_attrs( + self, + src_param: nn.Parameter, + tgt_param: nn.Parameter, + mode: Literal["q_proj_decoder", "kv_proj_encoder"], + ): + missing_attrs_dict = { + k: getattr(src_param, k) + for k in (set(src_param.__dict__.keys()) - + set(tgt_param.__dict__.keys())) + } + # TODO(Isotr0py): handle bitsandbytes 8bit + use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit", + False) + if (missing_attrs_dict and use_bitsandbytes_4bit): + q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard( + missing_attrs_dict) + if mode == "q_proj_decoder": + set_weight_attrs(tgt_param, q_proj_attrs) + elif mode == "kv_proj_encoder": + set_weight_attrs(tgt_param, kv_proj_attrs) + else: + set_weight_attrs(tgt_param, missing_attrs_dict) - def forward(self, decoder_hidden_states, encoder_hidden_states): + def _is_same_param( + self, + src_param: torch.nn.Parameter, + map_param: torch.nn.Parameter, + ) -> bool: + """Check if two parameters are exactly pointing to same things.""" + # ignore weight_loader because it's always different + key_to_ignore = ["weight_loader", "_weight_loader"] + has_same_type_name = type(src_param) is type(map_param) + src_param_attrs = { + k: v + for k, v in src_param.__dict__.items() if k not in key_to_ignore + } + map_param_attrs = { + k: v + for k, v in map_param.__dict__.items() if k not in key_to_ignore + } + has_same_attrs = src_param_attrs == map_param_attrs + return has_same_type_name and has_same_attrs + + def select_proj_params( + self, + layer: nn.Module, + param: nn.Parameter, + ) -> nn.Parameter: + """ + Given the placeholder param, + return the corresponding param in the proj layers. + """ + target_param_list = [ + v for _, v in layer.named_parameters() + if self._is_same_param(param, v) + ] + assert len(target_param_list) == 1 + target_param = target_param_list[0] + return target_param + + def forward( # type: ignore[override] + self, + decoder_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: q, _ = self.q_proj_decoder(decoder_hidden_states) if encoder_hidden_states is None: # Encoder KV already cached. @@ -1300,25 +1447,21 @@ def forward(self, decoder_hidden_states, encoder_hidden_states): k, v = kv_enc.split(self.kv_size, dim=-1) return q, k, v - def weight_loader_weight(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): - # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param. 
- param = self.q_proj_decoder.weight if loaded_shard_id == "q" \ - else self.kv_proj_encoder.weight - param.weight_loader( - param, - loaded_weight) if loaded_shard_id == "q" else param.weight_loader( - param, loaded_weight, loaded_shard_id) - - def weight_loader_bias(self, - param: torch.nn.Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None): - param = self.q_proj_decoder.bias if loaded_shard_id == "q" \ - else self.kv_proj_encoder.bias - param.weight_loader( - param, - loaded_weight) if loaded_shard_id == "q" else param.weight_loader( - param, loaded_weight, loaded_shard_id) \ No newline at end of file + def weight_loader(self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + layer = (self.q_proj_decoder + if loaded_shard_id == "q" else self.kv_proj_encoder) + target_param = self.select_proj_params(layer, param) + shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else () + layer.weight_loader(target_param, loaded_weight, *shard_id_args) + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", q_size={self.q_proj_decoder.output_size_per_partition}" + s += f", kv_size={self.kv_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += ", gather_output=False" + return s diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 45f5dea08521..afc30f93b524 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -43,6 +43,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVCrossParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -813,20 +814,11 @@ def __init__( self.q_local_size = self.num_local_heads * self.head_dim self.kv_local_size = self.num_local_key_value_heads * self.head_dim - # TODO(Isotr0py): Use QKVCrossParallelLinear when it supports - # quantization - self.q_proj = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.num_heads * self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj", - ) - self.kv_proj = QKVParallelLinear( + self.qkv_proj = QKVCrossParallelLinear( self.hidden_size, self.head_dim, - total_num_heads=0, - total_num_kv_heads=self.num_key_value_heads, + self.num_heads, + self.num_key_value_heads, bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", @@ -862,15 +854,11 @@ def forward( kv_range_for_decode: Optional[List[Tuple[int, int]]], cross_attention_states: Optional[torch.Tensor], ) -> torch.Tensor: - q, _ = self.q_proj(hidden_states) + q, k, v = self.qkv_proj(hidden_states, cross_attention_states) if cross_attention_states is not None: - kv, _ = self.kv_proj(cross_attention_states) - k, v = kv.split([self.kv_local_size, self.kv_local_size], dim=-1) k = k.view(-1, self.num_local_key_value_heads, self.head_dim) v = v.view(-1, self.num_local_key_value_heads, self.head_dim) k = self.k_norm(k) - else: - k = v = None q = q.view(-1, self.num_local_heads, self.head_dim) q = self.q_norm(q) @@ -1161,13 +1149,8 @@ def forward( class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsV0Only): packed_modules_mapping = { - "self_attn.qkv_proj": [ - "self_attn.q_proj", - "self_attn.k_proj", - "self_attn.v_proj", - ], - "cross_attn.kv_proj": 
["cross_attn.k_proj", "cross_attn.v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -1437,11 +1420,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) - (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), - (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), - (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), - (".cross_attn.kv_proj", ".cross_attn.k_proj", "k"), - (".cross_attn.kv_proj", ".cross_attn.v_proj", "v"), + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] @@ -1570,4 +1551,4 @@ def convert_dense_cross_attention_mask_to_tensor( full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None]) mask *= full_text_mask # (num_prompt_tokens, num_encoder_tokens) - return mask + return mask \ No newline at end of file