diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 950a7597a..536990eb1 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -25,7 +25,7 @@ jobs: cache: 'pip' - name: Install lighteval in editable mode run: | - pip install -e .[dev,extended_tasks,multilingual] + pip install -e .[dev,extended_tasks,multilingual,litellm] - name: Get cached files uses: actions/cache@v4 id: get-cache diff --git a/pyproject.toml b/pyproject.toml index 9a4d3a3ce..2c3a76f5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ dependencies = [ ] [project.optional-dependencies] +litellm = ["litellm", "diskcache"] tgi = ["text-generation==0.6.0"] optimum = ["optimum==1.12.0"] quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"] diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py index 04a00f0a5..f992d65c9 100644 --- a/src/lighteval/main_endpoint.py +++ b/src/lighteval/main_endpoint.py @@ -367,3 +367,112 @@ def tgi( pipeline.save_and_push_results() return results + + +@app.command(rich_help_panel="Evaluation Backends") +def litellm( + # === general === + model_name: Annotated[ + str, Argument(help="The model name to evaluate (has to be available through the litellm API).") + ], + tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")], + # === Common parameters === + use_chat_template: Annotated[ + bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANEL_NAME_4) + ] = False, + system_prompt: Annotated[ + Optional[str], Option(help="Use system prompt for evaluation.", rich_help_panel=HELP_PANEL_NAME_4) + ] = None, + dataset_loading_processes: Annotated[ + int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANEL_NAME_1) + ] = 1, + custom_tasks: Annotated[ + Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANEL_NAME_1) + ] = None, + cache_dir: Annotated[ + str, Option(help="Cache directory for datasets and models.", rich_help_panel=HELP_PANEL_NAME_1) + ] = CACHE_DIR, + num_fewshot_seeds: Annotated[ + int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANEL_NAME_1) + ] = 1, + # === saving === + output_dir: Annotated[ + str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2) + ] = "results", + push_to_hub: Annotated[ + bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANEL_NAME_2) + ] = False, + push_to_tensorboard: Annotated[ + bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANEL_NAME_2) + ] = False, + public_run: Annotated[ + bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANEL_NAME_2) + ] = False, + results_org: Annotated[ + Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANEL_NAME_2) + ] = None, + save_details: Annotated[ + bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANEL_NAME_2) + ] = False, + # === debug === + max_samples: Annotated[ + Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANEL_NAME_3) + ] = None, + override_batch_size: Annotated[ + int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANEL_NAME_3) + ] = -1, + job_id: Annotated[ + int, Option(help="Optional job id for future reference.", rich_help_panel=HELP_PANEL_NAME_3) + ] = 0, +): + """ + Evaluate models using 
LiteLLM as the backend. + """ + + from lighteval.logging.evaluation_tracker import EvaluationTracker + from lighteval.models.litellm_model import LiteLLMModelConfig + from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters + + env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir) + evaluation_tracker = EvaluationTracker( + output_dir=output_dir, + save_details=save_details, + push_to_hub=push_to_hub, + push_to_tensorboard=push_to_tensorboard, + public=public_run, + hub_results_org=results_org, + ) + + # TODO (nathan): better handling of model_args + parallelism_manager = ParallelismManager.NONE + + model_config = LiteLLMModelConfig(model=model_name) + + pipeline_params = PipelineParameters( + launcher_type=parallelism_manager, + env_config=env_config, + job_id=job_id, + dataset_loading_processes=dataset_loading_processes, + custom_tasks_directory=custom_tasks, + override_batch_size=override_batch_size, + num_fewshot_seeds=num_fewshot_seeds, + max_samples=max_samples, + use_chat_template=use_chat_template, + system_prompt=system_prompt, + ) + pipeline = Pipeline( + tasks=tasks, + pipeline_parameters=pipeline_params, + evaluation_tracker=evaluation_tracker, + model_config=model_config, + ) + + pipeline.evaluate() + + pipeline.show_results() + + results = pipeline.get_results() + + pipeline.save_and_push_results() + + return results diff --git a/src/lighteval/metrics/llm_as_judge.py b/src/lighteval/metrics/llm_as_judge.py index a4b5dfb13..e98a64aa4 100644 --- a/src/lighteval/metrics/llm_as_judge.py +++ b/src/lighteval/metrics/llm_as_judge.py @@ -28,7 +28,7 @@ from tqdm import tqdm -from lighteval.utils.imports import is_openai_available, is_vllm_available +from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available logging.getLogger("openai").setLevel(logging.ERROR) @@ -73,7 +73,7 @@ def __init__( model: str, templates: Callable, process_judge_response: Callable, - judge_backend: Literal["openai", "transformers", "tgi", "vllm"], + judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"], url: str | None = None, api_key: str | None = None, ): @@ -93,7 +93,7 @@ def __init__( def __lazy_load_client(self): match self.backend: - # Wether we use openai or TGI models, we go trhough the openai API + # Whether we use openai or TGI models, we go through the openai API # to route to the endpoint case "openai" | "tgi" if is_openai_available(): if self.client is None: @@ -104,6 +104,8 @@ def __lazy_load_client(self): else: self.client = OpenAI(base_url=self.url, api_key=self.api_key) return self.__call_api_parallel + case "litellm" if is_litellm_available(): + return self.__call_litellm case "vllm" if is_vllm_available(): if self.pipe is None: from vllm import LLM, SamplingParams @@ -187,6 +189,37 @@ def __call_vllm(self, prompt): outputs = [output.outputs[0].text for output in output] return outputs + def __call_litellm(self, prompts): + import litellm + + def __call_api(prompt): + for _ in range(self.API_MAX_RETRY): + try: + response = litellm.completion( + model=self.model, + messages=prompt, + response_format={"type": "text"}, + max_tokens=512, + n=1, + caching=True, + ) + text = response.choices[0].message.content + return text + except Exception as e: + logger.warning(f"{type(e), e}") + time.sleep(self.API_RETRY_SLEEP) + raise Exception("Failed to get response from the API") + + results = [] + with ThreadPoolExecutor(100) as executor: + for entry in tqdm(executor.map(__call_api, prompts), total=len(prompts)): + 
results.append(entry) + + if None in results: + raise ValueError("Some entries are not annotated due to errors in annotate_p, please inspect and retry.") + + return results + def __call_api_parallel(self, prompts): results = [] with ThreadPoolExecutor(100) as executor: diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 2081b5606..b27d25c3e 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -858,7 +858,7 @@ def __init__( judge_model_name: str, template: Callable, process_judge_response: Callable, - judge_backend: Literal["openai", "transformers", "vllm", "tgi"], + judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi"], short_judge_name: str | None = None, ) -> None: match judge_backend: @@ -871,6 +871,9 @@ def __init__( case "tgi": api_key = os.getenv("HF_TOKEN") url = "https://api-inference.huggingface.co/v1/" + case "litellm": + api_key = None + url = None case "transformers" | "vllm": api = HfApi() models = api.list_models(model_name=judge_model_name) diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py index b2ca25285..8733474d0 100644 --- a/src/lighteval/models/endpoints/openai_model.py +++ b/src/lighteval/models/endpoints/openai_model.py @@ -145,7 +145,6 @@ def greedy_until( Args: requests (list[Request]): list of requests containing the context and ending conditions. - disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False. override_bs (int, optional): Override the batch size for generation. Defaults to None. Returns: diff --git a/src/lighteval/models/litellm_model.py b/src/lighteval/models/litellm_model.py new file mode 100644 index 000000000..21dfc45af --- /dev/null +++ b/src/lighteval/models/litellm_model.py @@ -0,0 +1,294 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import logging +import os +import time +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Optional + +from tqdm import tqdm + +from lighteval.data import GenerativeTaskDataset +from lighteval.models.abstract_model import LightevalModel +from lighteval.models.endpoints.endpoint_model import ModelInfo +from lighteval.models.model_output import ( + GenerativeResponse, + LoglikelihoodResponse, + LoglikelihoodSingleTokenResponse, +) +from lighteval.tasks.requests import ( + GreedyUntilRequest, + LoglikelihoodRequest, + LoglikelihoodRollingRequest, + LoglikelihoodSingleTokenRequest, +) +from lighteval.utils.imports import is_litellm_available + + +logger = logging.getLogger(__name__) + +if is_litellm_available(): + import litellm + from litellm import encode + from litellm.caching.caching import Cache + from litellm.utils import ModelResponse + + logging.getLogger("LiteLLM").setLevel(logging.WARNING) + logging.getLogger("LiteLLM").handlers.clear() + + litellm.cache = Cache(type="disk") + + +@dataclass +class LiteLLMModelConfig: + model: str + + +class LiteLLMClient(LightevalModel): + _DEFAULT_MAX_LENGTH: int = 4096 + + def __init__(self, config, env_config) -> None: + """ + IMPORTANT: Your API keys should be set in the environment variables. + If a base_url is not set, it will default to the public API. + """ + self.model_info = ModelInfo( + model_name=config.model, + model_sha="", + model_dtype=None, + model_size="", + ) + self.provider = config.model.split("/")[0] + self.base_url = os.getenv(f"{self.provider.upper()}_BASE_URL", None) + self.API_MAX_RETRY = 5 + self.API_RETRY_SLEEP = 3 + self.API_RETRY_MULTIPLIER = 2 + self.CONCURENT_CALLS = 20 # 100 leads to hitting Anthropic rate limits + self.TEMPERATURE = 0.7 + self.TOP_P = 0.95 + self.model = config.model + self._tokenizer = encode + self.pairwise_tokenization = False + litellm.drop_params = True + litellm.verbose = True + + def _prepare_stop_sequence(self, stop_sequence): + """Prepare and validate stop sequence.""" + if self.provider == "anthropic": + # Filter out whitespace-only stop sequences + if stop_sequence: + stop_sequence = [s for s in stop_sequence if s and s.strip()] + if not stop_sequence: # If empty after filtering + stop_sequence = ["\n"] + return stop_sequence + + def _prepare_max_new_tokens(self, max_new_tokens): + """Calculate completion tokens based on max_new_tokens.""" + if not max_new_tokens or max_new_tokens <= 0: + return None + + if "o1" in self.model: + # We need to allow more tokens to include reasoning tokens + max_new_tokens = min(max_new_tokens * 10, 32000) + return max_new_tokens + + def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence): + """Make API call with retries.""" + response = ModelResponse() + for attempt in range(self.API_MAX_RETRY): + try: + stop_sequence = self._prepare_stop_sequence(stop_sequence) + max_new_tokens = self._prepare_max_new_tokens(max_new_tokens) + + if return_logits and not self.provider == "openai": + logger.warning("Returning logits is not supported for this provider, ignoring.") + + # Prepare kwargs for completion call + kwargs = { + "model": self.model, + "messages": prompt, + "max_completion_tokens": max_new_tokens, + "logprobs": return_logits if self.provider == "openai" else None, + "stop": stop_sequence, + "base_url": self.base_url, + "n": num_samples, + "temperature": self.TEMPERATURE, + "top_p": self.TOP_P, + "caching": True, + } + + response = litellm.completion(**kwargs) + + # If 
response is empty, retry without caching (maybe the error is recoverable and solved with a retry) + if response.choices[0].message.content is None: + kwargs["caching"] = False + logger.info("Response is empty, retrying without caching") + response = litellm.completion(**kwargs) + return response + except litellm.BadRequestError as e: + if "message" in e.__dict__: + error_string = ( + "The response was filtered due to the prompt triggering Microsoft's content management policy" + ) + if error_string in e.__dict__["message"]: + logger.warning(f"{error_string}. Returning empty response.") + return ModelResponse() + except Exception as e: + wait_time = min(64, self.API_RETRY_SLEEP * (2**attempt)) # Exponential backoff with max 64s + logger.warning( + f"Error in API call: {e}, waiting {wait_time} seconds before retry {attempt + 1}/{self.API_MAX_RETRY}" + ) + time.sleep(wait_time) + + logger.error(f"API call failed after {self.API_MAX_RETRY} attempts, returning empty response.") + return ModelResponse() + + def __call_api_parallel( + self, + prompts, + return_logits: bool | list[bool], + max_new_tokens: int | list[int], + num_samples: int | list[int], + stop_sequence: list[str] | None = None, + ): + results = [] + + return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits + max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens + num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples + stop_sequencess = [stop_sequence for _ in prompts] + assert ( + len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(stop_sequencess) + ), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequences, system_prompts should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}" + + with ThreadPoolExecutor(self.CONCURENT_CALLS) as executor: + for entry in tqdm( + executor.map( + self.__call_api, + prompts, + return_logitss, + max_new_tokenss, + num_sampless, + stop_sequencess, + ), + total=len(prompts), + ): + results.append(entry) + + if None in results: + raise ValueError("Some entries are not annotated due to errors in annotate_p, please inspect and retry.") + + return results + + def greedy_until( + self, + requests: list[GreedyUntilRequest], + override_bs: Optional[int] = None, + ) -> list[GenerativeResponse]: + """ + Generates responses using a greedy decoding strategy until certain ending conditions are met. + + Args: + requests (list[Request]): list of requests containing the context and ending conditions. + override_bs (int, optional): Override the batch size for generation. Defaults to None. + + Returns: + list[GenerativeResponse]: list of generated responses. 
+ """ + for request in requests: + request.tokenized_context = self.tok_encode(request.context) + + dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS) + results = [] + + for _ in tqdm( + dataset.splits_start_end_iterator(), + total=dataset.num_dataset_splits, + desc="Splits", + position=0, + disable=False, # self.disable_tqdm, + ): + contexts = [c.context for c in dataset] + max_new_tokens = dataset[0].generation_size # could be none + return_logits = dataset[0].use_logits + num_samples = dataset[0].num_samples + stop_sequence = requests[0].stop_sequence + + responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples, stop_sequence) + + for response in responses: + result: list[str] = [choice.message.content for choice in response.choices] + + cur_response = GenerativeResponse( + # In empty responses, the model should return an empty string instead of None + result=result if result[0] else [""], + logits=None, + generated_tokens=[], + input_tokens=[], + ) + results.append(cur_response) + + return dataset.get_original_order(results) + + @property + def tokenizer(self): + return self._tokenizer + + def tok_encode(self, text: str | list[str]): + if isinstance(text, list): + toks = [encode(model=self.model, text=t["content"]) for t in text] + toks = [tok for tok in toks if tok] + return toks + return encode(model=self.model, text=text) + + @property + def add_special_tokens(self) -> bool: + return False + + @property + def max_length(self) -> int: + """Return the maximum sequence length of the model.""" + return 4096 + + def loglikelihood( + self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. + """ + raise NotImplementedError + + def loglikelihood_rolling( + self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodResponse]: + """This function is used to compute the log likelihood of the context for perplexity metrics.""" + raise NotImplementedError + + def loglikelihood_single_token( + self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None + ) -> list[LoglikelihoodSingleTokenResponse]: + """Tokenize the context and continuation and compute the log likelihood of those + tokenized sequences. 
+ """ + raise NotImplementedError diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index 66eb99886..dff3b9b4a 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -31,13 +31,16 @@ ) from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig +from lighteval.models.litellm_model import LiteLLMClient, LiteLLMModelConfig from lighteval.models.transformers.adapter_model import AdapterModel, AdapterModelConfig from lighteval.models.transformers.base_model import BaseModel, BaseModelConfig from lighteval.models.transformers.delta_model import DeltaModel, DeltaModelConfig from lighteval.models.vllm.vllm_model import VLLMModel, VLLMModelConfig from lighteval.utils.imports import ( + NO_LITELLM_ERROR_MSG, NO_TGI_ERROR_MSG, NO_VLLM_ERROR_MSG, + is_litellm_available, is_openai_available, is_tgi_available, is_vllm_available, @@ -58,6 +61,7 @@ def load_model( # noqa: C901 DummyModelConfig, VLLMModelConfig, OpenAIModelConfig, + LiteLLMModelConfig, ], env_config: EnvConfig, ) -> Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel]: @@ -95,6 +99,9 @@ def load_model( # noqa: C901 if isinstance(config, OpenAIModelConfig): return load_openai_model(config=config, env_config=env_config) + if isinstance(config, LiteLLMModelConfig): + return load_litellm_model(config=config, env_config=env_config) + def load_model_with_tgi(config: TGIModelConfig): if not is_tgi_available(): @@ -107,6 +114,14 @@ def load_model_with_tgi(config: TGIModelConfig): return model +def load_litellm_model(config: LiteLLMModelConfig, env_config: EnvConfig): + if not is_litellm_available(): + raise ImportError(NO_LITELLM_ERROR_MSG) + + model = LiteLLMClient(config, env_config) + return model + + def load_openai_model(config: OpenAIModelConfig, env_config: EnvConfig): if not is_openai_available(): raise ImportError() diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 2d413807d..206fd3a55 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -54,6 +54,12 @@ from vllm import LLM, SamplingParams from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel from vllm.transformers_utils.tokenizer import get_tokenizer + + logging.getLogger("vllm").propagate = True + logging.getLogger("vllm").handlers.clear() + + logging.getLogger("ray").propagate = True + logging.getLogger("ray").handlers.clear() else: LLM = None SamplingParams = None diff --git a/src/lighteval/tasks/prompt_manager.py b/src/lighteval/tasks/prompt_manager.py index cb9f94d04..af55c3184 100644 --- a/src/lighteval/tasks/prompt_manager.py +++ b/src/lighteval/tasks/prompt_manager.py @@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union from lighteval.models.abstract_model import LightevalModel +from lighteval.models.litellm_model import LiteLLMClient from lighteval.tasks.requests import Doc from lighteval.utils.utils import as_list @@ -205,7 +206,11 @@ def _single_turn_context( system_prompt=system_prompt, use_chat_template=use_chat_template, ) - toks = self.model.tok_encode(output) + if not use_chat_template: + toks = self.model.tok_encode(output) + else: + toks = [self.model.tok_encode(msg["content"]) for msg in output] + toks = [t for ts in toks for t in ts] # If we need to truncate few-shots to fit in the context if 
truncate_few_shots and self.model.max_length is not None and self.model.tokenizer is not None: @@ -223,7 +228,19 @@ def _single_turn_context( system_prompt=system_prompt, use_chat_template=use_chat_template, ) - toks = self.model.tokenizer(output)["input_ids"] + if not use_chat_template: + toks = self.model.tok_encode(output) + else: + toks = [self.model.tok_encode(msg["content"]) for msg in output] + toks = [t for ts in toks for t in ts] + + if isinstance(self.model, LiteLLMClient): + return output, num_effective_fewshots + + elif use_chat_template: + return self.model.tokenizer.apply_chat_template( + output, tokenize=False, add_generation_prompt=True + ), num_effective_fewshots return output, num_effective_fewshots @@ -256,7 +273,7 @@ def get_examples( examples.insert(0, {"role": "system", "content": system_prompt + instruction}) else: # Else we add the instruction to the first example examples[0]["content"] = instruction + examples[0]["content"] - return self.model.tokenizer.apply_chat_template(examples, tokenize=False, add_generation_prompt=True) + return examples else: if system_prompt is not None: output = system_prompt + instruction + "\n\n".join(examples) diff --git a/src/lighteval/utils/imports.py b/src/lighteval/utils/imports.py index d36c1acb4..c8fb2ce73 100644 --- a/src/lighteval/utils/imports.py +++ b/src/lighteval/utils/imports.py @@ -77,6 +77,13 @@ def is_openai_available() -> bool: NO_OPENAI_ERROR_MSG = "You are trying to use an Open AI LLM as a judge, for which you need `openai`, which is not available in your environment. Please install it using pip." +def is_litellm_available() -> bool: + return importlib.util.find_spec("litellm") is not None + + +NO_LITELLM_ERROR_MSG = "You are trying to use a LiteLLM model, for which you need `litellm`, which is not available in your environment. Please install it using pip." + + def is_vllm_available() -> bool: return importlib.util.find_spec("vllm") is not None and importlib.util.find_spec("ray") is not None
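
Below is a minimal usage sketch of the LiteLLM backend added in this diff, mirroring the litellm command in src/lighteval/main_endpoint.py. It is illustrative only: the model name, task string, cache directory and output directory are placeholders, the omitted EnvConfig / PipelineParameters / EvaluationTracker arguments are assumed to fall back to workable defaults, and the relevant provider API key (e.g. OPENAI_API_KEY) is assumed to be set in the environment, since LiteLLM reads credentials from there.

# Hypothetical sketch -- not part of the patch; placeholder values throughout.
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.litellm_model import LiteLLMModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

# LiteLLM model names are "<provider>/<model>"; LiteLLMClient derives the provider
# (and an optional <PROVIDER>_BASE_URL override) from the prefix.
model_config = LiteLLMModelConfig(model="openai/gpt-4o-mini")  # placeholder model

pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.NONE,  # requests go out over HTTP, no local parallelism manager
    env_config=EnvConfig(cache_dir="~/.cache/huggingface"),  # placeholder cache dir
    use_chat_template=True,  # provider chat endpoints expect message lists
)
evaluation_tracker = EvaluationTracker(output_dir="results", save_details=False)

pipeline = Pipeline(
    tasks="leaderboard|truthfulqa:mc|0|0",  # placeholder task spec
    pipeline_parameters=pipeline_params,
    evaluation_tracker=evaluation_tracker,
    model_config=model_config,
)
pipeline.evaluate()
pipeline.show_results()
pipeline.save_and_push_results()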