4 files changed (+57 −1)
Requirements:

```diff
 requests
 ray
 sentence-transformers # required for embedding
+sparseml==1.8.0 # required for compressed-tensors
+compressed-tensors==0.4.0 # required for compressed-tensors

 # Benchmarking
 aiohttp
```
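For context, loading one of these checkpoints end-to-end looks roughly like the sketch below. It is not part of the diff; it assumes the pins above are installed and a supported GPU is available, and the model name is taken from the test added further down.

```python
# Rough usage sketch (not part of the diff). vLLM should detect the
# compressed-tensors quantization from the checkpoint's config, so no
# explicit quantization argument is passed here.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel"
                "-A8-Dynamic-Per-Token-Test")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```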
Test harness behind `hf_runner`:

```diff
@@ -176,6 +176,7 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
         is_embedding_model: bool = False,
         is_vision_model: bool = False,
+        is_sparseml_model: bool = False,
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -193,6 +194,9 @@ def __init__(
         else:
             if is_vision_model:
                 auto_cls = AutoModelForVision2Seq
+            elif is_sparseml_model:
+                from sparseml.transformers import SparseAutoModelForCausalLM
+                auto_cls = SparseAutoModelForCausalLM
             else:
                 auto_cls = AutoModelForCausalLM
```
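The new flag routes the harness's model construction through sparseml instead of plain transformers. Standalone, that branch amounts to roughly the following sketch; `SparseAutoModelForCausalLM` mirrors the `from_pretrained` interface of `AutoModelForCausalLM`, and the dtype handling here is an assumption about what the harness passes through.

```python
# Sketch (not part of the diff) of what the sparseml branch does:
# load the checkpoint through sparseml's model class so the HF side
# of the comparison can handle compressed-tensors checkpoints.
import torch
from sparseml.transformers import SparseAutoModelForCausalLM

model = SparseAutoModelForCausalLM.from_pretrained(
    "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test",
    torch_dtype=torch.float16,  # assumed to mirror the harness dtype
)
```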
New test comparing sparseml and vLLM outputs:

```diff
+"""Compares vllm vs sparseml for compressed-tensors.
+
+Note: vllm and sparseml do not have bitwise correctness,
+so this test just confirms that the top selected token of
+each is in the other's top-5 selections.
+"""
+
+import pytest
+
+from tests.quantization.utils import is_quant_method_supported
+
+from .utils import check_logprobs_close
+
+MODELS = [
+    "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test",
+]
+
+MAX_TOKENS = 32
+NUM_LOGPROBS = 5
+
+
+@pytest.mark.skipif(
+    not is_quant_method_supported("compressed-tensors"),
+    reason="compressed-tensors is not supported on this machine type.")
+@pytest.mark.parametrize("model_name", MODELS)
+def test_models(
+    vllm_runner,
+    hf_runner,
+    example_prompts,
+    model_name,
+) -> None:
+    # Run sparseml.
+    with hf_runner(model_name=model_name,
+                   is_sparseml_model=True) as sparseml_model:
+        sparseml_outputs = sparseml_model.generate_greedy_logprobs_limit(
+            example_prompts, MAX_TOKENS, NUM_LOGPROBS)
+
+    # Run vllm.
+    with vllm_runner(model_name=model_name) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, MAX_TOKENS, NUM_LOGPROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=sparseml_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="sparseml",
+        name_1="vllm",
+    )
```
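`check_logprobs_close` lives in the test utilities and is not part of this diff. A self-contained sketch of the relaxed comparison the docstring describes follows; the helper name and data shapes (per-position dicts mapping token id to logprob) are assumptions, not the actual utility.

```python
# Assumed shapes: tokens_* are the greedily selected token ids and
# logprobs_* are per-position dicts {token_id: logprob} holding each
# model's top NUM_LOGPROBS candidates. Helper name is an assumption.
from typing import Dict, List


def top_tokens_mutually_close(
    tokens_0: List[int],
    logprobs_0: List[Dict[int, float]],
    tokens_1: List[int],
    logprobs_1: List[Dict[int, float]],
) -> bool:
    for pos, (tok_0, tok_1) in enumerate(zip(tokens_0, tokens_1)):
        if tok_0 == tok_1:
            continue  # exact agreement at this position
        # Tolerate a mismatch only if each model's pick still appears
        # in the other model's top-k candidates.
        if tok_0 not in logprobs_1[pos] or tok_1 not in logprobs_0[pos]:
            return False
    return True
```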
Quantization config (`vllm/model_executor/layers/quantization/compressed_tensors`):

```diff
@@ -34,7 +34,8 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]

     # Need to figure it out
-    def get_min_capability(self) -> int:
+    @classmethod
+    def get_min_capability(cls) -> int:
         return 60

     def get_name(self) -> str:
```
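Making `get_min_capability` a classmethod lets callers gate on GPU capability before any config instance exists. A hedged sketch of such a gate is below; the helper and call site are assumptions, not vLLM's actual code, and 60 encodes compute capability 6.0 in vLLM's `major * 10 + minor` convention.

```python
# Sketch of a pre-instantiation capability gate (helper name and call
# site are assumptions). Works on the class itself now that
# get_min_capability is a classmethod, so no config instance is needed.
from typing import Type

import torch


def device_supports(quant_cls: Type) -> bool:
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= quant_cls.get_min_capability()
```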