Skip to content

Commit 3106dce

Browse files
starpitclaudiosv
authored andcommitted
feat: some regex parser support for rust interpreter
Split mode TODO! Signed-off-by: Nick Mitchell <[email protected]> Signed-off-by: Claudio Spiess <[email protected]>
1 parent 3ba5cab commit 3106dce

File tree

6 files changed

+128
-9
lines changed

6 files changed

+128
-9
lines changed

pdl-live-react/src-tauri/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pdl-live-react/src-tauri/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ fs4 = "0.13.1"
5050
derive_builder = "0.20.2"
5151
iana-time-zone = "0.1.63"
5252
async-openai = "0.28.1"
53+
regex = "1.11.1"
5354

5455
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
5556
tauri-plugin-cli = "2"

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,18 @@ pub enum Role {
1919
Tool,
2020
}
2121

22+
/// Function used to parse to value (https://docs.python.org/3/library/re.html).
2223
#[derive(Serialize, Deserialize, Debug, Clone)]
23-
#[serde(rename_all_fields(serialize = "lowercase"))]
2424
pub enum RegexMode {
25+
#[serde(rename = "search")]
2526
Search,
27+
#[serde(rename = "match")]
2628
Match,
29+
#[serde(rename = "fullmatch")]
2730
Fullmatch,
31+
#[serde(rename = "split")]
2832
Split,
33+
#[serde(rename = "findall")]
2934
Findall,
3035
}
3136

@@ -40,7 +45,7 @@ pub struct RegexParser {
4045

4146
/// Expected type of the parsed value
4247
#[serde(skip_serializing_if = "Option::is_none")]
43-
pub spec: Option<Value>,
48+
pub spec: Option<IndexMap<String, PdlType>>,
4449
}
4550

4651
#[derive(Serialize, Deserialize, Debug, Clone)]
@@ -744,3 +749,15 @@ impl From<Number> for PdlResult {
744749
PdlResult::Number(n)
745750
}
746751
}
752+
impl PartialEq for PdlResult {
753+
fn eq(&self, other: &Self) -> bool {
754+
match (self, other) {
755+
(PdlResult::Number(a), PdlResult::Number(b)) => a == b,
756+
(PdlResult::String(a), PdlResult::String(b)) => a == b,
757+
(PdlResult::Bool(a), PdlResult::Bool(b)) => a == b,
758+
(PdlResult::List(a), PdlResult::List(b)) => a == b,
759+
(PdlResult::Dict(a), PdlResult::Dict(b)) => a == b,
760+
_ => false,
761+
}
762+
}
763+
}

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ use crate::pdl::ast::{
2727
IncludeBlock, ListOrString, MessageBlock, Metadata, MetadataBuilder, ModelBlock, ObjectBlock,
2828
PdlBlock,
2929
PdlBlock::Advanced,
30-
PdlParser, PdlResult, PdlUsage, PythonCodeBlock, ReadBlock, RepeatBlock, Role, Scope,
31-
SequencingBlock, StringOrBoolean, StringOrNull, Timing,
30+
PdlParser, PdlResult, PdlUsage, PythonCodeBlock, ReadBlock, RegexMode, RegexParser,
31+
RepeatBlock, Role, Scope, SequencingBlock, StringOrBoolean, StringOrNull, Timing,
3232
};
3333

3434
type Messages = Vec<ChatMessage>;
@@ -1172,12 +1172,13 @@ impl<'a> Interpreter<'a> {
11721172
Some(Value::String(s)) => Some(ChatMessage::user(s.clone())),
11731173
_ => None,
11741174
},
1175-
_ => None,
1175+
m => Some(ChatMessage::user(m.to_string())),
11761176
})
11771177
.collect(),
1178-
_ => vec![],
1178+
_ => vec![ChatMessage::user(to_string(m)?)],
11791179
},
1180-
_ => vec![],
1180+
Value::Array(a) => vec![ChatMessage::user(to_string(a)?)],
1181+
m => vec![ChatMessage::user(m.to_string())],
11811182
};
11821183
Ok((result, messages, Data(trace)))
11831184
}
@@ -1288,7 +1289,38 @@ impl<'a> Interpreter<'a> {
12881289
.collect::<Result<_, _>>()?,
12891290
)),
12901291
PdlParser::Yaml => from_yaml_str(result).map_err(|e| Box::from(e)),
1291-
PdlParser::Regex(_) => todo!(),
1292+
PdlParser::Regex(RegexParser { regex, mode, spec }) => {
1293+
use regex::Regex;
1294+
let re = Regex::new(regex)?;
1295+
let expected_captures: Vec<&str> = if let Some(spec) = spec {
1296+
spec.keys().map(|k| k.as_str()).collect()
1297+
} else {
1298+
vec![]
1299+
};
1300+
1301+
match mode {
1302+
Some(RegexMode::Findall) => Ok(PdlResult::List(
1303+
re.captures_iter(result)
1304+
.flat_map(|cap| {
1305+
expected_captures.iter().filter_map(move |k| {
1306+
cap.name(k).and_then(|m| Some(m.as_str().into()))
1307+
})
1308+
})
1309+
.collect(),
1310+
)),
1311+
Some(RegexMode::Split) => todo!(),
1312+
_ => Ok(PdlResult::Dict(
1313+
re.captures_iter(result)
1314+
.flat_map(|cap| {
1315+
expected_captures.iter().filter_map(move |k| {
1316+
cap.name(k)
1317+
.and_then(|m| Some((k.to_string(), m.as_str().into())))
1318+
})
1319+
})
1320+
.collect(),
1321+
)),
1322+
}
1323+
}
12921324
}
12931325
}
12941326

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ mod tests {
44
use serde_json::json;
55

66
use crate::pdl::{
7-
ast::{Block, Body::*, ModelBlockBuilder, PdlBlock, PdlBlock::Advanced, Scope},
7+
ast::{Block, Body::*, ModelBlockBuilder, PdlBlock, PdlBlock::Advanced, PdlResult, Scope},
88
interpreter::{RunOptions, load_scope, run_json_sync as run_json, run_sync as run},
99
};
1010

@@ -679,6 +679,67 @@ mod tests {
679679
Ok(())
680680
}
681681

682+
#[test]
683+
fn regex_findall() -> Result<(), Box<dyn Error>> {
684+
let program = json!({
685+
"data": "aaa999bbb888",
686+
"parser": {
687+
"regex": "[^0-9]*(?P<answer1>[0-9]+)[^0-9]*(?P<answer2>[0-9]+)$",
688+
"mode": "findall",
689+
"spec": {
690+
"answer1": "str",
691+
"answer2": "str"
692+
}
693+
}
694+
});
695+
696+
let (_, messages, _) = run_json(program, streaming(), initial_scope())?;
697+
assert_eq!(messages.len(), 1);
698+
assert_eq!(messages[0].role, MessageRole::User);
699+
assert_eq!(messages[0].content, "[\"999\",\"888\"]");
700+
Ok(())
701+
}
702+
703+
#[test]
704+
fn regex_plain_1() -> Result<(), Box<dyn Error>> {
705+
let program = json!({
706+
"data": "aaa999bbb888",
707+
"parser": {
708+
"regex": "[^0-9]*(?P<answer1>[0-9]+)[^0-9]*(?P<answer2>[0-9]+)$",
709+
"spec": {
710+
"answer1": "str",
711+
"answer2": "str"
712+
}
713+
}
714+
});
715+
716+
let (result, _, _) = run_json(program, streaming(), initial_scope())?;
717+
let mut m = ::std::collections::HashMap::new();
718+
m.insert("answer1".into(), "999".into());
719+
m.insert("answer2".into(), "888".into());
720+
assert_eq!(result, PdlResult::Dict(m));
721+
Ok(())
722+
}
723+
724+
#[test]
725+
fn regex_plain_2() -> Result<(), Box<dyn Error>> {
726+
let program = json!({
727+
"data": "aaa999bbb888",
728+
"parser": {
729+
"regex": "[^0-9]*(?P<answer1>[0-9]+)[^0-9]*(?P<answer2>[0-9]+)$",
730+
"spec": {
731+
"answer1": "str",
732+
}
733+
}
734+
});
735+
736+
let (result, _, _) = run_json(program, streaming(), initial_scope())?;
737+
let mut m = ::std::collections::HashMap::new();
738+
m.insert("answer1".into(), "999".into());
739+
assert_eq!(result, PdlResult::Dict(m));
740+
Ok(())
741+
}
742+
682743
#[test]
683744
fn bee_1() -> Result<(), Box<dyn Error>> {
684745
let program = crate::compile::beeai::compile("./tests/data/bee_1.py", false)?;
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
data: aaa999bbb888
2+
parser:
3+
regex: '[^0-9]*(?P<answer1>[0-9]+)[^0-9]*(?P<answer2>[0-9]+)$'
4+
mode: findall
5+
spec:
6+
answer1: str
7+
answer2: str

0 commit comments

Comments
 (0)