
Commit 0008715

dsikka authored and Robert Shaw committed
[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization (vllm-project#5542)
1 parent be2f123 commit 0008715

File tree

4 files changed (+45, -32 lines):

    tests/quantization/test_compressed_tensors.py
    vllm/model_executor/layers/linear.py
    vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
    vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py

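For context before the per-file diffs: the scheme pairs int8 weights quantized per output channel (or per tensor) with int8 activations quantized dynamically per token at runtime. Below is a minimal sketch of those numerics. It is not vLLM's actual kernel path (vLLM dispatches to custom int8 ops); the helper name quantize_rowwise and all shapes are illustrative assumptions.

import torch

# Symmetric int8 row-wise quantization: one scale per row. The same helper
# covers channel-wise weight scales (rows = output channels) and dynamic
# per-token activation scales (rows = tokens).
def quantize_rowwise(t: torch.Tensor):
    scale = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp((t / scale).round(), -128, 127).to(torch.int8)
    return q, scale

w = torch.randn(8, 16)  # (out_features, in_features)
x = torch.randn(4, 16)  # (num_tokens, in_features)
qw, w_scale = quantize_rowwise(w)  # w_scale: (out_features, 1)
qx, x_scale = quantize_rowwise(x)  # x_scale: (num_tokens, 1)

# Integer-valued matmul, then dequantize. Real kernels accumulate in int32;
# float32 is used here only so the sketch runs anywhere.
acc = qx.to(torch.float32) @ qw.to(torch.float32).t()
y = acc * x_scale * w_scale.t()  # scales broadcast to (num_tokens, out_features)
print((y - x @ w.t()).abs().max())  # small quantization error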

tests/quantization/test_compressed_tensors.py

Lines changed: 9 additions & 4 deletions
@@ -19,7 +19,7 @@


 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -48,15 +48,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):


 def test_compressed_tensors_no_enforce_eager(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
+    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
     with vllm_runner(model_path) as llm:
         sampling_params = SamplingParams()
         output = llm.generate("Hello world!", sampling_params=sampling_params)
     assert output


-def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
+@pytest.mark.parametrize("model_args", [
+    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
+    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
+])
+def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
+    model_path, strategy = model_args
     with vllm_runner(model_path, dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -65,6 +69,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):

     assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
     assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
+    assert qkv_proj.scheme.strategy == strategy
     assert qkv_proj.weight.dtype is torch.int8


vllm/model_executor/layers/linear.py

Lines changed: 0 additions & 13 deletions
@@ -476,13 +476,6 @@ def weight_loader(self,
                 "MergedColumnParallelLinear, assume the weight is "
                 "the same for all partitions.")

-        if fp8_scales_shard_indexer is None:
-            if len(param_data.shape) == 0:
-                param_data = param_data.reshape(1)
-
-            if len(loaded_weight.shape) == 0:
-                loaded_weight = loaded_weight.reshape(1)
-
         # UPSTREAM SYNC: needed for LazyCompressedParameter
         self.loaded_shards.add(loaded_shard_id)
         assert param_data.shape == loaded_weight.shape
@@ -707,12 +700,6 @@ def weight_loader(self,
                 "QKVParallelLinear, assume the weight is the same "
                 "for all partitions.")

-        if len(param_data.shape) == 0:
-            param_data = param_data.reshape(1)
-
-        if len(loaded_weight.shape) == 0:
-            loaded_weight = loaded_weight.reshape(1)
-
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

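A note on the deleted guard: it existed because a per-tensor scale can be serialized as a 0-dim tensor, whose shape () fails the shape assertion against a 1-element 1-D parameter. A standalone sketch of that failure mode (illustrative values, not vLLM code):

import torch

param_data = torch.empty(1)          # scale parameter created with shape (1,)
loaded_weight = torch.tensor(0.02)   # scalar loaded from a checkpoint, shape ()

assert param_data.shape != loaded_weight.shape  # torch.Size([1]) vs torch.Size([])
loaded_weight = loaded_weight.reshape(1)        # the removed workaround
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)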

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 8 additions & 6 deletions
@@ -95,14 +95,15 @@ def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
     def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
-        is_token_tensor = (weight_quant.strategy
-                           == QuantizationStrategy.TENSOR.value) and (
-                               input_quant.strategy
-                               == QuantizationStrategy.TOKEN.value)
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_token = (weight_strategy and input_quant.strategy
+                    == QuantizationStrategy.TOKEN.value)
         is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_dynamic = not weight_quant.dynamic and input_quant.dynamic

-        return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
+        return is_8_bits and is_token and is_symmetric and is_dynamic

     def _is_w4a16(self, weight_quant: BaseModel,
                   input_quant: BaseModel) -> bool:
@@ -133,7 +134,8 @@ def _get_schema(self, weight_quant: BaseModel,
             return CompressedTensorsW8A8StaticTensor()

         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8DynamicToken()
+            return CompressedTensorsW8A8DynamicToken(
+                strategy=weight_quant.strategy)

         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py

Lines changed: 28 additions & 9 deletions
@@ -6,13 +6,18 @@
 from vllm import _custom_ops as custom_ops
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
 from vllm.model_executor.utils import set_weight_attrs

 __all__ = ["CompressedTensorsW8A8DynamicToken"]


 class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

+    def __init__(self, strategy: str):
+        self.strategy = strategy
+
     def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
         if isinstance(shard_id, int):
             return shard_id
@@ -45,11 +50,17 @@ def create_weights(self, layer: torch.nn.Module,
         # CompressedTensorsW8A8StaticTensor::create_weights for further
         # information.
         is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(
-            output_partition_sizes) if is_tensor_partitioned else 1
+        # when doing channel-wise quantization, number of scales
+        # is equal to output_dim
+        weight_scale_dim = sum(output_partition_sizes) if (
+            is_tensor_partitioned
+            or self.strategy == QuantizationStrategy.CHANNEL) else 1
+
+        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            shape = (weight_scale_dim, 1)

-        weight_scale = Parameter(torch.empty(weight_scale_dim,
-                                             dtype=torch.float32),
+        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
                                  requires_grad=False)

         weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -67,12 +78,20 @@ def create_weights(self, layer: torch.nn.Module,
         })

         layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(
-            weight_scale, {
-                "weight_loader": weight_loader,
-                "shard_splitter": self.scales_shard_splitter,
-                "logical_widths": output_partition_sizes
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+
+        # Don't need a shard_splitter for channel-wise quantization
+        # Use the default loading method
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "output_dim": 0,
             })
+        else:
+            set_weight_attrs(
+                weight_scale, {
+                    "logical_widths": output_partition_sizes,
+                    "shard_splitter": self.scales_shard_splitter,
+                })

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
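To make the new shape logic concrete, a small standalone sketch of what create_weights now allocates per strategy. The partition sizes are made up, and the bare strings stand in for QuantizationStrategy members:

output_partition_sizes = [1024, 1024, 1024]  # e.g. a fused QKV projection

def weight_scale_shape(strategy: str):
    is_tensor_partitioned = len(output_partition_sizes) != 1
    weight_scale_dim = sum(output_partition_sizes) if (
        is_tensor_partitioned or strategy == "channel") else 1
    # channel-wise scales carry a trailing dim of 1 so they broadcast
    # against the (sum(sizes), input_size) int8 weight
    return (weight_scale_dim, 1) if strategy == "channel" else (weight_scale_dim,)

print(weight_scale_shape("tensor"))   # (3072,)  filled by scales_shard_splitter
print(weight_scale_shape("channel"))  # (3072, 1)  loaded via the default output_dim=0 path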
