Skip to content
Merged
39 changes: 25 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ utoipa = { version = "3.5.0", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.1.5", features = ["axum"] }
ngrok = { version = "0.13.1", features = ["axum"], optional = true }
init-tracing-opentelemetry = { version = "0.14.1", features = ["opentelemetry-otlp"] }
minijinja = "1.0.10"
futures-util = "0.3.30"

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
Expand Down
51 changes: 38 additions & 13 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
/// Batching and inference logic
use crate::validation::{Validation, ValidationError};
use crate::HubTokenizerConfig;
use crate::{ChatRequest, GenerateRequest, GenerateStreamResponse, PrefillToken};
use crate::{Entry, Queue, Token};
use crate::{GenerateRequest, PrefillToken};
use futures::future::try_join_all;
use minijinja::{Environment, ErrorKind, Template};
use nohash_hasher::IntMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Expand All @@ -13,7 +15,7 @@ use text_generation_client::{
};
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::sync::{mpsc, Notify, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
Expand All @@ -30,6 +32,8 @@ pub struct Infer {
shared: Arc<Shared>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
/// Chat template
template: Option<Template<'static, 'static>>,
}

/// Infer shared state
Expand All @@ -52,6 +56,7 @@ impl Infer {
window_size: Option<u32>,
speculate: u32,
generation_health: Arc<AtomicBool>,
tokenizer_config: HubTokenizerConfig,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size, speculate);
Expand All @@ -74,11 +79,21 @@ impl Infer {
// Inference limit with a semaphore
let semaphore = Arc::new(Semaphore::new(max_concurrent_requests));

let template = tokenizer_config.chat_template.map(|t| {
let env = Box::new(Environment::new());
let template_str = t.into_boxed_str();
// leaking env and template_str as read-only, static resources for performance.
Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap()
});

Self {
validation,
queue,
shared,
limit_concurrent_requests: semaphore,
template,
}
}

Expand All @@ -87,14 +102,7 @@ impl Infer {
pub(crate) async fn generate_stream(
&self,
request: GenerateRequest,
) -> Result<
(
OwnedSemaphorePermit,
u32,
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
),
InferError,
> {
) -> Result<GenerateStreamResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
.clone()
Expand Down Expand Up @@ -139,6 +147,20 @@ impl Infer {
))
}

/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(&self, chat: ChatRequest) -> Result<String, InferError> {
self.template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.render(chat)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
tracing::error!("{e}");
InferError::TemplateError(e)
})
}

/// Add a new request to the queue and return a InferResponse
#[instrument(skip_all)]
pub(crate) async fn generate(
Expand Down Expand Up @@ -550,9 +572,9 @@ fn send_responses(
let mut iterator = tokens_
.ids
.into_iter()
.zip(tokens_.logprobs.into_iter())
.zip(tokens_.texts.into_iter())
.zip(tokens_.is_special.into_iter())
.zip(tokens_.logprobs)
.zip(tokens_.texts)
.zip(tokens_.is_special)
.enumerate()
.peekable();
while let Some((i, (((id, logprob), text), special))) = iterator.next() {
Expand Down Expand Up @@ -665,6 +687,8 @@ pub enum InferError {
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
#[error("Template error: {0}")]
TemplateError(#[from] minijinja::Error),
}

impl InferError {
Expand All @@ -674,6 +698,7 @@ impl InferError {
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error",
}
}
}
Loading