From 81f91d9caade2ab2aa216b52de5b2dc426470aea Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 20 Aug 2025 16:32:55 +0800 Subject: [PATCH 1/3] refactor: remove transaction management logic and use pgwire api --- datafusion-postgres/src/handlers.rs | 261 ++-------------------------- 1 file changed, 14 insertions(+), 247 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 9565e67..d9c5ac8 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::auth::{AuthManager, Permission, ResourceType}; @@ -19,20 +18,12 @@ use pgwire::api::stmt::QueryParser; use pgwire::api::stmt::StoredStatement; use pgwire::api::{ClientInfo, PgWireServerHandlers, Type}; use pgwire::error::{PgWireError, PgWireResult}; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::time::{Duration, Instant}; -use tokio::sync::{Mutex, RwLock}; +use pgwire::messages::response::TransactionStatus; +use tokio::sync::Mutex; use arrow_pg::datatypes::df; use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum TransactionState { - None, - Active, - Failed, -} - /// Simple startup handler that does no authentication /// For production, use DfAuthSource with proper pgwire authentication handlers pub struct SimpleStartupHandler; @@ -66,26 +57,12 @@ impl PgWireServerHandlers for HandlerFactory { } } -/// Per-connection transaction state storage -/// We use a hash of both PID and secret key as the connection identifier for better uniqueness -pub type ConnectionId = u64; - -#[derive(Debug, Clone)] -struct ConnectionState { - transaction_state: TransactionState, - last_activity: Instant, -} - -type ConnectionStates = Arc>>; - /// The pgwire handler backed by a datafusion `SessionContext` pub struct DfSessionService { session_context: Arc, parser: Arc, timezone: Arc>, - connection_states: ConnectionStates, auth_manager: Arc, - cleanup_counter: AtomicU64, } impl DfSessionService { @@ -100,57 +77,10 @@ impl DfSessionService { session_context, parser, timezone: Arc::new(Mutex::new("UTC".to_string())), - connection_states: Arc::new(RwLock::new(HashMap::new())), auth_manager, - cleanup_counter: AtomicU64::new(0), - } - } - - async fn get_transaction_state(&self, client_id: ConnectionId) -> TransactionState { - self.connection_states - .read() - .await - .get(&client_id) - .map(|s| s.transaction_state) - .unwrap_or(TransactionState::None) - } - - async fn update_transaction_state(&self, client_id: ConnectionId, new_state: TransactionState) { - let mut states = self.connection_states.write().await; - - // Update or insert state using entry API - states - .entry(client_id) - .and_modify(|s| { - s.transaction_state = new_state; - s.last_activity = Instant::now(); - }) - .or_insert(ConnectionState { - transaction_state: new_state, - last_activity: Instant::now(), - }); - - // Inline cleanup every 100 operations - if self.cleanup_counter.fetch_add(1, Ordering::Relaxed) % 100 == 0 { - let cutoff = Instant::now() - Duration::from_secs(3600); - states.retain(|_, state| state.last_activity > cutoff); } } - fn get_client_id(client: &C) -> ConnectionId { - // Use a hash of PID, secret key, and socket address for better uniqueness - let (pid, secret) = client.pid_and_secret_key(); - let socket_addr = client.socket_addr(); - - // Create a hash of all identifying values - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - pid.hash(&mut hasher); - secret.hash(&mut hasher); - socket_addr.hash(&mut hasher); - - hasher.finish() - } - /// Check if the current user has permission to execute a query async fn check_query_permission(&self, client: &C, query: &str) -> PgWireResult<()> where @@ -290,24 +220,15 @@ impl DfSessionService { where C: ClientInfo, { - let client_id = Self::get_client_id(client); - // Transaction handling based on pgwire example: // https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57 match query_lower.trim() { "begin" | "begin transaction" | "begin work" | "start transaction" => { - match self.get_transaction_state(client_id).await { - TransactionState::None => { - self.update_transaction_state(client_id, TransactionState::Active) - .await; - Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))) - } - TransactionState::Active => { - // Already in transaction, PostgreSQL allows this but issues a warning - // For simplicity, we'll just return BEGIN again + match client.transaction_status() { + TransactionStatus::Idle | TransactionStatus::Transaction => { Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))) } - TransactionState::Failed => { + TransactionStatus::Error => { // Can't start new transaction from failed state Err(PgWireError::UserError(Box::new( pgwire::error::ErrorInfo::new( @@ -320,27 +241,16 @@ impl DfSessionService { } } "commit" | "commit transaction" | "commit work" | "end" | "end transaction" => { - match self.get_transaction_state(client_id).await { - TransactionState::Active => { - self.update_transaction_state(client_id, TransactionState::None) - .await; - Ok(Some(Response::TransactionEnd(Tag::new("COMMIT")))) - } - TransactionState::None => { - // PostgreSQL allows COMMIT outside transaction with warning + match client.transaction_status() { + TransactionStatus::Idle | TransactionStatus::Transaction => { Ok(Some(Response::TransactionEnd(Tag::new("COMMIT")))) } - TransactionState::Failed => { - // COMMIT in failed transaction is treated as ROLLBACK - self.update_transaction_state(client_id, TransactionState::None) - .await; + TransactionStatus::Error => { Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))) } } } "rollback" | "rollback transaction" | "rollback work" | "abort" => { - self.update_transaction_state(client_id, TransactionState::None) - .await; Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK")))) } _ => Ok(None), @@ -429,9 +339,9 @@ impl SimpleQueryHandler for DfSessionService { return Ok(vec![resp]); } - // Check if we're in a failed transaction and block non-transaction commands - let client_id = Self::get_client_id(client); - if self.get_transaction_state(client_id).await == TransactionState::Failed { + // Check if we're in a failed transaction and block non-transaction + // commands + if client.transaction_status() == TransactionStatus::Error { return Err(PgWireError::UserError(Box::new( pgwire::error::ErrorInfo::new( "ERROR".to_string(), @@ -447,12 +357,6 @@ impl SimpleQueryHandler for DfSessionService { let df = match df_result { Ok(df) => df, Err(e) => { - // If we're in a transaction and a query fails, mark transaction as failed - let client_id = Self::get_client_id(client); - if self.get_transaction_state(client_id).await == TransactionState::Active { - self.update_transaction_state(client_id, TransactionState::Failed) - .await; - } return Err(PgWireError::ApiError(Box::new(e))); } }; @@ -580,9 +484,9 @@ impl ExtendedQueryHandler for DfSessionService { return Ok(resp); } - // Check if we're in a failed transaction and block non-transaction commands - let client_id = Self::get_client_id(client); - if self.get_transaction_state(client_id).await == TransactionState::Failed { + // Check if we're in a failed transaction and block non-transaction + // commands + if client.transaction_status() == TransactionStatus::Error { return Err(PgWireError::UserError(Box::new( pgwire::error::ErrorInfo::new( "ERROR".to_string(), @@ -605,12 +509,6 @@ impl ExtendedQueryHandler for DfSessionService { let dataframe = match self.session_context.execute_logical_plan(plan).await { Ok(df) => df, Err(e) => { - // If we're in a transaction and a query fails, mark transaction as failed - let client_id = Self::get_client_id(client); - if self.get_transaction_state(client_id).await == TransactionState::Active { - self.update_transaction_state(client_id, TransactionState::Failed) - .await; - } return Err(PgWireError::ApiError(Box::new(e))); } }; @@ -654,134 +552,3 @@ fn ordered_param_types(types: &HashMap>) -> Vec Date: Wed, 20 Aug 2025 16:38:03 +0800 Subject: [PATCH 2/3] chore: lint fix --- datafusion-postgres/src/handlers.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index d9c5ac8..664f41f 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -309,7 +309,7 @@ impl SimpleQueryHandler for DfSessionService { C: ClientInfo + Unpin + Send + Sync, { let query_lower = query.to_lowercase().trim().to_string(); - log::debug!("Received query: {}", query); // Log the query for debugging + log::debug!("Received query: {query}"); // Log the query for debugging // Check permissions for the query (skip for SET, transaction, and SHOW statements) if !query_lower.starts_with("set") @@ -461,7 +461,7 @@ impl ExtendedQueryHandler for DfSessionService { .to_lowercase() .trim() .to_string(); - log::debug!("Received execute extended query: {}", query); // Log for debugging + log::debug!("Received execute extended query: {query}"); // Log for debugging // Check permissions for the query (skip for SET and SHOW statements) if !query.starts_with("set") && !query.starts_with("show") { @@ -531,7 +531,7 @@ impl QueryParser for Parser { sql: &str, _types: &[Type], ) -> PgWireResult { - log::debug!("Received parse extended query: {}", sql); // Log for debugging + log::debug!("Received parse extended query: {sql}"); // Log for debugging let context = &self.session_context; let state = context.state(); let logical_plan = state From 366bc58a0b836e0f171aec70fb4a67c4981399a4 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Wed, 20 Aug 2025 16:38:50 +0800 Subject: [PATCH 3/3] flake.lock: Update MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Flake lock file updates: • Updated input 'fenix': 'github:nix-community/fenix/b37f026b49ecb295a448c96bcbb0c174c14fc91b?narHash=sha256-icKFXb83uv2ezRCfuq5G8QSwCuaoLywLljSL%2BUGmPPI%3D' (2025-06-27) → 'github:nix-community/fenix/6ed03ef4c8ec36d193c18e06b9ecddde78fb7e42?narHash=sha256-tl/0cnsqB/Yt7DbaGMel2RLa7QG5elA8lkaOXli6VdY%3D' (2025-08-19) • Updated input 'fenix/rust-analyzer-src': 'github:rust-lang/rust-analyzer/317542c1e4a3ec3467d21d1c25f6a43b80d83e7d?narHash=sha256-hMNZXMtlhfjQdu1F4Fa/UFiMoXdZag4cider2R9a648%3D' (2025-06-25) → 'github:rust-lang/rust-analyzer/a905e3b21b144d77e1b304e49f3264f6f8d4db75?narHash=sha256-VX0B9hwhJypCGqncVVLC%2BSmeMVd/GAYbJZ0MiiUn2Pk%3D' (2025-08-18) • Updated input 'nixpkgs': 'github:NixOS/nixpkgs/7284e2decc982b81a296ab35aa46e804baaa1cfe?narHash=sha256-aVkL3/yu50oQzi2YuKo0ceiCypVZpZXYd2P2p1FMJM4%3D' (2025-06-25) → 'github:NixOS/nixpkgs/a58390ab6f1aa810eb8e0f0fc74230e7cc06de03?narHash=sha256-BA9MuPjBDx/WnpTJ0EGhStyfE7hug8g85Y3Ju9oTsM4%3D' (2025-08-19) --- flake.lock | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/flake.lock b/flake.lock index b5fcedb..e90c42b 100644 --- a/flake.lock +++ b/flake.lock @@ -8,11 +8,11 @@ "rust-analyzer-src": "rust-analyzer-src" }, "locked": { - "lastModified": 1751006353, - "narHash": "sha256-icKFXb83uv2ezRCfuq5G8QSwCuaoLywLljSL+UGmPPI=", + "lastModified": 1755585599, + "narHash": "sha256-tl/0cnsqB/Yt7DbaGMel2RLa7QG5elA8lkaOXli6VdY=", "owner": "nix-community", "repo": "fenix", - "rev": "b37f026b49ecb295a448c96bcbb0c174c14fc91b", + "rev": "6ed03ef4c8ec36d193c18e06b9ecddde78fb7e42", "type": "github" }, "original": { @@ -41,11 +41,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1750838302, - "narHash": "sha256-aVkL3/yu50oQzi2YuKo0ceiCypVZpZXYd2P2p1FMJM4=", + "lastModified": 1755593991, + "narHash": "sha256-BA9MuPjBDx/WnpTJ0EGhStyfE7hug8g85Y3Ju9oTsM4=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "7284e2decc982b81a296ab35aa46e804baaa1cfe", + "rev": "a58390ab6f1aa810eb8e0f0fc74230e7cc06de03", "type": "github" }, "original": { @@ -65,11 +65,11 @@ "rust-analyzer-src": { "flake": false, "locked": { - "lastModified": 1750871759, - "narHash": "sha256-hMNZXMtlhfjQdu1F4Fa/UFiMoXdZag4cider2R9a648=", + "lastModified": 1755504847, + "narHash": "sha256-VX0B9hwhJypCGqncVVLC+SmeMVd/GAYbJZ0MiiUn2Pk=", "owner": "rust-lang", "repo": "rust-analyzer", - "rev": "317542c1e4a3ec3467d21d1c25f6a43b80d83e7d", + "rev": "a905e3b21b144d77e1b304e49f3264f6f8d4db75", "type": "github" }, "original": {