@@ -3,9 +3,9 @@ use std::num::NonZeroUsize;
3
3
use std:: sync:: { Arc , LazyLock , Mutex } ;
4
4
5
5
use lru:: LruCache ;
6
+ use pgt_lexer:: lex;
6
7
use pgt_query_ext:: diagnostics:: * ;
7
8
use pgt_text_size:: TextRange ;
8
- use pgt_tokenizer:: tokenize;
9
9
use regex:: Regex ;
10
10
11
11
use super :: statement_identifier:: StatementId ;
@@ -104,6 +104,27 @@ fn is_composite_type_error(err: &str) -> bool {
104
104
COMPOSITE_TYPE_ERROR_RE . is_match ( err)
105
105
}
106
106
107
+ // Keywords that, when preceding a named parameter, indicate that the parameter should be treated
108
+ // as an identifier rather than a positional parameter.
109
+ const IDENTIFIER_CONTEXT : [ pgt_lexer:: SyntaxKind ; 15 ] = [
110
+ pgt_lexer:: SyntaxKind :: TO_KW ,
111
+ pgt_lexer:: SyntaxKind :: FROM_KW ,
112
+ pgt_lexer:: SyntaxKind :: SCHEMA_KW ,
113
+ pgt_lexer:: SyntaxKind :: TABLE_KW ,
114
+ pgt_lexer:: SyntaxKind :: INDEX_KW ,
115
+ pgt_lexer:: SyntaxKind :: CONSTRAINT_KW ,
116
+ pgt_lexer:: SyntaxKind :: OWNER_KW ,
117
+ pgt_lexer:: SyntaxKind :: ROLE_KW ,
118
+ pgt_lexer:: SyntaxKind :: USER_KW ,
119
+ pgt_lexer:: SyntaxKind :: DATABASE_KW ,
120
+ pgt_lexer:: SyntaxKind :: TYPE_KW ,
121
+ pgt_lexer:: SyntaxKind :: CAST_KW ,
122
+ pgt_lexer:: SyntaxKind :: ALTER_KW ,
123
+ pgt_lexer:: SyntaxKind :: DROP_KW ,
124
+ // for schema.table style identifiers
125
+ pgt_lexer:: SyntaxKind :: DOT ,
126
+ ] ;
127
+
107
128
/// Converts named parameters in a SQL query string to positional parameters.
108
129
///
109
130
/// This function scans the input SQL string for named parameters (e.g., `@param`, `:param`, `:'param'`)
@@ -116,13 +137,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
116
137
let mut result = String :: with_capacity ( text. len ( ) ) ;
117
138
let mut param_mapping: HashMap < & str , usize > = HashMap :: new ( ) ;
118
139
let mut param_index = 1 ;
119
- let mut position = 0 ;
120
140
121
- for token in tokenize ( text) {
122
- let token_len = token. len as usize ;
123
- let token_text = & text[ position..position + token_len] ;
141
+ let lexed = lex ( text) ;
142
+ for ( token_idx, kind) in lexed. tokens ( ) . enumerate ( ) {
143
+ if kind == pgt_lexer:: SyntaxKind :: EOF {
144
+ break ;
145
+ }
146
+
147
+ let token_text = lexed. text ( token_idx) ;
124
148
125
- if matches ! ( token . kind, pgt_tokenizer :: TokenKind :: NamedParam { .. } ) {
149
+ if matches ! ( kind, pgt_lexer :: SyntaxKind :: NAMED_PARAM ) {
126
150
let idx = match param_mapping. get ( token_text) {
127
151
Some ( & index) => index,
128
152
None => {
@@ -133,7 +157,16 @@ pub fn convert_to_positional_params(text: &str) -> String {
133
157
}
134
158
} ;
135
159
136
- let replacement = format ! ( "${}" , idx) ;
160
+ // find previous non-trivia token
161
+ let prev_token = ( 0 ..token_idx)
162
+ . rev ( )
163
+ . map ( |i| lexed. kind ( i) )
164
+ . find ( |kind| !kind. is_trivia ( ) ) ;
165
+
166
+ let replacement = match prev_token {
167
+ Some ( k) if IDENTIFIER_CONTEXT . contains ( & k) => deterministic_identifier ( idx - 1 ) ,
168
+ _ => format ! ( "${}" , idx) ,
169
+ } ;
137
170
let original_len = token_text. len ( ) ;
138
171
let replacement_len = replacement. len ( ) ;
139
172
@@ -146,17 +179,45 @@ pub fn convert_to_positional_params(text: &str) -> String {
146
179
} else {
147
180
result. push_str ( token_text) ;
148
181
}
149
-
150
- position += token_len;
151
182
}
152
183
153
184
result
154
185
}
155
186
187
+ const ALPHABET : [ char ; 26 ] = [
188
+ 'a' , 'b' , 'c' , 'd' , 'e' , 'f' , 'g' , 'h' , 'i' , 'j' , 'k' , 'l' , 'm' , 'n' , 'o' , 'p' , 'q' , 'r' , 's' ,
189
+ 't' , 'u' , 'v' , 'w' , 'x' , 'y' , 'z' ,
190
+ ] ;
191
+
192
+ /// Generates a deterministic identifier based on the given index.
193
+ fn deterministic_identifier ( idx : usize ) -> String {
194
+ let iteration = idx / ALPHABET . len ( ) ;
195
+ let pos = idx % ALPHABET . len ( ) ;
196
+
197
+ format ! (
198
+ "{}{}" ,
199
+ ALPHABET [ pos] ,
200
+ if iteration > 0 {
201
+ deterministic_identifier( iteration - 1 )
202
+ } else {
203
+ "" . to_string( )
204
+ }
205
+ )
206
+ }
207
+
156
208
#[ cfg( test) ]
157
209
mod tests {
158
210
use super :: * ;
159
211
212
+ #[ test]
213
+ fn test_deterministic_identifier ( ) {
214
+ assert_eq ! ( deterministic_identifier( 0 ) , "a" ) ;
215
+ assert_eq ! ( deterministic_identifier( 25 ) , "z" ) ;
216
+ assert_eq ! ( deterministic_identifier( 26 ) , "aa" ) ;
217
+ assert_eq ! ( deterministic_identifier( 27 ) , "ba" ) ;
218
+ assert_eq ! ( deterministic_identifier( 51 ) , "za" ) ;
219
+ }
220
+
160
221
#[ test]
161
222
fn test_convert_to_positional_params ( ) {
162
223
let input = "select * from users where id = @one and name = :two and email = :'three';" ;
@@ -177,6 +238,24 @@ mod tests {
177
238
) ;
178
239
}
179
240
241
+ #[ test]
242
+ fn test_positional_params_in_grant ( ) {
243
+ let input = "grant usage on schema public, app_public, app_hidden to :DB_ROLE;" ;
244
+
245
+ let result = convert_to_positional_params ( input) ;
246
+
247
+ assert_eq ! (
248
+ result,
249
+ "grant usage on schema public, app_public, app_hidden to a ;"
250
+ ) ;
251
+
252
+ let store = PgQueryStore :: new ( ) ;
253
+
254
+ let res = store. get_or_cache_ast ( & StatementId :: new ( input) ) ;
255
+
256
+ assert ! ( res. is_ok( ) ) ;
257
+ }
258
+
180
259
#[ test]
181
260
fn test_plpgsql_syntax_error ( ) {
182
261
let input = "
0 commit comments