Commit d5ee4f3

add configs with their models (#421)
* add configs with their models
* fix tests
* doc update
* doc update
* fix path
1 parent: 2b19f45

20 files changed: +342 −378 lines
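The thrust of the change: each model config moves out of the central `lighteval.models.model_config` module and into the module of the model it configures. A minimal before/after sketch of the import-path migration, assembled from the diffs below:

```python
# Before this commit: configs were centralized in lighteval.models.model_config.
# from lighteval.models.model_config import BaseModelConfig, VLLMModelConfig

# After this commit: each config is imported from its model's own module.
from lighteval.models.transformers.base_model import BaseModelConfig
from lighteval.models.vllm.vllm_model import VLLMModelConfig
```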

docs/source/_toctree.yml

Lines changed: 1 addition & 3 deletions
@@ -33,9 +33,7 @@
     - local: package_reference/evaluation_tracker
       title: EvaluationTracker
     - local: package_reference/models
-      title: Models
-    - local: package_reference/model_config
-      title: ModelConfig
+      title: Models and ModelConfigs
     - local: package_reference/pipeline
       title: Pipeline
   title: Main classes

docs/source/package_reference/model_config.mdx

Lines changed: 0 additions & 10 deletions
This file was deleted.

docs/source/package_reference/models.mdx

Lines changed: 23 additions & 9 deletions
@@ -4,24 +4,38 @@
 ### LightevalModel
 [[autodoc]] models.abstract_model.LightevalModel
 
+
 ## Accelerate and Transformers Models
 ### BaseModel
-[[autodoc]] models.base_model.BaseModel
+[[autodoc]] models.transformers.base_model.BaseModelConfig
+[[autodoc]] models.transformers.base_model.BaseModel
+
 ### AdapterModel
-[[autodoc]] models.adapter_model.AdapterModel
+[[autodoc]] models.transformers.adapter_model.AdapterModelConfig
+[[autodoc]] models.transformers.adapter_model.AdapterModel
+
 ### DeltaModel
-[[autodoc]] models.delta_model.DeltaModel
+[[autodoc]] models.transformers.delta_model.DeltaModelConfig
+[[autodoc]] models.transformers.delta_model.DeltaModel
 
-## Inference Endpoints and TGI Models
+## Endpoints-based Models
 ### InferenceEndpointModel
-[[autodoc]] models.endpoint_model.InferenceEndpointModel
-### ModelClient
-[[autodoc]] models.tgi_model.ModelClient
+[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
+[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
+[[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel
+
+### TGI ModelClient
+[[autodoc]] models.endpoints.tgi_model.TGIModelConfig
+[[autodoc]] models.endpoints.tgi_model.ModelClient
+
+### Open AI Models
+[[autodoc]] models.endpoints.openai_model.OpenAIClient
 
 ## Nanotron Model
 ### NanotronLightevalModel
-[[autodoc]] models.nanotron_model.NanotronLightevalModel
+[[autodoc]] models.nanotron.nanotron_model.NanotronLightevalModel
 
 ## VLLM Model
 ### VLLMModel
-[[autodoc]] models.vllm_model.VLLMModel
+[[autodoc]] models.vllm.vllm_model.VLLMModelConfig
+[[autodoc]] models.vllm.vllm_model.VLLMModel

src/lighteval/main_accelerate.py

Lines changed: 3 additions & 1 deletion
@@ -107,7 +107,9 @@ def accelerate( # noqa C901
     from accelerate import Accelerator, InitProcessGroupKwargs
 
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.model_config import AdapterModelConfig, BaseModelConfig, BitsAndBytesConfig, DeltaModelConfig
+    from lighteval.models.transformers.adapter_model import AdapterModelConfig
+    from lighteval.models.transformers.base_model import BaseModelConfig, BitsAndBytesConfig
+    from lighteval.models.transformers.delta_model import DeltaModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
 
     accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])

src/lighteval/main_endpoint.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def inference_endpoint(
     import yaml
 
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.model_config import (
+    from lighteval.models.endpoints.endpoint_model import (
         InferenceEndpointModelConfig,
     )
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

src/lighteval/main_vllm.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def vllm(
     Evaluate models using vllm as backend.
     """
     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.model_config import VLLMModelConfig
+    from lighteval.models.vllm.vllm_model import VLLMModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters
 
     TOKEN = os.getenv("HF_TOKEN")

src/lighteval/models/dummy_model.py renamed to src/lighteval/models/dummy/dummy_model.py

Lines changed: 6 additions & 1 deletion
@@ -23,12 +23,12 @@
 # inspired by https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/dummy.py
 
 import random
+from dataclasses import dataclass
 from typing import Optional
 
 from transformers import AutoTokenizer
 
 from lighteval.models.abstract_model import LightevalModel, ModelInfo
-from lighteval.models.model_config import DummyModelConfig
 from lighteval.models.model_output import GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse
 from lighteval.tasks.requests import (
     GreedyUntilRequest,
@@ -39,6 +39,11 @@
 from lighteval.utils.utils import EnvConfig
 
 
+@dataclass
+class DummyModelConfig:
+    seed: int = 42
+
+
 class DummyModel(LightevalModel):
     """Dummy model to generate random baselines."""
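With the config now co-located with the model, a quick sketch of the new dataclass in use (the values are illustrative, not taken from the diff):

```python
from lighteval.models.dummy.dummy_model import DummyModelConfig

# seed defaults to 42; override it to vary the random baseline.
config = DummyModelConfig(seed=1234)
print(config.seed)  # 1234
```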

src/lighteval/models/endpoint_model.py renamed to src/lighteval/models/endpoints/endpoint_model.py

Lines changed: 55 additions & 2 deletions
@@ -24,7 +24,8 @@
 import logging
 import re
 import time
-from typing import Coroutine, List, Optional, Union
+from dataclasses import dataclass
+from typing import Coroutine, Dict, List, Optional, Union
 
 import requests
 import torch
@@ -47,7 +48,6 @@
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel, ModelInfo
-from lighteval.models.model_config import InferenceEndpointModelConfig, InferenceModelConfig
 from lighteval.models.model_output import GenerativeResponse, LoglikelihoodResponse, LoglikelihoodSingleTokenResponse
 from lighteval.tasks.requests import (
     GreedyUntilRequest,
@@ -74,6 +74,59 @@
 ]
 
 
+@dataclass
+class InferenceModelConfig:
+    model: str
+    add_special_tokens: bool = True
+
+
+@dataclass
+class InferenceEndpointModelConfig:
+    endpoint_name: str = None
+    model_name: str = None
+    should_reuse_existing: bool = False
+    accelerator: str = "gpu"
+    model_dtype: str = None  # if empty, we use the default
+    vendor: str = "aws"
+    region: str = "us-east-1"  # this region has the most hardware options available
+    instance_size: str = None  # if none, we autoscale
+    instance_type: str = None  # if none, we autoscale
+    framework: str = "pytorch"
+    endpoint_type: str = "protected"
+    add_special_tokens: bool = True
+    revision: str = "main"
+    namespace: str = None  # the namespace under which to launch the endpoint; defaults to the current user's namespace
+    image_url: str = None
+    env_vars: dict = None
+
+    def __post_init__(self):
+        # xor: raise if exactly one of instance_size/instance_type is set
+        if (self.instance_size is None) ^ (self.instance_type is None):
+            raise ValueError(
+                "When creating an inference endpoint, you need to specify explicitly both instance_type and instance_size, or none of them for autoscaling."
+            )
+
+        if not (self.endpoint_name is None) ^ (self.model_name is None):
+            raise ValueError("You need to set either endpoint_name or model_name (but not both).")
+
+    def get_dtype_args(self) -> Dict[str, str]:
+        if self.model_dtype is None:
+            return {}
+        model_dtype = self.model_dtype.lower()
+        if model_dtype in ["awq", "eetq", "gptq"]:
+            return {"QUANTIZE": model_dtype}
+        if model_dtype == "8bit":
+            return {"QUANTIZE": "bitsandbytes"}
+        if model_dtype == "4bit":
+            return {"QUANTIZE": "bitsandbytes-nf4"}
+        if model_dtype in ["bfloat16", "float16"]:
+            return {"DTYPE": model_dtype}
+        return {}
+
+    def get_custom_env_vars(self) -> Dict[str, str]:
+        return {k: str(v) for k, v in self.env_vars.items()} if self.env_vars else {}
+
+
 class InferenceEndpointModel(LightevalModel):
     """InferenceEndpointModels can be used both with the free inference client, or with inference
     endpoints, which will use text-generation-inference to deploy your model for the duration of the evaluation.
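A sketch of how the relocated `InferenceEndpointModelConfig` behaves, following the `__post_init__` and `get_dtype_args` logic above; the model and instance names are hypothetical placeholders:

```python
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig

# Valid: model_name set (endpoint_name unset) and neither instance field set -> autoscaling.
config = InferenceEndpointModelConfig(model_name="my-org/my-model", model_dtype="4bit")
print(config.get_dtype_args())  # {'QUANTIZE': 'bitsandbytes-nf4'}

# Invalid: setting instance_type without instance_size trips the XOR check in __post_init__.
try:
    InferenceEndpointModelConfig(model_name="my-org/my-model", instance_type="g5.2xlarge")
except ValueError as err:
    print(err)
```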

src/lighteval/models/openai_model.py renamed to src/lighteval/models/endpoints/openai_model.py

Lines changed: 7 additions & 1 deletion
@@ -24,13 +24,14 @@
 import os
 import time
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from typing import Optional
 
 from tqdm import tqdm
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel
-from lighteval.models.endpoint_model import ModelInfo
+from lighteval.models.endpoints.endpoint_model import ModelInfo
 from lighteval.models.model_output import (
     GenerativeResponse,
     LoglikelihoodResponse,
@@ -58,6 +59,11 @@
 logging.getLogger("httpx").setLevel(logging.ERROR)
 
 
+@dataclass
+class OpenAIModelConfig:
+    model: str
+
+
 class OpenAIClient(LightevalModel):
     _DEFAULT_MAX_LENGTH: int = 4096
 
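`OpenAIModelConfig` carries only the model name for now; a trivial construction sketch (the model string is an arbitrary example):

```python
from lighteval.models.endpoints.openai_model import OpenAIModelConfig

config = OpenAIModelConfig(model="gpt-4o-mini")  # any OpenAI model identifier
```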

src/lighteval/models/tgi_model.py renamed to src/lighteval/models/endpoints/tgi_model.py

Lines changed: 9 additions & 1 deletion
@@ -21,13 +21,14 @@
 # SOFTWARE.
 
 import asyncio
+from dataclasses import dataclass
 from typing import Coroutine, Optional
 
 import requests
 from huggingface_hub import TextGenerationInputGrammarType, TextGenerationOutput
 from transformers import AutoTokenizer
 
-from lighteval.models.endpoint_model import InferenceEndpointModel, ModelInfo
+from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel, ModelInfo
 from lighteval.utils.imports import NO_TGI_ERROR_MSG, is_tgi_available
 
 
@@ -44,6 +45,13 @@ def divide_chunks(array, n):
         yield array[i : i + n]
 
 
+@dataclass
+class TGIModelConfig:
+    inference_server_address: str
+    inference_server_auth: str
+    model_id: str
+
+
 # inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite
 # the client functions, since they use a different client.
 class ModelClient(InferenceEndpointModel):
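`TGIModelConfig` is likewise a plain dataclass with three required fields; a hypothetical construction against a local TGI server:

```python
from lighteval.models.endpoints.tgi_model import TGIModelConfig

config = TGIModelConfig(
    inference_server_address="http://localhost:8080",  # hypothetical local TGI endpoint
    inference_server_auth="hf_xxx",                    # auth token, if the server needs one
    model_id="my-org/my-model",                        # hypothetical model id
)
```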
