120 changes: 85 additions & 35 deletions datafusion-postgres/src/pg_catalog.rs
@@ -214,9 +214,10 @@ impl PgTypesData {

#[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
enum OidCacheKey {
Schema(String),
Catalog(String),
Schema(String, String),
/// Table by catalog, schema and table name
Table(String, String),
Table(String, String, String),
}

// Create custom schema provider for pg_catalog
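The cache keys are now fully qualified by catalog: a catalog is keyed by its name, a schema by (catalog, schema), and a table by (catalog, schema, table), so identically named schemas or tables in different catalogs no longer share an OID entry. A minimal sketch of a lookup against the new keys; the `resolve_table_oid` helper is hypothetical and not part of this change:

```rust
use std::collections::HashMap;

// Hypothetical helper (not in this PR): build a fully qualified key and
// look it up in the shared OID cache. The Oid value type is assumed to be u32.
fn resolve_table_oid(
    cache: &HashMap<OidCacheKey, u32>,
    catalog: &str,
    schema: &str,
    table: &str,
) -> Option<u32> {
    let key = OidCacheKey::Table(catalog.to_string(), schema.to_string(), table.to_string());
    cache.get(&key).copied()
}
```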
@@ -262,7 +263,11 @@ impl SchemaProvider for PgCatalogSchemaProvider {
)))
}
PG_CATALOG_TABLE_PG_DATABASE => {
let table = Arc::new(PgDatabaseTable::new(self.catalog_list.clone()));
let table = Arc::new(PgDatabaseTable::new(
self.catalog_list.clone(),
self.oid_counter.clone(),
self.oid_cache.clone(),
));
Ok(Some(Arc::new(
StreamingTable::try_new(Arc::clone(table.schema()), vec![table]).unwrap(),
)))
@@ -289,7 +294,7 @@ impl PgCatalogSchemaProvider {
pub fn new(catalog_list: Arc<dyn CatalogProviderList>) -> PgCatalogSchemaProvider {
Self {
catalog_list,
oid_counter: Arc::new(AtomicU32::new(0)),
oid_counter: Arc::new(AtomicU32::new(16384)),
oid_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
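The counter now starts at 16384, the first OID PostgreSQL reserves for user-defined objects, and each pg_catalog table provider draws new OIDs from it with an atomic `fetch_add`. A hedged sketch of the look-up-or-allocate pattern the hunks below repeat; `get_or_allocate` is an illustrative simplification, since the real code stages new entries in a temporary swap cache before merging them back:

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};

// Illustrative simplification of the per-key pattern: reuse a cached OID
// if one exists, otherwise allocate the next value from the shared counter.
fn get_or_allocate(
    cache: &mut HashMap<OidCacheKey, u32>,
    counter: &AtomicU32,
    key: OidCacheKey,
) -> u32 {
    *cache
        .entry(key)
        .or_insert_with(|| counter.fetch_add(1, Ordering::Relaxed))
}
```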
@@ -1156,10 +1161,19 @@ impl PgClassTable {

// Iterate through all catalogs and schemas
for catalog_name in this.catalog_list.catalog_names() {
let cache_key = OidCacheKey::Catalog(catalog_name.clone());
let catalog_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
this.oid_counter.fetch_add(1, Ordering::Relaxed)
};
swap_cache.insert(cache_key, catalog_oid);

if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
for schema_name in catalog.schema_names() {
if let Some(schema) = catalog.schema(&schema_name) {
let cache_key = OidCacheKey::Schema(schema_name.clone());
let cache_key =
OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
@@ -1172,8 +1186,11 @@

// Now process all tables in this schema
for table_name in schema.table_names() {
let cache_key =
OidCacheKey::Table(schema_name.clone(), table_name.clone());
let cache_key = OidCacheKey::Table(
catalog_name.clone(),
schema_name.clone(),
table_name.clone(),
);
let table_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
@@ -1334,7 +1351,7 @@ impl PgNamespaceTable {
for catalog_name in this.catalog_list.catalog_names() {
if let Some(catalog) = this.catalog_list.catalog(&catalog_name) {
for schema_name in catalog.schema_names() {
let cache_key = OidCacheKey::Schema(schema_name.clone());
let cache_key = OidCacheKey::Schema(catalog_name.clone(), schema_name.clone());
let schema_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
@@ -1353,10 +1370,10 @@

// remove all schema cache entries, plus table entries whose schema no longer exists
oid_cache.retain(|key, _| match key {
OidCacheKey::Schema(_) => false,
OidCacheKey::Table(schema_name, _) => {
schema_oid_cache.contains_key(&OidCacheKey::Schema(schema_name.clone()))
}
OidCacheKey::Catalog(..) => true,
OidCacheKey::Schema(..) => false,
OidCacheKey::Table(catalog, schema_name, _) => schema_oid_cache
.contains_key(&OidCacheKey::Schema(catalog.clone(), schema_name.clone())),
});
// add new schema cache
oid_cache.extend(schema_oid_cache);
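The pruning rule in pg_namespace keeps catalog keys (those are owned by pg_database), rebuilds schema keys from the current catalog list, and drops table keys whose schema has disappeared, so surviving objects keep stable OIDs without the cache growing unboundedly. A small illustrative scenario with assumed catalog, schema, and table names:

```rust
use std::collections::HashMap;

// Assumed names: schema "old" was dropped, so its table entry is discarded,
// while the catalog entry and the refreshed "public" schema entry survive.
let mut cache: HashMap<OidCacheKey, u32> = HashMap::from([
    (OidCacheKey::Catalog("datafusion".into()), 16384),
    (OidCacheKey::Schema("datafusion".into(), "old".into()), 16385),
    (OidCacheKey::Table("datafusion".into(), "old".into(), "t".into()), 16386),
]);
let fresh: HashMap<OidCacheKey, u32> =
    HashMap::from([(OidCacheKey::Schema("datafusion".into(), "public".into()), 16387)]);

cache.retain(|key, _| match key {
    OidCacheKey::Catalog(..) => true,
    OidCacheKey::Schema(..) => false,
    OidCacheKey::Table(catalog, schema, _) => {
        fresh.contains_key(&OidCacheKey::Schema(catalog.clone(), schema.clone()))
    }
});
cache.extend(fresh);
assert_eq!(cache.len(), 2); // the catalog and the "public" schema remain
```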
@@ -1391,14 +1408,20 @@ impl PartitionStream for PgNamespaceTable {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct PgDatabaseTable {
schema: SchemaRef,
catalog_list: Arc<dyn CatalogProviderList>,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
}

impl PgDatabaseTable {
pub fn new(catalog_list: Arc<dyn CatalogProviderList>) -> Self {
pub fn new(
catalog_list: Arc<dyn CatalogProviderList>,
oid_counter: Arc<AtomicU32>,
oid_cache: Arc<RwLock<HashMap<OidCacheKey, Oid>>>,
) -> Self {
// Define the schema for pg_database
// This matches PostgreSQL's pg_database table columns
let schema = Arc::new(Schema::new(vec![
@@ -1421,14 +1444,13 @@ impl PgDatabaseTable {
Self {
schema,
catalog_list,
oid_counter,
oid_cache,
}
}

/// Generate record batches based on the current state of the catalog
async fn get_data(
schema: SchemaRef,
catalog_list: Arc<dyn CatalogProviderList>,
) -> Result<RecordBatch> {
async fn get_data(this: PgDatabaseTable) -> Result<RecordBatch> {
// Vectors to store column data
let mut oids = Vec::new();
let mut datnames = Vec::new();
@@ -1445,15 +1467,22 @@
let mut dattablespaces = Vec::new();
let mut datacles: Vec<Option<String>> = Vec::new();

// Start OID counter (this is simplistic and would need to be more robust in practice)
let mut next_oid = 16384; // Standard PostgreSQL starting OID for user databases
// temporarily store all catalog-to-OID mappings before merging them into the global OID cache
let mut catalog_oid_cache = HashMap::new();

// Add a record for each catalog (treating catalogs as "databases")
for catalog_name in catalog_list.catalog_names() {
let oid = next_oid;
next_oid += 1;
let mut oid_cache = this.oid_cache.write().await;

oids.push(oid);
// Add a record for each catalog (treating catalogs as "databases")
for catalog_name in this.catalog_list.catalog_names() {
let cache_key = OidCacheKey::Catalog(catalog_name.clone());
let catalog_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
this.oid_counter.fetch_add(1, Ordering::Relaxed)
};
catalog_oid_cache.insert(cache_key, catalog_oid);

oids.push(catalog_oid as i32);
datnames.push(catalog_name.clone());
datdbas.push(10); // Default owner (assuming 10 = postgres user)
encodings.push(6); // 6 = UTF8 in PostgreSQL
@@ -1471,11 +1500,18 @@

// Always include a "postgres" database entry if not already present
// (This is for compatibility with tools that expect it)
if !datnames.contains(&"postgres".to_string()) {
let oid = next_oid;

oids.push(oid);
datnames.push("postgres".to_string());
let default_datname = "postgres".to_string();
if !datnames.contains(&default_datname) {
let cache_key = OidCacheKey::Catalog(default_datname.clone());
let catalog_oid = if let Some(oid) = oid_cache.get(&cache_key) {
*oid
} else {
this.oid_counter.fetch_add(1, Ordering::Relaxed)
};
catalog_oid_cache.insert(cache_key, catalog_oid);

oids.push(catalog_oid as i32);
datnames.push(default_datname);
datdbas.push(10);
encodings.push(6);
datcollates.push("en_US.UTF-8".to_string());
Expand Down Expand Up @@ -1509,7 +1545,22 @@ impl PgDatabaseTable {
];

// Create a full record batch
let full_batch = RecordBatch::try_new(schema.clone(), arrays)?;
let full_batch = RecordBatch::try_new(this.schema.clone(), arrays)?;

// update cache
// drop all catalog cache entries, plus schema and table entries whose catalog no longer exists
oid_cache.retain(|key, _| match key {
OidCacheKey::Catalog(..) => false,
OidCacheKey::Schema(catalog, ..) => {
catalog_oid_cache.contains_key(&OidCacheKey::Catalog(catalog.clone()))
}
OidCacheKey::Table(catalog, ..) => {
catalog_oid_cache.contains_key(&OidCacheKey::Catalog(catalog.clone()))
}
});
// add the refreshed catalog cache entries
oid_cache.extend(catalog_oid_cache);

Ok(full_batch)
}
}
@@ -1520,11 +1571,10 @@ impl PartitionStream for PgDatabaseTable {
}

fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
let catalog_list = self.catalog_list.clone();
let schema = Arc::clone(&self.schema);
let this = self.clone();
Box::pin(RecordBatchStreamAdapter::new(
schema.clone(),
futures::stream::once(async move { Self::get_data(schema, catalog_list).await }),
this.schema.clone(),
futures::stream::once(async move { Self::get_data(this).await }),
))
}
}
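Cloning `self` in `execute` is cheap because every field is an `Arc`, so the clone only bumps reference counts, and the owned `this` can then be moved into the `'static` future the stream adapter needs. A minimal standalone sketch of the same move-a-cheap-clone-into-a-stream pattern, with assumed types that are not part of the PR:

```rust
use std::sync::Arc;

// Assumed types, not from the PR: a handle whose clone only bumps an Arc
// refcount, moved into a one-shot stream the way execute() moves `this`.
#[derive(Clone)]
struct Handle {
    data: Arc<Vec<String>>,
}

fn make_stream(handle: &Handle) -> impl futures::Stream<Item = usize> {
    let this = handle.clone(); // only the Arc refcount is incremented
    futures::stream::once(async move { this.data.len() })
}
```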