Skip to content

Commit c01cf50

Browse files
authored
feat: remove additional unsupported cast (#146)
1 parent e5644ec commit c01cf50

File tree

1 file changed

+65
-64
lines changed

1 file changed

+65
-64
lines changed

datafusion-postgres/src/sql.rs

Lines changed: 65 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::collections::HashSet;
2+
use std::ops::ControlFlow;
23
use std::sync::Arc;
34

45
use datafusion::sql::sqlparser::ast::Expr;
@@ -13,6 +14,8 @@ use datafusion::sql::sqlparser::ast::Statement;
1314
use datafusion::sql::sqlparser::ast::TableFactor;
1415
use datafusion::sql::sqlparser::ast::TableWithJoins;
1516
use datafusion::sql::sqlparser::ast::Value;
17+
use datafusion::sql::sqlparser::ast::VisitMut;
18+
use datafusion::sql::sqlparser::ast::VisitorMut;
1619
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
1720
use datafusion::sql::sqlparser::parser::Parser;
1821
use datafusion::sql::sqlparser::parser::ParserError;
@@ -272,8 +275,16 @@ impl RemoveUnsupportedTypes {
272275

273276
Self { unsupported_types }
274277
}
278+
}
279+
280+
struct RemoveUnsupportedTypesVisitor<'a> {
281+
unsupported_types: &'a HashSet<String>,
282+
}
283+
284+
impl<'a> VisitorMut for RemoveUnsupportedTypesVisitor<'a> {
285+
type Break = ();
275286

276-
fn rewrite_expr_unsupported_types(&self, expr: &mut Expr) {
287+
fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
277288
match expr {
278289
// This is the key part: identify constants with type annotations.
279290
Expr::TypedString { value, data_type } => {
@@ -297,65 +308,58 @@ impl RemoveUnsupportedTypes {
297308
*expr = *value.clone();
298309
}
299310
}
300-
// Handle binary operations by recursively rewriting both sides.
301-
Expr::BinaryOp { left, right, .. } => {
302-
self.rewrite_expr_unsupported_types(left);
303-
self.rewrite_expr_unsupported_types(right);
304-
}
305311
// Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
306312
_ => {}
307313
}
314+
315+
ControlFlow::Continue(())
308316
}
309317
}
310318

311319
impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
312-
fn rewrite(&self, mut s: Statement) -> Statement {
313-
// Traverse the AST to find the WHERE clause and rewrite it.
314-
if let Statement::Query(query) = &mut s {
315-
if let SetExpr::Select(select) = query.body.as_mut() {
316-
if let Some(expr) = &mut select.selection {
317-
self.rewrite_expr_unsupported_types(expr);
318-
}
319-
}
320-
}
321-
322-
s
320+
fn rewrite(&self, mut statement: Statement) -> Statement {
321+
let mut visitor = RemoveUnsupportedTypesVisitor {
322+
unsupported_types: &self.unsupported_types,
323+
};
324+
let _ = statement.visit(&mut visitor);
325+
statement
323326
}
324327
}
325328

326329
#[cfg(test)]
327330
mod tests {
328331
use super::*;
329332

333+
macro_rules! assert_rewrite {
334+
($rules:expr, $orig:expr, $rewt:expr) => {
335+
let sql = $orig;
336+
let statement = parse(sql).expect("Failed to parse").remove(0);
337+
338+
let statement = rewrite(statement, $rules);
339+
assert_eq!(statement.to_string(), $rewt);
340+
};
341+
}
342+
330343
#[test]
331344
fn test_alias_rewrite() {
332345
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
333346
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
334347

335-
let sql = "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n";
336-
let statement = parse(sql).expect("Failed to parse").remove(0);
337-
338-
let statement = rewrite(statement, &rules);
339-
assert_eq!(
340-
statement.to_string(),
348+
assert_rewrite!(
349+
&rules,
350+
"SELECT n.oid, n.* FROM pg_catalog.pg_namespace n",
341351
"SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n"
342352
);
343353

344-
let sql = "SELECT oid, * FROM pg_catalog.pg_namespace";
345-
let statement = parse(sql).expect("Failed to parse").remove(0);
346-
347-
let statement = rewrite(statement, &rules);
348-
assert_eq!(
349-
statement.to_string(),
354+
assert_rewrite!(
355+
&rules,
356+
"SELECT oid, * FROM pg_catalog.pg_namespace",
350357
"SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace"
351358
);
352359

353-
let sql = "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id";
354-
let statement = parse(sql).expect("Failed to parse").remove(0);
355-
356-
let statement = rewrite(statement, &rules);
357-
assert_eq!(
358-
statement.to_string(),
360+
assert_rewrite!(
361+
&rules,
362+
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id",
359363
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
360364
);
361365
}
@@ -365,30 +369,21 @@ mod tests {
365369
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
366370
vec![Arc::new(ResolveUnqualifiedIdentifer)];
367371

368-
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname";
369-
let statement = parse(sql).expect("Failed to parse").remove(0);
370-
371-
let statement = rewrite(statement, &rules);
372-
assert_eq!(
373-
statement.to_string(),
372+
assert_rewrite!(
373+
&rules,
374+
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
374375
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
375376
);
376377

377-
let sql = "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname";
378-
let statement = parse(sql).expect("Failed to parse").remove(0);
379-
380-
let statement = rewrite(statement, &rules);
381-
assert_eq!(
382-
statement.to_string(),
378+
assert_rewrite!(
379+
&rules,
380+
"SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
383381
"SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname"
384382
);
385383

386-
let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname";
387-
let statement = parse(sql).expect("Failed to parse").remove(0);
388-
389-
let statement = rewrite(statement, &rules);
390-
assert_eq!(
391-
statement.to_string(),
384+
assert_rewrite!(
385+
&rules,
386+
"SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
392387
"SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
393388
);
394389
}
@@ -398,21 +393,27 @@ mod tests {
398393
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
399394
vec![Arc::new(RemoveUnsupportedTypes::new())];
400395

401-
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname";
402-
let statement = parse(sql).expect("Failed to parse").remove(0);
403-
404-
let statement = rewrite(statement, &rules);
405-
assert_eq!(
406-
statement.to_string(),
396+
assert_rewrite!(
397+
&rules,
398+
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
407399
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
408400
);
409401

410-
let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname";
411-
let statement = parse(sql).expect("Failed to parse").remove(0);
402+
assert_rewrite!(
403+
&rules,
404+
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname",
405+
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname"
406+
);
407+
408+
assert_rewrite!(
409+
&rules,
410+
"SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname",
411+
"SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname"
412+
);
412413

413-
let statement = rewrite(statement, &rules);
414-
assert_eq!(
415-
statement.to_string(),
414+
assert_rewrite!(
415+
&rules,
416+
"SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
416417
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
417418
);
418419
}

0 commit comments

Comments
 (0)