Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions datafusion-postgres/src/sql.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use std::collections::HashSet;
use std::sync::Arc;

use datafusion::sql::sqlparser::ast::Expr;
use datafusion::sql::sqlparser::ast::Ident;
use datafusion::sql::sqlparser::ast::OrderByKind;
use datafusion::sql::sqlparser::ast::Query;
use datafusion::sql::sqlparser::ast::Select;
use datafusion::sql::sqlparser::ast::SelectItem;
use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind;
use datafusion::sql::sqlparser::ast::SetExpr;
use datafusion::sql::sqlparser::ast::Statement;
use datafusion::sql::sqlparser::ast::TableFactor;
use datafusion::sql::sqlparser::ast::TableWithJoins;
use datafusion::sql::sqlparser::dialect::PostgreSqlDialect;
use datafusion::sql::sqlparser::parser::Parser;
use datafusion::sql::sqlparser::parser::ParserError;
Expand Down Expand Up @@ -135,6 +140,124 @@ impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite {
}
}

/// Rewrite rule that prepends a table qualifier to bare column names in the
/// `WHERE` and `ORDER BY` clauses of a query.
///
/// Postgres allows unqualified identifiers in `ORDER BY` and `WHERE`, but they
/// are not accepted by datafusion when the projection uses a qualified
/// wildcard. When the projection contains exactly one qualified wildcard and
/// the `FROM`/`JOIN` clauses declare at least one table alias, the wildcard's
/// qualifier is prepended to those identifiers, e.g.
/// `SELECT n.* FROM t n ORDER BY col` becomes
/// `SELECT n.* FROM t AS n ORDER BY n.col`.
#[derive(Debug)]
pub struct ResolveUnqualifiedIdentifer;

impl ResolveUnqualifiedIdentifer {
fn rewrite_unqualified_identifiers(query: &mut Box<Query>) {
if let SetExpr::Select(select) = query.body.as_mut() {
// Step 1: Find all table aliases from FROM and JOIN clauses.
let table_aliases = Self::get_table_aliases(&select.from);

// Step 2: Check for a single qualified wildcard in the projection.
let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection);
if qualified_wildcard_alias.is_none() || table_aliases.is_empty() {
return; // Conditions not met.
}

let wildcard_alias = qualified_wildcard_alias.unwrap();

// Step 3: Rewrite expressions in the WHERE and ORDER BY clauses.
if let Some(selection) = &mut select.selection {
Self::rewrite_expr(selection, &wildcard_alias, &table_aliases);
}

if let Some(OrderByKind::Expressions(order_by_exprs)) =
query.order_by.as_mut().map(|o| &mut o.kind)
{
for order_by_expr in order_by_exprs {
Self::rewrite_expr(&mut order_by_expr.expr, &wildcard_alias, &table_aliases);
}
}
}
}

fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet<String> {
let mut aliases = HashSet::new();
for table_with_joins in tables {
if let TableFactor::Table {
alias: Some(alias), ..
} = &table_with_joins.relation
{
aliases.insert(alias.name.value.clone());
}
for join in &table_with_joins.joins {
if let TableFactor::Table {
alias: Some(alias), ..
} = &join.relation
{
aliases.insert(alias.name.value.clone());
}
}
}
aliases
}

fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option<String> {
let mut qualified_wildcards = projection
.iter()
.filter_map(|item| {
if let SelectItem::QualifiedWildcard(
SelectItemQualifiedWildcardKind::ObjectName(objname),
_,
) = item
{
Some(
objname
.0
.iter()
.map(|v| v.as_ident().unwrap().value.clone())
.collect::<Vec<_>>()
.join("."),
)
} else {
None
}
})
.collect::<Vec<_>>();

if qualified_wildcards.len() == 1 {
Some(qualified_wildcards.remove(0))
} else {
None
}
}

fn rewrite_expr(expr: &mut Expr, wildcard_alias: &str, table_aliases: &HashSet<String>) {
match expr {
Expr::Identifier(ident) => {
// If the identifier is not a table alias itself, rewrite it.
if !table_aliases.contains(&ident.value) {
*expr = Expr::CompoundIdentifier(vec![
Ident::new(wildcard_alias.to_string()),
ident.clone(),
]);
}
}
Expr::BinaryOp { left, right, .. } => {
Self::rewrite_expr(left, wildcard_alias, table_aliases);
Self::rewrite_expr(right, wildcard_alias, table_aliases);
}
// Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.)
_ => {}
}
}
}

impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer {
    /// Applies the identifier qualification to query statements and hands every
    /// other kind of statement back unchanged.
    fn rewrite(&self, statement: Statement) -> Statement {
        match statement {
            Statement::Query(mut query) => {
                Self::rewrite_unqualified_identifiers(&mut query);
                Statement::Query(query)
            }
            other => other,
        }
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -171,4 +294,37 @@ mod tests {
"SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id"
);
}

#[test]
fn test_qualifier_prepend() {
    let rules: Vec<Arc<dyn SqlStatementRewriteRule>> =
        vec![Arc::new(ResolveUnqualifiedIdentifer)];

    // Each case pairs an input query with the SQL expected after rewriting.
    let cases = [
        // A single qualified wildcard: WHERE and ORDER BY columns get the qualifier.
        (
            "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname",
            "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname",
        ),
        // An unqualified wildcard without table alias: the query is left as is.
        (
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
            "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname",
        ),
        // A qualified wildcard next to other projections and a join.
        (
            "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname",
            "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname",
        ),
    ];

    for (sql, expected) in cases {
        let statement = parse(sql).expect("Failed to parse").remove(0);
        let rewritten = rewrite(statement, &rules);
        assert_eq!(rewritten.to_string(), expected);
    }
}
}
Loading