diff --git a/tests/kernels/quantization/test_gguf.py b/tests/kernels/quantization/test_gguf.py
index e520e99b071c..ad755fe7f7a0 100644
--- a/tests/kernels/quantization/test_gguf.py
+++ b/tests/kernels/quantization/test_gguf.py
@@ -8,7 +8,6 @@
 from huggingface_hub import snapshot_download
 
 import vllm._custom_ops as ops
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_experts
 from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
 from vllm.platforms import current_platform
@@ -176,12 +175,11 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
     w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
                               device="cuda").to(dtype)
-    act = SiluAndMul()
 
     output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
                              torch.tensor(w2.data, device="cuda"),
                              topk_weights,
-                             topk_ids, quant_type, quant_type, act)
+                             topk_ids, quant_type, quant_type, "silu")
 
     ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
                                topk_ids).reshape(output.shape)
diff --git a/tests/models/quantization/test_gguf.py b/tests/models/quantization/test_gguf.py
index 3ff36502df57..5f17d12284a0 100644
--- a/tests/models/quantization/test_gguf.py
+++ b/tests/models/quantization/test_gguf.py
@@ -78,8 +78,12 @@ def gguf_model(self):
 )
 
 MODELS = [
-    LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG,
-    DOLPHIN_CONFIG
+    LLAMA_CONFIG,
+    QWEN2_CONFIG,
+    PHI3_CONFIG,
+    GPT2_CONFIG,
+    # STABLELM_CONFIG, # enable this when v1 supports head_size=80
+    DOLPHIN_CONFIG,
     # STARCODER_CONFIG, # broken
 ]
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 3b90880167dc..442e4100fea1 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1291,14 +1291,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False
 
-        # Some quantization is not compatible with torch.compile.
-        V1_UNSUPPORTED_QUANT = ["gguf"]
-        if model_config.quantization in V1_UNSUPPORTED_QUANT:
-            _raise_or_fallback(
-                feature_name=f"--quantization {model_config.quantization}",
-                recommend_to_remove=False)
-            return False
-
         # No Embedding Models so far.
         if model_config.task not in ["generate"]:
             _raise_or_fallback(feature_name=f"--task {model_config.task}",
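Note on the activation change above: `_fused_moe_gguf` now takes the activation as a string ("silu", with "gelu" also accepted) instead of a `SiluAndMul` module, since arguments to a registered torch custom op must be schema-serializable types, not arbitrary `nn.Module`s. For reference, a minimal pure-PyTorch sketch of the gated activation the "silu" string selects (the kernel path in this diff uses `torch.ops._C.silu_and_mul` instead):

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
        # Split the last dim in half: SiLU(gate) * up, matching the
        # semantics of vLLM's SiluAndMul.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]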
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index dd2e477f3954..269ac043d26c 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -587,8 +587,6 @@ def weight_loader(self,
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
             return
 
         param_data = param.data
@@ -982,8 +980,6 @@ def weight_loader(self,
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
             return
 
         param_data = param.data
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index d7d4a5d6acdb..1fcb6d7afc9b 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -9,7 +9,6 @@
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
@@ -19,6 +18,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -96,8 +96,8 @@ def get_quant_method(self, layer: torch.nn.Module,
 MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
 
 
-def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
-                  qweight_type: int) -> torch.Tensor:
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
+                        qweight_type: int) -> torch.Tensor:
     # HACK: when doing chunked prefill we don't generate output tokens
     # so input to logits generator is empty which causes invalid parameter
     if x.shape[0] == 0:
@@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
     return y
 
 
+def _fused_mul_mat_gguf_fake(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+) -> torch.Tensor:
+    return torch.empty(x.shape[0],
+                       qweight.shape[0],
+                       dtype=x.dtype,
+                       device=x.device)
+
+
+try:
+    direct_register_custom_op(
+        op_name="_fused_mul_mat_gguf",
+        op_func=_fused_mul_mat_gguf,
+        mutates_args=[],
+        fake_impl=_fused_mul_mat_gguf_fake,
+    )
+    fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
+
+except AttributeError as error:
+    raise error
+
+
 def _fused_moe_gguf(
     x: torch.Tensor,
     w1: torch.Tensor,
@@ -138,8 +162,21 @@ def _fused_moe_gguf(
     topk_ids: torch.Tensor,
     qweight_type: int,
     qweight_type2: int,
-    act,
+    activation: str,
 ) -> torch.Tensor:
+
+    def act(x: torch.Tensor):
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        if activation == "silu":
+            torch.ops._C.silu_and_mul(out, x)
+        elif activation == "gelu":
+            torch.ops._C.gelu_and_mul(out, x)
+        else:
+            raise ValueError(f"Unsupported activation: {activation}")
+        return out
+
     # lazy import to avoid triggering triton import in CPU backend
     from vllm.model_executor.layers.fused_moe.fused_moe import (
         moe_align_block_size)
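The `direct_register_custom_op` blocks above wrap each GGUF kernel as a torch custom op together with a fake (meta) implementation, which is what lets torch.compile and CUDA Graph capture trace output shapes without running the real kernel; this is the piece that unblocks GGUF on V1 (see the removed `V1_UNSUPPORTED_QUANT` check in arg_utils.py). A standalone sketch of the same pattern using only the public PyTorch 2.4+ API; the `demo::mul_mat` op is illustrative, not part of this PR:

    import torch

    @torch.library.custom_op("demo::mul_mat", mutates_args=())
    def mul_mat(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Real implementation, run in eager mode.
        return x @ w.t()

    @mul_mat.register_fake
    def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Shape-only implementation used while tracing/compiling;
        # mirrors _fused_mul_mat_gguf_fake above.
        return torch.empty(x.shape[0], w.shape[0],
                           dtype=x.dtype, device=x.device)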
@@ -189,12 +226,12 @@ def _fused_moe_gguf(
             for ww, ii in zip(w, idx):
                 expert_up = w1[ii]
 
-                out = _fuse_mul_mat(inp, expert_up, qweight_type)
+                out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
                 out = act(out)
 
                 expert_down = w2[ii]
-                current_state = _fuse_mul_mat(out, expert_down,
-                                              qweight_type2).mul_(ww)
+                current_state = fused_mul_mat_gguf(out, expert_down,
+                                                   qweight_type2).mul_(ww)
                 if current_hidden_state is None:
                     current_hidden_state = current_state
                 else:
@@ -203,6 +240,78 @@ def _fused_moe_gguf(
     return out_hidden_states
 
 
+def _fused_moe_gguf_fake(
+    x: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    qweight_type: int,
+    qweight_type2: int,
+    activation: str,
+) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
+try:
+    direct_register_custom_op(
+        op_name="_fused_moe_gguf",
+        op_func=_fused_moe_gguf,
+        mutates_args=[],
+        fake_impl=_fused_moe_gguf_fake,
+    )
+    fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
+
+except AttributeError as error:
+    raise error
+
+
+def _apply_gguf_embedding(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    hidden_size: int,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    if qweight_type in UNQUANTIZED_TYPES:
+        return torch.embedding(qweight, x)
+    elif qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        x_flat = x.flatten()
+        assert (hidden_size == qweight.shape[1] // type_size * block_size)
+        quant = torch.index_select(qweight, dim=0, index=x_flat)
+        dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
+                                      x_flat.shape[0], dtype)
+        return dequant.view(*x.shape, hidden_size)
+    else:
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(
+            f"Unsupported GGUF quantization type: {qweight_type}")
+
+
+def _apply_gguf_embedding_fake(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    qweight_type: int,
+    hidden_size: int,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)
+
+
+try:
+    direct_register_custom_op(
+        op_name="_apply_gguf_embedding",
+        op_func=_apply_gguf_embedding,
+        mutates_args=[],
+        fake_impl=_apply_gguf_embedding_fake,
+    )
+    apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
+
+except AttributeError as error:
+    raise error
+
+
 class GGUFLinearMethod(LinearMethodBase):
     """Linear method for GGUF.
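For reference, the assert inside `_apply_gguf_embedding` inverts the GGUF row packing: each embedding row stores `hidden_size / block_size` quantized blocks of `type_size` bytes, so `row_bytes // type_size * block_size` must reproduce `hidden_size`. A worked check using the well-known Q4_0 layout (32 weights per 18-byte block):

    # Q4_0 in gguf.GGML_QUANT_SIZES: block_size=32 weights, type_size=18 bytes.
    block_size, type_size = 32, 18
    hidden_size = 4096
    row_bytes = hidden_size // block_size * type_size  # 4096 // 32 * 18 == 2304
    assert hidden_size == row_bytes // type_size * block_size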
@@ -249,26 +358,76 @@ def create_weights(self, layer: torch.nn.Module,
         set_weight_attrs(qweight_type, extra_weight_attrs)
         layer.register_parameter("qweight_type", qweight_type)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        qweight_type = layer.qweight_type.weight_type
+        if not (qweight_type in UNQUANTIZED_TYPES
+                or qweight_type in DEQUANT_TYPES):
+            qweight_type = WeightType(qweight_type)
+            raise ValueError(
+                f"Unsupported GGUF quantization type {qweight_type} in "
+                f"layer {layer}.")
+        # For MergedColumnParallelLinear and QKVParallelLinear, we need to
+        # materialize the padded weight parameter for CUDA Graph compatibility.
+        self._create_padded_weight_param(layer)
+
+    def _create_padded_weight_param(self, layer: torch.nn.Module):
+        """Create padded weight parameter for GGUF MergedLinear layer."""
+        qweight = layer.qweight
+        shard_id_map = qweight.shard_id_map
+        shard_id = qweight.shard_id
+        if len(data_container := qweight.data_container) > 1:
+            dtype = {data.dtype for data in data_container}
+            assert len(dtype) == 1, ValueError(
+                f"Data container has mixed dtypes: {dtype}")
+            dtype = next(iter(dtype))
+            # concat dim0 and pad dim1
+            padded_side = max(x.size(1) for x in data_container)
+            concat_side = sum(x.size(0) for x in data_container)
+            # Pad the quantized weights to a dense tensor, and create a map
+            # with the location of each shard in the padded tensor.
+            padded_data = torch.zeros((concat_side, padded_side),
+                                      dtype=dtype,
+                                      device=qweight.device)
+            # (dim0_start, dim0_end, dim1_size)
+            shard_offset_map = dict[str, tuple[int, int, int]]()
+            for idx in shard_id:
+                id_in_container = shard_id_map[idx]
+                start = sum(
+                    x.size(0) for x in data_container[:id_in_container])
+                end = start + data_container[id_in_container].size(0)
+                size = data_container[id_in_container].size(1)
+                padded_data[start:end, :size] = data_container[id_in_container]
+                shard_offset_map[idx] = (start, end, size)
+            qweight.data_container.clear()
+            padded_param = Parameter(padded_data, requires_grad=False)
+            set_weight_attrs(padded_param, vars(qweight))
+            set_weight_attrs(padded_param,
+                             {"shard_offset_map": shard_offset_map})
+            layer.register_parameter("qweight", padded_param)
+
     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        shard_id = getattr(layer.qweight, "shard_id", None)
+        shard_id = layer.qweight.shard_id
 
         if shard_id:
             # dequantize shard weights respectively
             shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
-            qweight = layer.qweight.unbind(0)
+            qweight = layer.qweight
             result = []
             for idx in shard_id:
-                q_idx = layer.qweight.shard_id_map[idx]
+                start, end, offset = layer.qweight.shard_offset_map[idx]
                 qweight_type = layer.qweight_type.shard_weight_type[idx]
-                result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
+                result.append(
+                    fused_mul_mat_gguf(
+                        x, qweight[start:end, :offset].contiguous(),
+                        qweight_type))
             out = torch.cat(result, axis=1)
         else:
             qweight = layer.qweight
             qweight_type = layer.qweight_type.weight_type
-            out = _fuse_mul_mat(x, qweight, qweight_type)
+            out = fused_mul_mat_gguf(x, qweight, qweight_type)
         if bias is not None:
             out.add_(bias)
         return out
@@ -338,7 +497,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_qweight_type, extra_weight_attrs)
         layer.register_parameter("w2_qweight_type", w2_qweight_type)
-        self.act = SiluAndMul()
 
     def apply(
         self,
@@ -375,10 +533,10 @@ def apply(
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
-        return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
-                               topk_weights, topk_ids,
-                               layer.w13_qweight_type.weight_type,
-                               layer.w2_qweight_type.weight_type, self.act)
+        return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
+                              topk_weights, topk_ids,
+                              layer.w13_qweight_type.weight_type,
+                              layer.w2_qweight_type.weight_type, activation)
 
 
 class GGUFEmbeddingMethod(GGUFLinearMethod):
@@ -392,34 +550,15 @@ def embedding(self, layer: torch.nn.Module,
                   x: torch.Tensor) -> torch.Tensor:
         qweight = layer.qweight
         qweight_type = layer.qweight_type.weight_type
+        hidden_size = qweight.tensor_shape[1]
-        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
-        hidden_size = qweight.shape[1] // type_size * block_size
-        if qweight_type < 2:
-            return torch.embedding(qweight, x)
-        x_flat = x.flatten()
-        quant = torch.index_select(qweight, dim=0, index=x_flat)
-        dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
-                                      x_flat.shape[0], self.params_dtype)
-        return dequant.view(*x.shape, hidden_size)
+        return apply_gguf_embedding(x,
+                                    qweight,
+                                    qweight_type,
+                                    hidden_size,
+                                    dtype=self.params_dtype)
 
 
 class GGUFUninitializedParameter(UninitializedParameter):
     cls_to_become = Parameter
     data_container: list[torch.Tensor]
-
-    def materialize_nested(self) -> Parameter:
-        dtype = {data.dtype for data in self.data_container}
-        assert len(dtype) == 1, ValueError(
-            f"Data container has mixed dtypes: {dtype}")
-        dtype = next(iter(dtype))
-        nested_data = torch.nested.nested_tensor(self.data_container,
-                                                 device=self.device,
-                                                 dtype=dtype)
-        self.data_container.clear()
-        param = torch.Tensor._make_subclass(self.cls_to_become,
-                                            nested_data,
-                                            require_grad=False)
-        for k, v in self.__dict__.items():
-            setattr(param, k, v)
-        return param
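The padded dense layout built by `_create_padded_weight_param` is the replacement for the deleted nested-tensor `materialize_nested`: shards are concatenated along dim 0, zero-padded along dim 1 to the widest shard, and `shard_offset_map` records `(dim0_start, dim0_end, dim1_size)` so `apply` can slice each shard back out. Unlike a nested tensor, the result is a plain dense `torch.Tensor`, which torch.compile and CUDA Graphs can handle. A toy illustration of the layout; shard names and sizes here are invented:

    import torch

    # Hypothetical q/k/v shards whose packed dim-1 widths differ,
    # e.g. because they use different GGUF quantization types.
    shards = {"q": torch.ones(4, 6), "k": torch.ones(2, 4), "v": torch.ones(2, 4)}

    padded_side = max(t.size(1) for t in shards.values())  # widest shard: 6
    concat_side = sum(t.size(0) for t in shards.values())  # total rows: 8
    padded = torch.zeros(concat_side, padded_side)

    shard_offset_map = {}
    start = 0
    for name, t in shards.items():
        end = start + t.size(0)
        padded[start:end, :t.size(1)] = t
        shard_offset_map[name] = (start, end, t.size(1))
        start = end

    # At apply() time a shard is recovered as a contiguous dense slice:
    q0, q1, qw = shard_offset_map["q"]
    q_shard = padded[q0:q1, :qw].contiguous()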