Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use std::collections::HashMap;
use std::sync::Arc;

use crate::auth::{AuthManager, Permission, ResourceType};
use crate::sql::{parse, rewrite, AliasDuplicatedProjectionRewrite, SqlStatementRewriteRule};
use crate::sql::{
parse, rewrite, AliasDuplicatedProjectionRewrite, RemoveUnsupportedTypes,
ResolveUnqualifiedIdentifer, SqlStatementRewriteRule,
};
use async_trait::async_trait;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::LogicalPlan;
Expand Down Expand Up @@ -73,8 +76,11 @@ impl DfSessionService {
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
) -> DfSessionService {
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
vec![Arc::new(AliasDuplicatedProjectionRewrite)];
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
Arc::new(AliasDuplicatedProjectionRewrite),
Arc::new(ResolveUnqualifiedIdentifer),
Arc::new(RemoveUnsupportedTypes::new()),
];
let parser = Arc::new(Parser {
session_context: session_context.clone(),
sql_rewrite_rules: sql_rewrite_rules.clone(),
Expand Down
89 changes: 89 additions & 0 deletions datafusion-postgres/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use datafusion::sql::sqlparser::ast::SetExpr;
use datafusion::sql::sqlparser::ast::Statement;
use datafusion::sql::sqlparser::ast::TableFactor;
use datafusion::sql::sqlparser::ast::TableWithJoins;
use datafusion::sql::sqlparser::ast::Value;
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
use datafusion::sql::sqlparser::parser::Parser;
use datafusion::sql::sqlparser::parser::ParserError;
Expand Down Expand Up @@ -258,6 +259,70 @@ impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
}
}

/// Remove datafusion unsupported type annotations
#[derive(Debug)]
pub struct RemoveUnsupportedTypes {
unsupported_types: HashSet<String>,
}

impl RemoveUnsupportedTypes {
pub fn new() -> Self {
let mut unsupported_types = HashSet::new();
unsupported_types.insert("regclass".to_owned());

Self { unsupported_types }
}

fn rewrite_expr_unsupported_types(&self, expr: &mut Expr) {
match expr {
// This is the key part: identify constants with type annotations.
Expr::TypedString { value, data_type } => {
if self
.unsupported_types
.contains(data_type.to_string().to_lowercase().as_str())
{
*expr =
Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span());
}
}
Expr::Cast {
data_type,
expr: value,
..
} => {
if self
.unsupported_types
.contains(data_type.to_string().to_lowercase().as_str())
{
*expr = *value.clone();
}
}
// Handle binary operations by recursively rewriting both sides.
Expr::BinaryOp { left, right, .. } => {
self.rewrite_expr_unsupported_types(left);
self.rewrite_expr_unsupported_types(right);
}
// Add more match arms for other expression types (e.g., `Function`, `InList`) as needed.
_ => {}
}
}
}

impl SqlStatementRewriteRule for RemoveUnsupportedTypes {
fn rewrite(&self, mut s: Statement) -> Statement {
// Traverse the AST to find the WHERE clause and rewrite it.
if let Statement::Query(query) = &mut s {
if let SetExpr::Select(select) = query.body.as_mut() {
if let Some(expr) = &mut select.selection {
self.rewrite_expr_unsupported_types(expr);
}
}
}

s
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -327,4 +392,28 @@ mod tests {
"SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname"
);
}

#[test]
fn test_remove_unsupported_types() {
let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
vec![Arc::new(RemoveUnsupportedTypes::new())];

let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
);

let sql = "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname";
let statement = parse(sql).expect("Failed to parse").remove(0);

let statement = rewrite(statement, &rules);
assert_eq!(
statement.to_string(),
"SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname"
);
}
}
Loading