Skip to content

Commit e3d1576

Browse files
authored
fix: context-aware named param conversion (#513)
we were blindly converting named parameters to positional parameters. but the latter is only valid as a literal, and not as an identifier. statements like ```sql grant usage on schema public, app_public, app_hidden to :DB_ROLE; ``` are not valid when converted to ```sql grant usage on schema public, app_public, app_hidden to $1; ``` i went a bit back and forth on this and decided the easiest way to fix this is to convert to identifiers like `a` if the previous non-trivia token is one of a set list of tokens. We will probably have a bunch of edge cases here but fixing them should be as easy as adding a keyword to the list. now, we convert to ```sql grant usage on schema public, app_public, app_hidden to a; ``` closes #510
1 parent e2826fb commit e3d1576

File tree

5 files changed

+152
-10
lines changed

5 files changed

+152
-10
lines changed

crates/pgt_lexer/src/lexer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ impl<'a> Lexer<'a> {
143143
}
144144
_ => {}
145145
};
146-
SyntaxKind::POSITIONAL_PARAM
146+
SyntaxKind::NAMED_PARAM
147147
}
148148
pgt_tokenizer::TokenKind::QuotedIdent { terminated } => {
149149
if !terminated {

crates/pgt_lexer_codegen/src/syntax_kind.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ pub fn syntax_kind_mod() -> proc_macro2::TokenStream {
6565

6666
let mut enum_variants: Vec<TokenStream> = Vec::new();
6767
let mut from_kw_match_arms: Vec<TokenStream> = Vec::new();
68+
let mut is_kw_match_arms: Vec<TokenStream> = Vec::new();
69+
70+
let mut is_trivia_match_arms: Vec<TokenStream> = Vec::new();
6871

6972
// collect keywords
7073
for kw in &all_keywords {
@@ -78,18 +81,30 @@ pub fn syntax_kind_mod() -> proc_macro2::TokenStream {
7881
from_kw_match_arms.push(quote! {
7982
#kw => Some(SyntaxKind::#kind_ident)
8083
});
84+
is_kw_match_arms.push(quote! {
85+
SyntaxKind::#kind_ident => true
86+
});
8187
}
8288

8389
// collect extra keywords
8490
EXTRA.iter().for_each(|&name| {
8591
let variant_name = format_ident!("{}", name);
8692
enum_variants.push(quote! { #variant_name });
93+
94+
if name == "COMMENT" {
95+
is_trivia_match_arms.push(quote! {
96+
SyntaxKind::#variant_name => true
97+
});
98+
}
8799
});
88100

89101
// collect whitespace variants
90102
WHITESPACE.iter().for_each(|&name| {
91103
let variant_name = format_ident!("{}", name);
92104
enum_variants.push(quote! { #variant_name });
105+
is_trivia_match_arms.push(quote! {
106+
SyntaxKind::#variant_name => true
107+
});
93108
});
94109

95110
// collect punctuations
@@ -119,6 +134,20 @@ pub fn syntax_kind_mod() -> proc_macro2::TokenStream {
119134
_ => None
120135
}
121136
}
137+
138+
pub fn is_keyword(&self) -> bool {
139+
match self {
140+
#(#is_kw_match_arms),*,
141+
_ => false
142+
}
143+
}
144+
145+
pub fn is_trivia(&self) -> bool {
146+
match self {
147+
#(#is_trivia_match_arms),*,
148+
_ => false
149+
}
150+
}
122151
}
123152
}
124153
}

crates/pgt_tokenizer/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,13 @@ mod tests {
668668
assert_debug_snapshot!(result);
669669
}
670670

671+
#[test]
672+
fn graphile_named_param() {
673+
let result =
674+
lex("grant usage on schema public, app_public, app_hidden to :DATABASE_VISITOR;");
675+
assert_debug_snapshot!(result);
676+
}
677+
671678
#[test]
672679
fn named_param_dollar_raw() {
673680
let result = lex("select 1 from c where id = $id;");
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
---
2+
source: crates/pgt_tokenizer/src/lib.rs
3+
expression: result
4+
snapshot_kind: text
5+
---
6+
[
7+
"grant" @ Ident,
8+
" " @ Space,
9+
"usage" @ Ident,
10+
" " @ Space,
11+
"on" @ Ident,
12+
" " @ Space,
13+
"schema" @ Ident,
14+
" " @ Space,
15+
"public" @ Ident,
16+
"," @ Comma,
17+
" " @ Space,
18+
"app_public" @ Ident,
19+
"," @ Comma,
20+
" " @ Space,
21+
"app_hidden" @ Ident,
22+
" " @ Space,
23+
"to" @ Ident,
24+
" " @ Space,
25+
":DATABASE_VISITOR" @ NamedParam { kind: ColonRaw },
26+
";" @ Semi,
27+
]

crates/pgt_workspace/src/workspace/server/pg_query.rs

Lines changed: 88 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ use std::num::NonZeroUsize;
33
use std::sync::{Arc, LazyLock, Mutex};
44

55
use lru::LruCache;
6+
use pgt_lexer::lex;
67
use pgt_query_ext::diagnostics::*;
78
use pgt_text_size::TextRange;
8-
use pgt_tokenizer::tokenize;
99
use regex::Regex;
1010

1111
use super::statement_identifier::StatementId;
@@ -104,6 +104,27 @@ fn is_composite_type_error(err: &str) -> bool {
104104
COMPOSITE_TYPE_ERROR_RE.is_match(err)
105105
}
106106

107+
// Keywords that, when preceding a named parameter, indicate that the parameter should be treated
108+
// as an identifier rather than a positional parameter.
109+
const IDENTIFIER_CONTEXT: [pgt_lexer::SyntaxKind; 15] = [
110+
pgt_lexer::SyntaxKind::TO_KW,
111+
pgt_lexer::SyntaxKind::FROM_KW,
112+
pgt_lexer::SyntaxKind::SCHEMA_KW,
113+
pgt_lexer::SyntaxKind::TABLE_KW,
114+
pgt_lexer::SyntaxKind::INDEX_KW,
115+
pgt_lexer::SyntaxKind::CONSTRAINT_KW,
116+
pgt_lexer::SyntaxKind::OWNER_KW,
117+
pgt_lexer::SyntaxKind::ROLE_KW,
118+
pgt_lexer::SyntaxKind::USER_KW,
119+
pgt_lexer::SyntaxKind::DATABASE_KW,
120+
pgt_lexer::SyntaxKind::TYPE_KW,
121+
pgt_lexer::SyntaxKind::CAST_KW,
122+
pgt_lexer::SyntaxKind::ALTER_KW,
123+
pgt_lexer::SyntaxKind::DROP_KW,
124+
// for schema.table style identifiers
125+
pgt_lexer::SyntaxKind::DOT,
126+
];
127+
107128
/// Converts named parameters in a SQL query string to positional parameters.
108129
///
109130
/// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`)
@@ -116,13 +137,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
116137
let mut result = String::with_capacity(text.len());
117138
let mut param_mapping: HashMap<&str, usize> = HashMap::new();
118139
let mut param_index = 1;
119-
let mut position = 0;
120140

121-
for token in tokenize(text) {
122-
let token_len = token.len as usize;
123-
let token_text = &text[position..position + token_len];
141+
let lexed = lex(text);
142+
for (token_idx, kind) in lexed.tokens().enumerate() {
143+
if kind == pgt_lexer::SyntaxKind::EOF {
144+
break;
145+
}
146+
147+
let token_text = lexed.text(token_idx);
124148

125-
if matches!(token.kind, pgt_tokenizer::TokenKind::NamedParam { .. }) {
149+
if matches!(kind, pgt_lexer::SyntaxKind::NAMED_PARAM) {
126150
let idx = match param_mapping.get(token_text) {
127151
Some(&index) => index,
128152
None => {
@@ -133,7 +157,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
133157
}
134158
};
135159

136-
let replacement = format!("${}", idx);
160+
// find previous non-trivia token
161+
let prev_token = (0..token_idx)
162+
.rev()
163+
.map(|i| lexed.kind(i))
164+
.find(|kind| !kind.is_trivia());
165+
166+
let replacement = match prev_token {
167+
Some(k) if IDENTIFIER_CONTEXT.contains(&k) => deterministic_identifier(idx - 1),
168+
_ => format!("${}", idx),
169+
};
137170
let original_len = token_text.len();
138171
let replacement_len = replacement.len();
139172

@@ -146,17 +179,45 @@ pub fn convert_to_positional_params(text: &str) -> String {
146179
} else {
147180
result.push_str(token_text);
148181
}
149-
150-
position += token_len;
151182
}
152183

153184
result
154185
}
155186

187+
const ALPHABET: [char; 26] = [
188+
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
189+
't', 'u', 'v', 'w', 'x', 'y', 'z',
190+
];
191+
192+
/// Generates a deterministic identifier based on the given index.
193+
fn deterministic_identifier(idx: usize) -> String {
194+
let iteration = idx / ALPHABET.len();
195+
let pos = idx % ALPHABET.len();
196+
197+
format!(
198+
"{}{}",
199+
ALPHABET[pos],
200+
if iteration > 0 {
201+
deterministic_identifier(iteration - 1)
202+
} else {
203+
"".to_string()
204+
}
205+
)
206+
}
207+
156208
#[cfg(test)]
157209
mod tests {
158210
use super::*;
159211

212+
#[test]
213+
fn test_deterministic_identifier() {
214+
assert_eq!(deterministic_identifier(0), "a");
215+
assert_eq!(deterministic_identifier(25), "z");
216+
assert_eq!(deterministic_identifier(26), "aa");
217+
assert_eq!(deterministic_identifier(27), "ba");
218+
assert_eq!(deterministic_identifier(51), "za");
219+
}
220+
160221
#[test]
161222
fn test_convert_to_positional_params() {
162223
let input = "select * from users where id = @one and name = :two and email = :'three';";
@@ -177,6 +238,24 @@ mod tests {
177238
);
178239
}
179240

241+
#[test]
242+
fn test_positional_params_in_grant() {
243+
let input = "grant usage on schema public, app_public, app_hidden to :DB_ROLE;";
244+
245+
let result = convert_to_positional_params(input);
246+
247+
assert_eq!(
248+
result,
249+
"grant usage on schema public, app_public, app_hidden to a ;"
250+
);
251+
252+
let store = PgQueryStore::new();
253+
254+
let res = store.get_or_cache_ast(&StatementId::new(input));
255+
256+
assert!(res.is_ok());
257+
}
258+
180259
#[test]
181260
fn test_plpgsql_syntax_error() {
182261
let input = "

0 commit comments

Comments
 (0)