Skip to content

Commit 8653808

Browse files
committed
feat: compile bee to pdl
Signed-off-by: Nick Mitchell <[email protected]> This also improves error handling in the rust code. Signed-off-by: Nick Mitchell <[email protected]>
1 parent 7607841 commit 8653808

File tree

7 files changed

+298
-64
lines changed

7 files changed

+298
-64
lines changed

.github/workflows/tauri-cli.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,14 @@ jobs:
4242
run: |
4343
PATH=./src-tauri/target/release/:$PATH
4444
45-
# 1. `run` subcommand works without any arguments
46-
pdl run | grep Usage
45+
# 1a. `run` subcommand errors due to missing required positional parameter
46+
pdl run && (echo "This should have failed" && exit 1) || (echo "Great, expected failure received" && exit 0)
47+
48+
# 1b.`run` subcommand works without any arguments to print Usage
49+
pdl run 2>&1 | grep Usage
50+
51+
# 1c.`run` subcommand works with -h to print Usage
52+
pdl run -h 2>&1 | grep Usage
4753
4854
# 2. `run` subcommand works with UI demos (yaml source)
4955
pdl run ./demos/demo1.pdl | grep 'write a hello'

pdl-live-react/src-tauri/src/cli/run.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub fn run_pdl_program(
1313
trace_file: Option<&tauri_plugin_cli::ArgData>,
1414
data: Option<&tauri_plugin_cli::ArgData>,
1515
stream: Option<&tauri_plugin_cli::ArgData>,
16-
) -> Result<(), tauri::Error> {
16+
) -> Result<(), Box<dyn ::std::error::Error>> {
1717
println!(
1818
"Running {:#?}",
1919
Path::new(&source_file_path).file_name().unwrap()
Lines changed: 69 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,86 @@
1-
use std::path::Path;
2-
use std::process::exit;
1+
use ::std::path::Path;
32

43
use serde_json::Value;
54
use urlencoding::encode;
65

76
use tauri_plugin_cli::CliExt;
87

98
use crate::cli::run;
9+
use crate::compile;
1010
use crate::gui::setup as gui_setup;
1111

1212
#[cfg(desktop)]
13-
pub fn cli(app: &mut tauri::App) -> Result<(), tauri::Error> {
13+
pub fn cli(app: &mut tauri::App) -> Result<(), Box<dyn ::std::error::Error>> {
1414
app.handle().plugin(tauri_plugin_cli::init())?;
1515

16-
match app.cli().matches() {
17-
// `matches` here is a Struct with { args, subcommand }.
18-
// `args` is `HashMap<String, ArgData>` where `ArgData` is a struct with { value, occurrences }.
19-
// `subcommand` is `Option<Box<SubcommandMatches>>` where `SubcommandMatches` is a struct with { name, matches }.
20-
Ok(matches) => match matches.subcommand {
21-
Some(subcommand_matches) => match subcommand_matches.name.as_str() {
22-
"run" => {
23-
if let Some(source) = subcommand_matches.matches.args.get("source") {
24-
if let Value::String(source_file_path) = &source.value {
25-
match run::run_pdl_program(
26-
source_file_path.clone(),
27-
app.handle().clone(),
28-
subcommand_matches.matches.args.get("trace"),
29-
subcommand_matches.matches.args.get("data"),
30-
subcommand_matches.matches.args.get("stream"),
31-
) {
32-
Ok(()) => exit(0),
33-
_ => exit(1),
34-
}
35-
}
36-
}
37-
println!("Usage: run <source.pdl>");
38-
exit(1)
39-
}
40-
"view" => match subcommand_matches.matches.args.get("trace") {
41-
Some(trace) => match &trace.value {
42-
Value::String(trace_file) => {
43-
gui_setup(
44-
app.handle().clone(),
45-
Path::new("/local")
46-
.join(encode(trace_file).as_ref())
47-
.display()
48-
.to_string(),
49-
)?;
50-
Ok(())
51-
}
52-
_ => {
53-
println!("Usage: view <tracefile.json>");
54-
exit(1)
55-
}
56-
},
57-
_ => {
58-
println!("Usage: view <tracefile.json>");
59-
exit(1)
60-
}
61-
},
62-
_ => {
63-
println!("Invalid subcommand");
64-
exit(1)
16+
// `matches` here is a Struct with { args, subcommand }.
17+
// `args` is `HashMap<String, ArgData>` where `ArgData` is a struct with { value, occurrences }.
18+
// `subcommand` is `Option<Box<SubcommandMatches>>` where `SubcommandMatches` is a struct with { name, matches }.
19+
let Some(subcommand_matches) = app.cli().matches()?.subcommand else {
20+
if let Some(help) = app.cli().matches()?.args.get("help") {
21+
return Err(Box::from(help.value.as_str().or(Some("Internal Error")).unwrap()));
22+
} else {
23+
return Err(Box::from("Internal Error"));
24+
}
25+
};
26+
27+
match subcommand_matches.name.as_str() {
28+
"compile" => {
29+
let Some(compile_subcommand_matches) = subcommand_matches.matches.subcommand else {
30+
return Err(Box::from("Missing compile subcommand"));
31+
};
32+
33+
match compile_subcommand_matches.name.as_str() {
34+
"beeai" => {
35+
let Some(source) = compile_subcommand_matches.matches.args.get("source") else {
36+
return Err(Box::from("Missing source file"));
37+
};
38+
let Value::String(source_file_path) = &source.value else {
39+
return Err(Box::from("Invalid source file argument"));
40+
};
41+
let Some(output) = compile_subcommand_matches.matches.args.get("output") else {
42+
return Err(Box::from("Missing output argument"));
43+
};
44+
let Value::String(output_file_path) = &output.value else {
45+
return Err(Box::from("Invalid output file argument"));
46+
};
47+
return compile::beeai::compile(source_file_path, output_file_path);
6548
}
66-
},
67-
None => {
68-
println!("Invalid command");
69-
exit(1)
49+
_ => {}
7050
}
71-
},
72-
Err(s) => {
73-
println!("{:?}", s);
74-
exit(1)
7551
}
52+
"run" => {
53+
let Some(source) = subcommand_matches.matches.args.get("source") else {
54+
return Err(Box::from("Missing source file"));
55+
};
56+
let Value::String(source_file_path) = &source.value else {
57+
return Err(Box::from("Invalid source file argument"));
58+
};
59+
return run::run_pdl_program(
60+
source_file_path.clone(),
61+
app.handle().clone(),
62+
subcommand_matches.matches.args.get("trace"),
63+
subcommand_matches.matches.args.get("data"),
64+
subcommand_matches.matches.args.get("stream"),
65+
);
66+
}
67+
"view" => {
68+
let Some(trace) = subcommand_matches.matches.args.get("trace") else {
69+
return Err(Box::from("Missing trace file"));
70+
};
71+
let Value::String(trace_file) = &trace.value else {
72+
return Err(Box::from("Invalid trace file argument"));
73+
};
74+
gui_setup(
75+
app.handle().clone(),
76+
Path::new("/local")
77+
.join(encode(trace_file).as_ref())
78+
.display()
79+
.to_string(),
80+
)?
81+
}
82+
_ => {}
7683
}
84+
85+
Ok(())
7786
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
use ::std::collections::HashMap;
2+
use ::std::error::Error;
3+
use ::std::fs::File;
4+
use ::std::io::BufReader;
5+
6+
use serde::{Deserialize, Serialize};
7+
use serde_json::{from_reader, to_string, Value};
8+
9+
macro_rules! zip {
10+
($x: expr) => ($x);
11+
($x: expr, $($y: expr), +) => (
12+
$x.into_iter().zip(
13+
zip!($($y), +))
14+
)
15+
}
16+
17+
#[derive(Deserialize, Debug)]
18+
struct BeeAiInputStateDict {
19+
prompt: Option<String>,
20+
// expected_output: Option<String>,
21+
}
22+
#[derive(Deserialize, Debug)]
23+
struct BeeAiInputState {
24+
#[serde(rename = "__dict__")]
25+
dict: BeeAiInputStateDict,
26+
}
27+
#[derive(Deserialize, Debug)]
28+
struct BeeAiInput {
29+
#[serde(rename = "py/state")]
30+
state: BeeAiInputState,
31+
}
32+
#[derive(Deserialize, Debug)]
33+
struct BeeAiTool {
34+
//#[serde(rename = "py/object")]
35+
//tool: String,
36+
//options: Option<String>, // TODO maybe more general than String?
37+
}
38+
#[derive(Deserialize, Debug)]
39+
struct BeeAiLlmParametersState {
40+
#[serde(rename = "__dict__")]
41+
dict: HashMap<String, Value>,
42+
}
43+
#[derive(Deserialize, Debug)]
44+
struct BeeAiLlmParameters {
45+
#[serde(rename = "py/state")]
46+
state: BeeAiLlmParametersState,
47+
}
48+
#[derive(Deserialize, Debug)]
49+
struct BeeAiLlmSettings {
50+
api_key: String,
51+
// base_url: String,
52+
}
53+
#[derive(Deserialize, Debug)]
54+
struct BeeAiLlm {
55+
// might be helpful to know it's Ollama?
56+
//#[serde(rename = "py/object")]
57+
//object: String,
58+
parameters: BeeAiLlmParameters,
59+
60+
#[serde(rename = "_model_id")]
61+
model_id: String,
62+
//#[serde(rename = "_litellm_provider_id")]
63+
//provider_id: String,
64+
#[serde(rename = "_settings")]
65+
settings: BeeAiLlmSettings,
66+
}
67+
#[derive(Deserialize, Debug)]
68+
struct BeeAiWorkflowStepStateMeta {
69+
//name: String,
70+
role: String,
71+
llm: BeeAiLlm,
72+
instructions: Option<String>,
73+
//tools: Option<Vec<BeeAiTool>>,
74+
}
75+
#[derive(Deserialize, Debug)]
76+
struct BeeAiWorkflowStepStateDict {
77+
meta: BeeAiWorkflowStepStateMeta,
78+
}
79+
#[derive(Deserialize, Debug)]
80+
struct BeeAiWorkflowStepState {
81+
#[serde(rename = "__dict__")]
82+
dict: BeeAiWorkflowStepStateDict,
83+
}
84+
#[derive(Deserialize, Debug)]
85+
struct BeeAiWorkflowStep {
86+
#[serde(rename = "py/state")]
87+
state: BeeAiWorkflowStepState,
88+
}
89+
#[derive(Deserialize, Debug)]
90+
struct BeeAiWorkflowInner {
91+
#[serde(rename = "_name")]
92+
name: String,
93+
#[serde(rename = "_steps")]
94+
steps: HashMap<String, BeeAiWorkflowStep>,
95+
}
96+
#[derive(Deserialize, Debug)]
97+
struct BeeAiWorkflow {
98+
workflow: BeeAiWorkflowInner,
99+
}
100+
#[derive(Deserialize, Debug)]
101+
struct BeeAiProgram {
102+
inputs: Vec<BeeAiInput>,
103+
workflow: BeeAiWorkflow,
104+
}
105+
106+
#[derive(Serialize, Debug)]
107+
#[serde(untagged)]
108+
enum PdlBlock {
109+
String(String),
110+
Text {
111+
#[serde(skip_serializing_if = "Option::is_none")]
112+
description: Option<String>,
113+
#[serde(skip_serializing_if = "Option::is_none")]
114+
role: Option<String>,
115+
text: Vec<PdlBlock>,
116+
},
117+
Model {
118+
#[serde(skip_serializing_if = "Option::is_none")]
119+
description: Option<String>,
120+
model: String,
121+
parameters: HashMap<String, Value>,
122+
},
123+
}
124+
125+
pub fn compile(source_file_path: &String, output_path: &String) -> Result<(), Box<dyn Error>> {
126+
println!("Compiling beeai {} to {}", source_file_path, output_path);
127+
128+
// Open the file in read-only mode with buffer.
129+
let file = File::open(source_file_path)?;
130+
let reader = BufReader::new(file);
131+
132+
// Read the JSON contents of the file as an instance of `User`.
133+
let bee: BeeAiProgram = from_reader(reader)?;
134+
135+
let inputs: Vec<PdlBlock> = bee
136+
.inputs
137+
.into_iter()
138+
.map(|input| input.state.dict.prompt)
139+
.flatten()
140+
.map(|prompt| PdlBlock::String(format!("{}\n", prompt)))
141+
.collect::<Vec<_>>();
142+
143+
let system_prompts = bee
144+
.workflow
145+
.workflow
146+
.steps
147+
.values()
148+
.filter_map(|step| step.state.dict.meta.instructions.clone())
149+
.map(|instructions| PdlBlock::Text {
150+
role: Some(String::from("system")),
151+
text: vec![PdlBlock::String(instructions)],
152+
description: None,
153+
})
154+
.collect::<Vec<_>>();
155+
156+
let model_calls = bee
157+
.workflow
158+
.workflow
159+
.steps
160+
.into_values()
161+
.map(|step| (step.state.dict.meta.role, step.state.dict.meta.llm))
162+
.map(|(role, llm)| PdlBlock::Model {
163+
description: Some(role),
164+
model: format!("{}/{}", llm.settings.api_key, llm.model_id),
165+
parameters: llm.parameters.state.dict,
166+
})
167+
.collect::<Vec<_>>();
168+
169+
let pdl: PdlBlock = PdlBlock::Text {
170+
description: Some(bee.workflow.workflow.name),
171+
role: None,
172+
text: zip!(inputs, system_prompts, model_calls)
173+
.map(|(a, (b, c))| [a, b, c])
174+
.flatten()
175+
.collect(),
176+
};
177+
178+
match output_path.as_str() {
179+
"-" => println!("{}", to_string(&pdl)?),
180+
_ => {
181+
::std::fs::write(output_path, to_string(&pdl)?)?;
182+
}
183+
}
184+
185+
Ok(())
186+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod beeai;

pdl-live-react/src-tauri/src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use tauri_plugin_pty;
33

44
mod cli;
55
mod commands;
6+
mod compile;
67
mod gui;
78
mod interpreter;
89

@@ -14,7 +15,13 @@ pub fn run() {
1415
if args_os().count() <= 1 {
1516
gui::setup(app.handle().clone(), "".to_owned())?;
1617
} else {
17-
cli::setup::cli(app)?;
18+
match cli::setup::cli(app) {
19+
Ok(()) => ::std::process::exit(0),
20+
Err(s) => {
21+
eprintln!("{}", s);
22+
::std::process::exit(1)
23+
}
24+
}
1825
}
1926

2027
Ok(())

0 commit comments

Comments
 (0)