From a5dee947916199f85368696c23fd5ee2442fcc40 Mon Sep 17 00:00:00 2001
From: Ning Sun
Date: Sun, 14 Sep 2025 11:25:23 +0800
Subject: [PATCH] refactor: add a layer of abstraction for CatalogProviderList

---
 datafusion-postgres/src/pg_catalog.rs         | 50 +++-------
 .../src/pg_catalog/catalog_info.rs            | 91 +++++++++++++++++++
 .../src/pg_catalog/pg_attribute.rs            | 31 ++++---
 .../src/pg_catalog/pg_class.rs                | 68 ++++++++------
 .../src/pg_catalog/pg_database.rs             | 15 +--
 .../src/pg_catalog/pg_namespace.rs            | 19 ++--
 .../src/pg_catalog/pg_tables.rs               | 25 ++---
 7 files changed, 191 insertions(+), 108 deletions(-)
 create mode 100644 datafusion-postgres/src/pg_catalog/catalog_info.rs

diff --git a/datafusion-postgres/src/pg_catalog.rs b/datafusion-postgres/src/pg_catalog.rs
index 5d68fb0..b82f2ea 100644
--- a/datafusion-postgres/src/pg_catalog.rs
+++ b/datafusion-postgres/src/pg_catalog.rs
@@ -9,9 +9,9 @@ use datafusion::arrow::array::{
 use datafusion::arrow::datatypes::{DataType, Field, SchemaRef};
 use datafusion::arrow::ipc::reader::FileReader;
 use datafusion::catalog::streaming::StreamingTable;
-use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider, TableFunctionImpl};
+use datafusion::catalog::{MemTable, SchemaProvider, TableFunctionImpl};
 use datafusion::common::utils::SingleRowListArrayBuilder;
-use datafusion::datasource::{TableProvider, ViewTable};
+use datafusion::datasource::TableProvider;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility};
 use datafusion::physical_plan::streaming::PartitionStream;
@@ -19,6 +19,9 @@ use datafusion::prelude::{create_udf, Expr, SessionContext};
 use postgres_types::Oid;
 use tokio::sync::RwLock;
 
+use crate::pg_catalog::catalog_info::CatalogInfo;
+
+pub mod catalog_info;
 pub mod empty_table;
 pub mod has_privilege_udf;
 pub mod pg_attribute;
@@ -97,37 +100,6 @@ const PG_CATALOG_VIEW_PG_MATVIEWS: &str = "pg_matviews";
 const PG_CATALOG_VIEW_PG_TABLES: &str = "pg_tables";
 const PG_CATALOG_VIEW_PG_STAT_USER_TABELS: &str = "pg_stat_user_tables";
 
-/// Determine PostgreSQL table type (relkind) from DataFusion TableProvider
-fn get_table_type(table: &Arc<dyn TableProvider>) -> &'static str {
-    // Use Any trait to determine the actual table provider type
-    if table.as_any().is::<ViewTable>() {
-        "v" // view
-    } else {
-        "r" // All other table types (StreamingTable, MemTable, etc.) are treated as regular tables
-    }
-}
-
-/// Determine PostgreSQL table type (relkind) with table name context
-fn get_table_type_with_name(
-    table: &Arc<dyn TableProvider>,
-    table_name: &str,
-    schema_name: &str,
-) -> &'static str {
-    // Check if this is a system catalog table
-    if schema_name == "pg_catalog" || schema_name == "information_schema" {
-        if table_name.starts_with("pg_")
-            || table_name.contains("_table")
-            || table_name.contains("_column")
-        {
-            "r" // System tables are still regular tables in PostgreSQL
-        } else {
-            "v" // Some system objects might be views
-        }
-    } else {
-        get_table_type(table)
-    }
-}
-
 pub const PG_CATALOG_TABLES: &[&str] = &[
     PG_CATALOG_TABLE_PG_AGGREGATE,
     PG_CATALOG_TABLE_PG_AM,
@@ -206,15 +178,15 @@ pub(crate) enum OidCacheKey {
 
 // Create custom schema provider for pg_catalog
 #[derive(Debug)]
-pub struct PgCatalogSchemaProvider {
-    catalog_list: Arc<dyn CatalogProviderList>,
+pub struct PgCatalogSchemaProvider<C> {
+    catalog_list: C,
     oid_counter: Arc<AtomicU32>,
     oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
     static_tables: Arc<PgCatalogStaticTables>,
 }
 
 #[async_trait]
-impl SchemaProvider for PgCatalogSchemaProvider {
+impl<C: CatalogInfo> SchemaProvider for PgCatalogSchemaProvider<C> {
     fn as_any(&self) -> &dyn std::any::Any {
         self
     }
@@ -389,11 +361,11 @@ impl SchemaProvider for PgCatalogSchemaProvider {
     }
 }
 
-impl PgCatalogSchemaProvider {
+impl<C: CatalogInfo> PgCatalogSchemaProvider<C> {
     pub fn try_new(
-        catalog_list: Arc<dyn CatalogProviderList>,
+        catalog_list: C,
         static_tables: Arc<PgCatalogStaticTables>,
-    ) -> Result<PgCatalogSchemaProvider> {
+    ) -> Result<PgCatalogSchemaProvider<C>> {
         Ok(Self {
             catalog_list,
             oid_counter: Arc::new(AtomicU32::new(16384)),
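Note (sketch, not part of the patch): PgCatalogSchemaProvider is now generic over CatalogInfo instead of taking a concrete Arc<dyn CatalogProviderList>. Since catalog_info.rs (below) implements the trait for Arc<dyn CatalogProviderList>, existing call sites should keep compiling. A minimal wiring sketch; PgCatalogStaticTables::try_new() and the exact accessor for the session's catalog list are assumptions here, not verified against this revision:

    use std::sync::Arc;

    use datafusion::catalog::CatalogProviderList;
    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    use crate::pg_catalog::PgCatalogSchemaProvider;

    fn build_pg_catalog(ctx: &SessionContext) -> Result<()> {
        // The session's catalog list still satisfies CatalogInfo via the
        // blanket impl for Arc<dyn CatalogProviderList> added by this patch.
        let catalog_list: Arc<dyn CatalogProviderList> = ctx.state().catalog_list().clone();
        // Assumed constructor for the preloaded static tables.
        let static_tables = Arc::new(PgCatalogStaticTables::try_new()?);
        let _provider = PgCatalogSchemaProvider::try_new(catalog_list, static_tables)?;
        Ok(())
    }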
diff --git a/datafusion-postgres/src/pg_catalog/catalog_info.rs b/datafusion-postgres/src/pg_catalog/catalog_info.rs
new file mode 100644
index 0000000..cf58bd9
--- /dev/null
+++ b/datafusion-postgres/src/pg_catalog/catalog_info.rs
@@ -0,0 +1,91 @@
+use std::fmt::Debug;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use datafusion::{
+    arrow::datatypes::SchemaRef, catalog::CatalogProviderList, datasource::TableType,
+    error::DataFusionError,
+};
+
+/// Defines the interface for retrieving catalog data for pg_catalog tables
+#[async_trait]
+pub trait CatalogInfo: Clone + Send + Sync + Debug + 'static {
+    fn catalog_names(&self) -> Vec<String>;
+
+    fn schema_names(&self, catalog_name: &str) -> Option<Vec<String>>;
+
+    fn table_names(&self, catalog_name: &str, schema_name: &str) -> Option<Vec<String>>;
+
+    async fn table_schema(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+    ) -> Result<Option<SchemaRef>, DataFusionError>;
+
+    async fn table_type(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+    ) -> Result<Option<TableType>, DataFusionError>;
+}
+
+#[async_trait]
+impl CatalogInfo for Arc<dyn CatalogProviderList> {
+    fn catalog_names(&self) -> Vec<String> {
+        CatalogProviderList::catalog_names(self.as_ref())
+    }
+
+    fn schema_names(&self, catalog_name: &str) -> Option<Vec<String>> {
+        self.catalog(catalog_name).map(|c| c.schema_names())
+    }
+
+    fn table_names(&self, catalog_name: &str, schema_name: &str) -> Option<Vec<String>> {
+        self.catalog(catalog_name)
+            .and_then(|c| c.schema(schema_name))
+            .map(|s| s.table_names())
+    }
+
+    async fn table_schema(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+    ) -> Result<Option<SchemaRef>, DataFusionError> {
+        let schema = self
+            .catalog(catalog_name)
+            .and_then(|c| c.schema(schema_name));
+        if let Some(schema) = schema {
+            let table_schema = schema.table(table_name).await?.map(|t| t.schema());
+            Ok(table_schema)
+        } else {
+            Ok(None)
+        }
+    }
+
+    async fn table_type(
+        &self,
+        catalog_name: &str,
+        schema_name: &str,
+        table_name: &str,
+    ) -> Result<Option<TableType>, DataFusionError> {
+        let schema = self
+            .catalog(catalog_name)
+            .and_then(|c| c.schema(schema_name));
+        if let Some(schema) = schema {
+            let table_type = schema.table_type(table_name).await?;
+            Ok(table_type)
+        } else {
+            Ok(None)
+        }
+    }
+}
+
+pub fn table_type_to_string(tt: &TableType) -> String {
+    match tt {
+        TableType::Base => "r".to_string(),
+        TableType::View => "v".to_string(),
+        TableType::Temporary => "r".to_string(),
+    }
+}
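Note (sketch, not part of the patch): the value of the new trait is that pg_catalog tables no longer require a live CatalogProviderList; anything that can answer these five questions works, for example a fixed snapshot in tests. All names below are illustrative:

    use std::fmt::Debug;

    use async_trait::async_trait;
    use datafusion::arrow::datatypes::SchemaRef;
    use datafusion::datasource::TableType;
    use datafusion::error::DataFusionError;

    use crate::pg_catalog::catalog_info::CatalogInfo;

    // A static snapshot exposing one catalog with one empty schema.
    #[derive(Clone, Debug)]
    struct EmptyCatalogInfo;

    #[async_trait]
    impl CatalogInfo for EmptyCatalogInfo {
        fn catalog_names(&self) -> Vec<String> {
            vec!["datafusion".to_string()]
        }

        fn schema_names(&self, catalog_name: &str) -> Option<Vec<String>> {
            (catalog_name == "datafusion").then(|| vec!["public".to_string()])
        }

        fn table_names(&self, catalog_name: &str, schema_name: &str) -> Option<Vec<String>> {
            (catalog_name == "datafusion" && schema_name == "public").then(Vec::new)
        }

        async fn table_schema(
            &self,
            _catalog_name: &str,
            _schema_name: &str,
            _table_name: &str,
        ) -> Result<Option<SchemaRef>, DataFusionError> {
            Ok(None) // no tables, so nothing to report
        }

        async fn table_type(
            &self,
            _catalog_name: &str,
            _schema_name: &str,
            _table_name: &str,
        ) -> Result<Option<TableType>, DataFusionError> {
            Ok(None)
        }
    }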
diff --git a/datafusion-postgres/src/pg_catalog/pg_attribute.rs b/datafusion-postgres/src/pg_catalog/pg_attribute.rs
index 1f6e596..dbc951c 100644
--- a/datafusion-postgres/src/pg_catalog/pg_attribute.rs
+++ b/datafusion-postgres/src/pg_catalog/pg_attribute.rs
@@ -6,7 +6,6 @@ use datafusion::arrow::array::{
     ArrayRef, BooleanArray, Int16Array, Int32Array, RecordBatch, StringArray,
 };
 use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion::catalog::CatalogProviderList;
 use datafusion::error::Result;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
@@ -14,19 +13,21 @@ use datafusion::physical_plan::streaming::PartitionStream;
 use postgres_types::Oid;
 use tokio::sync::RwLock;
 
+use crate::pg_catalog::catalog_info::CatalogInfo;
+
 use super::OidCacheKey;
 
 #[derive(Debug, Clone)]
-pub(crate) struct PgAttributeTable {
+pub(crate) struct PgAttributeTable<C> {
     schema: SchemaRef,
-    catalog_list: Arc<dyn CatalogProviderList>,
+    catalog_list: C,
     oid_counter: Arc<AtomicU32>,
     oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
 }
 
-impl PgAttributeTable {
+impl<C: CatalogInfo> PgAttributeTable<C> {
     pub(crate) fn new(
-        catalog_list: Arc<dyn CatalogProviderList>,
+        catalog_list: C,
         oid_counter: Arc<AtomicU32>,
         oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
     ) -> Self {
@@ -105,11 +106,13 @@
         let mut swap_cache = HashMap::new();
 
         for catalog_name in this.catalog_list.catalog_names() {
-            if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
-                for schema_name in catalog.schema_names() {
-                    if let Some(schema_provider) = catalog.schema(&schema_name) {
+            if let Some(schema_names) = this.catalog_list.schema_names(&catalog_name) {
+                for schema_name in schema_names {
+                    if let Some(table_names) =
+                        this.catalog_list.table_names(&catalog_name, &schema_name)
+                    {
                         // Process all tables in this schema
-                        for table_name in schema_provider.table_names() {
+                        for table_name in table_names {
                             let cache_key = OidCacheKey::Table(
                                 catalog_name.clone(),
                                 schema_name.clone(),
                                 table_name.clone(),
                             );
                             let table_oid = if let Some(oid) = oid_cache.get(&cache_key) {
                                 *oid
                             } else {
                                 this.oid_counter.fetch_add(1, Ordering::Relaxed)
                             };
                             swap_cache.insert(cache_key, table_oid);
 
-                            if let Some(table) = schema_provider.table(&table_name).await? {
-                                let table_schema = table.schema();
-
+                            if let Some(table_schema) = this
+                                .catalog_list
+                                .table_schema(&catalog_name, &schema_name, &table_name)
+                                .await?
+                            {
                                 // Add column entries for this table
                                 for (column_idx, field) in
                                     table_schema.fields().iter().enumerate()
                                 {
@@ -233,7 +238,7 @@
     }
 }
 
-impl PartitionStream for PgAttributeTable {
+impl<C: CatalogInfo> PartitionStream for PgAttributeTable<C> {
     fn schema(&self) -> &SchemaRef {
         &self.schema
     }
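Note (sketch, not part of the patch): the lookup-or-allocate dance around oid_cache now repeats in pg_attribute, pg_class and pg_namespace; if it spreads further it may be worth a small shared helper, along these lines:

    use std::collections::HashMap;
    use std::sync::atomic::{AtomicU32, Ordering};

    use postgres_types::Oid;

    use crate::pg_catalog::OidCacheKey;

    /// Reuse the cached OID for a key, or allocate a fresh one.
    fn oid_for(cache: &HashMap<OidCacheKey, Oid>, counter: &AtomicU32, key: &OidCacheKey) -> Oid {
        cache
            .get(key)
            .copied()
            .unwrap_or_else(|| counter.fetch_add(1, Ordering::Relaxed))
    }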
diff --git a/datafusion-postgres/src/pg_catalog/pg_class.rs b/datafusion-postgres/src/pg_catalog/pg_class.rs
index 2767c6c..7284111 100644
--- a/datafusion-postgres/src/pg_catalog/pg_class.rs
+++ b/datafusion-postgres/src/pg_catalog/pg_class.rs
@@ -6,7 +6,7 @@ use datafusion::arrow::array::{
     ArrayRef, BooleanArray, Float64Array, Int16Array, Int32Array, RecordBatch, StringArray,
 };
 use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion::catalog::CatalogProviderList;
+use datafusion::datasource::TableType;
 use datafusion::error::Result;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
@@ -14,22 +14,24 @@ use datafusion::physical_plan::streaming::PartitionStream;
 use postgres_types::Oid;
 use tokio::sync::RwLock;
 
-use super::{get_table_type_with_name, OidCacheKey};
+use crate::pg_catalog::catalog_info::{table_type_to_string, CatalogInfo};
+
+use super::OidCacheKey;
 
 #[derive(Debug, Clone)]
-pub(crate) struct PgClassTable {
+pub(crate) struct PgClassTable<C> {
     schema: SchemaRef,
-    catalog_list: Arc<dyn CatalogProviderList>,
+    catalog_list: C,
     oid_counter: Arc<AtomicU32>,
     oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
 }
 
-impl PgClassTable {
+impl<C: CatalogInfo> PgClassTable<C> {
     pub(crate) fn new(
-        catalog_list: Arc<dyn CatalogProviderList>,
+        catalog_list: C,
         oid_counter: Arc<AtomicU32>,
         oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
-    ) -> PgClassTable {
+    ) -> Self {
         // Define the schema for pg_class
         // This matches key columns from PostgreSQL's pg_class
         let schema = Arc::new(Schema::new(vec![
@@ -75,7 +77,7 @@
     }
 
     /// Generate record batches based on the current state of the catalog
-    async fn get_data(this: PgClassTable) -> Result<RecordBatch> {
+    async fn get_data(this: Self) -> Result<RecordBatch> {
         // Vectors to store column data
         let mut oids = Vec::new();
         let mut relnames = Vec::new();
@@ -124,23 +126,24 @@
             };
             swap_cache.insert(cache_key, catalog_oid);
 
-            if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
-                for schema_name in catalog.schema_names() {
-                    if let Some(schema) = catalog.schema(&schema_name) {
-                        let cache_key =
-                            OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
-                        let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
-                            *oid
-                        } else {
-                            this.oid_counter.fetch_add(1, Ordering::Relaxed)
-                        };
-                        swap_cache.insert(cache_key, schema_oid);
+            if let Some(schema_names) = this.catalog_list.schema_names(&catalog_name) {
+                for schema_name in schema_names {
+                    let cache_key = OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
+                    let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
+                        *oid
+                    } else {
+                        this.oid_counter.fetch_add(1, Ordering::Relaxed)
+                    };
+                    swap_cache.insert(cache_key, schema_oid);
 
-                        // Add an entry for the schema itself (as a namespace)
-                        // (In a full implementation, this would go in pg_namespace)
+                    // Add an entry for the schema itself (as a namespace)
+                    // (In a full implementation, this would go in pg_namespace)
 
-                        // Now process all tables in this schema
-                        for table_name in schema.table_names() {
+                    // Now process all tables in this schema
+                    if let Some(table_names) =
+                        this.catalog_list.table_names(&catalog_name, &schema_name)
+                    {
+                        for table_name in table_names {
                             let cache_key = OidCacheKey::Table(
                                 catalog_name.clone(),
                                 schema_name.clone(),
@@ -153,13 +156,20 @@
                             };
                             swap_cache.insert(cache_key, table_oid);
 
-                            if let Some(table) = schema.table(&table_name).await? {
+                            if let Some(table_schema) = this
+                                .catalog_list
+                                .table_schema(&catalog_name, &schema_name, &table_name)
+                                .await?
+                            {
                                 // Determine the correct table type based on the table provider and context
-                                let table_type =
-                                    get_table_type_with_name(&table, &table_name, &schema_name);
+                                let table_type = this
+                                    .catalog_list
+                                    .table_type(&catalog_name, &schema_name, &table_name)
+                                    .await?
+                                    .unwrap_or(TableType::Temporary);
 
                                 // Get column count from schema
-                                let column_count = table.schema().fields().len() as i16;
+                                let column_count = table_schema.fields().len() as i16;
 
                                 // Add table entry
                                 oids.push(table_oid as i32);
@@ -178,7 +188,7 @@
                                 relhasindexes.push(false);
                                 relisshareds.push(false);
                                 relpersistences.push("p".to_string()); // Permanent
-                                relkinds.push(table_type.to_string());
+                                relkinds.push(table_type_to_string(&table_type));
                                 relnattses.push(column_count);
                                 relcheckses.push(0);
                                 relhasruleses.push(false);
@@ -244,7 +254,7 @@
     }
 }
 
-impl PartitionStream for PgClassTable {
+impl<C: CatalogInfo> PartitionStream for PgClassTable<C> {
     fn schema(&self) -> &SchemaRef {
         &self.schema
     }
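Note (sketch, not part of the patch): when table_type() returns None, the code above falls back to TableType::Temporary, which table_type_to_string still renders as "r", so the visible relkind matches the old default. In PostgreSQL, temporary tables are distinguished by relpersistence rather than relkind, and relpersistence is hardcoded to "p" here. A small test pinning the mapping down:

    #[cfg(test)]
    mod relkind_tests {
        use datafusion::datasource::TableType;

        use crate::pg_catalog::catalog_info::table_type_to_string;

        #[test]
        fn relkind_mapping() {
            assert_eq!(table_type_to_string(&TableType::Base), "r");
            assert_eq!(table_type_to_string(&TableType::View), "v");
            // Temporary tables also surface as ordinary relations here.
            assert_eq!(table_type_to_string(&TableType::Temporary), "r");
        }
    }

diff --git a/datafusion-postgres/src/pg_catalog/pg_database.rs b/datafusion-postgres/src/pg_catalog/pg_database.rs
index 5959977..ccc27e6 100644
--- a/datafusion-postgres/src/pg_catalog/pg_database.rs
+++ b/datafusion-postgres/src/pg_catalog/pg_database.rs
@@ -4,7 +4,6 @@ use std::sync::Arc;
 
 use datafusion::arrow::array::{ArrayRef, BooleanArray, Int32Array, RecordBatch, StringArray};
 use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion::catalog::CatalogProviderList;
 use datafusion::error::Result;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
@@ -12,19 +11,21 @@ use datafusion::physical_plan::streaming::PartitionStream;
 use postgres_types::Oid;
 use tokio::sync::RwLock;
 
+use crate::pg_catalog::catalog_info::CatalogInfo;
+
 use super::OidCacheKey;
 
 #[derive(Debug, Clone)]
-pub(crate) struct PgDatabaseTable {
+pub(crate) struct PgDatabaseTable<C> {
     schema: SchemaRef,
-    catalog_list: Arc<dyn CatalogProviderList>,
+    catalog_list: C,
     oid_counter: Arc<AtomicU32>,
     oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
 }
 
-impl PgDatabaseTable {
+impl<C: CatalogInfo> PgDatabaseTable<C> {
     pub(crate) fn new(
-        catalog_list: Arc<dyn CatalogProviderList>,
+        catalog_list: C,
         oid_counter: Arc<AtomicU32>,
         oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
     ) -> Self {
@@ -56,7 +57,7 @@
     }
 
     /// Generate record batches based on the current state of the catalog
-    async fn get_data(this: PgDatabaseTable) -> Result<RecordBatch> {
+    async fn get_data(this: Self) -> Result<RecordBatch> {
         // Vectors to store column data
         let mut oids = Vec::new();
         let mut datnames = Vec::new();
@@ -171,7 +172,7 @@
     }
 }
 
-impl PartitionStream for PgDatabaseTable {
+impl<C: CatalogInfo> PartitionStream for PgDatabaseTable<C> {
     fn schema(&self) -> &SchemaRef {
         &self.schema
     }
diff --git a/datafusion-postgres/src/pg_catalog/pg_namespace.rs b/datafusion-postgres/src/pg_catalog/pg_namespace.rs
index 060a996..c423b7b 100644
--- a/datafusion-postgres/src/pg_catalog/pg_namespace.rs
+++ b/datafusion-postgres/src/pg_catalog/pg_namespace.rs
@@ -4,7 +4,6 @@ use std::sync::Arc;
 
 use datafusion::arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
 use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion::catalog::CatalogProviderList;
 use datafusion::error::Result;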
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
@@ -12,19 +11,21 @@ use datafusion::physical_plan::streaming::PartitionStream;
 use postgres_types::Oid;
 use tokio::sync::RwLock;
 
+use crate::pg_catalog::catalog_info::CatalogInfo;
+
 use super::OidCacheKey;
 
 #[derive(Debug, Clone)]
-pub(crate) struct PgNamespaceTable {
+pub(crate) struct PgNamespaceTable<C> {
     schema: SchemaRef,
-    catalog_list: Arc<dyn CatalogProviderList>,
+    catalog_list: C,
     oid_counter: Arc<AtomicU32>,
     oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
 }
 
-impl PgNamespaceTable {
+impl<C: CatalogInfo> PgNamespaceTable<C> {
     pub(crate) fn new(
-        catalog_list: Arc<dyn CatalogProviderList>,
+        catalog_list: C,
         oid_counter: Arc<AtomicU32>,
         oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
     ) -> Self {
@@ -47,7 +48,7 @@
     }
 
     /// Generate record batches based on the current state of the catalog
-    async fn get_data(this: PgNamespaceTable) -> Result<RecordBatch> {
+    async fn get_data(this: Self) -> Result<RecordBatch> {
         // Vectors to store column data
         let mut oids = Vec::new();
         let mut nspnames = Vec::new();
@@ -62,8 +63,8 @@
 
         // Now add all schemas from DataFusion catalogs
         for catalog_name in this.catalog_list.catalog_names() {
-            if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
-                for schema_name in catalog.schema_names() {
+            if let Some(schema_names) = this.catalog_list.schema_names(&catalog_name) {
+                for schema_name in schema_names {
                     let cache_key = OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
                     let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
                         *oid
@@ -107,7 +108,7 @@
     }
 }
 
-impl PartitionStream for PgNamespaceTable {
+impl<C: CatalogInfo> PartitionStream for PgNamespaceTable<C> {
     fn schema(&self) -> &SchemaRef {
         &self.schema
     }
diff --git a/datafusion-postgres/src/pg_catalog/pg_tables.rs b/datafusion-postgres/src/pg_catalog/pg_tables.rs
index 7220d9d..155f68b 100644
--- a/datafusion-postgres/src/pg_catalog/pg_tables.rs
+++ b/datafusion-postgres/src/pg_catalog/pg_tables.rs
@@ -2,20 +2,21 @@ use std::sync::Arc;
 
 use datafusion::arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
 use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use datafusion::catalog::CatalogProviderList;
 use datafusion::error::Result;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::streaming::PartitionStream;
 
+use crate::pg_catalog::catalog_info::CatalogInfo;
+
 #[derive(Debug, Clone)]
-pub(crate) struct PgTablesTable {
+pub(crate) struct PgTablesTable<C> {
     schema: SchemaRef,
-    catalog_list: Arc<dyn CatalogProviderList>,
+    catalog_list: C,
 }
 
-impl PgTablesTable {
-    pub(crate) fn new(catalog_list: Arc<dyn CatalogProviderList>) -> PgTablesTable {
+impl<C: CatalogInfo> PgTablesTable<C> {
+    pub(crate) fn new(catalog_list: C) -> Self {
         // Define the schema for pg_tables
         // This matches key columns from PostgreSQL's pg_tables
         let schema = Arc::new(Schema::new(vec![
@@ -36,7 +37,7 @@
     }
 
     /// Generate record batches based on the current state of the catalog
-    async fn get_data(this: PgTablesTable) -> Result<RecordBatch> {
+    async fn get_data(this: Self) -> Result<RecordBatch> {
         // Vectors to store column data
         let mut schema_names = Vec::new();
         let mut table_names = Vec::new();
@@ -49,11 +50,13 @@
 
         // Iterate through all catalogs and schemas
         for catalog_name in this.catalog_list.catalog_names() {
-            if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
-                for schema_name in catalog.schema_names() {
-                    if let Some(schema) = catalog.schema(&schema_name) {
+            if let Some(catalog_schema_names) = this.catalog_list.schema_names(&catalog_name) {
+                for schema_name in catalog_schema_names {
+                    if let Some(catalog_table_names) =
+                        this.catalog_list.table_names(&catalog_name, &schema_name)
+                    {
                         // Now process all tables in this schema
-                        for table_name in schema.table_names() {
+                        for table_name in catalog_table_names {
                             schema_names.push(schema_name.to_string());
                             table_names.push(table_name.to_string());
                             table_owners.push("postgres".to_string());
@@ -87,7 +90,7 @@
     }
 }
 
-impl PartitionStream for PgTablesTable {
+impl<C: CatalogInfo> PartitionStream for PgTablesTable<C> {
     fn schema(&self) -> &SchemaRef {
         &self.schema
     }
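Note (sketch, not part of the patch): each of these structs remains a PartitionStream, so exposing one as a queryable table should still go through StreamingTable, which re-runs the stream on every scan and therefore reflects the catalog state at query time. A sketch assuming the session's catalog list as the CatalogInfo implementation:

    use std::sync::Arc;

    use datafusion::catalog::streaming::StreamingTable;
    use datafusion::catalog::CatalogProviderList;
    use datafusion::error::Result;
    use datafusion::physical_plan::streaming::PartitionStream;

    use crate::pg_catalog::pg_tables::PgTablesTable;

    fn pg_tables_provider(catalog_list: Arc<dyn CatalogProviderList>) -> Result<StreamingTable> {
        let stream: Arc<dyn PartitionStream> = Arc::new(PgTablesTable::new(catalog_list));
        // The table is rebuilt from the partition stream on each scan.
        StreamingTable::try_new(stream.schema().clone(), vec![stream])
    }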