Skip to content

Commit e7a6fb2

Browse files
committed
fix: initialize chat template single time, fix defaults and add seed param
1 parent fa6b227 commit e7a6fb2

File tree

3 files changed

+30
-26
lines changed

3 files changed

+30
-26
lines changed

router/src/infer.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::HubTokenizerConfig;
44
use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
55
use crate::{Entry, Queue, Token};
66
use futures::future::try_join_all;
7+
use minijinja::{Environment, ErrorKind, Template};
78
use nohash_hasher::IntMap;
89
use std::sync::{
910
atomic::{AtomicBool, Ordering},
@@ -27,12 +28,12 @@ pub struct Infer {
2728
validation: Validation,
2829
/// Request queue
2930
queue: Queue,
30-
/// Chat formatter
31-
tokenizer_config: HubTokenizerConfig,
3231
/// Shared state
3332
shared: Arc<Shared>,
3433
/// Inference limit
3534
limit_concurrent_requests: Arc<Semaphore>,
35+
/// Chat template
36+
template: Option<Template<'static, 'static>>,
3637
}
3738

3839
/// Infer shared state
@@ -78,12 +79,21 @@ impl Infer {
7879
// Inference limit with a semaphore
7980
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));
8081

82+
let template = tokenizer_config.chat_template.map(|t| {
83+
let env = Box::new(Environment::new());
84+
let template_str = t.into_boxed_str();
85+
// Intentionally leak env and template_str: they become read-only, 'static resources so the template is built only once (for performance).
86+
Box::leak(env)
87+
.template_from_str(Box::leak(template_str))
88+
.unwrap()
89+
});
90+
8191
Self {
8292
validation,
8393
queue,
8494
shared,
8595
limit_concurrent_requests: semaphore,
86-
tokenizer_config,
96+
template,
8797
}
8898
}
8999

@@ -139,9 +149,15 @@ impl Infer {
139149
/// Apply the chat template to the chat request
140150
#[instrument(skip_all)]
141151
pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
142-
self.tokenizer_config
143-
.apply_chat_template(chat)
144-
.map_err(InferError::TemplateError)
152+
self.template
153+
.as_ref()
154+
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
155+
.render(chat)
156+
.map_err(|e| {
157+
metrics::increment_counter!("tgi_request_failure", "err" => "template");
158+
tracing::error!("{e}");
159+
InferError::TemplateError(e)
160+
})
145161
}
146162

147163
/// Add a new request to the queue and return an InferResponse

router/src/lib.rs

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,6 @@ pub struct HubTokenizerConfig {
3636
pub chat_template: Option<String>,
3737
}
3838

39-
impl HubTokenizerConfig {
40-
/// Apply the chat template to the chat request
41-
pub(crate) fn apply_chat_template(
42-
&self,
43-
chat: ChatRequest,
44-
) -> Result<String, minijinja::Error> {
45-
let mut env = minijinja::Environment::new();
46-
let chat_template = self
47-
.chat_template
48-
.as_ref()
49-
.ok_or(minijinja::ErrorKind::TemplateNotFound)?;
50-
env.add_template("_", chat_template)?;
51-
env.get_template("_")?.render(chat)
52-
}
53-
}
54-
5539
#[derive(Clone, Debug, Serialize, ToSchema)]
5640
pub struct Info {
5741
/// Model info
@@ -292,7 +276,7 @@ impl ChatCompletionChunk {
292276
finish_reason: Option<String>,
293277
) -> Self {
294278
Self {
295-
id: "".to_string(),
279+
id: String::new(),
296280
object: "text_completion".to_string(),
297281
created,
298282
model,
@@ -312,7 +296,7 @@ impl ChatCompletionChunk {
312296

313297
fn default_request_messages() -> Vec<Message> {
314298
vec![Message {
315-
role: "system".to_string(),
299+
role: "user".to_string(),
316300
content: "My name is David and I".to_string(),
317301
}]
318302
}
@@ -371,11 +355,14 @@ pub(crate) struct ChatRequest {
371355

372356
#[serde(default = "bool::default")]
373357
pub stream: bool,
358+
359+
#[schema(nullable = true, example = 42)]
360+
pub seed: Option<u64>,
374361
}
375362

376363
#[derive(Clone, Deserialize, ToSchema, Serialize)]
377364
pub(crate) struct Message {
378-
#[schema(example = "system")]
365+
#[schema(example = "user")]
379366
pub role: String,
380367
#[schema(example = "My name is David and I")]
381368
pub content: String,

router/src/server.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ async fn chat_completions(
564564
// rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
565565
.map(|x| x + 2.0);
566566
let logprobs = req.logprobs.unwrap_or(false);
567+
let seed = req.seed;
567568

568569
// apply chat template to flatten the request into a single input
569570
let inputs = match infer.apply_chat_template(req) {
@@ -599,7 +600,7 @@ async fn chat_completions(
599600
watermark: false,
600601
details: true,
601602
decoder_input_details: false,
602-
seed: None,
603+
seed,
603604
top_n_tokens: None,
604605
},
605606
};

0 commit comments

Comments
 (0)