From d20457b282682f863c73a9017f99a3a656f8837f Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Mon, 15 Sep 2025 22:59:43 +0800 Subject: [PATCH 1/9] feat: implement first \d table query --- datafusion-postgres/src/handlers.rs | 8 ++++-- datafusion-postgres/src/sql.rs | 44 +++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index b41ccd5..307ce8a 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -4,9 +4,9 @@ use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; use crate::sql::{ parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter, - CurrentUserVariableToSessionUserFunctionCall, FixArrayLiteral, PrependUnqualifiedPgTableName, - RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer, - RewriteArrayAnyAllOperation, SqlStatementRewriteRule, + CurrentUserVariableToSessionUserFunctionCall, FixArrayLiteral, FixCollate, + PrependUnqualifiedPgTableName, RemoveTableFunctionQualifier, RemoveUnsupportedTypes, + ResolveUnqualifiedIdentifer, RewriteArrayAnyAllOperation, SqlStatementRewriteRule, }; use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, Schema}; @@ -111,6 +111,7 @@ impl DfSessionService { Arc::new(FixArrayLiteral), Arc::new(RemoveTableFunctionQualifier), Arc::new(CurrentUserVariableToSessionUserFunctionCall), + Arc::new(FixCollate), ]; let parser = Arc::new(Parser { session_context: session_context.clone(), @@ -464,6 +465,7 @@ impl SimpleQueryHandler for DfSessionService { // Attempt to rewrite statement = rewrite(statement, &self.sql_rewrite_rules); + dbg!(&statement); // TODO: improve statement check by using statement directly let query = statement.to_string(); diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 736769c..9760c9f 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -664,6 +664,43 @@ impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall { } } +/// Fix collate and regex calls +#[derive(Debug)] +pub struct FixCollate; + +struct FixCollateVisitor; + +impl VisitorMut for FixCollateVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + match expr { + Expr::Collate { expr: inner, .. } => { + *expr = inner.as_ref().clone(); + } + Expr::BinaryOp { op, .. } => { + if let BinaryOperator::PGCustomBinaryOperator(ops) = op { + if *ops == ["pg_catalog", "~"] { + *op = BinaryOperator::PGRegexMatch; + } + } + } + _ => {} + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for FixCollate { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = FixCollateVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + #[cfg(test)] mod tests { use super::*; @@ -869,4 +906,11 @@ mod tests { "SELECT is_null(session_user)" ); } + + #[test] + fn test_collate_fix() { + let rules: Vec> = vec![Arc::new(FixCollate)]; + + assert_rewrite!(&rules, "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;", "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3"); + } } From de3abe2a5ad7d99f3f9b9d61c78792ece939855f Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Fri, 19 Sep 2025 11:40:54 -0700 Subject: [PATCH 2/9] feat: update rules WIP --- datafusion-postgres/src/pg_catalog.rs | 53 +++-- .../src/pg_catalog/format_type.rs | 196 ++++++++++++++++++ datafusion-postgres/src/sql.rs | 42 +++- datafusion-postgres/tests/psql.rs | 105 ++++++++++ 4 files changed, 372 insertions(+), 24 deletions(-) create mode 100644 datafusion-postgres/src/pg_catalog/format_type.rs create mode 100644 datafusion-postgres/tests/psql.rs diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index b82f2ea..1588c22 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -23,6 +23,7 @@ use crate::pg_catalog::catalog_info::CatalogInfo; pub mod catalog_info; pub mod empty_table; +pub mod format_type; pub mod has_privilege_udf; pub mod pg_attribute; pub mod pg_class; @@ -835,29 +836,13 @@ pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF { ) } -pub fn create_format_type_udf() -> ScalarUDF { - let func = move |args: &[ColumnarValue]| { - let args = ColumnarValue::values_to_arrays(args)?; - let type_oids = &args[0]; // Table (can be name or OID) - let _type_mods = &args[1]; // Privilege type (SELECT, INSERT, etc.) - - // For now, always return true (full access for current user) - let mut builder = StringBuilder::new(); - for _ in 0..type_oids.len() { - builder.append_value("???"); - } - - let array: ArrayRef = Arc::new(builder.finish()); - - Ok(ColumnarValue::Array(array)) - }; - +pub fn create_format_type_udf(name: &str) -> ScalarUDF { create_udf( - "format_type", + name, vec![DataType::Int64, DataType::Int32], DataType::Utf8, Volatility::Stable, - Arc::new(func), + Arc::new(format_type::format_type_impl), ) } @@ -905,6 +890,30 @@ pub fn create_pg_get_partkeydef_udf() -> ScalarUDF { ) } +pub fn create_pg_relation_is_publishable_udf(name: &str) -> ScalarUDF { + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let oid = &args[0]; + + let mut builder = BooleanBuilder::new(); + for _ in 0..oid.len() { + builder.append_value(true); + } + + let array: ArrayRef = Arc::new(builder.finish()); + + Ok(ColumnarValue::Array(array)) + }; + + create_udf( + name, + vec![DataType::Int32], + DataType::Boolean, + Volatility::Stable, + Arc::new(func), + ) +} + /// Install pg_catalog and postgres UDFs to current `SessionContext` pub fn setup_pg_catalog( session_context: &SessionContext, @@ -951,11 +960,15 @@ pub fn setup_pg_catalog( )); session_context.register_udf(create_pg_table_is_visible("pg_table_is_visible")); session_context.register_udf(create_pg_table_is_visible("pg_catalog.pg_table_is_visible")); - session_context.register_udf(create_format_type_udf()); + session_context.register_udf(create_format_type_udf("format_type")); + session_context.register_udf(create_format_type_udf("pg_catalog.format_type")); session_context.register_udf(create_session_user_udf()); session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone()); session_context.register_udf(pg_get_expr_udf::create_pg_get_expr_udf()); session_context.register_udf(create_pg_get_partkeydef_udf()); + session_context.register_udf(create_pg_relation_is_publishable_udf( + "pg_catalog.pg_relation_is_publishable", + )); Ok(()) } diff --git a/datafusion-postgres/src/pg_catalog/format_type.rs b/datafusion-postgres/src/pg_catalog/format_type.rs new file mode 100644 index 0000000..0126af6 --- /dev/null +++ b/datafusion-postgres/src/pg_catalog/format_type.rs @@ -0,0 +1,196 @@ +use std::sync::Arc; + +use datafusion::{ + arrow::array::{Array, StringBuilder}, + common::{cast::as_int32_array, DataFusionError}, + logical_expr::ColumnarValue, +}; + +pub(crate) fn format_type_impl(args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + let type_oids = as_int32_array(&args[0])?; + + let typemods = if args.len() > 1 { + Some(as_int32_array(&args[1])?) + } else { + None + }; + + let mut result = StringBuilder::new(); + + for i in 0..type_oids.len() { + if type_oids.is_null(i) { + result.append_null(); + continue; + } + + let type_oid = type_oids.value(i); + let typemod = typemods + .map(|tm| if tm.is_null(i) { -1 } else { tm.value(i) }) + .unwrap_or(-1); + + let formatted_type = format_postgres_type(type_oid, typemod); + result.append_value(formatted_type); + } + + Ok(ColumnarValue::Array(Arc::new(result.finish()))) +} + +/// Format PostgreSQL type based on OID and type modifier +fn format_postgres_type(type_oid: i32, typemod: i32) -> String { + match type_oid { + // Core types + 16 => "boolean".to_string(), + 17 => "bytea".to_string(), + 18 => "\"char\"".to_string(), // Note: quoted to distinguish from char(n) + 19 => "name".to_string(), + 20 => "bigint".to_string(), + 21 => "smallint".to_string(), + 23 => "integer".to_string(), + 24 => "regproc".to_string(), + 25 => "text".to_string(), + 26 => "oid".to_string(), + 27 => "tid".to_string(), + 28 => "xid".to_string(), + 29 => "cid".to_string(), + + // JSON types + 114 => "json".to_string(), + 3802 => "jsonb".to_string(), + + // Numeric types + 700 => "real".to_string(), + 701 => "double precision".to_string(), + + // Character types with length + 1042 => { + // char(n) + if typemod > 4 { + format!("character({})", typemod - 4) + } else { + "character".to_string() + } + } + 1043 => { + // varchar(n) + if typemod > 4 { + format!("character varying({})", typemod - 4) + } else { + "character varying".to_string() + } + } + + // Numeric with precision/scale + 1700 => { + // numeric/decimal + if typemod >= 4 { + let precision = ((typemod - 4) >> 16) & 0xffff; + let scale = (typemod - 4) & 0xffff; + if scale > 0 { + format!("numeric({},{})", precision, scale) + } else { + format!("numeric({})", precision) + } + } else { + "numeric".to_string() + } + } + + // Date/Time types + 1082 => "date".to_string(), + 1083 => { + // time without time zone + if typemod >= 0 { + format!("time({}) without time zone", typemod) + } else { + "time without time zone".to_string() + } + } + 1114 => { + // timestamp without time zone + if typemod >= 0 { + format!("timestamp({}) without time zone", typemod) + } else { + "timestamp without time zone".to_string() + } + } + 1184 => { + // timestamp with time zone + if typemod >= 0 { + format!("timestamp({}) with time zone", typemod) + } else { + "timestamp with time zone".to_string() + } + } + 1266 => { + // time with time zone + if typemod >= 0 { + format!("time({}) with time zone", typemod) + } else { + "time with time zone".to_string() + } + } + 1186 => "interval".to_string(), + + // Bit types + 1560 => { + // bit + if typemod > 0 { + format!("bit({})", typemod) + } else { + "bit".to_string() + } + } + 1562 => { + // bit varying + if typemod > 0 { + format!("bit varying({})", typemod) + } else { + "bit varying".to_string() + } + } + + // UUID + 2950 => "uuid".to_string(), + + // Arrays (append [] to base type) + oid if oid > 0 => { + // For array types, we need to look up the base type + // This is a simplified approach - in practice you'd query pg_types + if let Some(base_oid) = get_array_base_type(oid) { + format!("{}[]", format_postgres_type(base_oid, -1)) + } else { + format!("oid({})", oid) + } + } + + // Unknown or invalid OID + _ => format!("oid({})", type_oid), + } +} + +/// Get base type OID for array types +/// This is a simplified mapping - in practice you'd query your pg_types table +fn get_array_base_type(array_oid: i32) -> Option { + match array_oid { + 1000 => Some(16), // _bool -> bool + 1001 => Some(17), // _bytea -> bytea + 1005 => Some(21), // _int2 -> int2 + 1007 => Some(23), // _int4 -> int4 + 1016 => Some(20), // _int8 -> int8 + 1009 => Some(25), // _text -> text + 1014 => Some(1042), // _bpchar -> bpchar + 1015 => Some(1043), // _varchar -> varchar + 1021 => Some(700), // _float4 -> float4 + 1022 => Some(701), // _float8 -> float8 + 1231 => Some(1700), // _numeric -> numeric + 1182 => Some(1082), // _date -> date + 1183 => Some(1083), // _time -> time + 1115 => Some(1114), // _timestamp -> timestamp + 1185 => Some(1184), // _timestamptz -> timestamptz + 199 => Some(114), // _json -> json + 3807 => Some(3802), // _jsonb -> jsonb + 2951 => Some(2950), // _uuid -> uuid + _ => None, + } +} diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 9760c9f..b38b95a 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -280,6 +280,7 @@ impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer { } /// Remove datafusion unsupported type annotations +/// it also removes pg_catalog as qualifier #[derive(Debug)] pub struct RemoveUnsupportedTypes { unsupported_types: HashSet, @@ -288,10 +289,17 @@ pub struct RemoveUnsupportedTypes { impl RemoveUnsupportedTypes { pub fn new() -> Self { let mut unsupported_types = HashSet::new(); - unsupported_types.insert("regclass".to_owned()); - unsupported_types.insert("regproc".to_owned()); - unsupported_types.insert("regtype".to_owned()); - unsupported_types.insert("regtype[]".to_owned()); + + for item in [ + "regclass", + "regproc", + "regtype", + "regtype[]", + "regnamespace", + ] { + unsupported_types.insert(item.to_owned()); + unsupported_types.insert(format!("pg_catalog.{item}")); + } Self { unsupported_types } } @@ -321,6 +329,22 @@ impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> { expr: value, .. } => { + dbg!(&data_type); + // rewrite custom pg_catalog. qualified types + let data_type_str = data_type.to_string(); + match data_type_str.as_str() { + "pg_catalog.text" => { + *data_type = DataType::Text; + } + "pg_catalog.int2[]" => { + *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket( + Box::new(DataType::Int16), + None, + )); + } + _ => {} + } + if self .unsupported_types .contains(data_type.to_string().to_lowercase().as_str()) @@ -800,6 +824,16 @@ mod tests { "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname", "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" ); + + assert_rewrite!( + &rules, + "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid) + LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid) + WHERE c.oid = '16386'", + "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, c.relpersistence, c.relreplident, am.amname FROM pg_catalog.pg_class AS c LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) WHERE c.oid = '16386'" + ); } #[test] diff --git a/datafusion-postgres/tests/psql.rs b/datafusion-postgres/tests/psql.rs new file mode 100644 index 0000000..f36d2a7 --- /dev/null +++ b/datafusion-postgres/tests/psql.rs @@ -0,0 +1,105 @@ +mod common; + +use common::*; +use pgwire::api::query::SimpleQueryHandler; + +const PSQL_QUERIES: &[&str] = &[ + "SELECT c.oid, + n.nspname, + c.relname + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE c.relname OPERATOR(pg_catalog.~) '^(tt)$' COLLATE pg_catalog.default + AND pg_catalog.pg_table_is_visible(c.oid) + ORDER BY 2, 3;", + "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid) + LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid) + WHERE c.oid = '16384';", + // the query contains all necessary information of columns + "SELECT a.attname, + pg_catalog.format_type(a.atttypid, a.atttypmod), + (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) + FROM pg_catalog.pg_attrdef d + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef LIMIT 1), + a.attnotnull, + (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t + WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, + a.attidentity, + a.attgenerated + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum;", + "SELECT pol.polname, pol.polpermissive, + CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END, + pg_catalog.pg_get_expr(pol.polqual, pol.polrelid), + pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid), + CASE pol.polcmd + WHEN 'r' THEN 'SELECT' + WHEN 'a' THEN 'INSERT' + WHEN 'w' THEN 'UPDATE' + WHEN 'd' THEN 'DELETE' + END AS cmd + FROM pg_catalog.pg_policy pol + WHERE pol.polrelid = '16384' ORDER BY 1;", + " SELECT oid, stxrelid::pg_catalog.regclass, stxnamespace::pg_catalog.regnamespace::pg_catalog.text AS nsp, stxname, + pg_catalog.pg_get_statisticsobjdef_columns(oid) AS columns, + 'd' = any(stxkind) AS ndist_enabled, + 'f' = any(stxkind) AS deps_enabled, + 'm' = any(stxkind) AS mcv_enabled, + stxstattarget + FROM pg_catalog.pg_statistic_ext + WHERE stxrelid = '16384' + ORDER BY nsp, stxname;", + "SELECT pubname + , NULL + , NULL + FROM pg_catalog.pg_publication p + JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid + JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid + WHERE pc.oid ='16384' and pg_catalog.pg_relation_is_publishable('16384') + UNION + SELECT pubname + , pg_get_expr(pr.prqual, c.oid) + , (CASE WHEN pr.prattrs IS NOT NULL THEN + (SELECT string_agg(attname, ', ') + FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s, + pg_catalog.pg_attribute + WHERE attrelid = pr.prrelid AND attnum = prattrs[s]) + ELSE NULL END) FROM pg_catalog.pg_publication p + JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid + JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid + WHERE pr.prrelid = '16384' + UNION + SELECT pubname + , NULL + , NULL + FROM pg_catalog.pg_publication p + WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable('16384') + ORDER BY 1;", + "SELECT c.oid::pg_catalog.regclass + FROM pg_catalog.pg_class c, pg_catalog.pg_inherits i + WHERE c.oid = i.inhparent AND i.inhrelid = '16384' + AND c.relkind != 'p' AND c.relkind != 'I' + ORDER BY inhseqno;", + "SELECT c.oid::pg_catalog.regclass, c.relkind, inhdetachpending, pg_catalog.pg_get_expr(c.relpartbound, c.oid) + FROM pg_catalog.pg_class c, pg_catalog.pg_inherits i + WHERE c.oid = i.inhrelid AND i.inhparent = '16384' + ORDER BY pg_catalog.pg_get_expr(c.relpartbound, c.oid) = 'DEFAULT', c.oid::pg_catalog.regclass::pg_catalog.text;", +]; + +#[tokio::test] +pub async fn test_psql_startup_sql() { + env_logger::init(); + let service = setup_handlers(); + let mut client = MockClient::new(); + + for query in PSQL_QUERIES { + SimpleQueryHandler::do_query(&service, &mut client, query) + .await + .unwrap_or_else(|e| { + panic!("failed to run sql:\n--------------\n {query}\n--------------\n{e}") + }); + } +} From 3ec415bca0068b540161719c0b09ca68ec4cc1f5 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Fri, 19 Sep 2025 15:12:05 -0700 Subject: [PATCH 3/9] feat: implement more psql \d queries --- datafusion-postgres/src/handlers.rs | 10 +- datafusion-postgres/src/pg_catalog.rs | 105 ++++++------- .../src/pg_catalog/pg_get_expr_udf.rs | 4 +- datafusion-postgres/src/sql.rs | 147 +++++++++++++----- datafusion-postgres/src/sql/blacklist.rs | 79 ++++++++++ datafusion-postgres/tests/dbeaver.rs | 2 +- datafusion-postgres/tests/psql.rs | 11 +- 7 files changed, 254 insertions(+), 104 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 307ce8a..a5d387f 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -5,8 +5,9 @@ use crate::auth::{AuthManager, Permission, ResourceType}; use crate::sql::{ parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter, CurrentUserVariableToSessionUserFunctionCall, FixArrayLiteral, FixCollate, - PrependUnqualifiedPgTableName, RemoveTableFunctionQualifier, RemoveUnsupportedTypes, - ResolveUnqualifiedIdentifer, RewriteArrayAnyAllOperation, SqlStatementRewriteRule, + PrependUnqualifiedPgTableName, RemoveQualifier, RemoveSubqueryFromProjection, + RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer, RewriteArrayAnyAllOperation, + SqlStatementRewriteRule, }; use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, Schema}; @@ -105,13 +106,14 @@ impl DfSessionService { Arc::new(BlacklistSqlRewriter::new()), Arc::new(AliasDuplicatedProjectionRewrite), Arc::new(ResolveUnqualifiedIdentifer), - Arc::new(RemoveUnsupportedTypes::new()), Arc::new(RewriteArrayAnyAllOperation), Arc::new(PrependUnqualifiedPgTableName), + Arc::new(RemoveQualifier), + Arc::new(RemoveUnsupportedTypes::new()), Arc::new(FixArrayLiteral), - Arc::new(RemoveTableFunctionQualifier), Arc::new(CurrentUserVariableToSessionUserFunctionCall), Arc::new(FixCollate), + Arc::new(RemoveSubqueryFromProjection), ]; let parser = Arc::new(Parser { session_context: session_context.clone(), diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index 1588c22..7fae49c 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -686,7 +686,7 @@ impl PgCatalogStaticTables { } } -pub fn create_current_schemas_udf(name: &str) -> ScalarUDF { +pub fn create_current_schemas_udf() -> ScalarUDF { // Define the function implementation let func = move |args: &[ColumnarValue]| { let args = ColumnarValue::values_to_arrays(args)?; @@ -709,7 +709,7 @@ pub fn create_current_schemas_udf(name: &str) -> ScalarUDF { // Wrap the implementation in a scalar function create_udf( - name, + "current_schemas", vec![DataType::Boolean], DataType::List(Arc::new(Field::new("schema", DataType::Utf8, false))), Volatility::Immutable, @@ -717,7 +717,7 @@ pub fn create_current_schemas_udf(name: &str) -> ScalarUDF { ) } -pub fn create_current_schema_udf(name: &str) -> ScalarUDF { +pub fn create_current_schema_udf() -> ScalarUDF { // Define the function implementation let func = move |_args: &[ColumnarValue]| { // Create a UTF8 array with a single value @@ -730,7 +730,7 @@ pub fn create_current_schema_udf(name: &str) -> ScalarUDF { // Wrap the implementation in a scalar function create_udf( - name, + "current_schema", vec![], DataType::Utf8, Volatility::Immutable, @@ -738,7 +738,7 @@ pub fn create_current_schema_udf(name: &str) -> ScalarUDF { ) } -pub fn create_current_database_udf(name: &str) -> ScalarUDF { +pub fn create_current_database_udf() -> ScalarUDF { // Define the function implementation let func = move |_args: &[ColumnarValue]| { // Create a UTF8 array with a single value @@ -751,30 +751,7 @@ pub fn create_current_database_udf(name: &str) -> ScalarUDF { // Wrap the implementation in a scalar function create_udf( - name, - vec![], - DataType::Utf8, - Volatility::Immutable, - Arc::new(func), - ) -} - -pub fn create_version_udf() -> ScalarUDF { - // Define the function implementation - let func = move |_args: &[ColumnarValue]| { - // Create a UTF8 array with version information - let mut builder = StringBuilder::new(); - // TODO: improve version string generation - builder - .append_value("DataFusion PostgreSQL 48.0.0 on x86_64-pc-linux-gnu, compiled by Rust"); - let array: ArrayRef = Arc::new(builder.finish()); - - Ok(ColumnarValue::Array(array)) - }; - - // Wrap the implementation in a scalar function - create_udf( - "version", + "current_database", vec![], DataType::Utf8, Volatility::Immutable, @@ -801,7 +778,7 @@ pub fn create_pg_get_userbyid_udf() -> ScalarUDF { // Wrap the implementation in a scalar function create_udf( - "pg_catalog.pg_get_userbyid", + "pg_get_userbyid", vec![DataType::Int32], DataType::Utf8, Volatility::Stable, @@ -809,7 +786,7 @@ pub fn create_pg_get_userbyid_udf() -> ScalarUDF { ) } -pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF { +pub fn create_pg_table_is_visible() -> ScalarUDF { // Define the function implementation let func = move |args: &[ColumnarValue]| { let args = ColumnarValue::values_to_arrays(args)?; @@ -828,7 +805,7 @@ pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF { // Wrap the implementation in a scalar function create_udf( - name, + "pg_table_is_visible", vec![DataType::Int32], DataType::Boolean, Volatility::Stable, @@ -836,9 +813,9 @@ pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF { ) } -pub fn create_format_type_udf(name: &str) -> ScalarUDF { +pub fn create_format_type_udf() -> ScalarUDF { create_udf( - name, + "format_type", vec![DataType::Int64, DataType::Int32], DataType::Utf8, Volatility::Stable, @@ -882,7 +859,7 @@ pub fn create_pg_get_partkeydef_udf() -> ScalarUDF { }; create_udf( - "pg_catalog.pg_get_partkeydef", + "pg_get_partkeydef", vec![DataType::Utf8], DataType::Utf8, Volatility::Stable, @@ -890,7 +867,7 @@ pub fn create_pg_get_partkeydef_udf() -> ScalarUDF { ) } -pub fn create_pg_relation_is_publishable_udf(name: &str) -> ScalarUDF { +pub fn create_pg_relation_is_publishable_udf() -> ScalarUDF { let func = move |args: &[ColumnarValue]| { let args = ColumnarValue::values_to_arrays(args)?; let oid = &args[0]; @@ -906,7 +883,7 @@ pub fn create_pg_relation_is_publishable_udf(name: &str) -> ScalarUDF { }; create_udf( - name, + "pg_relation_is_publishable", vec![DataType::Int32], DataType::Boolean, Volatility::Stable, @@ -914,6 +891,30 @@ pub fn create_pg_relation_is_publishable_udf(name: &str) -> ScalarUDF { ) } +pub fn create_pg_get_statisticsobjdef_columns_udf() -> ScalarUDF { + let func = move |args: &[ColumnarValue]| { + let args = ColumnarValue::values_to_arrays(args)?; + let oid = &args[0]; + + let mut builder = BooleanBuilder::new(); + for _ in 0..oid.len() { + builder.append_null(); + } + + let array: ArrayRef = Arc::new(builder.finish()); + + Ok(ColumnarValue::Array(array)) + }; + + create_udf( + "pg_get_statisticsobjdef_columns", + vec![DataType::UInt32], + DataType::Utf8, + Volatility::Stable, + Arc::new(func), + ) +} + /// Install pg_catalog and postgres UDFs to current `SessionContext` pub fn setup_pg_catalog( session_context: &SessionContext, @@ -933,42 +934,28 @@ pub fn setup_pg_catalog( })? .register_schema("pg_catalog", Arc::new(pg_catalog))?; - session_context.register_udf(create_current_database_udf("current_database")); - session_context.register_udf(create_current_schema_udf("current_schema")); - session_context.register_udf(create_current_schema_udf("pg_catalog.current_schema")); - session_context.register_udf(create_current_schemas_udf("current_schemas")); - session_context.register_udf(create_current_schemas_udf("pg_catalog.current_schemas")); - session_context.register_udf(create_version_udf()); + session_context.register_udf(create_current_database_udf()); + session_context.register_udf(create_current_schema_udf()); + session_context.register_udf(create_current_schemas_udf()); + // session_context.register_udf(create_version_udf()); session_context.register_udf(create_pg_get_userbyid_udf()); session_context.register_udf(has_privilege_udf::create_has_privilege_udf( "has_table_privilege", )); - session_context.register_udf(has_privilege_udf::create_has_privilege_udf( - "pg_catalog.has_table_privilege", - )); session_context.register_udf(has_privilege_udf::create_has_privilege_udf( "has_schema_privilege", )); - session_context.register_udf(has_privilege_udf::create_has_privilege_udf( - "pg_catalog.has_schema_privilege", - )); session_context.register_udf(has_privilege_udf::create_has_privilege_udf( "has_any_column_privilege", )); - session_context.register_udf(has_privilege_udf::create_has_privilege_udf( - "pg_catalog.has_any_column_privilege", - )); - session_context.register_udf(create_pg_table_is_visible("pg_table_is_visible")); - session_context.register_udf(create_pg_table_is_visible("pg_catalog.pg_table_is_visible")); - session_context.register_udf(create_format_type_udf("format_type")); - session_context.register_udf(create_format_type_udf("pg_catalog.format_type")); + session_context.register_udf(create_pg_table_is_visible()); + session_context.register_udf(create_format_type_udf()); session_context.register_udf(create_session_user_udf()); session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone()); session_context.register_udf(pg_get_expr_udf::create_pg_get_expr_udf()); session_context.register_udf(create_pg_get_partkeydef_udf()); - session_context.register_udf(create_pg_relation_is_publishable_udf( - "pg_catalog.pg_relation_is_publishable", - )); + session_context.register_udf(create_pg_relation_is_publishable_udf()); + session_context.register_udf(create_pg_get_statisticsobjdef_columns_udf()); Ok(()) } diff --git a/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs b/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs index d6672cd..5997273 100644 --- a/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs +++ b/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs @@ -30,12 +30,12 @@ impl PgGetExprUDF { ], Volatility::Stable, ), - name: "pg_catalog.pg_get_expr", + name: "pg_get_expr", } } pub fn into_scalar_udf(self) -> ScalarUDF { - ScalarUDF::new_from_impl(self).with_aliases(vec!["pg_get_expr"]) + ScalarUDF::new_from_impl(self) } } diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index b38b95a..74fdaf7 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -296,6 +296,7 @@ impl RemoveUnsupportedTypes { "regtype", "regtype[]", "regnamespace", + "oid", ] { unsupported_types.insert(item.to_owned()); unsupported_types.insert(format!("pg_catalog.{item}")); @@ -329,22 +330,6 @@ impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> { expr: value, .. } => { - dbg!(&data_type); - // rewrite custom pg_catalog. qualified types - let data_type_str = data_type.to_string(); - match data_type_str.as_str() { - "pg_catalog.text" => { - *data_type = DataType::Text; - } - "pg_catalog.int2[]" => { - *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket( - Box::new(DataType::Int16), - None, - )); - } - _ => {} - } - if self .unsupported_types .contains(data_type.to_string().to_lowercase().as_str()) @@ -493,14 +478,17 @@ impl VisitorMut for PrependUnqualifiedPgTableNameVisitor { &mut self, table_factor: &mut TableFactor, ) -> ControlFlow { - if let TableFactor::Table { name, .. } = table_factor { - if name.0.len() == 1 { - let ObjectNamePart::Identifier(ident) = &name.0[0]; - if ident.value.starts_with("pg_") { - *name = ObjectName(vec![ - ObjectNamePart::Identifier(Ident::new("pg_catalog")), - name.0[0].clone(), - ]); + if let TableFactor::Table { name, args, .. } = table_factor { + // not a table function + if args.is_none() { + if name.0.len() == 1 { + let ObjectNamePart::Identifier(ident) = &name.0[0]; + if ident.value.starts_with("pg_") { + *name = ObjectName(vec![ + ObjectNamePart::Identifier(Ident::new("pg_catalog")), + name.0[0].clone(), + ]); + } } } } @@ -599,21 +587,25 @@ impl SqlStatementRewriteRule for FixArrayLiteral { } } -/// Remove qualifier from table function +/// Remove qualifier from unsupported items /// -/// The query engine doesn't support qualified table function name +/// This rewriter removes qualifier from following items: +/// 1. type cast: for example: `pg_catalog.text` +/// 2. function name: for example: `pg_catalog.array_to_string`, +/// 3. table function name #[derive(Debug)] -pub struct RemoveTableFunctionQualifier; +pub struct RemoveQualifier; -struct RemoveTableFunctionQualifierVisitor; +struct RemoveQualifierVisitor; -impl VisitorMut for RemoveTableFunctionQualifierVisitor { +impl VisitorMut for RemoveQualifierVisitor { type Break = (); fn pre_visit_table_factor( &mut self, table_factor: &mut TableFactor, ) -> ControlFlow { + // remove table function qualifier if let TableFactor::Table { name, args, .. } = table_factor { if args.is_some() { // multiple idents in name, which means it's a qualified table name @@ -626,11 +618,44 @@ impl VisitorMut for RemoveTableFunctionQualifierVisitor { } ControlFlow::Continue(()) } + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + match expr { + Expr::Cast { data_type, .. } => { + // rewrite custom pg_catalog. qualified types + let data_type_str = data_type.to_string(); + match data_type_str.as_str() { + "pg_catalog.text" => { + *data_type = DataType::Text; + } + "pg_catalog.int2[]" => { + *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket( + Box::new(DataType::Int16), + None, + )); + } + _ => {} + } + } + Expr::Function(function) => { + // remove qualifier from pg_catalog.function + let name = &mut function.name; + if name.0.len() > 1 { + if let Some(last_ident) = name.0.pop() { + *name = ObjectName(vec![last_ident]); + } + } + } + + _ => {} + } + ControlFlow::Continue(()) + } } -impl SqlStatementRewriteRule for RemoveTableFunctionQualifier { +impl SqlStatementRewriteRule for RemoveQualifier { fn rewrite(&self, mut s: Statement) -> Statement { - let mut visitor = RemoveTableFunctionQualifierVisitor; + let mut visitor = RemoveQualifierVisitor; let _ = s.visit(&mut visitor); s @@ -725,6 +750,47 @@ impl SqlStatementRewriteRule for FixCollate { } } +/// Datafusion doesn't support subquery on projection +#[derive(Debug)] +pub struct RemoveSubqueryFromProjection; + +struct RemoveSubqueryFromProjectionVisitor; + +impl VisitorMut for RemoveSubqueryFromProjectionVisitor { + type Break = (); + + fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow { + if let SetExpr::Select(select) = query.body.as_mut() { + for projection in &mut select.projection { + match projection { + SelectItem::UnnamedExpr(expr) => { + if let Expr::Subquery(_) = expr { + *expr = Expr::Value(Value::Null.with_empty_span()); + } + } + SelectItem::ExprWithAlias { expr, .. } => { + if let Expr::Subquery(_) = expr { + *expr = Expr::Value(Value::Null.with_empty_span()); + } + } + _ => {} + } + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RemoveSubqueryFromProjection { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = RemoveSubqueryFromProjectionVisitor; + let _ = s.visit(&mut visitor); + + s + } +} + #[cfg(test)] mod tests { use super::*; @@ -798,8 +864,10 @@ mod tests { #[test] fn test_remove_unsupported_types() { - let rules: Vec> = - vec![Arc::new(RemoveUnsupportedTypes::new())]; + let rules: Vec> = vec![ + Arc::new(RemoveQualifier), + Arc::new(RemoveUnsupportedTypes::new()), + ]; assert_rewrite!( &rules, @@ -915,8 +983,7 @@ mod tests { #[test] fn test_remove_qualifier_from_table_function() { - let rules: Vec> = - vec![Arc::new(RemoveTableFunctionQualifier)]; + let rules: Vec> = vec![Arc::new(RemoveQualifier)]; assert_rewrite!( &rules, @@ -947,4 +1014,14 @@ mod tests { assert_rewrite!(&rules, "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;", "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3"); } + + #[test] + fn test_remove_subquery() { + let rules: Vec> = + vec![Arc::new(RemoveSubqueryFromProjection)]; + + assert_rewrite!(&rules, + "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;", + "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum"); + } } diff --git a/datafusion-postgres/src/sql/blacklist.rs b/datafusion-postgres/src/sql/blacklist.rs index 1757e35..6bf69e2 100644 --- a/datafusion-postgres/src/sql/blacklist.rs +++ b/datafusion-postgres/src/sql/blacklist.rs @@ -68,7 +68,86 @@ const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[ "SELECT NULL::TEXT AS schema_name, NULL::TEXT AS type_name WHERE false" ), +// psql \d queries + ( +"SELECT pol.polname, pol.polpermissive, + CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END, + pg_catalog.pg_get_expr(pol.polqual, pol.polrelid), + pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid), + CASE pol.polcmd + WHEN 'r' THEN 'SELECT' + WHEN 'a' THEN 'INSERT' + WHEN 'w' THEN 'UPDATE' + WHEN 'd' THEN 'DELETE' + END AS cmd + FROM pg_catalog.pg_policy pol + WHERE pol.polrelid = '$1' ORDER BY 1;", +"SELECT + NULL::TEXT AS polname, + NULL::TEXT AS polpermissive, + NULL::TEXT AS array_to_string, + NULL::TEXT AS pg_get_expr_1, + NULL::TEXT AS pg_get_expr_2, + NULL::TEXT AS cmd + WHERE false" + ), + + ( +"SELECT oid, stxrelid::pg_catalog.regclass, stxnamespace::pg_catalog.regnamespace::pg_catalog.text AS nsp, stxname, + pg_catalog.pg_get_statisticsobjdef_columns(oid) AS columns, + 'd' = any(stxkind) AS ndist_enabled, + 'f' = any(stxkind) AS deps_enabled, + 'm' = any(stxkind) AS mcv_enabled, + stxstattarget + FROM pg_catalog.pg_statistic_ext + WHERE stxrelid = '$1' + ORDER BY nsp, stxname;", +"SELECT + NULL::INT32 AS oid, + NULL::TEXT AS stxrelid, + NULL::TEXT AS nsp, + NULL::TEXT AS stxname, + NULL::TEXT AS columns, + NULL::BOOLEAN AS ndist_enabled, + NULL::BOOLEAN AS deps_enabled, + NULL::BOOLEAN AS mcv_enabled, + NULL::TEXT AS stxstattarget + WHERE false" + ), + ( +"SELECT pubname + , NULL + , NULL + FROM pg_catalog.pg_publication p + JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid + JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid + WHERE pc.oid ='$1' and pg_catalog.pg_relation_is_publishable('$1') + UNION + SELECT pubname + , pg_get_expr(pr.prqual, c.oid) + , (CASE WHEN pr.prattrs IS NOT NULL THEN + (SELECT string_agg(attname, ', ') + FROM pg_catalog.generate_series(0, pg_catalog.array_upper(pr.prattrs::pg_catalog.int2[], 1)) s, + pg_catalog.pg_attribute + WHERE attrelid = pr.prrelid AND attnum = prattrs[s]) + ELSE NULL END) FROM pg_catalog.pg_publication p + JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid + JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid + WHERE pr.prrelid = '$1' + UNION + SELECT pubname + , NULL + , NULL + FROM pg_catalog.pg_publication p + WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable('$1') + ORDER BY 1;", +"SELECT + NULL::TEXT AS pubname, + NULL::TEXT AS _1, + NULL::TEXT AS _2, + WHERE false" + ), ]; /// A blacklist based sql rewrite, when the input matches, return the output diff --git a/datafusion-postgres/tests/dbeaver.rs b/datafusion-postgres/tests/dbeaver.rs index 24e5ab8..f99817e 100644 --- a/datafusion-postgres/tests/dbeaver.rs +++ b/datafusion-postgres/tests/dbeaver.rs @@ -34,6 +34,6 @@ pub async fn test_dbeaver_startup_sql() { for query in DBEAVER_QUERIES { SimpleQueryHandler::do_query(&service, &mut client, query) .await - .unwrap_or_else(|_| panic!("failed to run sql: {query}")); + .unwrap_or_else(|e| panic!("failed to run sql: {query}\n{e}")); } } diff --git a/datafusion-postgres/tests/psql.rs b/datafusion-postgres/tests/psql.rs index f36d2a7..2a87a13 100644 --- a/datafusion-postgres/tests/psql.rs +++ b/datafusion-postgres/tests/psql.rs @@ -22,15 +22,16 @@ const PSQL_QUERIES: &[&str] = &[ pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d - WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef LIMIT 1), + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t - WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, + WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;", + // the following queries should return empty results at least for now "SELECT pol.polname, pol.polpermissive, CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END, pg_catalog.pg_get_expr(pol.polqual, pol.polrelid), @@ -43,7 +44,8 @@ const PSQL_QUERIES: &[&str] = &[ END AS cmd FROM pg_catalog.pg_policy pol WHERE pol.polrelid = '16384' ORDER BY 1;", - " SELECT oid, stxrelid::pg_catalog.regclass, stxnamespace::pg_catalog.regnamespace::pg_catalog.text AS nsp, stxname, + + "SELECT oid, stxrelid::pg_catalog.regclass, stxnamespace::pg_catalog.regnamespace::pg_catalog.text AS nsp, stxname, pg_catalog.pg_get_statisticsobjdef_columns(oid) AS columns, 'd' = any(stxkind) AS ndist_enabled, 'f' = any(stxkind) AS deps_enabled, @@ -52,6 +54,7 @@ const PSQL_QUERIES: &[&str] = &[ FROM pg_catalog.pg_statistic_ext WHERE stxrelid = '16384' ORDER BY nsp, stxname;", + "SELECT pubname , NULL , NULL @@ -78,11 +81,13 @@ const PSQL_QUERIES: &[&str] = &[ FROM pg_catalog.pg_publication p WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable('16384') ORDER BY 1;", + "SELECT c.oid::pg_catalog.regclass FROM pg_catalog.pg_class c, pg_catalog.pg_inherits i WHERE c.oid = i.inhparent AND i.inhrelid = '16384' AND c.relkind != 'p' AND c.relkind != 'I' ORDER BY inhseqno;", + "SELECT c.oid::pg_catalog.regclass, c.relkind, inhdetachpending, pg_catalog.pg_get_expr(c.relpartbound, c.oid) FROM pg_catalog.pg_class c, pg_catalog.pg_inherits i WHERE c.oid = i.inhrelid AND i.inhparent = '16384' From f71981bbaaa31d781d2d7e431d7bcc8a53195ab3 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 21 Sep 2025 00:54:14 +0800 Subject: [PATCH 4/9] feat: add support psql \d queries --- datafusion-postgres/src/handlers.rs | 51 ++---- datafusion-postgres/src/pg_catalog.rs | 2 +- datafusion-postgres/src/sql.rs | 41 ++--- datafusion-postgres/src/sql/blacklist.rs | 203 +++++++++++++++++++---- 4 files changed, 211 insertions(+), 86 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index a5d387f..e23a3a5 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -2,13 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; -use crate::sql::{ - parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter, - CurrentUserVariableToSessionUserFunctionCall, FixArrayLiteral, FixCollate, - PrependUnqualifiedPgTableName, RemoveQualifier, RemoveSubqueryFromProjection, - RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer, RewriteArrayAnyAllOperation, - SqlStatementRewriteRule, -}; +use crate::sql::PostgresCompatibilityParser; use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::ToDFSchema; @@ -92,7 +86,6 @@ pub struct DfSessionService { parser: Arc, timezone: Arc>, auth_manager: Arc, - sql_rewrite_rules: Vec>, } impl DfSessionService { @@ -100,31 +93,15 @@ impl DfSessionService { session_context: Arc, auth_manager: Arc, ) -> DfSessionService { - let sql_rewrite_rules: Vec> = vec![ - // make sure blacklist based rewriter it on the top to prevent sql - // being rewritten from other rewriters - Arc::new(BlacklistSqlRewriter::new()), - Arc::new(AliasDuplicatedProjectionRewrite), - Arc::new(ResolveUnqualifiedIdentifer), - Arc::new(RewriteArrayAnyAllOperation), - Arc::new(PrependUnqualifiedPgTableName), - Arc::new(RemoveQualifier), - Arc::new(RemoveUnsupportedTypes::new()), - Arc::new(FixArrayLiteral), - Arc::new(CurrentUserVariableToSessionUserFunctionCall), - Arc::new(FixCollate), - Arc::new(RemoveSubqueryFromProjection), - ]; let parser = Arc::new(Parser { session_context: session_context.clone(), - sql_rewrite_rules: sql_rewrite_rules.clone(), + sql_parser: PostgresCompatibilityParser::new(), }); DfSessionService { session_context, parser, timezone: Arc::new(Mutex::new("UTC".to_string())), auth_manager, - sql_rewrite_rules, } } @@ -460,14 +437,14 @@ impl SimpleQueryHandler for DfSessionService { return Ok(vec![resp]); } - let mut statements = parse(query).map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let mut statements = self + .parser + .sql_parser + .parse(query) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // TODO: deal with multiple statements - let mut statement = statements.remove(0); - - // Attempt to rewrite - statement = rewrite(statement, &self.sql_rewrite_rules); - dbg!(&statement); + let statement = statements.remove(0); // TODO: improve statement check by using statement directly let query = statement.to_string(); @@ -721,7 +698,7 @@ impl ExtendedQueryHandler for DfSessionService { pub struct Parser { session_context: Arc, - sql_rewrite_rules: Vec>, + sql_parser: PostgresCompatibilityParser, } impl Parser { @@ -794,11 +771,11 @@ impl QueryParser for Parser { return Ok((sql.to_string(), plan)); } - let mut statements = parse(sql).map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let mut statement = statements.remove(0); - - // Attempt to rewrite - statement = rewrite(statement, &self.sql_rewrite_rules); + let mut statements = self + .sql_parser + .parse(sql) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let statement = statements.remove(0); let query = statement.to_string(); diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index 7fae49c..62b69cd 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -816,7 +816,7 @@ pub fn create_pg_table_is_visible() -> ScalarUDF { pub fn create_format_type_udf() -> ScalarUDF { create_udf( "format_type", - vec![DataType::Int64, DataType::Int32], + vec![DataType::Int32, DataType::Int32], DataType::Utf8, Volatility::Stable, Arc::new(format_type::format_type_impl), diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 74fdaf7..51c8fcb 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; +use std::fmt::Debug; use std::ops::ControlFlow; -use std::sync::Arc; use datafusion::sql::sqlparser::ast::Array; use datafusion::sql::sqlparser::ast::ArrayElemTypeDef; @@ -30,28 +30,11 @@ use datafusion::sql::sqlparser::ast::Value; use datafusion::sql::sqlparser::ast::ValueWithSpan; use datafusion::sql::sqlparser::ast::VisitMut; use datafusion::sql::sqlparser::ast::VisitorMut; -use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; -use datafusion::sql::sqlparser::parser::Parser; -use datafusion::sql::sqlparser::parser::ParserError; mod blacklist; -pub use blacklist::BlacklistSqlRewriter; +pub use blacklist::PostgresCompatibilityParser; -pub fn parse(sql: &str) -> Result, ParserError> { - let dialect = PostgreSqlDialect {}; - - Parser::parse_sql(&dialect, sql) -} - -pub fn rewrite(mut s: Statement, rules: &[Arc]) -> Statement { - for rule in rules { - s = rule.rewrite(s); - } - - s -} - -pub trait SqlStatementRewriteRule: Send + Sync { +pub trait SqlStatementRewriteRule: Send + Sync + Debug { fn rewrite(&self, s: Statement) -> Statement; } @@ -794,6 +777,24 @@ impl SqlStatementRewriteRule for RemoveSubqueryFromProjection { #[cfg(test)] mod tests { use super::*; + use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; + use datafusion::sql::sqlparser::parser::Parser; + use datafusion::sql::sqlparser::parser::ParserError; + use std::sync::Arc; + + fn parse(sql: &str) -> Result, ParserError> { + let dialect = PostgreSqlDialect {}; + + Parser::parse_sql(&dialect, sql) + } + + fn rewrite(mut s: Statement, rules: &[Arc]) -> Statement { + for rule in rules { + s = rule.rewrite(s); + } + + s + } macro_rules! assert_rewrite { ($rules:expr, $orig:expr, $rewt:expr) => { diff --git a/datafusion-postgres/src/sql/blacklist.rs b/datafusion-postgres/src/sql/blacklist.rs index 6bf69e2..86e5b3f 100644 --- a/datafusion-postgres/src/sql/blacklist.rs +++ b/datafusion-postgres/src/sql/blacklist.rs @@ -1,9 +1,23 @@ -use std::collections::HashMap; +use std::sync::Arc; use datafusion::sql::sqlparser::ast::Statement; +use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; +use datafusion::sql::sqlparser::parser::Parser; +use datafusion::sql::sqlparser::parser::ParserError; +use datafusion::sql::sqlparser::tokenizer::Token; +use datafusion::sql::sqlparser::tokenizer::TokenWithSpan; -use super::parse; -use super::SqlStatementRewriteRule; +use crate::sql::AliasDuplicatedProjectionRewrite; +use crate::sql::CurrentUserVariableToSessionUserFunctionCall; +use crate::sql::FixArrayLiteral; +use crate::sql::FixCollate; +use crate::sql::PrependUnqualifiedPgTableName; +use crate::sql::RemoveQualifier; +use crate::sql::RemoveSubqueryFromProjection; +use crate::sql::RemoveUnsupportedTypes; +use crate::sql::ResolveUnqualifiedIdentifer; +use crate::sql::RewriteArrayAnyAllOperation; +use crate::sql::SqlStatementRewriteRule; const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[ // pgcli startup query @@ -81,7 +95,7 @@ const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[ WHEN 'd' THEN 'DELETE' END AS cmd FROM pg_catalog.pg_policy pol - WHERE pol.polrelid = '$1' ORDER BY 1;", + WHERE pol.polrelid = $1 ORDER BY 1;", "SELECT NULL::TEXT AS polname, NULL::TEXT AS polpermissive, @@ -100,10 +114,10 @@ const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[ 'm' = any(stxkind) AS mcv_enabled, stxstattarget FROM pg_catalog.pg_statistic_ext - WHERE stxrelid = '$1' + WHERE stxrelid = $1 ORDER BY nsp, stxname;", "SELECT - NULL::INT32 AS oid, + NULL::INT AS oid, NULL::TEXT AS stxrelid, NULL::TEXT AS nsp, NULL::TEXT AS stxname, @@ -122,7 +136,7 @@ const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[ FROM pg_catalog.pg_publication p JOIN pg_catalog.pg_publication_namespace pn ON p.oid = pn.pnpubid JOIN pg_catalog.pg_class pc ON pc.relnamespace = pn.pnnspid - WHERE pc.oid ='$1' and pg_catalog.pg_relation_is_publishable('$1') + WHERE pc.oid = $1 and pg_catalog.pg_relation_is_publishable($1) UNION SELECT pubname , pg_get_expr(pr.prqual, c.oid) @@ -134,50 +148,183 @@ const BLACKLIST_SQL_MAPPING: &[(&str, &str)] = &[ ELSE NULL END) FROM pg_catalog.pg_publication p JOIN pg_catalog.pg_publication_rel pr ON p.oid = pr.prpubid JOIN pg_catalog.pg_class c ON c.oid = pr.prrelid - WHERE pr.prrelid = '$1' + WHERE pr.prrelid = $1 UNION SELECT pubname , NULL , NULL FROM pg_catalog.pg_publication p - WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable('$1') + WHERE p.puballtables AND pg_catalog.pg_relation_is_publishable($1) ORDER BY 1;", "SELECT NULL::TEXT AS pubname, NULL::TEXT AS _1, - NULL::TEXT AS _2, + NULL::TEXT AS _2 WHERE false" ), ]; -/// A blacklist based sql rewrite, when the input matches, return the output +/// A parser with Postgres Compatibility for Datafusion /// -/// This rewriter is for those complex but meaningless queries we won't spend -/// effort to rewrite to datafusion supported version in near future. +/// This parser will try its best to rewrite postgres SQL into a form that +/// datafuiosn supports. It also maintains a blacklist that will transform the +/// statement to a similar version if rewrite doesn't worth the effort for now. #[derive(Debug)] -pub struct BlacklistSqlRewriter(HashMap); +pub struct PostgresCompatibilityParser { + blacklist: Vec<(Vec, Statement)>, + rewrite_rules: Vec>, +} + +impl PostgresCompatibilityParser { + pub fn new() -> Self { + let mut mapping = Vec::with_capacity(BLACKLIST_SQL_MAPPING.len()); + + for (sql_from, sql_to) in BLACKLIST_SQL_MAPPING { + mapping.push(( + Parser::new(&PostgreSqlDialect {}) + .try_with_sql(sql_from) + .unwrap() + .into_tokens() + .into_iter() + .map(|t| t.token) + .filter(|t| !matches!(t, Token::Whitespace(_) | Token::SemiColon)) + .collect(), + Parser::new(&PostgreSqlDialect {}) + .try_with_sql(sql_to) + .unwrap() + .parse_statement() + .unwrap(), + )); + } + + Self { + blacklist: mapping, + rewrite_rules: vec![ + // make sure blacklist based rewriter it on the top to prevent sql + // being rewritten from other rewriters + Arc::new(AliasDuplicatedProjectionRewrite), + Arc::new(ResolveUnqualifiedIdentifer), + Arc::new(RewriteArrayAnyAllOperation), + Arc::new(PrependUnqualifiedPgTableName), + Arc::new(RemoveQualifier), + Arc::new(RemoveUnsupportedTypes::new()), + Arc::new(FixArrayLiteral), + Arc::new(CurrentUserVariableToSessionUserFunctionCall), + Arc::new(FixCollate), + Arc::new(RemoveSubqueryFromProjection), + ], + } + } + + /// return statement if matched + fn parse_and_replace(&self, input: &str) -> Result { + let parser = Parser::new(&PostgreSqlDialect {}); + let tokens = parser.try_with_sql(input)?.into_tokens(); + + let tokens_without_whitespace = tokens + .iter() + .filter(|t| !matches!(t.token, Token::Whitespace(_) | Token::SemiColon)) + .collect::>(); + + for (blacklisted_sql_tokens, replacement) in &self.blacklist { + if blacklisted_sql_tokens.len() == tokens_without_whitespace.len() { + let matches = blacklisted_sql_tokens + .iter() + .zip(tokens_without_whitespace.iter()) + .all(|(a, b)| { + if matches!(a, Token::Placeholder(_)) { + true + } else { + *a == b.token + } + }); + if matches { + return Ok(MatchResult::Matches(replacement.clone())); + } + } else { + continue; + } + } + + Ok(MatchResult::Unmatches(tokens)) + } + + fn parse_tokens(&self, tokens: Vec) -> Result, ParserError> { + let parser = Parser::new(&PostgreSqlDialect {}); + parser.with_tokens_with_locations(tokens).parse_statements() + } + + pub fn parse(&self, input: &str) -> Result, ParserError> { + let statements = match self.parse_and_replace(input)? { + MatchResult::Matches(statement) => vec![statement], + MatchResult::Unmatches(tokens) => self.parse_tokens(tokens)?, + }; -impl SqlStatementRewriteRule for BlacklistSqlRewriter { - fn rewrite(&self, mut s: Statement) -> Statement { - if let Some(stmt) = self.0.get(&s) { - s = stmt.clone(); + let statements = statements.into_iter().map(|s| self.rewrite(s)).collect(); + + Ok(statements) + } + + pub fn rewrite(&self, mut s: Statement) -> Statement { + for rule in &self.rewrite_rules { + s = rule.rewrite(s); } s } } -impl BlacklistSqlRewriter { - pub(crate) fn new() -> BlacklistSqlRewriter { - let mut mapping = HashMap::new(); +pub(crate) enum MatchResult { + Matches(Statement), + Unmatches(Vec), +} - for (sql_from, sql_to) in BLACKLIST_SQL_MAPPING { - mapping.insert( - parse(sql_from).unwrap().remove(0), - parse(sql_to).unwrap().remove(0), - ); - } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sql_mapping() { + let sql = "SELECT pol.polname, pol.polpermissive, + CASE WHEN pol.polroles = '{0}' THEN NULL ELSE pg_catalog.array_to_string(array(select rolname from pg_catalog.pg_roles where oid = any (pol.polroles) order by 1),',') END, + pg_catalog.pg_get_expr(pol.polqual, pol.polrelid), + pg_catalog.pg_get_expr(pol.polwithcheck, pol.polrelid), + CASE pol.polcmd + WHEN 'r' THEN 'SELECT' + WHEN 'a' THEN 'INSERT' + WHEN 'w' THEN 'UPDATE' + WHEN 'd' THEN 'DELETE' + END AS cmd + FROM pg_catalog.pg_policy pol + WHERE pol.polrelid = '16384' ORDER BY 1;"; + + let parser = PostgresCompatibilityParser::new(); + let match_result = parser.parse_and_replace(sql).expect("failed to parse sql"); + assert!(matches!(match_result, MatchResult::Matches(_))); + + let sql = "SELECT n.nspname schema_name, + t.typname type_name + FROM pg_catalog.pg_type t + INNER JOIN pg_catalog.pg_namespace n + ON n.oid = t.typnamespace + WHERE ( t.typrelid = 0 -- non-composite types + OR ( -- composite type, but not a table + SELECT c.relkind = 'c' + FROM pg_catalog.pg_class c + WHERE c.oid = t.typrelid + ) + ) + AND NOT EXISTS( -- ignore array types + SELECT 1 + FROM pg_catalog.pg_type el + WHERE el.oid = t.typelem AND el.typarray = t.oid + ) + AND n.nspname <> 'pg_catalog' + AND n.nspname <> 'information_schema' + ORDER BY 1, 2"; - Self(mapping) + let parser = PostgresCompatibilityParser::new(); + let match_result = parser.parse_and_replace(sql).expect("failed to parse sql"); + assert!(matches!(match_result, MatchResult::Matches(_))); } } From b49beca6d9fc8da22c345e7cb41161d8ca444d04 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 21 Sep 2025 01:04:13 +0800 Subject: [PATCH 5/9] feat: make format_type work for both int64/int32 --- datafusion-postgres/src/pg_catalog.rs | 12 +--- .../src/pg_catalog/format_type.rs | 59 ++++++++++++++++++- .../src/pg_catalog/pg_get_expr_udf.rs | 4 +- 3 files changed, 59 insertions(+), 16 deletions(-) diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs index 62b69cd..196a836 100644 --- a/datafusion-postgres/src/pg_catalog.rs +++ b/datafusion-postgres/src/pg_catalog.rs @@ -813,16 +813,6 @@ pub fn create_pg_table_is_visible() -> ScalarUDF { ) } -pub fn create_format_type_udf() -> ScalarUDF { - create_udf( - "format_type", - vec![DataType::Int32, DataType::Int32], - DataType::Utf8, - Volatility::Stable, - Arc::new(format_type::format_type_impl), - ) -} - pub fn create_session_user_udf() -> ScalarUDF { let func = move |_args: &[ColumnarValue]| { let mut builder = StringBuilder::new(); @@ -949,7 +939,7 @@ pub fn setup_pg_catalog( "has_any_column_privilege", )); session_context.register_udf(create_pg_table_is_visible()); - session_context.register_udf(create_format_type_udf()); + session_context.register_udf(format_type::create_format_type_udf()); session_context.register_udf(create_session_user_udf()); session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone()); session_context.register_udf(pg_get_expr_udf::create_pg_get_expr_udf()); diff --git a/datafusion-postgres/src/pg_catalog/format_type.rs b/datafusion-postgres/src/pg_catalog/format_type.rs index 0126af6..2c953e8 100644 --- a/datafusion-postgres/src/pg_catalog/format_type.rs +++ b/datafusion-postgres/src/pg_catalog/format_type.rs @@ -1,9 +1,15 @@ use std::sync::Arc; use datafusion::{ - arrow::array::{Array, StringBuilder}, + arrow::{ + array::{Array, StringBuilder}, + datatypes::DataType, + }, common::{cast::as_int32_array, DataFusionError}, - logical_expr::ColumnarValue, + logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, + Volatility, + }, }; pub(crate) fn format_type_impl(args: &[ColumnarValue]) -> Result { @@ -194,3 +200,52 @@ fn get_array_base_type(array_oid: i32) -> Option { _ => None, } } + +#[derive(Debug)] +pub struct FormatTypeUDF { + signature: Signature, +} + +impl FormatTypeUDF { + pub(crate) fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int64, DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int32, DataType::Int32]), + ], + Volatility::Stable, + ), + } + } + + pub fn into_scalar_udf(self) -> ScalarUDF { + ScalarUDF::new_from_impl(self) + } +} + +impl ScalarUDFImpl for FormatTypeUDF { + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn name(&self) -> &str { + "format_type" + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + format_type_impl(&args.args) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +pub fn create_format_type_udf() -> ScalarUDF { + FormatTypeUDF::new().into_scalar_udf() +} diff --git a/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs b/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs index 5997273..25f993e 100644 --- a/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs +++ b/datafusion-postgres/src/pg_catalog/pg_get_expr_udf.rs @@ -11,7 +11,6 @@ use datafusion::{ #[derive(Debug)] pub struct PgGetExprUDF { signature: Signature, - name: &'static str, } impl PgGetExprUDF { @@ -30,7 +29,6 @@ impl PgGetExprUDF { ], Volatility::Stable, ), - name: "pg_get_expr", } } @@ -49,7 +47,7 @@ impl ScalarUDFImpl for PgGetExprUDF { } fn name(&self) -> &str { - self.name + "pg_get_expr" } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { From 19b7a08f133401218965e55e0353bf2738dd2d23 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 21 Sep 2025 01:08:41 +0800 Subject: [PATCH 6/9] refactor: split and rename sql module --- datafusion-postgres/src/sql.rs | 1031 +---------------- .../src/sql/{blacklist.rs => parser.rs} | 22 +- datafusion-postgres/src/sql/rules.rs | 1025 ++++++++++++++++ 3 files changed, 1039 insertions(+), 1039 deletions(-) rename datafusion-postgres/src/sql/{blacklist.rs => parser.rs} (96%) create mode 100644 datafusion-postgres/src/sql/rules.rs diff --git a/datafusion-postgres/src/sql.rs b/datafusion-postgres/src/sql.rs index 51c8fcb..42c6a17 100644 --- a/datafusion-postgres/src/sql.rs +++ b/datafusion-postgres/src/sql.rs @@ -1,1028 +1,3 @@ -use std::collections::HashSet; -use std::fmt::Debug; -use std::ops::ControlFlow; - -use datafusion::sql::sqlparser::ast::Array; -use datafusion::sql::sqlparser::ast::ArrayElemTypeDef; -use datafusion::sql::sqlparser::ast::BinaryOperator; -use datafusion::sql::sqlparser::ast::CastKind; -use datafusion::sql::sqlparser::ast::DataType; -use datafusion::sql::sqlparser::ast::Expr; -use datafusion::sql::sqlparser::ast::Function; -use datafusion::sql::sqlparser::ast::FunctionArg; -use datafusion::sql::sqlparser::ast::FunctionArgExpr; -use datafusion::sql::sqlparser::ast::FunctionArgumentList; -use datafusion::sql::sqlparser::ast::FunctionArguments; -use datafusion::sql::sqlparser::ast::Ident; -use datafusion::sql::sqlparser::ast::ObjectName; -use datafusion::sql::sqlparser::ast::ObjectNamePart; -use datafusion::sql::sqlparser::ast::OrderByKind; -use datafusion::sql::sqlparser::ast::Query; -use datafusion::sql::sqlparser::ast::Select; -use datafusion::sql::sqlparser::ast::SelectItem; -use datafusion::sql::sqlparser::ast::SelectItemQualifiedWildcardKind; -use datafusion::sql::sqlparser::ast::SetExpr; -use datafusion::sql::sqlparser::ast::Statement; -use datafusion::sql::sqlparser::ast::TableFactor; -use datafusion::sql::sqlparser::ast::TableWithJoins; -use datafusion::sql::sqlparser::ast::UnaryOperator; -use datafusion::sql::sqlparser::ast::Value; -use datafusion::sql::sqlparser::ast::ValueWithSpan; -use datafusion::sql::sqlparser::ast::VisitMut; -use datafusion::sql::sqlparser::ast::VisitorMut; - -mod blacklist; -pub use blacklist::PostgresCompatibilityParser; - -pub trait SqlStatementRewriteRule: Send + Sync + Debug { - fn rewrite(&self, s: Statement) -> Statement; -} - -/// Rewrite rule for adding alias to duplicated projection -/// -/// This rule is to deal with sql like `SELECT n.oid, n.* FROM n`, which is a -/// valid statement in postgres. But datafusion treat it as illegal because of -/// duplicated column oid in projection. -/// -/// This rule will add alias to column, when there is a wildcard found in -/// projection. -#[derive(Debug)] -pub struct AliasDuplicatedProjectionRewrite; - -impl AliasDuplicatedProjectionRewrite { - // Rewrites a SELECT statement to alias explicit columns from the same table as a qualified wildcard. - fn rewrite_select_with_alias(select: &mut Box) { + // 1. Collect all table aliases from qualified wildcards. + let mut wildcard_tables = Vec::new(); + let mut has_simple_wildcard = false; + for p in &select.projection { + match p { + SelectItem::QualifiedWildcard(name, _) => match name { + SelectItemQualifiedWildcardKind::ObjectName(objname) => { + // for n.oid, + let idents = objname + .0 + .iter() + .map(|v| v.as_ident().unwrap().value.clone()) + .collect::>() + .join("."); + + wildcard_tables.push(idents); + } + SelectItemQualifiedWildcardKind::Expr(_expr) => { + // FIXME: + } + }, + SelectItem::Wildcard(_) => { + has_simple_wildcard = true; + } + _ => {} + } + } + + // If there are no qualified wildcards, there's nothing to do. + if wildcard_tables.is_empty() && !has_simple_wildcard { + return; + } + + // 2. Rewrite the projection, adding aliases to matching columns. + let mut new_projection = vec![]; + for p in select.projection.drain(..) { + match p { + SelectItem::UnnamedExpr(expr) => { + let alias_partial = match &expr { + // Case for `oid` (unqualified identifier) + Expr::Identifier(ident) => Some(ident.clone()), + // Case for `n.oid` (compound identifier) + Expr::CompoundIdentifier(idents) => { + // compare every ident but the last + if idents.len() > 1 { + let table_name = &idents[..idents.len() - 1] + .iter() + .map(|i| i.value.clone()) + .collect::>() + .join("."); + if wildcard_tables.iter().any(|name| name == table_name) { + Some(idents[idents.len() - 1].clone()) + } else { + None + } + } else { + None + } + } + _ => None, + }; + + if let Some(name) = alias_partial { + let alias = format!("__alias_{name}"); + new_projection.push(SelectItem::ExprWithAlias { + expr, + alias: Ident::new(alias), + }); + } else { + new_projection.push(SelectItem::UnnamedExpr(expr)); + } + } + // Preserve existing aliases and wildcards. + _ => new_projection.push(p), + } + } + select.projection = new_projection; + } +} + +impl SqlStatementRewriteRule for AliasDuplicatedProjectionRewrite { + fn rewrite(&self, mut statement: Statement) -> Statement { + if let Statement::Query(query) = &mut statement { + if let SetExpr::Select(select) = query.body.as_mut() { + Self::rewrite_select_with_alias(select); + } + } + + statement + } +} + +/// Prepend qualifier for order by or filter when there is qualified wildcard +/// +/// Postgres allows unqualified identifier in ORDER BY and FILTER but it's not +/// accepted by datafusion. +#[derive(Debug)] +pub struct ResolveUnqualifiedIdentifer; + +impl ResolveUnqualifiedIdentifer { + fn rewrite_unqualified_identifiers(query: &mut Box) { + if let SetExpr::Select(select) = query.body.as_mut() { + // Step 1: Find all table aliases from FROM and JOIN clauses. + let table_aliases = Self::get_table_aliases(&select.from); + + // Step 2: Check for a single qualified wildcard in the projection. + let qualified_wildcard_alias = Self::get_qualified_wildcard_alias(&select.projection); + if qualified_wildcard_alias.is_none() || table_aliases.is_empty() { + return; // Conditions not met. + } + + let wildcard_alias = qualified_wildcard_alias.unwrap(); + + // Step 3: Rewrite expressions in the WHERE and ORDER BY clauses. + if let Some(selection) = &mut select.selection { + Self::rewrite_expr(selection, &wildcard_alias, &table_aliases); + } + + if let Some(OrderByKind::Expressions(order_by_exprs)) = + query.order_by.as_mut().map(|o| &mut o.kind) + { + for order_by_expr in order_by_exprs { + Self::rewrite_expr(&mut order_by_expr.expr, &wildcard_alias, &table_aliases); + } + } + } + } + + fn get_table_aliases(tables: &[TableWithJoins]) -> HashSet { + let mut aliases = HashSet::new(); + for table_with_joins in tables { + if let TableFactor::Table { + alias: Some(alias), .. + } = &table_with_joins.relation + { + aliases.insert(alias.name.value.clone()); + } + for join in &table_with_joins.joins { + if let TableFactor::Table { + alias: Some(alias), .. + } = &join.relation + { + aliases.insert(alias.name.value.clone()); + } + } + } + aliases + } + + fn get_qualified_wildcard_alias(projection: &[SelectItem]) -> Option { + let mut qualified_wildcards = projection + .iter() + .filter_map(|item| { + if let SelectItem::QualifiedWildcard( + SelectItemQualifiedWildcardKind::ObjectName(objname), + _, + ) = item + { + Some( + objname + .0 + .iter() + .map(|v| v.as_ident().unwrap().value.clone()) + .collect::>() + .join("."), + ) + } else { + None + } + }) + .collect::>(); + + if qualified_wildcards.len() == 1 { + Some(qualified_wildcards.remove(0)) + } else { + None + } + } + + fn rewrite_expr(expr: &mut Expr, wildcard_alias: &str, table_aliases: &HashSet) { + match expr { + Expr::Identifier(ident) => { + // If the identifier is not a table alias itself, rewrite it. + if !table_aliases.contains(&ident.value) { + *expr = Expr::CompoundIdentifier(vec![ + Ident::new(wildcard_alias.to_string()), + ident.clone(), + ]); + } + } + Expr::BinaryOp { left, right, .. } => { + Self::rewrite_expr(left, wildcard_alias, table_aliases); + Self::rewrite_expr(right, wildcard_alias, table_aliases); + } + // Add more cases for other expression types as needed (e.g., `InList`, `Between`, etc.) + _ => {} + } + } +} + +impl SqlStatementRewriteRule for ResolveUnqualifiedIdentifer { + fn rewrite(&self, mut statement: Statement) -> Statement { + if let Statement::Query(query) = &mut statement { + Self::rewrite_unqualified_identifiers(query); + } + + statement + } +} + +/// Remove datafusion unsupported type annotations +/// it also removes pg_catalog as qualifier +#[derive(Debug)] +pub struct RemoveUnsupportedTypes { + unsupported_types: HashSet, +} + +impl RemoveUnsupportedTypes { + pub fn new() -> Self { + let mut unsupported_types = HashSet::new(); + + for item in [ + "regclass", + "regproc", + "regtype", + "regtype[]", + "regnamespace", + "oid", + ] { + unsupported_types.insert(item.to_owned()); + unsupported_types.insert(format!("pg_catalog.{item}")); + } + + Self { unsupported_types } + } +} + +struct RemoveUnsupportedTypesVisitor<'a> { + unsupported_types: &'a HashSet, +} + +impl VisitorMut for RemoveUnsupportedTypesVisitor<'_> { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + match expr { + // This is the key part: identify constants with type annotations. + Expr::TypedString { value, data_type } => { + if self + .unsupported_types + .contains(data_type.to_string().to_lowercase().as_str()) + { + *expr = + Expr::Value(Value::SingleQuotedString(value.to_string()).with_empty_span()); + } + } + Expr::Cast { + data_type, + expr: value, + .. + } => { + if self + .unsupported_types + .contains(data_type.to_string().to_lowercase().as_str()) + { + *expr = *value.clone(); + } + } + // Add more match arms for other expression types (e.g., `Function`, `InList`) as needed. + _ => {} + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RemoveUnsupportedTypes { + fn rewrite(&self, mut statement: Statement) -> Statement { + let mut visitor = RemoveUnsupportedTypesVisitor { + unsupported_types: &self.unsupported_types, + }; + let _ = statement.visit(&mut visitor); + statement + } +} + +/// Rewrite Postgres's ANY operator to array_contains +#[derive(Debug)] +pub struct RewriteArrayAnyAllOperation; + +struct RewriteArrayAnyAllOperationVisitor; + +impl RewriteArrayAnyAllOperationVisitor { + fn any_to_array_cofntains(&self, left: &Expr, right: &Expr) -> Expr { + let array = if let Expr::Value(ValueWithSpan { + value: Value::SingleQuotedString(array_literal), + .. + }) = right + { + let array_literal = array_literal.trim(); + if array_literal.starts_with('{') && array_literal.ends_with('}') { + let items = array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' '); + let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()); + + // For now, we assume the data type is string + let elems = items + .map(|s| { + Expr::Value(Value::SingleQuotedString(s.to_string()).with_empty_span()) + }) + .collect(); + Expr::Array(Array { + elem: elems, + named: true, + }) + } else { + right.clone() + } + } else { + right.clone() + }; + + Expr::Function(Function { + name: ObjectName::from(vec![Ident::new("array_contains")]), + args: FunctionArguments::List(FunctionArgumentList { + args: vec![ + FunctionArg::Unnamed(FunctionArgExpr::Expr(array)), + FunctionArg::Unnamed(FunctionArgExpr::Expr(left.clone())), + ], + duplicate_treatment: None, + clauses: vec![], + }), + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }) + } +} + +impl VisitorMut for RewriteArrayAnyAllOperationVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + match expr { + Expr::AnyOp { + left, + compare_op, + right, + .. + } => match compare_op { + BinaryOperator::Eq => { + *expr = self.any_to_array_cofntains(left.as_ref(), right.as_ref()); + } + BinaryOperator::NotEq => { + // TODO:left not equals to any element in array + } + _ => {} + }, + Expr::AllOp { + left, + compare_op, + right, + } => match compare_op { + BinaryOperator::Eq => { + // TODO: left equals to every element in array + } + BinaryOperator::NotEq => { + *expr = Expr::UnaryOp { + op: UnaryOperator::Not, + expr: Box::new(self.any_to_array_cofntains(left.as_ref(), right.as_ref())), + } + } + _ => {} + }, + _ => {} + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RewriteArrayAnyAllOperation { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = RewriteArrayAnyAllOperationVisitor; + + let _ = s.visit(&mut visitor); + + s + } +} + +/// Prepend qualifier to table_name +/// +/// Postgres has pg_catalog in search_path by default so it allow access to +/// `pg_namespace` without `pg_catalog.` qualifier +#[derive(Debug)] +pub struct PrependUnqualifiedPgTableName; + +struct PrependUnqualifiedPgTableNameVisitor; + +impl VisitorMut for PrependUnqualifiedPgTableNameVisitor { + type Break = (); + + fn pre_visit_table_factor( + &mut self, + table_factor: &mut TableFactor, + ) -> ControlFlow { + if let TableFactor::Table { name, args, .. } = table_factor { + // not a table function + if args.is_none() { + if name.0.len() == 1 { + let ObjectNamePart::Identifier(ident) = &name.0[0]; + if ident.value.starts_with("pg_") { + *name = ObjectName(vec![ + ObjectNamePart::Identifier(Ident::new("pg_catalog")), + name.0[0].clone(), + ]); + } + } + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for PrependUnqualifiedPgTableName { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = PrependUnqualifiedPgTableNameVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + +#[derive(Debug)] +pub struct FixArrayLiteral; + +struct FixArrayLiteralVisitor; + +impl FixArrayLiteralVisitor { + fn is_string_type(dt: &DataType) -> bool { + matches!( + dt, + DataType::Text | DataType::Varchar(_) | DataType::Char(_) | DataType::String(_) + ) + } +} + +impl VisitorMut for FixArrayLiteralVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + if let Expr::Cast { + kind, + expr, + data_type, + .. + } = expr + { + if kind == &CastKind::DoubleColon { + if let DataType::Array(arr) = data_type { + // cast some to + if let Expr::Value(ValueWithSpan { + value: Value::SingleQuotedString(array_literal), + .. + }) = expr.as_ref() + { + let items = + array_literal.trim_matches(|c| c == '{' || c == '}' || c == ' '); + let items = items.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()); + + let is_text = match arr { + ArrayElemTypeDef::AngleBracket(dt) => Self::is_string_type(dt.as_ref()), + ArrayElemTypeDef::SquareBracket(dt, _) => { + Self::is_string_type(dt.as_ref()) + } + ArrayElemTypeDef::Parenthesis(dt) => Self::is_string_type(dt.as_ref()), + _ => false, + }; + + let elems = items + .map(|s| { + if is_text { + Expr::Value( + Value::SingleQuotedString(s.to_string()).with_empty_span(), + ) + } else { + Expr::Value( + Value::Number(s.to_string(), false).with_empty_span(), + ) + } + }) + .collect(); + *expr = Box::new(Expr::Array(Array { + elem: elems, + named: true, + })); + } + } + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for FixArrayLiteral { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = FixArrayLiteralVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + +/// Remove qualifier from unsupported items +/// +/// This rewriter removes qualifier from following items: +/// 1. type cast: for example: `pg_catalog.text` +/// 2. function name: for example: `pg_catalog.array_to_string`, +/// 3. table function name +#[derive(Debug)] +pub struct RemoveQualifier; + +struct RemoveQualifierVisitor; + +impl VisitorMut for RemoveQualifierVisitor { + type Break = (); + + fn pre_visit_table_factor( + &mut self, + table_factor: &mut TableFactor, + ) -> ControlFlow { + // remove table function qualifier + if let TableFactor::Table { name, args, .. } = table_factor { + if args.is_some() { + // multiple idents in name, which means it's a qualified table name + if name.0.len() > 1 { + if let Some(last_ident) = name.0.pop() { + *name = ObjectName(vec![last_ident]); + } + } + } + } + ControlFlow::Continue(()) + } + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + match expr { + Expr::Cast { data_type, .. } => { + // rewrite custom pg_catalog. qualified types + let data_type_str = data_type.to_string(); + match data_type_str.as_str() { + "pg_catalog.text" => { + *data_type = DataType::Text; + } + "pg_catalog.int2[]" => { + *data_type = DataType::Array(ArrayElemTypeDef::SquareBracket( + Box::new(DataType::Int16), + None, + )); + } + _ => {} + } + } + Expr::Function(function) => { + // remove qualifier from pg_catalog.function + let name = &mut function.name; + if name.0.len() > 1 { + if let Some(last_ident) = name.0.pop() { + *name = ObjectName(vec![last_ident]); + } + } + } + + _ => {} + } + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RemoveQualifier { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = RemoveQualifierVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + +/// Replace `current_user` with `session_user()` +#[derive(Debug)] +pub struct CurrentUserVariableToSessionUserFunctionCall; + +struct CurrentUserVariableToSessionUserFunctionCallVisitor; + +impl VisitorMut for CurrentUserVariableToSessionUserFunctionCallVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + if let Expr::Identifier(ident) = expr { + if ident.quote_style.is_none() && ident.value.to_lowercase() == "current_user" { + *expr = Expr::Function(Function { + name: ObjectName::from(vec![Ident::new("session_user")]), + args: FunctionArguments::None, + uses_odbc_syntax: false, + parameters: FunctionArguments::None, + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }); + } + } + + if let Expr::Function(func) = expr { + let fname = func + .name + .0 + .iter() + .map(|ident| ident.to_string()) + .collect::>() + .join("."); + if fname.to_lowercase() == "current_user" { + func.name = ObjectName::from(vec![Ident::new("session_user")]) + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for CurrentUserVariableToSessionUserFunctionCall { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = CurrentUserVariableToSessionUserFunctionCallVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + +/// Fix collate and regex calls +#[derive(Debug)] +pub struct FixCollate; + +struct FixCollateVisitor; + +impl VisitorMut for FixCollateVisitor { + type Break = (); + + fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow { + match expr { + Expr::Collate { expr: inner, .. } => { + *expr = inner.as_ref().clone(); + } + Expr::BinaryOp { op, .. } => { + if let BinaryOperator::PGCustomBinaryOperator(ops) = op { + if *ops == ["pg_catalog", "~"] { + *op = BinaryOperator::PGRegexMatch; + } + } + } + _ => {} + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for FixCollate { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = FixCollateVisitor; + + let _ = s.visit(&mut visitor); + s + } +} + +/// Datafusion doesn't support subquery on projection +#[derive(Debug)] +pub struct RemoveSubqueryFromProjection; + +struct RemoveSubqueryFromProjectionVisitor; + +impl VisitorMut for RemoveSubqueryFromProjectionVisitor { + type Break = (); + + fn pre_visit_query(&mut self, query: &mut Query) -> ControlFlow { + if let SetExpr::Select(select) = query.body.as_mut() { + for projection in &mut select.projection { + match projection { + SelectItem::UnnamedExpr(expr) => { + if let Expr::Subquery(_) = expr { + *expr = Expr::Value(Value::Null.with_empty_span()); + } + } + SelectItem::ExprWithAlias { expr, .. } => { + if let Expr::Subquery(_) = expr { + *expr = Expr::Value(Value::Null.with_empty_span()); + } + } + _ => {} + } + } + } + + ControlFlow::Continue(()) + } +} + +impl SqlStatementRewriteRule for RemoveSubqueryFromProjection { + fn rewrite(&self, mut s: Statement) -> Statement { + let mut visitor = RemoveSubqueryFromProjectionVisitor; + let _ = s.visit(&mut visitor); + + s + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::sql::sqlparser::dialect::PostgreSqlDialect; + use datafusion::sql::sqlparser::parser::Parser; + use datafusion::sql::sqlparser::parser::ParserError; + use std::sync::Arc; + + fn parse(sql: &str) -> Result, ParserError> { + let dialect = PostgreSqlDialect {}; + + Parser::parse_sql(&dialect, sql) + } + + fn rewrite(mut s: Statement, rules: &[Arc]) -> Statement { + for rule in rules { + s = rule.rewrite(s); + } + + s + } + + macro_rules! assert_rewrite { + ($rules:expr, $orig:expr, $rewt:expr) => { + let sql = $orig; + let statement = parse(sql).expect("Failed to parse").remove(0); + + let statement = rewrite(statement, $rules); + assert_eq!(statement.to_string(), $rewt); + }; + } + + #[test] + fn test_alias_rewrite() { + let rules: Vec> = + vec![Arc::new(AliasDuplicatedProjectionRewrite)]; + + assert_rewrite!( + &rules, + "SELECT n.oid, n.* FROM pg_catalog.pg_namespace n", + "SELECT n.oid AS __alias_oid, n.* FROM pg_catalog.pg_namespace AS n" + ); + + assert_rewrite!( + &rules, + "SELECT oid, * FROM pg_catalog.pg_namespace", + "SELECT oid AS __alias_oid, * FROM pg_catalog.pg_namespace" + ); + + assert_rewrite!( + &rules, + "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id", + "SELECT t1.oid, t2.* FROM tbl1 AS t1 JOIN tbl2 AS t2 ON t1.id = t2.id" + ); + + let sql = "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname"; + let statement = parse(sql).expect("Failed to parse").remove(0); + + let statement = rewrite(statement, &rules); + assert_eq!( + statement.to_string(), + "SELECT n.oid AS __alias_oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspsname" + ); + } + + #[test] + fn test_qualifier_prepend() { + let rules: Vec> = + vec![Arc::new(ResolveUnqualifiedIdentifer)]; + + assert_rewrite!( + &rules, + "SELECT n.* FROM pg_catalog.pg_namespace n WHERE nspname = 'pg_catalog' ORDER BY nspname", + "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" + ); + + assert_rewrite!( + &rules, + "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname", + "SELECT * FROM pg_catalog.pg_namespace ORDER BY nspname" + ); + + assert_rewrite!( + &rules, + "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace' ORDER BY nspsname", + "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY n.nspsname" + ); + } + + #[test] + fn test_remove_unsupported_types() { + let rules: Vec> = vec![ + Arc::new(RemoveQualifier), + Arc::new(RemoveUnsupportedTypes::new()), + ]; + + assert_rewrite!( + &rules, + "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname", + "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" + ); + + assert_rewrite!( + &rules, + "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.oid = 1 AND n.nspname = 'pg_catalog'::regclass ORDER BY n.nspname", + "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.oid = 1 AND n.nspname = 'pg_catalog' ORDER BY n.nspname" + ); + + assert_rewrite!( + &rules, + "SELECT n.oid,n.*,d.description FROM pg_catalog.pg_namespace n LEFT OUTER JOIN pg_catalog.pg_description d ON d.objoid=n.oid AND d.objsubid=0 AND d.classoid='pg_namespace'::regclass ORDER BY nspname", + "SELECT n.oid, n.*, d.description FROM pg_catalog.pg_namespace AS n LEFT OUTER JOIN pg_catalog.pg_description AS d ON d.objoid = n.oid AND d.objsubid = 0 AND d.classoid = 'pg_namespace' ORDER BY nspname" + ); + + assert_rewrite!( + &rules, + "SELECT n.* FROM pg_catalog.pg_namespace n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname", + "SELECT n.* FROM pg_catalog.pg_namespace AS n WHERE n.nspname = 'pg_catalog' ORDER BY n.nspname" + ); + + assert_rewrite!( + &rules, + "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::pg_catalog.regtype::pg_catalog.text END, c.relpersistence, c.relreplident, am.amname + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid) + LEFT JOIN pg_catalog.pg_am am ON (c.relam = am.oid) + WHERE c.oid = '16386'", + "SELECT c.relchecks, c.relkind, c.relhasindex, c.relhasrules, c.relhastriggers, c.relrowsecurity, c.relforcerowsecurity, false AS relhasoids, c.relispartition, '', c.reltablespace, CASE WHEN c.reloftype = 0 THEN '' ELSE c.reloftype::TEXT END, c.relpersistence, c.relreplident, am.amname FROM pg_catalog.pg_class AS c LEFT JOIN pg_catalog.pg_class AS tc ON (c.reltoastrelid = tc.oid) LEFT JOIN pg_catalog.pg_am AS am ON (c.relam = am.oid) WHERE c.oid = '16386'" + ); + } + + #[test] + fn test_any_to_array_contains() { + let rules: Vec> = + vec![Arc::new(RewriteArrayAnyAllOperation)]; + + assert_rewrite!( + &rules, + "SELECT a = ANY(current_schemas(true))", + "SELECT array_contains(current_schemas(true), a)" + ); + + assert_rewrite!( + &rules, + "SELECT a <> ALL(current_schemas(true))", + "SELECT NOT array_contains(current_schemas(true), a)" + ); + + assert_rewrite!( + &rules, + "SELECT a = ANY('{r, l, e}')", + "SELECT array_contains(ARRAY['r', 'l', 'e'], a)" + ); + + assert_rewrite!( + &rules, + "SELECT a FROM tbl WHERE a = ANY(current_schemas(true))", + "SELECT a FROM tbl WHERE array_contains(current_schemas(true), a)" + ); + } + + #[test] + fn test_prepend_unqualified_table_name() { + let rules: Vec> = + vec![Arc::new(PrependUnqualifiedPgTableName)]; + + assert_rewrite!( + &rules, + "SELECT * FROM pg_catalog.pg_namespace", + "SELECT * FROM pg_catalog.pg_namespace" + ); + + assert_rewrite!( + &rules, + "SELECT * FROM pg_namespace", + "SELECT * FROM pg_catalog.pg_namespace" + ); + + assert_rewrite!( + &rules, + "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_namespace as ns ON ns.oid = oid", + "SELECT typtype, typname, pg_type.oid FROM pg_catalog.pg_type LEFT JOIN pg_catalog.pg_namespace AS ns ON ns.oid = oid" + ); + } + + #[test] + fn test_array_literal_fix() { + let rules: Vec> = vec![Arc::new(FixArrayLiteral)]; + + assert_rewrite!( + &rules, + "SELECT '{a, abc}'::text[]", + "SELECT ARRAY['a', 'abc']::TEXT[]" + ); + + assert_rewrite!( + &rules, + "SELECT '{1, 2}'::int[]", + "SELECT ARRAY[1, 2]::INT[]" + ); + + assert_rewrite!( + &rules, + "SELECT '{t, f}'::bool[]", + "SELECT ARRAY[t, f]::BOOL[]" + ); + } + + #[test] + fn test_remove_qualifier_from_table_function() { + let rules: Vec> = vec![Arc::new(RemoveQualifier)]; + + assert_rewrite!( + &rules, + "SELECT * FROM pg_catalog.pg_get_keywords()", + "SELECT * FROM pg_get_keywords()" + ); + } + + #[test] + fn test_current_user() { + let rules: Vec> = + vec![Arc::new(CurrentUserVariableToSessionUserFunctionCall)]; + + assert_rewrite!(&rules, "SELECT current_user", "SELECT session_user"); + + assert_rewrite!(&rules, "SELECT CURRENT_USER", "SELECT session_user"); + + assert_rewrite!( + &rules, + "SELECT is_null(current_user)", + "SELECT is_null(session_user)" + ); + } + + #[test] + fn test_collate_fix() { + let rules: Vec> = vec![Arc::new(FixCollate)]; + + assert_rewrite!(&rules, "SELECT c.oid, c.relname FROM pg_catalog.pg_class c WHERE c.relname OPERATOR(pg_catalog.~) '^(tablename)$' COLLATE pg_catalog.default AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3;", "SELECT c.oid, c.relname FROM pg_catalog.pg_class AS c WHERE c.relname ~ '^(tablename)$' AND pg_catalog.pg_table_is_visible(c.oid) ORDER BY 2, 3"); + } + + #[test] + fn test_remove_subquery() { + let rules: Vec> = + vec![Arc::new(RemoveSubqueryFromProjection)]; + + assert_rewrite!(&rules, + "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid, true) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef), a.attnotnull, (SELECT c.collname FROM pg_catalog.pg_collation c, pg_catalog.pg_type t WHERE c.oid = a.attcollation AND t.oid = a.atttypid AND a.attcollation <> t.typcollation LIMIT 1) AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum;", + "SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), NULL, a.attnotnull, NULL AS attcollation, a.attidentity, a.attgenerated FROM pg_catalog.pg_attribute AS a WHERE a.attrelid = '16384' AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum"); + } +} From 766effb4e47de9585076ca2cdbccc11d348fce37 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 21 Sep 2025 01:13:46 +0800 Subject: [PATCH 7/9] fix: lint --- datafusion-postgres/src/sql/parser.rs | 6 +++--- datafusion-postgres/src/sql/rules.rs | 16 +++++++--------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/datafusion-postgres/src/sql/parser.rs b/datafusion-postgres/src/sql/parser.rs index f1aed22..9b8d44c 100644 --- a/datafusion-postgres/src/sql/parser.rs +++ b/datafusion-postgres/src/sql/parser.rs @@ -239,7 +239,7 @@ impl PostgresCompatibilityParser { } }); if matches { - return Ok(MatchResult::Matches(replacement.clone())); + return Ok(MatchResult::Matches(Box::new(replacement.clone()))); } } else { continue; @@ -256,7 +256,7 @@ impl PostgresCompatibilityParser { pub fn parse(&self, input: &str) -> Result, ParserError> { let statements = match self.parse_and_replace(input)? { - MatchResult::Matches(statement) => vec![statement], + MatchResult::Matches(statement) => vec![*statement], MatchResult::Unmatches(tokens) => self.parse_tokens(tokens)?, }; @@ -275,7 +275,7 @@ impl PostgresCompatibilityParser { } pub(crate) enum MatchResult { - Matches(Statement), + Matches(Box), Unmatches(Vec), } diff --git a/datafusion-postgres/src/sql/rules.rs b/datafusion-postgres/src/sql/rules.rs index 596150d..c13284d 100644 --- a/datafusion-postgres/src/sql/rules.rs +++ b/datafusion-postgres/src/sql/rules.rs @@ -460,15 +460,13 @@ impl VisitorMut for PrependUnqualifiedPgTableNameVisitor { ) -> ControlFlow { if let TableFactor::Table { name, args, .. } = table_factor { // not a table function - if args.is_none() { - if name.0.len() == 1 { - let ObjectNamePart::Identifier(ident) = &name.0[0]; - if ident.value.starts_with("pg_") { - *name = ObjectName(vec![ - ObjectNamePart::Identifier(Ident::new("pg_catalog")), - name.0[0].clone(), - ]); - } + if args.is_none() && name.0.len() == 1 { + let ObjectNamePart::Identifier(ident) = &name.0[0]; + if ident.value.starts_with("pg_") { + *name = ObjectName(vec![ + ObjectNamePart::Identifier(Ident::new("pg_catalog")), + name.0[0].clone(), + ]); } } } From e4440d6c0a4b9ae1f693b7764d54d2d34e5b1404 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 21 Sep 2025 01:31:13 +0800 Subject: [PATCH 8/9] fix: integration test on version() --- .gitignore | 2 +- tests-integration/test_csv.py | 48 +++++++++++++++++------------------ 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 54271a8..c64ceab 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,4 @@ .envrc .vscode .aider* -/test_env \ No newline at end of file +/tests-integration/test_env diff --git a/tests-integration/test_csv.py b/tests-integration/test_csv.py index 3e036e4..fde5d2b 100644 --- a/tests-integration/test_csv.py +++ b/tests-integration/test_csv.py @@ -5,23 +5,23 @@ def main(): print("šŸ” Testing CSV data loading and PostgreSQL compatibility...") - + conn = psycopg.connect("host=127.0.0.1 port=5433 user=postgres dbname=public") conn.autocommit = True with conn.cursor() as cur: print("\nšŸ“Š Basic Data Access Tests:") test_basic_data_access(cur) - + print("\nšŸ—‚ļø Enhanced pg_catalog Tests:") test_enhanced_pg_catalog(cur) - + print("\nšŸ”§ PostgreSQL Functions Tests:") test_postgresql_functions(cur) - + print("\nšŸ“‹ Table Type Detection Tests:") test_table_type_detection(cur) - + print("\nšŸ” Transaction Integration Tests:") test_transaction_integration(cur) @@ -56,21 +56,21 @@ def test_enhanced_pg_catalog(cur): pg_type_count = cur.fetchone()[0] assert pg_type_count >= 16 print(f" āœ“ pg_catalog.pg_type: {pg_type_count} data types") - + # Test specific data types exist cur.execute("SELECT typname FROM pg_catalog.pg_type WHERE typname IN ('bool', 'int4', 'text', 'float8', 'date') ORDER BY typname") types = [row[0] for row in cur.fetchall()] expected_types = ['bool', 'date', 'float8', 'int4', 'text'] assert all(t in types for t in expected_types) print(f" āœ“ Core PostgreSQL types present: {', '.join(expected_types)}") - + # Test pg_class with proper table types cur.execute("SELECT relname, relkind FROM pg_catalog.pg_class WHERE relname = 'delhi'") result = cur.fetchone() assert result is not None assert result[1] == 'r' # Should be regular table print(f" āœ“ Table type detection: {result[0]} = '{result[1]}' (regular table)") - + # Test pg_attribute has column information cur.execute("SELECT count(*) FROM pg_catalog.pg_attribute WHERE attnum > 0") attr_count = cur.fetchone()[0] @@ -82,21 +82,21 @@ def test_postgresql_functions(cur): # Test version function cur.execute("SELECT version()") version = cur.fetchone()[0] - assert "DataFusion" in version and "PostgreSQL" in version + assert "DataFusion" in version print(f" āœ“ version(): {version[:50]}...") - + # Test current_schema function cur.execute("SELECT current_schema()") schema = cur.fetchone()[0] assert schema == "public" print(f" āœ“ current_schema(): {schema}") - + # Test current_schemas function cur.execute("SELECT current_schemas(false)") schemas = cur.fetchone()[0] assert "public" in schemas print(f" āœ“ current_schemas(): {schemas}") - + # Test has_table_privilege function (2-parameter version) cur.execute("SELECT has_table_privilege('delhi', 'SELECT')") result = cur.fetchone()[0] @@ -106,28 +106,28 @@ def test_postgresql_functions(cur): def test_table_type_detection(cur): """Test table type detection in pg_class.""" cur.execute(""" - SELECT relname, relkind, - CASE relkind + SELECT relname, relkind, + CASE relkind WHEN 'r' THEN 'regular table' - WHEN 'v' THEN 'view' + WHEN 'v' THEN 'view' WHEN 'i' THEN 'index' ELSE 'other' END as description - FROM pg_catalog.pg_class + FROM pg_catalog.pg_class ORDER BY relname """) results = cur.fetchall() - + # Should have multiple tables with proper types table_types = {} for name, kind, desc in results: table_types[name] = (kind, desc) - + # Delhi should be a regular table assert 'delhi' in table_types assert table_types['delhi'][0] == 'r' print(f" āœ“ Delhi table type: {table_types['delhi'][1]}") - + # System tables should also be regular tables system_tables = [name for name, (kind, _) in table_types.items() if name.startswith('pg_')] regular_system_tables = [name for name, (kind, _) in table_types.items() if name.startswith('pg_') and kind == 'r'] @@ -138,19 +138,19 @@ def test_transaction_integration(cur): # Test transaction with data queries cur.execute("BEGIN") print(" āœ“ Transaction started") - + # Execute multiple queries in transaction cur.execute("SELECT count(*) FROM delhi") count1 = cur.fetchone()[0] - + cur.execute("SELECT max(meantemp) FROM delhi") max_temp = cur.fetchone()[0] - + cur.execute("SELECT min(meantemp) FROM delhi") min_temp = cur.fetchone()[0] - + print(f" āœ“ Queries in transaction: {count1} rows, temp range {min_temp}-{max_temp}") - + # Commit transaction cur.execute("COMMIT") print(" āœ“ Transaction committed successfully") From a0ed23e0d7580091089db34b01dadbbec94337df Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 21 Sep 2025 01:39:08 +0800 Subject: [PATCH 9/9] chore: update flake for integration test dependencies --- flake.nix | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flake.nix b/flake.nix index f827612..0136d8b 100644 --- a/flake.nix +++ b/flake.nix @@ -16,10 +16,12 @@ pkgs = nixpkgs.legacyPackages.${system}; pythonEnv = pkgs.python3.withPackages (ps: with ps; [ psycopg2-binary + psycopg pyarrow ]); buildInputs = with pkgs; [ llvmPackages.libclang + libpq ]; in {