Add litellm inference #385

Merged

Commits (50)
17783b2 Added inference using litellm. (JoelNiklaus)
9e92150 Add Udmurt (udm) translation literals (#381) (codemurt)
30a624c This PR adds translation literals for Belarusian language. (#382) (Kryuski)
6e6fed6 fix: cache directory variable (#378) (NazimHAli)
d1d4c69 greedy_until() fix (#344) (vsabolcec)
f69811f Fixed some params in completion call to enable more model providers. (JoelNiklaus)
dabb4a7 Added diskcache. (JoelNiklaus)
65f759c Merge branch 'main' into add_litellm_inference (JoelNiklaus)
f74afd4 Merge branch 'main' into add_litellm_inference (JoelNiklaus)
88a9838 Fix issue for openai evaluation. (JoelNiklaus)
02ed461 Added support for stop sequences and generation size. (JoelNiklaus)
34596c2 Merge branch 'main' into add_litellm_inference (JoelNiklaus)
190738f Fixed issue with too many concurrent calls to APIs. (JoelNiklaus)
2bb1917 Merge branch 'main' into add_litellm_inference (clefourrier)
81e4404 Merge branch 'main' into add_litellm_inference (JoelNiklaus)
ebdd900 Merge branch 'main' into add_litellm_inference (NathanHB)
251e181 few fixes (NathanHB)
47b1888 Fixed issues with stop_sequence, max_completion_tokens and system_pro… (JoelNiklaus)
20a1191 Merge branch 'main' into add_litellm_inference (JoelNiklaus)
ade8f0c Revert weird change to __main__.py. (JoelNiklaus)
a2587d6 Made configuration simpler. (JoelNiklaus)
7c0856e Merge branch 'main' into add_litellm_inference (JoelNiklaus)
932fd2c Fixed import issues. (JoelNiklaus)
8fc9b13 Merge branch 'main' into add_litellm_inference (NathanHB)
45d6d1d fix import location (NathanHB)
2a23836 Merge branch 'add_litellm_inference' of github.com:JoelNiklaus/lighte… (NathanHB)
cca1446 Merge branch 'main' into add_litellm_inference (JoelNiklaus)
1a10351 Enabled passing through system prompt to the models in the requests. (JoelNiklaus)
ff6d5de Fixed some bugs. (JoelNiklaus)
8d831b8 Merge branch 'main' into add_litellm_inference (JoelNiklaus)
5115403 Made litellm inference robust to content management errors. (JoelNiklaus)
78789c1 allow bette rmessage managment for litellm (NathanHB)
3ebff6c Merge branch 'main' into add_litellm_inference (NathanHB)
be77b15 allow system prompt to be passed to litellm models (NathanHB)
21d6112 Merge branch 'main' into add_litellm_inference (JoelNiklaus)
d045d92 use system prompt from the request and use litellm encode functino as… (NathanHB)
f1ed682 fixes from review (NathanHB)
ec306fd Merge branch 'add_litellm_inference' of github.com:JoelNiklaus/lighte… (NathanHB)
bae4506 fix tests (NathanHB)
6b0cb60 fix tests (NathanHB)
c826b0e Merge branch 'main' into add_litellm_inference (JoelNiklaus)
a6747f4 remove unecessary doc (NathanHB)
5554787 Merge branch 'add_litellm_inference' of github.com:JoelNiklaus/lighte… (NathanHB)
5b2b72d Update src/lighteval/models/litellm_model.py (NathanHB)
0265a74 Update src/lighteval/models/litellm_model.py (NathanHB)
4fa8311 Merge branch 'main' into add_litellm_inference (NathanHB)
86dd849 Support retrying of empty cached model responses. (JoelNiklaus)
db983e3 Merge branch 'main' into add_litellm_inference (JoelNiklaus)
221d5d5 Fixed error when stop sequence is None. (JoelNiklaus)
81f02ca Added support for litellm as judge backend. (JoelNiklaus)
Files changed

src/lighteval/models/litellm_model.py (new file, 257 lines)
```python
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional

from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel
from lighteval.models.endpoints.endpoint_model import ModelInfo
from lighteval.models.model_output import (
    GenerativeResponse,
    LoglikelihoodResponse,
    LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)
from lighteval.utils.imports import is_litellm_available


logger = logging.getLogger(__name__)

if is_litellm_available():
    import litellm
    from litellm.caching.caching import Cache

    logging.getLogger("LiteLLM").setLevel(logging.WARNING)
    logging.getLogger("LiteLLM").handlers.clear()

    litellm.cache = Cache(type="disk")
```
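The guard above keeps litellm strictly optional and, when it is present, routes responses through a local disk cache (backed by the diskcache package); the completion calls further down opt in with `caching=True`. A minimal standalone sketch of that caching behaviour, assuming litellm is installed and an `OPENAI_API_KEY` is set; the `openai/gpt-4o-mini` model name is only an illustration, not something this PR prescribes:

```python
import litellm
from litellm.caching.caching import Cache

# Same setup as in the module above: cache litellm responses on local disk.
litellm.cache = Cache(type="disk")

# The second, identical call should be answered from the disk cache
# instead of hitting the provider again.
for _ in range(2):
    response = litellm.completion(
        model="openai/gpt-4o-mini",  # illustrative "<provider>/<model>" string
        messages=[{"role": "user", "content": "Say hello in one word."}],
        caching=True,
    )
    print(response.choices[0].message.content)
```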
```python
@dataclass
class LiteLLMModelConfig:
    model: str


class LiteLLMClient(LightevalModel):
    _DEFAULT_MAX_LENGTH: int = 4096

    def __init__(self, config, env_config) -> None:
        """
        IMPORTANT: Your API keys should be set in the environment variables.
        If a base_url is not set, it will default to the public API.
        """
        self.model_info = ModelInfo(
            model_name=config.model,
            model_sha="",
            model_dtype=None,
            model_size="",
        )
        self.provider = config.model.split("/")[0]
        self.base_url = os.getenv(f"{self.provider.upper()}_BASE_URL", None)
        self.API_MAX_RETRY = 5
        self.API_RETRY_SLEEP = 3
        self.API_RETRY_MULTIPLIER = 2
        self.CONCURENT_CALLS = 20  # 100 leads to hitting Anthropic rate limits
        self.TEMPERATURE = 0.7
        self.TOP_P = 0.95
        self.model = config.model
        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Use a dummy tokenizer for compatibility
        self.pairwise_tokenization = False
        # TODO: Pass the system prompt from the pipeline through.
        self.system_prompt = "You are a helpful assistant."
        litellm.drop_params = True
        litellm.verbose = True
```
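`LiteLLMModelConfig` carries nothing but the litellm model string; `__init__` derives the provider from the part before the first `/`, reads an optional `<PROVIDER>_BASE_URL` environment variable, and leaves API keys to litellm's own environment handling. A hypothetical setup showing how those pieces fit together, assuming the class definitions above are in scope; the key and proxy URL are placeholders, and `env_config` is not used by this constructor, so `None` is passed purely for illustration:

```python
import os

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder; litellm reads provider keys from the environment
os.environ["OPENAI_BASE_URL"] = "https://my-proxy.example.com/v1"  # optional; omit to use the public API

config = LiteLLMModelConfig(model="openai/gpt-4o-mini")  # "<provider>/<model>"
client = LiteLLMClient(config, env_config=None)  # env_config is unused by this __init__

print(client.provider)   # "openai"
print(client.base_url)   # the OPENAI_BASE_URL value, or None when unset
```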
```python
    def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):
        for attempt in range(self.API_MAX_RETRY):
            try:
                if self.provider == "anthropic":
                    # Filter out whitespace-only stop sequences
                    if stop_sequence:
                        stop_sequence = [s for s in stop_sequence if s.strip()]
                        if not stop_sequence:  # If empty after filtering
                            stop_sequence = ["\n"]

                if "o1" in self.model:
                    # We need to allow more tokens to include reasoning tokens
                    max_new_tokens *= 10

                response = litellm.completion(
                    model=self.model,
                    messages=[{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}],
                    max_completion_tokens=max_new_tokens if max_new_tokens > 0 else None,
                    logprobs=return_logits if self.provider == "openai" else None,
                    stop=stop_sequence,
                    base_url=self.base_url,
                    n=num_samples,
                    temperature=self.TEMPERATURE,
                    top_p=self.TOP_P,
                    caching=True,
                )
                return response
            except Exception as e:
                wait_time = min(64, self.API_RETRY_SLEEP * (2**attempt))  # Exponential backoff with max 64s
                logger.warning(
                    f"Error in API call: {e}, waiting {wait_time} seconds before retry {attempt + 1}/{self.API_MAX_RETRY}"
                )
                time.sleep(wait_time)

        logger.error(f"API call failed after {self.API_MAX_RETRY} attempts, skipping entry.")
```
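With `API_RETRY_SLEEP = 3`, `API_MAX_RETRY = 5` and the 64-second cap, the back-off applied after each failed attempt doubles every time (the `API_RETRY_MULTIPLIER` constant is set but the exponent is hard-coded as `2**attempt`). A tiny sketch of the resulting schedule:

```python
API_RETRY_SLEEP = 3
API_MAX_RETRY = 5

# Sleep applied after failed attempts 1..5: 3, 6, 12, 24, 48 seconds, all below the 64 s cap.
schedule = [min(64, API_RETRY_SLEEP * (2**attempt)) for attempt in range(API_MAX_RETRY)]
print(schedule)  # [3, 6, 12, 24, 48]
```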
```python
    def __call_api_parallel(
        self,
        prompts,
        return_logits: bool | list[bool],
        max_new_tokens: int | list[int],
        num_samples: int | list[int],
        stop_sequence: list[str] | None = None,
    ):
        results = []

        return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits
        max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
        num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
        stop_sequencess = [stop_sequence for _ in prompts]

        assert (
            len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(stop_sequencess)
        ), f"Length of prompts, return_logitss, max_new_tokenss, num_sampless, stop_sequences should be the same but are {len(prompts)}, {len(return_logitss)}, {len(max_new_tokenss)}, {len(num_sampless)}, {len(stop_sequencess)}"

        with ThreadPoolExecutor(self.CONCURENT_CALLS) as executor:
            for entry in tqdm(
                executor.map(
                    self.__call_api,
                    prompts,
                    return_logitss,
                    max_new_tokenss,
                    num_sampless,
                    stop_sequencess,
                ),
                total=len(prompts),
            ):
                results.append(entry)

        if None in results:
            raise ValueError("Some entries are not annotated due to errors in annotate_p, please inspect and retry.")

        return results
```
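`__call_api_parallel` accepts either a single value or one value per prompt for each generation parameter, broadcasts scalars into per-prompt lists, and fans the work out over a thread pool of `CONCURENT_CALLS` workers. A stripped-down sketch of the same broadcast-and-map pattern, with a stand-in worker instead of a real API call:

```python
from concurrent.futures import ThreadPoolExecutor

def fake_call(prompt: str, max_new_tokens: int) -> str:
    # Stand-in for __call_api: just echoes its arguments.
    return f"{prompt!r} (up to {max_new_tokens} tokens)"

prompts = ["First prompt", "Second prompt", "Third prompt"]
max_new_tokens = 256  # one scalar...

# ...broadcast to one value per prompt, mirroring max_new_tokenss above.
max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens

with ThreadPoolExecutor(max_workers=4) as executor:
    results = list(executor.map(fake_call, prompts, max_new_tokenss))

print(results)
```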
```python
    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
        override_bs: Optional[int] = None,
    ) -> list[GenerativeResponse]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.

        Args:
            requests (list[Request]): list of requests containing the context and ending conditions.
            disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerativeResponse]: list of generated responses.
        """
        for request in requests:
            request.tokenized_context = self.tok_encode(request.context)

        dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
        results = []

        for _ in tqdm(
            dataset.splits_start_end_iterator(),
            total=dataset.num_dataset_splits,
            desc="Splits",
            position=0,
            disable=False,  # self.disable_tqdm,
        ):
            contexts = [c.context for c in dataset]
            max_new_tokens = dataset[0].generation_size  # could be none
            return_logits = dataset[0].use_logits
            num_samples = dataset[0].num_samples
            stop_sequence = requests[0].stop_sequence

            responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples, stop_sequence)

            for response in responses:
                result: list[str] = [choice.message.content for choice in response.choices]

                cur_response = GenerativeResponse(
                    result=result,
                    logits=None,
                    generated_tokens=[],
                    input_tokens=[],
                )
                results.append(cur_response)

        return dataset.get_original_order(results)

    @property
    def tokenizer(self):
        return self._tokenizer

    def tok_encode(self, text: str):
        return self.tokenizer.encode(text)

    @property
    def add_special_tokens(self) -> bool:
        return False

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        return 4096
```
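The GPT-2 tokenizer here is explicitly a stand-in: the provider tokenizes prompts server-side, so the local token counts appear to matter only for how `GenerativeTaskDataset` orders and splits the requests, not for what the model sees. A small sketch of what `tok_encode` effectively does, assuming transformers is installed:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # same dummy tokenizer as in __init__
tokens = tokenizer.encode("Evaluate this prompt with litellm.")
print(len(tokens))  # number of GPT-2 BPE tokens in the prompt
```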
```python
    def loglikelihood(
        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError

    def loglikelihood_rolling(
        self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodResponse]:
        """This function is used to compute the log likelihood of the context for perplexity metrics."""
        raise NotImplementedError

    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodSingleTokenResponse]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        raise NotImplementedError
```
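The three loglikelihood entry points are left unimplemented, so this backend currently supports only generative evaluation through `greedy_until`. A minimal sketch of that behaviour, assuming litellm and transformers are installed and the class definitions above are in scope; no API key is needed because the method raises before any request is made, and the model name is again only illustrative:

```python
client = LiteLLMClient(LiteLLMModelConfig(model="openai/gpt-4o-mini"), env_config=None)

try:
    client.loglikelihood([])  # empty request list; raises immediately
except NotImplementedError:
    print("loglikelihood-based metrics are not supported by the litellm backend")
```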