Support for ChatCompletions #70

Merged: 1 commit, Mar 3, 2023
97 changes: 74 additions & 23 deletions src/llms/openai.rs
@@ -5,7 +5,13 @@ use async_trait::async_trait;
use tiktoken_rs::tiktoken::{p50k_base, CoreBPE};

use crate::settings::OpenAISettings;
use async_openai::{types::CreateCompletionRequestArgs, Client};
use async_openai::{
types::{
ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs,
CreateCompletionRequestArgs, Role,
},
Client,
};

use super::llm_client::LlmClient;

@@ -42,33 +48,17 @@ impl OpenAIClient {
_ => 4096,
}
}
}

#[async_trait]
impl LlmClient for OpenAIClient {
/// Sends a request to OpenAI's API to get a text completion.
/// It takes a prompt as input, and returns the completion.
async fn completions(&self, prompt: &str) -> Result<String> {
let prompt_token_limit = self.get_prompt_token_limit_for_model();
lazy_static! {
static ref BPE_TOKENIZER: CoreBPE = p50k_base().unwrap();
}
let output_length = 100;

let tokens = BPE_TOKENIZER.encode_with_special_tokens(prompt);
let prompt_token_count = tokens.len();
if prompt_token_count + output_length > prompt_token_limit {
let error_msg =
format!("skipping... token count: {prompt_token_count} < {prompt_token_limit}");
warn!("{}", error_msg);
bail!(error_msg)
}
pub(crate) fn should_use_chat_completion(model: &str) -> bool {
model.to_lowercase().starts_with("gpt-3.5-turbo")
}

pub(crate) async fn get_completions(&self, prompt: &str, output_length: u16) -> Result<String> {
// Create request using builder pattern
let request = CreateCompletionRequestArgs::default()
.model(&self.model)
.prompt(prompt)
.max_tokens(output_length as u16)
.max_tokens(output_length)
.temperature(0.5)
.top_p(1.)
.frequency_penalty(0.)
@@ -89,6 +79,67 @@ impl LlmClient for OpenAIClient {
.ok_or(anyhow!("No completion results"))
.map(|c| c.text.clone());

return completion;
completion
}

pub(crate) async fn get_chat_completions(
&self,
prompt: &str,
_output_length: u16,
) -> Result<String> {
let request = CreateChatCompletionRequestArgs::default()
.model(&self.model)
.messages([
ChatCompletionRequestMessageArgs::default()
.role(Role::System)
.content("You are an expect, helpful programming assistant that has a deep understanding of all programming languages including Python, Rust and Javascript.")
.build()?,
ChatCompletionRequestMessageArgs::default()
.role(Role::User)
.content(prompt)
.build()?,

])
.build()?;

let response = self.client.chat().create(request).await?;

if let Some(choice) = response.choices.into_iter().next() {
debug!(
"{}: Role: {} Content: {}",
choice.index, choice.message.role, choice.message.content
);

return Ok(choice.message.content);
}

bail!("No completion results")
}
}

#[async_trait]
impl LlmClient for OpenAIClient {
/// Sends a request to OpenAI's API to get a text completion.
/// It takes a prompt as input, and returns the completion.
async fn completions(&self, prompt: &str) -> Result<String> {
let prompt_token_limit = self.get_prompt_token_limit_for_model();
lazy_static! {
static ref BPE_TOKENIZER: CoreBPE = p50k_base().unwrap();
}
let n_tokens = 100;

let tokens = BPE_TOKENIZER.encode_with_special_tokens(prompt);
let prompt_token_count = tokens.len();
if prompt_token_count + n_tokens > prompt_token_limit {
let error_msg =
format!("skipping... token count: {prompt_token_count} < {prompt_token_limit}");
warn!("{}", error_msg);
bail!(error_msg)
}
if OpenAIClient::should_use_chat_completion(&self.model) {
self.get_chat_completions(prompt, n_tokens as u16).await
} else {
self.get_completions(prompt, n_tokens as u16).await
}
}
}
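The dispatch added in completions() hinges on the new should_use_chat_completion check: any model whose lowercased name starts with "gpt-3.5-turbo" is routed to the ChatCompletions endpoint, while every other model keeps using plain completions. The test module below is a minimal sketch of that routing, not part of this PR; it assumes it sits in src/llms/openai.rs so it can reach the pub(crate) associated function shown above.

#[cfg(test)]
mod chat_routing_tests {
    use super::OpenAIClient;

    #[test]
    fn turbo_models_use_chat_completions() {
        // The check is a case-insensitive prefix match, so dated snapshots qualify too.
        assert!(OpenAIClient::should_use_chat_completion("gpt-3.5-turbo"));
        assert!(OpenAIClient::should_use_chat_completion("GPT-3.5-Turbo-0301"));
    }

    #[test]
    fn legacy_models_keep_plain_completions() {
        assert!(!OpenAIClient::should_use_chat_completion("text-davinci-003"));
    }
}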
4 changes: 3 additions & 1 deletion src/settings.rs
@@ -21,6 +21,8 @@ use crate::{
},
};

static DEFAULT_OPENAI_MODEL: &str = "gpt-3.5-turbo";

static DEFAULT_FILES_TO_IGNORE: &[&str; 4] = &[
"package-lock.json",
"yarn.lock",
@@ -195,7 +197,7 @@ impl Settings {
"openai",
Some(OpenAISettings {
api_key: None,
model: Some("text-davinci-003".to_string()),
model: Some(DEFAULT_OPENAI_MODEL.to_string()),
}),
)?
.set_default(
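With the settings change, the default model flips from text-davinci-003 to the new DEFAULT_OPENAI_MODEL constant, so fresh configurations take the ChatCompletions path out of the box. The standalone sketch below illustrates the effect of the new default and how a user could explicitly pin the previous completions model; it is not part of the PR, the struct is re-declared only so the snippet compiles on its own, and the field names and Option types are taken from the diff.

// Sketch only: mirrors OpenAISettings from src/settings.rs for illustration.
#[derive(Debug)]
struct OpenAISettings {
    api_key: Option<String>,
    model: Option<String>,
}

static DEFAULT_OPENAI_MODEL: &str = "gpt-3.5-turbo";

fn main() {
    // Default after this PR: gpt-3.5-turbo, which the client routes to ChatCompletions.
    let default_settings = OpenAISettings {
        api_key: None,
        model: Some(DEFAULT_OPENAI_MODEL.to_string()),
    };
    // Pinning the model keeps the old plain-completions behaviour.
    let legacy_settings = OpenAISettings {
        api_key: None,
        model: Some("text-davinci-003".to_string()),
    };
    println!("{default_settings:?}\n{legacy_settings:?}");
}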