Commit d054602

update to use newest llama model
1 parent 37fe65a commit d054602

File tree: 6 files changed, +67 -26 lines changed


crates/llm-chain-llama-sys/src/bindings.rs

Lines changed: 9 additions & 11 deletions
@@ -759,14 +759,21 @@ fn bindgen_test_layout_llama_token_data_array() {
         )
     );
 }
+
+const LLAMA_MAX_DEVICES: usize = 1;
 pub type llama_progress_callback =
     ::std::option::Option<unsafe extern "C" fn(progress: f32, ctx: *mut ::std::os::raw::c_void)>;
 #[repr(C)]
 #[derive(Debug, Copy, Clone)]
 pub struct llama_context_params {
     pub n_ctx: ::std::os::raw::c_int,
-    pub n_parts: ::std::os::raw::c_int,
+    pub n_batch: ::std::os::raw::c_int,
+    pub n_gpu_layers: ::std::os::raw::c_int,
+    pub main_gpu: ::std::os::raw::c_int,
+    pub tensor_split: [::std::os::raw::c_float; LLAMA_MAX_DEVICES],
+
     pub seed: ::std::os::raw::c_int,
+
     pub f16_kv: bool,
     pub logits_all: bool,
     pub vocab_only: bool,
@@ -800,16 +807,7 @@ fn bindgen_test_layout_llama_context_params() {
             stringify!(n_ctx)
         )
     );
-    assert_eq!(
-        unsafe { ::std::ptr::addr_of!((*ptr).n_parts) as usize - ptr as usize },
-        4usize,
-        concat!(
-            "Offset of field: ",
-            stringify!(llama_context_params),
-            "::",
-            stringify!(n_parts)
-        )
-    );
+
     assert_eq!(
        unsafe { ::std::ptr::addr_of!((*ptr).seed) as usize - ptr as usize },
        8usize,
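
For orientation, here is a minimal sketch of how the widened llama_context_params shown above could be filled in from Rust. It is not part of the commit: only the fields visible in the hunk are reproduced, and every value (512 context, CPU-only offload, -1 seed) is an illustrative assumption.

#[allow(non_camel_case_types)]
mod sketch {
    // Mirrors the window of the struct visible in the diff above; illustrative only.
    pub const LLAMA_MAX_DEVICES: usize = 1;

    #[repr(C)]
    #[derive(Debug, Copy, Clone)]
    pub struct llama_context_params {
        pub n_ctx: ::std::os::raw::c_int,
        pub n_batch: ::std::os::raw::c_int,
        pub n_gpu_layers: ::std::os::raw::c_int,
        pub main_gpu: ::std::os::raw::c_int,
        pub tensor_split: [f32; LLAMA_MAX_DEVICES],
        pub seed: ::std::os::raw::c_int,
        pub f16_kv: bool,
        pub logits_all: bool,
        pub vocab_only: bool,
    }

    pub fn cpu_only_params() -> llama_context_params {
        llama_context_params {
            n_ctx: 512,
            n_batch: 512,
            n_gpu_layers: 0,                        // keep all layers on the CPU
            main_gpu: 0,
            tensor_split: [0.0; LLAMA_MAX_DEVICES], // unused when nothing is offloaded
            seed: -1,                               // assumed convention: -1 picks a random seed
            f16_kv: true,
            logits_all: false,
            vocab_only: false,
        }
    }
}

fn main() {
    println!("{:?}", sketch::cpu_only_params());
}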

crates/llm-chain-llama/examples/simple.rs

Lines changed: 36 additions & 3 deletions
@@ -1,5 +1,7 @@
 use llm_chain::{executor, parameters, prompt};
-
+use llm_chain::options;
+use llm_chain::options::{ModelRef, Options};
+use std::{env::args, error::Error};
 /// This example demonstrates how to use the llm-chain-llama crate to generate text using a
 /// LLaMA model.
 ///
@@ -9,11 +11,42 @@ use llm_chain::{executor, parameters, prompt};
 /// cargo run --example simple /models/llama
 #[tokio::main(flavor = "current_thread")]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let exec = executor!(llama)?;
+    let raw_args: Vec<String> = args().collect();
+    let args = match &raw_args.len() {
+        2 => (raw_args[1].as_str(), "Rust is a cool programming language because"),
+        3 => (raw_args[1].as_str(), raw_args[2].as_str()),
+        _ => panic!("Usage: cargo run --release --example simple <path to model> <optional prompt>")
+    };
+
+    let model_path = args.0;
+    let prompt = args.1;
+    let opts = options!(
+        Model: ModelRef::from_path(model_path),
+        ModelType: "llama",
+        MaxContextSize: 512_usize,
+        NThreads: 4_usize,
+        MaxTokens: 0_usize,
+        TopK: 40_i32,
+        TopP: 0.95,
+        TfsZ: 1.0,
+        TypicalP: 1.0,
+        Temperature: 0.8,
+        RepeatPenalty: 1.1,
+        RepeatPenaltyLastN: 64_usize,
+        FrequencyPenalty: 0.0,
+        PresencePenalty: 0.0,
+        Mirostat: 0_i32,
+        MirostatTau: 5.0,
+        MirostatEta: 0.1,
+        PenalizeNl: true,
+        StopSequence: vec!["\n".to_string()]
+    );
+    let exec = executor!(llama, opts.clone())?;
 
-    let res = prompt!("The Colors of the Rainbow are (in order): ")
+    let res = prompt!(prompt)
         .run(&parameters!(), &exec)
         .await?;
+
     println!("{}", res.to_immediate().await?);
     Ok(())
 }
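
With this change the example no longer hard-codes a prompt: it takes the model path and an optional prompt on the command line, builds an explicit options! set, and passes it to executor!. Following the usage string in the panic! above, an invocation would look something like the following (the model path is only a placeholder):

cargo run --release --example simple /models/llama "Rust is a cool programming language because"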

crates/llm-chain-llama/src/context.rs

Lines changed: 13 additions & 3 deletions
@@ -20,11 +20,15 @@ use serde::{Deserialize, Serialize};
 #[error("LLAMA.cpp returned error-code {0}")]
 pub struct LLAMACPPErrorCode(i32);
 
+const LLAMA_MAX_DEVICES: usize = 1; // corresponding to constant in llama.h
 // Represents the configuration parameters for a LLamaContext.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ContextParams {
     pub n_ctx: i32,
-    pub n_parts: i32,
+    pub n_batch: i32,
+    pub n_gpu_layers: i32,
+    pub main_gpu: i32,
+    pub tensor_split: [f32; LLAMA_MAX_DEVICES],
     pub seed: i32,
     pub f16_kv: bool,
     pub vocab_only: bool,
@@ -56,7 +60,10 @@ impl From<ContextParams> for llama_context_params {
     fn from(params: ContextParams) -> Self {
         llama_context_params {
             n_ctx: params.n_ctx,
-            n_parts: params.n_parts,
+            n_batch: params.n_batch,
+            n_gpu_layers: params.n_gpu_layers,
+            main_gpu: params.main_gpu,
+            tensor_split: params.tensor_split,
             seed: params.seed,
             f16_kv: params.f16_kv,
             logits_all: false,
@@ -74,7 +81,10 @@ impl From<llama_context_params> for ContextParams {
     fn from(params: llama_context_params) -> Self {
         ContextParams {
             n_ctx: params.n_ctx,
-            n_parts: params.n_parts,
+            n_batch: params.n_batch,
+            n_gpu_layers: params.n_gpu_layers,
+            main_gpu: params.main_gpu,
+            tensor_split: params.tensor_split,
             seed: params.seed,
             f16_kv: params.f16_kv,
             vocab_only: params.vocab_only,
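
The new tensor_split field mirrors llama.cpp's per-device split weights: one entry per device describing how offloaded work is divided across GPUs. The sketch below is illustrative only (not code from this crate) and treats the entries as relative weights; with LLAMA_MAX_DEVICES = 1, as declared above, the split is trivially all-on-one-device.

const LLAMA_MAX_DEVICES: usize = 1;

// Normalize relative weights into fractions, one per device.
fn normalized_split(weights: [f32; LLAMA_MAX_DEVICES]) -> [f32; LLAMA_MAX_DEVICES] {
    let total: f32 = weights.iter().sum();
    let mut out = weights;
    if total > 0.0 {
        for w in out.iter_mut() {
            *w /= total;
        }
    }
    out
}

fn main() {
    // Single-device build: device 0 gets 100% of the offloaded layers.
    println!("{:?}", normalized_split([1.0]));
}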

crates/llm-chain-llama/src/executor.rs

Lines changed: 4 additions & 5 deletions
@@ -51,7 +51,7 @@ impl Executor {
 
     // Run the LLAMA model with the provided input and generate output.
     // Executes the model with the provided input and context parameters.
-    fn run_model(&self, input: LlamaInvocation) -> Output {
+    async fn run_model(&self, input: LlamaInvocation) -> Output {
         let (sender, output) = Output::new_stream();
         // Tokenize the stop sequence and input prompt.
         let context = self.context.clone();
@@ -62,7 +62,6 @@ impl Executor {
         async move {
             let context_size = context_size;
             let context = context.lock().await;
-
             let tokenized_stop_prompt = tokenize(
                 &context,
                 input
@@ -87,7 +86,7 @@ impl Executor {
 
             // Embd contains the prompt and the completion. The longer the prompt, the shorter the completion.
             let mut embd = tokenized_input.clone();
-
+
             // Evaluate the prompt in full.
             bail!(
                 context
@@ -180,7 +179,7 @@ impl Executor {
                     }
                 }
             }
-        });
+        }).await.unwrap().await;
 
         output
     }
@@ -210,7 +209,7 @@ impl ExecutorTrait for Executor {
     async fn execute(&self, options: &Options, prompt: &Prompt) -> Result<Output, ExecutorError> {
         let invocation = LlamaInvocation::new(self.get_cascade(options), prompt)
             .ok_or(ExecutorError::InvalidOptions)?;
-        Ok(self.run_model(invocation))
+        Ok(self.run_model(invocation).await)
     }
 
     fn tokens_used(
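
The signature change (fn to async fn) together with the new trailing `}).await.unwrap().await;` means the spawned work is now joined before run_model returns: the first .await waits on the task handle, .unwrap() surfaces a join error or panic, and the second .await drives the future the task handed back. A minimal, self-contained sketch of that control flow follows; the names are hypothetical and this is not the crate's actual code.

async fn inner_work() -> u32 {
    42
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    // The spawned task itself returns a future instead of a plain value.
    let handle = tokio::spawn(async move { inner_work() });

    // First `.await` joins the task, `.unwrap()` propagates a panic from it,
    // and the second `.await` runs the future the task handed back.
    let value = handle.await.unwrap().await;
    println!("{value}");
}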

crates/llm-chain-llama/src/options.rs

Lines changed: 4 additions & 3 deletions
@@ -43,7 +43,8 @@ impl LlamaInvocation {
     pub(crate) fn new(opt: OptionsCascade, prompt: &Prompt) -> Option<LlamaInvocation> {
         opt_extract!(opt, n_threads, NThreads);
         opt_extract!(opt, n_tok_predict, MaxTokens);
-        opt_extract!(opt, token_bias, TokenBias);
+        // Skip TokenBias for now
+        //opt_extract!(opt, token_bias, TokenBias);
         opt_extract!(opt, top_k, TopK);
         opt_extract!(opt, top_p, TopP);
         opt_extract!(opt, tfs_z, TfsZ);
@@ -59,8 +60,8 @@ impl LlamaInvocation {
         opt_extract!(opt, penalize_nl, PenalizeNl);
         opt_extract!(opt, stop_sequence, StopSequence);
 
-        let logit_bias = token_bias.as_i32_f32_hashmap()?;
-
+        let logit_bias = HashMap::<i32,f32>::new();// token_bias.as_i32_f32_hashmap()?;
+
         Some(LlamaInvocation {
             n_threads: *n_threads as i32,
             n_tok_predict: *n_tok_predict,
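
TokenBias extraction is stubbed out here, so logit_bias becomes an empty HashMap<i32, f32> (token id mapped to an additive bias) and no per-token bias is applied. For illustration only, a hand-built map would look like the sketch below; the token ids are made up and carry no meaning for a real LLaMA vocabulary.

use std::collections::HashMap;

fn main() {
    // Illustrative only: the commit leaves this map empty, so no bias is applied.
    let mut logit_bias: HashMap<i32, f32> = HashMap::new();
    logit_bias.insert(13, -1.0); // discourage one (made-up) token id
    logit_bias.insert(42, 2.5);  // encourage another

    for (token, bias) in &logit_bias {
        println!("token {token}: bias {bias:+}");
    }
}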

0 commit comments
