Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
.envrc
.vscode
.aider*
/test_env
/tests-integration/test_env
47 changes: 14 additions & 33 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc;

use crate::auth::{AuthManager, Permission, ResourceType};
use crate::sql::{
parse, rewrite, AliasDuplicatedProjectionRewrite, BlacklistSqlRewriter,
CurrentUserVariableToSessionUserFunctionCall, FixArrayLiteral, PrependUnqualifiedPgTableName,
RemoveTableFunctionQualifier, RemoveUnsupportedTypes, ResolveUnqualifiedIdentifer,
RewriteArrayAnyAllOperation, SqlStatementRewriteRule,
};
use crate::sql::PostgresCompatibilityParser;
use async_trait::async_trait;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::ToDFSchema;
Expand Down Expand Up @@ -91,37 +86,22 @@ pub struct DfSessionService {
parser: Arc<Parser>,
timezone: Arc<Mutex<String>>,
auth_manager: Arc<AuthManager>,
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
}

impl DfSessionService {
pub fn new(
session_context: Arc<SessionContext>,
auth_manager: Arc<AuthManager>,
) -> DfSessionService {
let sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>> = vec![
// make sure blacklist based rewriter it on the top to prevent sql
// being rewritten from other rewriters
Arc::new(BlacklistSqlRewriter::new()),
Arc::new(AliasDuplicatedProjectionRewrite),
Arc::new(ResolveUnqualifiedIdentifer),
Arc::new(RemoveUnsupportedTypes::new()),
Arc::new(RewriteArrayAnyAllOperation),
Arc::new(PrependUnqualifiedPgTableName),
Arc::new(FixArrayLiteral),
Arc::new(RemoveTableFunctionQualifier),
Arc::new(CurrentUserVariableToSessionUserFunctionCall),
];
let parser = Arc::new(Parser {
session_context: session_context.clone(),
sql_rewrite_rules: sql_rewrite_rules.clone(),
sql_parser: PostgresCompatibilityParser::new(),
});
DfSessionService {
session_context,
parser,
timezone: Arc::new(Mutex::new("UTC".to_string())),
auth_manager,
sql_rewrite_rules,
}
}

Expand Down Expand Up @@ -457,13 +437,14 @@ impl SimpleQueryHandler for DfSessionService {
return Ok(vec![resp]);
}

let mut statements = parse(query).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let mut statements = self
.parser
.sql_parser
.parse(query)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

// TODO: deal with multiple statements
let mut statement = statements.remove(0);

// Attempt to rewrite
statement = rewrite(statement, &self.sql_rewrite_rules);
let statement = statements.remove(0);

// TODO: improve statement check by using statement directly
let query = statement.to_string();
Expand Down Expand Up @@ -717,7 +698,7 @@ impl ExtendedQueryHandler for DfSessionService {

pub struct Parser {
session_context: Arc<SessionContext>,
sql_rewrite_rules: Vec<Arc<dyn SqlStatementRewriteRule>>,
sql_parser: PostgresCompatibilityParser,
}

impl Parser {
Expand Down Expand Up @@ -790,11 +771,11 @@ impl QueryParser for Parser {
return Ok((sql.to_string(), plan));
}

let mut statements = parse(sql).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let mut statement = statements.remove(0);

// Attempt to rewrite
statement = rewrite(statement, &self.sql_rewrite_rules);
let mut statements = self
.sql_parser
.parse(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let statement = statements.remove(0);

let query = statement.to_string();

Expand Down
132 changes: 61 additions & 71 deletions datafusion-postgres/src/pg_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::pg_catalog::catalog_info::CatalogInfo;

pub mod catalog_info;
pub mod empty_table;
pub mod format_type;
pub mod has_privilege_udf;
pub mod pg_attribute;
pub mod pg_class;
Expand Down Expand Up @@ -685,7 +686,7 @@ impl PgCatalogStaticTables {
}
}

pub fn create_current_schemas_udf(name: &str) -> ScalarUDF {
pub fn create_current_schemas_udf() -> ScalarUDF {
// Define the function implementation
let func = move |args: &[ColumnarValue]| {
let args = ColumnarValue::values_to_arrays(args)?;
Expand All @@ -708,15 +709,15 @@ pub fn create_current_schemas_udf(name: &str) -> ScalarUDF {

// Wrap the implementation in a scalar function
create_udf(
name,
"current_schemas",
vec![DataType::Boolean],
DataType::List(Arc::new(Field::new("schema", DataType::Utf8, false))),
Volatility::Immutable,
Arc::new(func),
)
}

pub fn create_current_schema_udf(name: &str) -> ScalarUDF {
pub fn create_current_schema_udf() -> ScalarUDF {
// Define the function implementation
let func = move |_args: &[ColumnarValue]| {
// Create a UTF8 array with a single value
Expand All @@ -729,15 +730,15 @@ pub fn create_current_schema_udf(name: &str) -> ScalarUDF {

// Wrap the implementation in a scalar function
create_udf(
name,
"current_schema",
vec![],
DataType::Utf8,
Volatility::Immutable,
Arc::new(func),
)
}

pub fn create_current_database_udf(name: &str) -> ScalarUDF {
pub fn create_current_database_udf() -> ScalarUDF {
// Define the function implementation
let func = move |_args: &[ColumnarValue]| {
// Create a UTF8 array with a single value
Expand All @@ -750,30 +751,7 @@ pub fn create_current_database_udf(name: &str) -> ScalarUDF {

// Wrap the implementation in a scalar function
create_udf(
name,
vec![],
DataType::Utf8,
Volatility::Immutable,
Arc::new(func),
)
}

pub fn create_version_udf() -> ScalarUDF {
// Define the function implementation
let func = move |_args: &[ColumnarValue]| {
// Create a UTF8 array with version information
let mut builder = StringBuilder::new();
// TODO: improve version string generation
builder
.append_value("DataFusion PostgreSQL 48.0.0 on x86_64-pc-linux-gnu, compiled by Rust");
let array: ArrayRef = Arc::new(builder.finish());

Ok(ColumnarValue::Array(array))
};

// Wrap the implementation in a scalar function
create_udf(
"version",
"current_database",
vec![],
DataType::Utf8,
Volatility::Immutable,
Expand All @@ -800,15 +778,15 @@ pub fn create_pg_get_userbyid_udf() -> ScalarUDF {

// Wrap the implementation in a scalar function
create_udf(
"pg_catalog.pg_get_userbyid",
"pg_get_userbyid",
vec![DataType::Int32],
DataType::Utf8,
Volatility::Stable,
Arc::new(func),
)
}

pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF {
pub fn create_pg_table_is_visible() -> ScalarUDF {
// Define the function implementation
let func = move |args: &[ColumnarValue]| {
let args = ColumnarValue::values_to_arrays(args)?;
Expand All @@ -827,24 +805,42 @@ pub fn create_pg_table_is_visible(name: &str) -> ScalarUDF {

// Wrap the implementation in a scalar function
create_udf(
name,
"pg_table_is_visible",
vec![DataType::Int32],
DataType::Boolean,
Volatility::Stable,
Arc::new(func),
)
}

pub fn create_format_type_udf() -> ScalarUDF {
/// Build the UDF backing Postgres' `session_user`.
///
/// Stub implementation: always reports `"postgres"` regardless of the
/// actual connection, one value per invocation.
pub fn create_session_user_udf() -> ScalarUDF {
    let implementation = move |_args: &[ColumnarValue]| {
        // TODO: return real user
        let mut values = StringBuilder::new();
        values.append_value("postgres");
        let result: ArrayRef = Arc::new(values.finish());
        Ok(ColumnarValue::Array(result))
    };

    create_udf(
        "session_user",
        vec![],
        DataType::Utf8,
        Volatility::Stable,
        Arc::new(implementation),
    )
}

pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
let func = move |args: &[ColumnarValue]| {
let args = ColumnarValue::values_to_arrays(args)?;
let type_oids = &args[0]; // Table (can be name or OID)
let _type_mods = &args[1]; // Privilege type (SELECT, INSERT, etc.)
let oid = &args[0];

// For now, always return true (full access for current user)
let mut builder = StringBuilder::new();
for _ in 0..type_oids.len() {
builder.append_value("???");
for _ in 0..oid.len() {
builder.append_value("");
}

let array: ArrayRef = Arc::new(builder.finish());
Expand All @@ -853,42 +849,46 @@ pub fn create_format_type_udf() -> ScalarUDF {
};

create_udf(
"format_type",
vec![DataType::Int64, DataType::Int32],
"pg_get_partkeydef",
vec![DataType::Utf8],
DataType::Utf8,
Volatility::Stable,
Arc::new(func),
)
}

pub fn create_session_user_udf() -> ScalarUDF {
let func = move |_args: &[ColumnarValue]| {
let mut builder = StringBuilder::new();
// TODO: return real user
builder.append_value("postgres");
/// Build the UDF backing `pg_relation_is_publishable(oid)`.
///
/// Stub implementation: reports every relation as publishable by
/// returning `true` for each input row. (The stray `"session_user"` /
/// `vec![]` / `DataType::Utf8` lines from the diff's removed version have
/// been dropped from the `create_udf` argument list.)
pub fn create_pg_relation_is_publishable_udf() -> ScalarUDF {
    let func = move |args: &[ColumnarValue]| {
        let args = ColumnarValue::values_to_arrays(args)?;
        let oid = &args[0];

        // One `true` per input row; publication membership is not tracked.
        let mut builder = BooleanBuilder::new();
        for _ in 0..oid.len() {
            builder.append_value(true);
        }

        let array: ArrayRef = Arc::new(builder.finish());

        Ok(ColumnarValue::Array(array))
    };

    create_udf(
        "pg_relation_is_publishable",
        vec![DataType::Int32],
        DataType::Boolean,
        Volatility::Stable,
        Arc::new(func),
    )
}

pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
pub fn create_pg_get_statisticsobjdef_columns_udf() -> ScalarUDF {
let func = move |args: &[ColumnarValue]| {
let args = ColumnarValue::values_to_arrays(args)?;
let oid = &args[0];

let mut builder = StringBuilder::new();
let mut builder = BooleanBuilder::new();
for _ in 0..oid.len() {
builder.append_value("");
builder.append_null();
}

let array: ArrayRef = Arc::new(builder.finish());
Expand All @@ -897,8 +897,8 @@ pub fn create_pg_get_partkeydef_udf() -> ScalarUDF {
};

create_udf(
"pg_catalog.pg_get_partkeydef",
vec![DataType::Utf8],
"pg_get_statisticsobjdef_columns",
vec![DataType::UInt32],
DataType::Utf8,
Volatility::Stable,
Arc::new(func),
Expand All @@ -924,38 +924,28 @@ pub fn setup_pg_catalog(
})?
.register_schema("pg_catalog", Arc::new(pg_catalog))?;

session_context.register_udf(create_current_database_udf("current_database"));
session_context.register_udf(create_current_schema_udf("current_schema"));
session_context.register_udf(create_current_schema_udf("pg_catalog.current_schema"));
session_context.register_udf(create_current_schemas_udf("current_schemas"));
session_context.register_udf(create_current_schemas_udf("pg_catalog.current_schemas"));
session_context.register_udf(create_version_udf());
session_context.register_udf(create_current_database_udf());
session_context.register_udf(create_current_schema_udf());
session_context.register_udf(create_current_schemas_udf());
// session_context.register_udf(create_version_udf());
session_context.register_udf(create_pg_get_userbyid_udf());
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
"has_table_privilege",
));
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
"pg_catalog.has_table_privilege",
));
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
"has_schema_privilege",
));
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
"pg_catalog.has_schema_privilege",
));
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
"has_any_column_privilege",
));
session_context.register_udf(has_privilege_udf::create_has_privilege_udf(
"pg_catalog.has_any_column_privilege",
));
session_context.register_udf(create_pg_table_is_visible("pg_table_is_visible"));
session_context.register_udf(create_pg_table_is_visible("pg_catalog.pg_table_is_visible"));
session_context.register_udf(create_format_type_udf());
session_context.register_udf(create_pg_table_is_visible());
session_context.register_udf(format_type::create_format_type_udf());
session_context.register_udf(create_session_user_udf());
session_context.register_udtf("pg_get_keywords", static_tables.pg_get_keywords.clone());
session_context.register_udf(pg_get_expr_udf::create_pg_get_expr_udf());
session_context.register_udf(create_pg_get_partkeydef_udf());
session_context.register_udf(create_pg_relation_is_publishable_udf());
session_context.register_udf(create_pg_get_statisticsobjdef_columns_udf());

Ok(())
}
Expand Down
Loading
Loading