 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -74,7 +75,7 @@ def __repr__(self) -> str:
7475 f"group_size={ self .group_size } , sym={ self .sym } )" )
7576
7677 @classmethod
77- def get_name (cls ): ## use str will trigger preci issue
78+ def get_name (cls ) -> QuantizationMethods :
7879 return "auto-round"
7980
8081 @classmethod
@@ -142,18 +143,18 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
             prefix, layer.__class__.__name__, weight_bits, group_size,
             sym)
         if backend == "auto" or "marlin" in backend:
+            AWQ_TYPE_MAP = {
+                4: scalar_types.uint4,
+                8: scalar_types.uint8,
+            }
+            use_marlin = (weight_bits
+                          in AWQ_TYPE_MAP) and check_marlin_supported(
+                              AWQ_TYPE_MAP[weight_bits], group_size, not sym)
+
             if isinstance(layer, FusedMoE):
-                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
-            else:
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size)

-                AWQ_TYPE_MAP = {
-                    4: scalar_types.uint4,
-                    8: scalar_types.uint8,
-                }
-                use_marlin = ((weight_bits, sym) in AWQ_TYPE_MAP
-                              and check_marlin_supported(
-                                  AWQ_TYPE_MAP[(weight_bits)], group_size,
-                                  not sym))
         else:
             use_marlin = False
         if use_marlin:
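Note on the hunk above: the old membership test `(weight_bits, sym) in AWQ_TYPE_MAP` compared a tuple against a dict keyed by plain bit widths, so it could never match, and MoE layers were gated only by check_moe_marlin_supports_layer. The rewritten code tests `weight_bits` alone and requires MoE layers to pass both checks. A minimal, self-contained sketch of the membership fix (plain strings stand in for the scalar_types values; not part of the diff):

# Sketch only, not part of the diff; strings stand in for scalar_types values.
AWQ_TYPE_MAP = {4: "uint4", 8: "uint8"}     # keyed by bit width alone
weight_bits, sym = 4, True

print((weight_bits, sym) in AWQ_TYPE_MAP)   # False: a (bits, sym) tuple never matches an int key
print(weight_bits in AWQ_TYPE_MAP)          # True: the corrected test used in the new code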
@@ -180,10 +181,11 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
             config = {
-                "linear_quant_method": "awq",
-                "weight_bits": weight_bits,
+                "quant_method": "awq",
+                "bits": weight_bits,
                 "group_size": group_size,
                 "zero_point": not sym,
+                "lm_head": False,
             }
             return MoeWNA16Config.from_config(config).get_quant_method(
                 layer, prefix)
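The fallback path above now spells the dictionary keys as "quant_method"/"bits"/"lm_head" instead of "linear_quant_method"/"weight_bits"/"lm_head_quantized", presumably to match the keys MoeWNA16Config.from_config reads (not independently verified here); the GPTQ fallback further down gets the same rename. A small sketch of the two spellings side by side, with placeholder values:

# Sketch only: old vs. new key spelling for the AWQ fallback config.
weight_bits, group_size, sym = 4, 128, True

old_config = {"linear_quant_method": "awq", "weight_bits": weight_bits,
              "group_size": group_size, "zero_point": not sym}
new_config = {"quant_method": "awq", "bits": weight_bits,
              "group_size": group_size, "zero_point": not sym,
              "lm_head": False}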
@@ -213,18 +215,18 @@ def apply_gptq_quant_layer(self,
             prefix, layer.__class__.__name__, weight_bits, group_size,
             sym)
         if backend == "auto" or "marlin" in backend:
+            GPTQ_TYPE_MAP = {
+                (4, True): scalar_types.uint4b8,
+                (8, True): scalar_types.uint8b128,
+            }
+            use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
+                          and check_marlin_supported(
+                              GPTQ_TYPE_MAP[(weight_bits, sym)],
+                              group_size,
+                              has_zp=not sym))
             if isinstance(layer, FusedMoE):
-                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
-            else:
-                GPTQ_TYPE_MAP = {
-                    (4, True): scalar_types.uint4b8,
-                    (8, True): scalar_types.uint8b128,
-                }
-                use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
-                              and check_marlin_supported(
-                                  GPTQ_TYPE_MAP[(weight_bits, sym)],
-                                  group_size,
-                                  has_zp=not sym))
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size)
         else:
             use_marlin = False
         if use_marlin:
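Unlike the AWQ map, GPTQ_TYPE_MAP stays keyed by (bits, sym) and lists only symmetric entries, so asymmetric GPTQ checkpoints never take the Marlin path here; as with AWQ, a FusedMoE layer must now also pass check_moe_marlin_supports_layer. A minimal sketch of that lookup (placeholder strings instead of scalar_types; not part of the diff):

# Sketch only: the GPTQ map admits Marlin only for symmetric 4/8-bit schemes.
GPTQ_TYPE_MAP = {(4, True): "uint4b8", (8, True): "uint8b128"}

print((4, True) in GPTQ_TYPE_MAP)    # True  -> check_marlin_supported is then consulted
print((4, False) in GPTQ_TYPE_MAP)   # False -> use_marlin stays False on this branch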
@@ -251,11 +253,11 @@ def apply_gptq_quant_layer(self,
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
             config = {
-                "linear_quant_method": "gptq",
-                "weight_bits": weight_bits,
+                "quant_method": "gptq",
+                "bits": weight_bits,
                 "group_size": group_size,
                 "sym": sym,
-                "lm_head_quantized": False,
+                "lm_head": False,
             }
             return MoeWNA16Config.from_config(config).get_quant_method(
                 layer, prefix)