4 changes: 1 addition & 3 deletions tests/kernels/quantization/test_gguf.py
@@ -8,7 +8,6 @@
from huggingface_hub import snapshot_download

import vllm._custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
from vllm.platforms import current_platform
@@ -176,12 +175,11 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,

w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
device="cuda").to(dtype)
act = SiluAndMul()

output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
torch.tensor(w2.data,
device="cuda"), topk_weights,
topk_ids, quant_type, quant_type, act)
topk_ids, quant_type, quant_type, "silu")

ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
topk_ids).reshape(output.shape)
8 changes: 0 additions & 8 deletions vllm/engine/arg_utils.py
@@ -1291,14 +1291,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# Some quantization is not compatible with torch.compile.
V1_UNSUPPORTED_QUANT = ["gguf"]
if model_config.quantization in V1_UNSUPPORTED_QUANT:
_raise_or_fallback(
feature_name=f"--quantization {model_config.quantization}",
recommend_to_remove=False)
return False

# No Embedding Models so far.
if model_config.task not in ["generate"]:
_raise_or_fallback(feature_name=f"--task {model_config.task}",
4 changes: 0 additions & 4 deletions vllm/model_executor/layers/linear.py
@@ -587,8 +587,6 @@ def weight_loader(self,
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return

param_data = param.data
@@ -982,8 +980,6 @@ def weight_loader(self,
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
return

param_data = param.data
223 changes: 181 additions & 42 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -9,7 +9,6 @@

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
@@ -19,6 +18,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)

@@ -96,8 +96,8 @@ def get_quant_method(self, layer: torch.nn.Module,
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES


def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
# HACK: when doing chunked prefill we don't generate output tokens
# so input to logits generator is empty which causes invalid parameter
if x.shape[0] == 0:
@@ -130,6 +130,30 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
return y


def _fused_mul_mat_gguf_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
) -> torch.Tensor:
return torch.empty(x.shape[0],
qweight.shape[0],
dtype=x.dtype,
device=x.device)


try:
direct_register_custom_op(
op_name="_fused_mul_mat_gguf",
op_func=_fused_mul_mat_gguf,
mutates_args=[],
fake_impl=_fused_mul_mat_gguf_fake,
)
fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf

except AttributeError as error:
raise error
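
# For readers unfamiliar with the pattern above: a minimal, self-contained
# sketch of how direct_register_custom_op pairs an eager implementation with a
# fake (meta) implementation so torch.compile can trace the op by shape alone.
# The _toy_scale op and its registration are hypothetical and not part of this
# PR; only direct_register_custom_op and the torch.ops.vllm namespace come from
# the code above.
#
#     import torch
#     from vllm.utils import direct_register_custom_op
#
#     def _toy_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
#         # Eager implementation, executed at runtime.
#         return x * factor
#
#     def _toy_scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
#         # Shape/dtype-only implementation used during tracing.
#         return torch.empty_like(x)
#
#     direct_register_custom_op(
#         op_name="_toy_scale",
#         op_func=_toy_scale,
#         mutates_args=[],
#         fake_impl=_toy_scale_fake,
#     )
#     toy_scale = torch.ops.vllm._toy_scale
#
#     if torch.cuda.is_available():
#         y = toy_scale(torch.ones(4, device="cuda"), 2.0)  # tensor of 2.0s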


def _fused_moe_gguf(
Reviewer comment (Member): I think it would be overkill to do it in this PR, but we should aim to migrate this to use the modular kernel approach, inheriting from classes like class FusedMoEPermuteExpertsUnpermute(ABC). This will take care of activation dispatch and other features.

Author reply (Member): Sounds great. Let me migrate it in a following PR!

x: torch.Tensor,
w1: torch.Tensor,
@@ -138,8 +162,21 @@ def _fused_moe_gguf(
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
act,
activation: str,
) -> torch.Tensor:

def act(x: torch.Tensor):
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if activation == "silu":
torch.ops._C.silu_and_mul(out, x)
elif activation == "gelu":
torch.ops._C.gelu_and_mul(out, x)
else:
raise ValueError(f"Unsupported activation: {activation}")
return out

# lazy import to avoid triggering triton import in CPU backend
from vllm.model_executor.layers.fused_moe.fused_moe import (
moe_align_block_size)
@@ -189,12 +226,12 @@ def _fused_moe_gguf(
for ww, ii in zip(w, idx):
expert_up = w1[ii]

out = _fuse_mul_mat(inp, expert_up, qweight_type)
out = fused_mul_mat_gguf(inp, expert_up, qweight_type)
out = act(out)

expert_down = w2[ii]
current_state = _fuse_mul_mat(out, expert_down,
qweight_type2).mul_(ww)
current_state = fused_mul_mat_gguf(out, expert_down,
qweight_type2).mul_(ww)
if current_hidden_state is None:
current_hidden_state = current_state
else:
@@ -203,6 +240,78 @@
return out_hidden_states


def _fused_moe_gguf_fake(
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
qweight_type: int,
qweight_type2: int,
activation: str,
) -> torch.Tensor:
return torch.empty_like(x)


try:
direct_register_custom_op(
op_name="_fused_moe_gguf",
op_func=_fused_moe_gguf,
mutates_args=[],
fake_impl=_fused_moe_gguf_fake,
)
fused_moe_gguf = torch.ops.vllm._fused_moe_gguf

except AttributeError as error:
raise error
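
# A side note on the act() helper inside _fused_moe_gguf above: a minimal
# sketch of the gated-activation contract behind torch.ops._C.silu_and_mul,
# assuming a standard vLLM CUDA build so the _C extension is available. The
# reference expression and tolerances here are illustrative, not from the PR.
#
#     import torch
#     import torch.nn.functional as F
#
#     import vllm._custom_ops as ops  # noqa: F401  (loads the _C extension)
#
#     # The input packs [gate, up] along the last dim; the op writes
#     # silu(gate) * up into an output tensor of half that width.
#     d = 128
#     x = torch.randn(8, 2 * d, device="cuda", dtype=torch.float16)
#     out = torch.empty(8, d, device="cuda", dtype=torch.float16)
#     torch.ops._C.silu_and_mul(out, x)
#
#     ref = F.silu(x[:, :d]) * x[:, d:]
#     torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)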


def _apply_gguf_embedding(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
if qweight_type in UNQUANTIZED_TYPES:
return torch.embedding(qweight, x)
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
x_flat = x.flatten()
assert (hidden_size == qweight.shape[1] // type_size * block_size)
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0], dtype)
return dequant.view(*x.shape, hidden_size)
else:
qweight_type = WeightType(qweight_type)
raise NotImplementedError(
f"Unsupported GGUF quantization type: {qweight_type}")


def _apply_gguf_embedding_fake(
x: torch.Tensor,
qweight: torch.Tensor,
qweight_type: int,
hidden_size: int,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
return torch.empty(x.shape[0], hidden_size, dtype=dtype, device=x.device)


try:
direct_register_custom_op(
op_name="_apply_gguf_embedding",
op_func=_apply_gguf_embedding,
mutates_args=[],
fake_impl=_apply_gguf_embedding_fake,
)
apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding

except AttributeError as error:
raise error
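
# A small worked example (not from the PR) of the shape relation asserted in
# _apply_gguf_embedding, assuming the Q4_K entry in gguf-py's GGML_QUANT_SIZES
# (a block of 256 weights stored in 144 bytes): each qweight row packs
# hidden_size elements as hidden_size // block_size blocks of type_size bytes.
#
#     import gguf
#
#     block_size, type_size = gguf.GGML_QUANT_SIZES[
#         gguf.GGMLQuantizationType.Q4_K]
#     hidden_size = 4096
#     bytes_per_row = hidden_size // block_size * type_size  # 4096//256*144
#     assert hidden_size == bytes_per_row // type_size * block_size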


class GGUFLinearMethod(LinearMethodBase):
"""Linear method for GGUF.

@@ -249,26 +358,76 @@ def create_weights(self, layer: torch.nn.Module,
set_weight_attrs(qweight_type, extra_weight_attrs)
layer.register_parameter("qweight_type", qweight_type)

def process_weights_after_loading(self, layer: torch.nn.Module):
qweight_type = layer.qweight_type.weight_type
if not (qweight_type in UNQUANTIZED_TYPES
or qweight_type in DEQUANT_TYPES):
qweight_type = WeightType(qweight_type)
raise ValueError(
f"Unsupported GGUF quantization type {qweight_type} in "
f"layer {layer}.")
# For MergedColumnParallelLinear and QKVParallelLinear, we need to
# materialize the padded weight parameter for CUDA Graph compatibility.
self._create_padded_weight_param(layer)

def _create_padded_weight_param(self, layer: torch.nn.Module):
"""Create padded weight parameter for GGUF MergedLinear layer."""
qweight = layer.qweight
shard_id_map = qweight.shard_id_map
shard_id = qweight.shard_id
if len(data_container := qweight.data_container) > 1:
dtype = {data.dtype for data in data_container}
assert len(dtype) == 1, ValueError(
f"Data container has mixed dtypes: {dtype}")
dtype = next(iter(dtype))
# concat dim0 and pad dim1
padded_side = max(x.size(1) for x in data_container)
concat_side = sum(x.size(0) for x in data_container)
# Pad the quantized weights to dense tensor, and create a map
# with the location of each shard in the padded tensor.
padded_data = torch.zeros((concat_side, padded_side),
dtype=dtype,
device=qweight.device)
# (dim0_start, dim0_end, dim1_size)
shard_offset_map = dict[str, tuple[int, int, int]]()
for idx in shard_id:
id_in_container = shard_id_map[idx]
start = sum(
x.size(0) for x in data_container[:id_in_container])
end = start + data_container[id_in_container].size(0)
size = data_container[id_in_container].size(1)
padded_data[start:end, :size] = data_container[id_in_container]
shard_offset_map[idx] = (start, end, size)
qweight.data_container.clear()
padded_param = Parameter(padded_data, requires_grad=False)
set_weight_attrs(padded_param, vars(qweight))
set_weight_attrs(padded_param,
{"shard_offset_map": shard_offset_map})
layer.register_parameter("qweight", padded_param)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
shard_id = getattr(layer.qweight, "shard_id", None)
shard_id = layer.qweight.shard_id

if shard_id:
# dequantize shard weights respectively
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
qweight = layer.qweight.unbind(0)
qweight = layer.qweight
result = []
for idx in shard_id:
q_idx = layer.qweight.shard_id_map[idx]
start, end, offset = layer.qweight.shard_offset_map[idx]
qweight_type = layer.qweight_type.shard_weight_type[idx]
result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
result.append(
fused_mul_mat_gguf(
x, qweight[start:end, :offset].contiguous(),
qweight_type))
out = torch.cat(result, axis=1)
else:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
out = _fuse_mul_mat(x, qweight, qweight_type)
out = fused_mul_mat_gguf(x, qweight, qweight_type)
if bias is not None:
out.add_(bias)
return out
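
# A self-contained numeric sketch (not part of the PR) of the concat-and-pad
# layout built by _create_padded_weight_param and consumed by apply() above.
# The shard names and toy widths are made up; only the bookkeeping mirrors the
# code: shards are stacked along dim 0, right-padded along dim 1, and a
# (start, end, size) map lets each shard be sliced back out contiguously.
#
#     import torch
#
#     shards = {
#         "q": torch.full((4, 6), 1, dtype=torch.uint8),
#         "k": torch.full((2, 3), 2, dtype=torch.uint8),
#     }
#     padded_side = max(t.size(1) for t in shards.values())
#     concat_side = sum(t.size(0) for t in shards.values())
#     padded = torch.zeros((concat_side, padded_side), dtype=torch.uint8)
#
#     shard_offset_map = {}
#     row = 0
#     for name, t in shards.items():
#         start, end, size = row, row + t.size(0), t.size(1)
#         padded[start:end, :size] = t
#         shard_offset_map[name] = (start, end, size)
#         row = end
#
#     # Recover the "k" shard as a contiguous tensor, as apply() does before
#     # handing it to fused_mul_mat_gguf.
#     start, end, size = shard_offset_map["k"]
#     k_shard = padded[start:end, :size].contiguous()
#     assert torch.equal(k_shard, shards["k"])
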
@@ -338,7 +497,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,

set_weight_attrs(w2_qweight_type, extra_weight_attrs)
layer.register_parameter("w2_qweight_type", w2_qweight_type)
self.act = SiluAndMul()

def apply(
self,
@@ -375,10 +533,10 @@ def apply(
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return _fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids,
layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, self.act)
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids,
layer.w13_qweight_type.weight_type,
layer.w2_qweight_type.weight_type, activation)


class GGUFEmbeddingMethod(GGUFLinearMethod):
@@ -392,34 +550,15 @@ def embedding(self, layer: torch.nn.Module,
x: torch.Tensor) -> torch.Tensor:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
hidden_size = qweight.tensor_shape[1]

block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
hidden_size = qweight.shape[1] // type_size * block_size
if qweight_type < 2:
return torch.embedding(qweight, x)
x_flat = x.flatten()
quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0], self.params_dtype)
return dequant.view(*x.shape, hidden_size)
return apply_gguf_embedding(x,
qweight,
qweight_type,
hidden_size,
dtype=self.params_dtype)


class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter
data_container: list[torch.Tensor]

def materialize_nested(self) -> Parameter:
dtype = {data.dtype for data in self.data_container}
assert len(dtype) == 1, ValueError(
f"Data container has mixed dtypes: {dtype}")
dtype = next(iter(dtype))
nested_data = torch.nested.nested_tensor(self.data_container,
device=self.device,
dtype=dtype)
self.data_container.clear()
param = torch.Tensor._make_subclass(self.cls_to_become,
nested_data,
require_grad=False)
for k, v in self.__dict__.items():
setattr(param, k, v)
return param