Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 11 additions & 39 deletions datafusion-postgres/src/pg_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ use datafusion::arrow::array::{
use datafusion::arrow::datatypes::{DataType, Field, SchemaRef};
use datafusion::arrow::ipc::reader::FileReader;
use datafusion::catalog::streaming::StreamingTable;
use datafusion::catalog::{CatalogProviderList, MemTable, SchemaProvider, TableFunctionImpl};
use datafusion::catalog::{MemTable, SchemaProvider, TableFunctionImpl};
use datafusion::common::utils::SingleRowListArrayBuilder;
use datafusion::datasource::{TableProvider, ViewTable};
use datafusion::datasource::TableProvider;
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, Volatility};
use datafusion::physical_plan::streaming::PartitionStream;
use datafusion::prelude::{create_udf, Expr, SessionContext};
use postgres_types::Oid;
use tokio::sync::RwLock;

use crate::pg_catalog::catalog_info::CatalogInfo;

pub mod catalog_info;
pub mod empty_table;
pub mod has_privilege_udf;
pub mod pg_attribute;
Expand Down Expand Up @@ -97,37 +100,6 @@ const PG_CATALOG_VIEW_PG_MATVIEWS: &str = "pg_matviews";
const PG_CATALOG_VIEW_PG_TABLES: &str = "pg_tables";
const PG_CATALOG_VIEW_PG_STAT_USER_TABELS: &str = "pg_stat_user_tables";

/// Map a DataFusion table provider to its PostgreSQL `relkind` code.
///
/// A provider whose concrete type is `ViewTable` is reported as `"v"` (view).
/// Every other provider (StreamingTable, MemTable, etc.) is reported as `"r"`,
/// i.e. a regular table.
fn get_table_type(table: &Arc<dyn TableProvider>) -> &'static str {
    // The concrete provider type is only reachable through `Any`.
    let provider = table.as_any();
    if provider.is::<ViewTable>() {
        return "v";
    }
    "r"
}

/// Resolve the PostgreSQL `relkind` code for a table, taking its schema and
/// name into account.
///
/// Tables outside the `pg_catalog` and `information_schema` schemas are
/// classified by `get_table_type`. Inside those two schemas the decision is
/// purely name-based: a name that starts with `pg_` or contains `_table` or
/// `_column` yields `"r"` (regular table); any other name yields `"v"` (view).
fn get_table_type_with_name(
    table: &Arc<dyn TableProvider>,
    table_name: &str,
    schema_name: &str,
) -> &'static str {
    let in_system_schema = matches!(schema_name, "pg_catalog" | "information_schema");
    if !in_system_schema {
        // User schemas: decide from the provider's concrete type.
        return get_table_type(table);
    }

    // System tables are still regular tables in PostgreSQL; remaining system
    // objects might be views.
    let named_like_table = table_name.starts_with("pg_")
        || ["_table", "_column"]
            .iter()
            .any(|&marker| table_name.contains(marker));

    if named_like_table {
        "r"
    } else {
        "v"
    }
}

pub const PG_CATALOG_TABLES: &[&str] = &[
PG_CATALOG_TABLE_PG_AGGREGATE,
PG_CATALOG_TABLE_PG_AM,
Expand Down Expand Up @@ -206,15 +178,15 @@ pub(crate) enum OidCacheKey {

// Create custom schema provider for pg_catalog
#[derive(Debug)]
pub struct PgCatalogSchemaProvider {
catalog_list: Arc<dyn CatalogProviderList>,
pub struct PgCatalogSchemaProvider<C> {
catalog_list: C,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
static_tables: Arc<PgCatalogStaticTables>,
}

#[async_trait]
impl SchemaProvider for PgCatalogSchemaProvider {
impl<C: CatalogInfo> SchemaProvider for PgCatalogSchemaProvider<C> {
fn as_any(&self) -> &dyn std::any::Any {
self
}
Expand Down Expand Up @@ -389,11 +361,11 @@ impl SchemaProvider for PgCatalogSchemaProvider {
}
}

impl PgCatalogSchemaProvider {
impl<C: CatalogInfo> PgCatalogSchemaProvider<C> {
pub fn try_new(
catalog_list: Arc<dyn CatalogProviderList>,
catalog_list: C,
static_tables: Arc<PgCatalogStaticTables>,
) -> Result<PgCatalogSchemaProvider> {
) -> Result<PgCatalogSchemaProvider<C>> {
Ok(Self {
catalog_list,
oid_counter: Arc::new(AtomicU32::new(16384)),
Expand Down
91 changes: 91 additions & 0 deletions datafusion-postgres/src/pg_catalog/catalog_info.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use std::fmt::Debug;
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::{
arrow::datatypes::SchemaRef, catalog::CatalogProviderList, datasource::TableType,
error::DataFusionError,
};

/// Interface for retrieving the catalog metadata used to populate the
/// `pg_catalog` tables.
///
/// Lookups are addressed by catalog name, then schema name, then table name.
/// The two table-level lookups are async and fallible; the name listings are
/// synchronous.
#[async_trait]
pub trait CatalogInfo: Clone + Send + Sync + Debug + 'static {
/// Names of all catalogs known to this source.
fn catalog_names(&self) -> Vec<String>;

/// Names of the schemas inside `catalog_name`.
///
/// The `Arc<dyn CatalogProviderList>` implementation in this file returns
/// `None` when the catalog does not exist.
fn schema_names(&self, catalog_name: &str) -> Option<Vec<String>>;

/// Names of the tables inside `schema_name` of `catalog_name`.
///
/// The `Arc<dyn CatalogProviderList>` implementation in this file returns
/// `None` when either the catalog or the schema does not exist.
fn table_names(&self, catalog_name: &str, schema_name: &str) -> Option<Vec<String>>;

/// Arrow schema of the addressed table.
///
/// The `Arc<dyn CatalogProviderList>` implementation in this file yields
/// `Ok(None)` when the catalog or schema is missing, or when the schema's
/// table lookup returns no table; an error from that lookup is propagated.
async fn table_schema(
&self,
catalog_name: &str,
schema_name: &str,
table_name: &str,
) -> Result<Option<SchemaRef>, DataFusionError>;

/// DataFusion `TableType` of the addressed table.
///
/// The `Arc<dyn CatalogProviderList>` implementation in this file yields
/// `Ok(None)` when the catalog or schema is missing and otherwise forwards
/// the result of the schema's own `table_type` lookup, including its errors.
async fn table_type(
&self,
catalog_name: &str,
schema_name: &str,
table_name: &str,
) -> Result<Option<TableType>, DataFusionError>;
}

/// [`CatalogInfo`] served directly from a DataFusion [`CatalogProviderList`].
///
/// Each lookup walks catalog, then schema, then table on the wrapped list. A
/// missing catalog or schema is reported as `None` / `Ok(None)`, never as an
/// error.
#[async_trait]
impl CatalogInfo for Arc<dyn CatalogProviderList> {
    fn catalog_names(&self) -> Vec<String> {
        // Name the trait explicitly so the call is dispatched to the wrapped
        // provider list and not to this method of the same name.
        CatalogProviderList::catalog_names(&**self)
    }

    fn schema_names(&self, catalog_name: &str) -> Option<Vec<String>> {
        let catalog = self.catalog(catalog_name)?;
        Some(catalog.schema_names())
    }

    fn table_names(&self, catalog_name: &str, schema_name: &str) -> Option<Vec<String>> {
        let catalog = self.catalog(catalog_name)?;
        let schema = catalog.schema(schema_name)?;
        Some(schema.table_names())
    }

    async fn table_schema(
        &self,
        catalog_name: &str,
        schema_name: &str,
        table_name: &str,
    ) -> Result<Option<SchemaRef>, DataFusionError> {
        let Some(schema) = self
            .catalog(catalog_name)
            .and_then(|catalog| catalog.schema(schema_name))
        else {
            return Ok(None);
        };

        let table = schema.table(table_name).await?;
        Ok(table.map(|provider| provider.schema()))
    }

    async fn table_type(
        &self,
        catalog_name: &str,
        schema_name: &str,
        table_name: &str,
    ) -> Result<Option<TableType>, DataFusionError> {
        let Some(schema) = self
            .catalog(catalog_name)
            .and_then(|catalog| catalog.schema(schema_name))
        else {
            return Ok(None);
        };

        // The schema provider already answers with `Result<Option<TableType>>`.
        schema.table_type(table_name).await
    }
}

/// Convert a DataFusion [`TableType`] into the one-character code stored in
/// the `relkind` column of `pg_class`.
///
/// Views become `"v"`; base and temporary tables both become `"r"`.
pub fn table_type_to_string(tt: &TableType) -> String {
    let relkind = match tt {
        TableType::View => "v",
        TableType::Base | TableType::Temporary => "r",
    };
    relkind.to_owned()
}
31 changes: 18 additions & 13 deletions datafusion-postgres/src/pg_catalog/pg_attribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,28 @@ use datafusion::arrow::array::{
ArrayRef, BooleanArray, Int16Array, Int32Array, RecordBatch, StringArray,
};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::catalog::CatalogProviderList;
use datafusion::error::Result;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::streaming::PartitionStream;
use postgres_types::Oid;
use tokio::sync::RwLock;

use crate::pg_catalog::catalog_info::CatalogInfo;

use super::OidCacheKey;

#[derive(Debug, Clone)]
pub(crate) struct PgAttributeTable {
pub(crate) struct PgAttributeTable<C> {
schema: SchemaRef,
catalog_list: Arc<dyn CatalogProviderList>,
catalog_list: C,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
}

impl PgAttributeTable {
impl<C: CatalogInfo> PgAttributeTable<C> {
pub(crate) fn new(
catalog_list: Arc<dyn CatalogProviderList>,
catalog_list: C,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
) -> Self {
Expand Down Expand Up @@ -105,11 +106,13 @@ impl PgAttributeTable {
let mut swap_cache = HashMap::new();

for catalog_name in this.catalog_list.catalog_names() {
if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
for schema_name in catalog.schema_names() {
if let Some(schema_provider) = catalog.schema(&schema_name) {
if let Some(schema_names) = this.catalog_list.schema_names(&catalog_name) {
for schema_name in schema_names {
if let Some(table_names) =
this.catalog_list.table_names(&catalog_name, &schema_name)
{
// Process all tables in this schema
for table_name in schema_provider.table_names() {
for table_name in table_names {
let cache_key = OidCacheKey::Table(
catalog_name.clone(),
schema_name.clone(),
Expand All @@ -122,9 +125,11 @@ impl PgAttributeTable {
};
swap_cache.insert(cache_key, table_oid);

if let Some(table) = schema_provider.table(&table_name).await? {
let table_schema = table.schema();

if let Some(table_schema) = this
.catalog_list
.table_schema(&catalog_name, &schema_name, &table_name)
.await?
{
// Add column entries for this table
for (column_idx, field) in table_schema.fields().iter().enumerate()
{
Expand Down Expand Up @@ -233,7 +238,7 @@ impl PgAttributeTable {
}
}

impl PartitionStream for PgAttributeTable {
impl<C: CatalogInfo> PartitionStream for PgAttributeTable<C> {
fn schema(&self) -> &SchemaRef {
&self.schema
}
Expand Down
68 changes: 39 additions & 29 deletions datafusion-postgres/src/pg_catalog/pg_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,32 @@ use datafusion::arrow::array::{
ArrayRef, BooleanArray, Float64Array, Int16Array, Int32Array, RecordBatch, StringArray,
};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::catalog::CatalogProviderList;
use datafusion::datasource::TableType;
use datafusion::error::Result;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::streaming::PartitionStream;
use postgres_types::Oid;
use tokio::sync::RwLock;

use super::{get_table_type_with_name, OidCacheKey};
use crate::pg_catalog::catalog_info::{table_type_to_string, CatalogInfo};

use super::OidCacheKey;

#[derive(Debug, Clone)]
pub(crate) struct PgClassTable {
pub(crate) struct PgClassTable<C> {
schema: SchemaRef,
catalog_list: Arc<dyn CatalogProviderList>,
catalog_list: C,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
}

impl PgClassTable {
impl<C: CatalogInfo> PgClassTable<C> {
pub(crate) fn new(
catalog_list: Arc<dyn CatalogProviderList>,
catalog_list: C,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
) -> PgClassTable {
) -> Self {
// Define the schema for pg_class
// This matches key columns from PostgreSQL's pg_class
let schema = Arc::new(Schema::new(vec![
Expand Down Expand Up @@ -75,7 +77,7 @@ impl PgClassTable {
}

/// Generate record batches based on the current state of the catalog
async fn get_data(this: PgClassTable) -> Result<RecordBatch> {
async fn get_data(this: Self) -> Result<RecordBatch> {
// Vectors to store column data
let mut oids = Vec::new();
let mut relnames = Vec::new();
Expand Down Expand Up @@ -124,23 +126,24 @@ impl PgClassTable {
};
swap_cache.insert(cache_key, catalog_oid);

if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
for schema_name in catalog.schema_names() {
if let Some(schema) = catalog.schema(&schema_name) {
let cache_key =
OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
this.oid_counter.fetch_add(1, Ordering::Relaxed)
};
swap_cache.insert(cache_key, schema_oid);
if let Some(schema_names) = this.catalog_list.schema_names(&catalog_name) {
for schema_name in schema_names {
let cache_key = OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
this.oid_counter.fetch_add(1, Ordering::Relaxed)
};
swap_cache.insert(cache_key, schema_oid);

// Add an entry for the schema itself (as a namespace)
// (In a full implementation, this would go in pg_namespace)
// Add an entry for the schema itself (as a namespace)
// (In a full implementation, this would go in pg_namespace)

// Now process all tables in this schema
for table_name in schema.table_names() {
// Now process all tables in this schema
if let Some(table_names) =
this.catalog_list.table_names(&catalog_name, &schema_name)
{
for table_name in table_names {
let cache_key = OidCacheKey::Table(
catalog_name.clone(),
schema_name.clone(),
Expand All @@ -153,13 +156,20 @@ impl PgClassTable {
};
swap_cache.insert(cache_key, table_oid);

if let Some(table) = schema.table(&table_name).await? {
if let Some(table_schema) = this
.catalog_list
.table_schema(&catalog_name, &schema_name, &table_name)
.await?
{
// Determine the correct table type based on the table provider and context
let table_type =
get_table_type_with_name(&table, &table_name, &schema_name);
let table_type = this
.catalog_list
.table_type(&catalog_name, &schema_name, &table_name)
.await?
.unwrap_or(TableType::Temporary);

// Get column count from schema
let column_count = table.schema().fields().len() as i16;
let column_count = table_schema.fields().len() as i16;

// Add table entry
oids.push(table_oid as i32);
Expand All @@ -178,7 +188,7 @@ impl PgClassTable {
relhasindexes.push(false);
relisshareds.push(false);
relpersistences.push("p".to_string()); // Permanent
relkinds.push(table_type.to_string());
relkinds.push(table_type_to_string(&table_type));
relnattses.push(column_count);
relcheckses.push(0);
relhasruleses.push(false);
Expand Down Expand Up @@ -244,7 +254,7 @@ impl PgClassTable {
}
}

impl PartitionStream for PgClassTable {
impl<C: CatalogInfo> PartitionStream for PgClassTable<C> {
fn schema(&self) -> &SchemaRef {
&self.schema
}
Expand Down
Loading
Loading