
Commit 0eabc83

feat: supports openai chat completions API (#1427)
This PR adds support to make TGI a drop-in replacement for OpenAI clients by exposing the same HTTP interface.

Notes
- TGI initializes a single model at startup, so the `model` field is unused in HTTP requests.
- `max_tokens` and `stream` should work as expected, but other parameters may be unimplemented or not supported.

General approach
- fetch the `tokenizer_config` at startup from the hub
- pass `tokenizer_config` into `Infer` so we have it at request time
- use the `chat_template` on the config to format chat requests
  - parse the Jinja template and render the chat string (see the rendering sketch at the end of this description)
  - pass inputs into the existing generate function
  - wrap the generation output in the expected structure before returning

# How to test

### Streaming curl

```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -d '{
  "model": "tgi",
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": "What is deep learning?"
    }
  ],
  "stream": true,
  "max_tokens": 20
}' \
    -H 'Content-Type: application/json'
```

It is also possible to use the `openai` Python library by changing the base URL:

### 🌊 STREAMING REQUEST

```python
from openai import OpenAI

# init the client but point it to TGI
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="not needed for a local LLM"
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is deep learning?"}
    ],
    stream=True
)

# iterate and print stream
for message in chat_completion:
    print(message)

# ChatCompletionChunk(id='', choices=[Choice(delta=ChoiceDelta(content=' that', function_call=None, role='assistant', tool_calls=None), finish_reason=None, index=2, logprobs=None)], created=1704486761, model='', object='text_completion', system_fingerprint='')
```

### 🚗 SYNCHRONOUS REQUEST

```python
from openai import OpenAI

# init the client but point it to TGI
client = OpenAI(
    base_url="http://localhost:3000/v1",
    api_key="not needed for a local LLM"
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is deep learning?"}
    ],
    stream=False
)

print(chat_completion)

# ChatCompletion(id='', choices=[Choice(finish_reason=None, index=0, logprobs=None, message=ChatCompletionMessage(content='\nDeep learning is a new field of research that has been gaining traction in the last ...', role='assistant', function_call=None, tool_calls=None))], created=1704486762, model='', object='text_completion', system_fingerprint='', usage=CompletionUsage(completion_tokens=100, prompt_tokens=76, total_tokens=176))
```

## How to run dev

```bash
cd text-generation-inference/server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 text-generation-server serve --trust-remote-code gpt2
```

**Note:** many of the existing `chat_template`s use non-standard Jinja (e.g. adding a `raise` to the template), which will throw an error when parsing; `upstage/SOLAR-10.7B-Instruct-v1.0` is used here since it has a valid template.

```bash
cd text-generation-inference/router
cargo run -- --tokenizer-name upstage/SOLAR-10.7B-Instruct-v1.0
```

Trigger a request:

```bash
curl localhost:3000/v1/chat/completions \
    -X POST \
    -d '{
  "model": "gpt-3.5-turbo",
  "messages": [
    {
      "role": "system",
      "content": "You are a helpful assistant."
    },
    {
      "role": "user",
      "content": "What is the IP address of the Google DNS servers?"
    }
  ],
  "stream": true,
  "max_tokens": 20,
  "logprobs": true
}' \
    -H 'Content-Type: application/json'
```

^ supports both `stream: true` and `stream: false` requests
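To illustrate the template-rendering step from the approach above, here is a minimal, standalone sketch of rendering a chat-style template with `minijinja`. The template string and messages are made up for the example (real templates come from the model's `tokenizer_config.json` on the hub); this snippet is not code from the PR itself.

```rust
use minijinja::{context, Environment};

fn main() -> Result<(), minijinja::Error> {
    let mut env = Environment::new();
    // Hypothetical chat template; real ones ship in tokenizer_config.json.
    env.add_template(
        "chat",
        "{% for message in messages %}<|{{ message.role }}|>\n{{ message.content }}\n{% endfor %}",
    )?;

    // Render the same message list used in the curl examples above.
    let rendered = env.get_template("chat")?.render(context! {
        messages => vec![
            context! { role => "system", content => "You are a helpful assistant." },
            context! { role => "user", content => "What is deep learning?" },
        ],
    })?;
    println!("{rendered}");
    Ok(())
}
```

The router does the same rendering through `Infer::apply_chat_template`, but compiles the template once at startup (see the `router/src/infer.rs` diff below).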
1 parent ac08b4e commit 0eabc83

File tree

7 files changed: 557 additions & 64 deletions


Cargo.lock

Lines changed: 25 additions & 14 deletions
Some generated files are not rendered by default.

router/Cargo.toml

Lines changed: 2 additions & 0 deletions
```diff
@@ -43,6 +43,8 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
 utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
 ngrok = { version = "0.13.1", features = ["axum"], optional = true }
 init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
+minijinja = "1.0.10"
+futures-util = "0.3.30"
 
 [build-dependencies]
 vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
```
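The two new dependencies map onto the two halves of the feature: `minijinja` parses and renders the chat template, and `futures-util` is presumably pulled in for stream handling around the new streaming endpoint. As a rough, hedged illustration of the latter, adapting a token stream into OpenAI-style chunk events looks roughly like the sketch below; the payload shape and types are hand-rolled for the example and are not the router's actual serialization code.

```rust
// Illustrative sketch only (assumes a tokio runtime with the "macros" and
// "rt-multi-thread" features): map a pretend token stream into minimal
// chat-completion-chunk-shaped SSE events.
use futures_util::{stream, StreamExt};

#[tokio::main]
async fn main() {
    // Stand-in for the token stream coming out of generate_stream.
    let tokens = stream::iter(vec!["Deep", " learning", " is", " ..."]);

    let mut events = tokens.map(|t| {
        format!(
            "data: {{\"object\":\"chat.completion.chunk\",\"choices\":[{{\"delta\":{{\"content\":\"{t}\"}}}}]}}\n\n"
        )
    });

    while let Some(event) = events.next().await {
        print!("{event}");
    }
}
```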

router/src/infer.rs

Lines changed: 38 additions & 13 deletions
```diff
@@ -1,8 +1,10 @@
 /// Batching and inference logic
 use crate::validation::{Validation, ValidationError};
+use crate::HubTokenizerConfig;
+use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
 use crate::{Entry, Queue, Token};
-use crate::{GenerateRequest, PrefillToken};
 use futures::future::try_join_all;
+use minijinja::{Environment, ErrorKind, Template};
 use nohash_hasher::IntMap;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
@@ -13,7 +15,7 @@ use text_generation_client::{
 };
 use thiserror::Error;
 use tokio::sync::mpsc::error::SendError;
-use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
 use tokio::time::Instant;
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_stream::StreamExt;
@@ -30,6 +32,8 @@ pub struct Infer {
     shared: Arc<Shared>,
     /// Inference limit
     limit_concurrent_requests: Arc<Semaphore>,
+    /// Chat template
+    template: Option<Template<'static, 'static>>,
 }
 
 /// Infer shared state
@@ -52,6 +56,7 @@ impl Infer {
         window_size: Option<u32>,
         speculate: u32,
         generation_health: Arc<AtomicBool>,
+        tokenizer_config: HubTokenizerConfig,
     ) -> Self {
         // Infer shared state
         let queue = Queue::new(requires_padding, 16, window_size, speculate);
@@ -74,11 +79,21 @@ impl Infer {
         // Inference limit with a semaphore
         let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
 
+        let template = tokenizer_config.chat_template.map(|t| {
+            let env = Box::new(Environment::new());
+            let template_str = t.into_boxed_str();
+            // leaking env and template_str as read-only, static resources for performance.
+            Box::leak(env)
+                .template_from_str(Box::leak(template_str))
+                .unwrap()
+        });
+
         Self {
             validation,
             queue,
             shared,
             limit_concurrent_requests: semaphore,
+            template,
         }
     }
 
@@ -87,14 +102,7 @@ impl Infer {
     pub(crate) async fn generate_stream(
         &self,
         request: GenerateRequest,
-    ) -> Result<
-        (
-            OwnedSemaphorePermit,
-            u32,
-            UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
-        ),
-        InferError,
-    > {
+    ) -> Result<GenerateStreamResponse, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let permit = self
             .clone()
@@ -139,6 +147,20 @@ impl Infer {
         ))
     }
 
+    /// Apply the chat template to the chat request
+    #[instrument(skip_all)]
+    pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
+        self.template
+            .as_ref()
+            .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
+            .render(chat)
+            .map_err(|e| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "template");
+                tracing::error!("{e}");
+                InferError::TemplateError(e)
+            })
+    }
+
     /// Add a new request to the queue and return a InferResponse
     #[instrument(skip_all)]
     pub(crate) async fn generate(
@@ -550,9 +572,9 @@ fn send_responses(
     let mut iterator = tokens_
         .ids
         .into_iter()
-        .zip(tokens_.logprobs.into_iter())
-        .zip(tokens_.texts.into_iter())
-        .zip(tokens_.is_special.into_iter())
+        .zip(tokens_.logprobs)
+        .zip(tokens_.texts)
+        .zip(tokens_.is_special)
         .enumerate()
         .peekable();
     while let Some((i, (((id, logprob), text), special))) = iterator.next() {
@@ -665,6 +687,8 @@ pub enum InferError {
     ValidationError(#[from] ValidationError),
     #[error("Incomplete generation")]
     IncompleteGeneration,
+    #[error("Template error: {0}")]
+    TemplateError(#[from] minijinja::Error),
 }
 
 impl InferError {
@@ -674,6 +698,7 @@ impl InferError {
             InferError::Overloaded(_) => "overloaded",
             InferError::ValidationError(_) => "validation",
             InferError::IncompleteGeneration => "incomplete_generation",
+            InferError::TemplateError(_) => "template_error",
         }
     }
 }
```
