Skip to content

Commit cb261e5

Browse files
authored
refactor: update transaction management (#132)
* refactor: remove transaction management logic and use pgwire api

* chore: lint fix

* flake.lock: Update

Flake lock file updates:

• Updated input 'fenix':
    'github:nix-community/fenix/b37f026b49ecb295a448c96bcbb0c174c14fc91b?narHash=sha256-icKFXb83uv2ezRCfuq5G8QSwCuaoLywLljSL%2BUGmPPI%3D' (2025-06-27)
  → 'github:nix-community/fenix/6ed03ef4c8ec36d193c18e06b9ecddde78fb7e42?narHash=sha256-tl/0cnsqB/Yt7DbaGMel2RLa7QG5elA8lkaOXli6VdY%3D' (2025-08-19)
• Updated input 'fenix/rust-analyzer-src':
    'github:rust-lang/rust-analyzer/317542c1e4a3ec3467d21d1c25f6a43b80d83e7d?narHash=sha256-hMNZXMtlhfjQdu1F4Fa/UFiMoXdZag4cider2R9a648%3D' (2025-06-25)
  → 'github:rust-lang/rust-analyzer/a905e3b21b144d77e1b304e49f3264f6f8d4db75?narHash=sha256-VX0B9hwhJypCGqncVVLC%2BSmeMVd/GAYbJZ0MiiUn2Pk%3D' (2025-08-18)
• Updated input 'nixpkgs':
    'github:NixOS/nixpkgs/7284e2decc982b81a296ab35aa46e804baaa1cfe?narHash=sha256-aVkL3/yu50oQzi2YuKo0ceiCypVZpZXYd2P2p1FMJM4%3D' (2025-06-25)
  → 'github:NixOS/nixpkgs/a58390ab6f1aa810eb8e0f0fc74230e7cc06de03?narHash=sha256-BA9MuPjBDx/WnpTJ0EGhStyfE7hug8g85Y3Ju9oTsM4%3D' (2025-08-19)
1 parent 0af480f commit cb261e5

File tree

2 files changed

+26
-259
lines changed

2 files changed

+26
-259
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 17 additions & 250 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use std::collections::HashMap;
2-
use std::hash::{Hash, Hasher};
32
use std::sync::Arc;
43

54
use crate::auth::{AuthManager, Permission, ResourceType};
@@ -19,20 +18,12 @@ use pgwire::api::stmt::QueryParser;
1918
use pgwire::api::stmt::StoredStatement;
2019
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
2120
use pgwire::error::{PgWireError, PgWireResult};
22-
use std::sync::atomic::{AtomicU64, Ordering};
23-
use std::time::{Duration, Instant};
24-
use tokio::sync::{Mutex, RwLock};
21+
use pgwire::messages::response::TransactionStatus;
22+
use tokio::sync::Mutex;
2523

2624
use arrow_pg::datatypes::df;
2725
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};
2826

29-
#[derive(Debug, Clone, Copy, PartialEq)]
30-
pub enum TransactionState {
31-
None,
32-
Active,
33-
Failed,
34-
}
35-
3627
/// Simple startup handler that does no authentication
3728
/// For production, use DfAuthSource with proper pgwire authentication handlers
3829
pub struct SimpleStartupHandler;
@@ -66,26 +57,12 @@ impl PgWireServerHandlers for HandlerFactory {
6657
}
6758
}
6859

69-
/// Per-connection transaction state storage
70-
/// We use a hash of both PID and secret key as the connection identifier for better uniqueness
71-
pub type ConnectionId = u64;
72-
73-
#[derive(Debug, Clone)]
74-
struct ConnectionState {
75-
transaction_state: TransactionState,
76-
last_activity: Instant,
77-
}
78-
79-
type ConnectionStates = Arc<RwLock<HashMap<ConnectionId, ConnectionState>>>;
80-
8160
/// The pgwire handler backed by a datafusion `SessionContext`
8261
pub struct DfSessionService {
8362
session_context: Arc<SessionContext>,
8463
parser: Arc<Parser>,
8564
timezone: Arc<Mutex<String>>,
86-
connection_states: ConnectionStates,
8765
auth_manager: Arc<AuthManager>,
88-
cleanup_counter: AtomicU64,
8966
}
9067

9168
impl DfSessionService {
@@ -100,57 +77,10 @@ impl DfSessionService {
10077
session_context,
10178
parser,
10279
timezone: Arc::new(Mutex::new("UTC".to_string())),
103-
connection_states: Arc::new(RwLock::new(HashMap::new())),
10480
auth_manager,
105-
cleanup_counter: AtomicU64::new(0),
106-
}
107-
}
108-
109-
async fn get_transaction_state(&self, client_id: ConnectionId) -> TransactionState {
110-
self.connection_states
111-
.read()
112-
.await
113-
.get(&client_id)
114-
.map(|s| s.transaction_state)
115-
.unwrap_or(TransactionState::None)
116-
}
117-
118-
async fn update_transaction_state(&self, client_id: ConnectionId, new_state: TransactionState) {
119-
let mut states = self.connection_states.write().await;
120-
121-
// Update or insert state using entry API
122-
states
123-
.entry(client_id)
124-
.and_modify(|s| {
125-
s.transaction_state = new_state;
126-
s.last_activity = Instant::now();
127-
})
128-
.or_insert(ConnectionState {
129-
transaction_state: new_state,
130-
last_activity: Instant::now(),
131-
});
132-
133-
// Inline cleanup every 100 operations
134-
if self.cleanup_counter.fetch_add(1, Ordering::Relaxed) % 100 == 0 {
135-
let cutoff = Instant::now() - Duration::from_secs(3600);
136-
states.retain(|_, state| state.last_activity > cutoff);
13781
}
13882
}
13983

140-
fn get_client_id<C: ClientInfo>(client: &C) -> ConnectionId {
141-
// Use a hash of PID, secret key, and socket address for better uniqueness
142-
let (pid, secret) = client.pid_and_secret_key();
143-
let socket_addr = client.socket_addr();
144-
145-
// Create a hash of all identifying values
146-
let mut hasher = std::collections::hash_map::DefaultHasher::new();
147-
pid.hash(&mut hasher);
148-
secret.hash(&mut hasher);
149-
socket_addr.hash(&mut hasher);
150-
151-
hasher.finish()
152-
}
153-
15484
/// Check if the current user has permission to execute a query
15585
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
15686
where
@@ -290,24 +220,15 @@ impl DfSessionService {
290220
where
291221
C: ClientInfo,
292222
{
293-
let client_id = Self::get_client_id(client);
294-
295223
// Transaction handling based on pgwire example:
296224
// https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
297225
match query_lower.trim() {
298226
"begin" | "begin transaction" | "begin work" | "start transaction" => {
299-
match self.get_transaction_state(client_id).await {
300-
TransactionState::None => {
301-
self.update_transaction_state(client_id, TransactionState::Active)
302-
.await;
303-
Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
304-
}
305-
TransactionState::Active => {
306-
// Already in transaction, PostgreSQL allows this but issues a warning
307-
// For simplicity, we'll just return BEGIN again
227+
match client.transaction_status() {
228+
TransactionStatus::Idle | TransactionStatus::Transaction => {
308229
Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
309230
}
310-
TransactionState::Failed => {
231+
TransactionStatus::Error => {
311232
// Can't start new transaction from failed state
312233
Err(PgWireError::UserError(Box::new(
313234
pgwire::error::ErrorInfo::new(
@@ -320,27 +241,16 @@ impl DfSessionService {
320241
}
321242
}
322243
"commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
323-
match self.get_transaction_state(client_id).await {
324-
TransactionState::Active => {
325-
self.update_transaction_state(client_id, TransactionState::None)
326-
.await;
327-
Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
328-
}
329-
TransactionState::None => {
330-
// PostgreSQL allows COMMIT outside transaction with warning
244+
match client.transaction_status() {
245+
TransactionStatus::Idle | TransactionStatus::Transaction => {
331246
Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
332247
}
333-
TransactionState::Failed => {
334-
// COMMIT in failed transaction is treated as ROLLBACK
335-
self.update_transaction_state(client_id, TransactionState::None)
336-
.await;
248+
TransactionStatus::Error => {
337249
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
338250
}
339251
}
340252
}
341253
"rollback" | "rollback transaction" | "rollback work" | "abort" => {
342-
self.update_transaction_state(client_id, TransactionState::None)
343-
.await;
344254
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
345255
}
346256
_ => Ok(None),
@@ -399,7 +309,7 @@ impl SimpleQueryHandler for DfSessionService {
399309
C: ClientInfo + Unpin + Send + Sync,
400310
{
401311
let query_lower = query.to_lowercase().trim().to_string();
402-
log::debug!("Received query: {}", query); // Log the query for debugging
312+
log::debug!("Received query: {query}"); // Log the query for debugging
403313

404314
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
405315
if !query_lower.starts_with("set")
@@ -429,9 +339,9 @@ impl SimpleQueryHandler for DfSessionService {
429339
return Ok(vec![resp]);
430340
}
431341

432-
// Check if we're in a failed transaction and block non-transaction commands
433-
let client_id = Self::get_client_id(client);
434-
if self.get_transaction_state(client_id).await == TransactionState::Failed {
342+
// Check if we're in a failed transaction and block non-transaction
343+
// commands
344+
if client.transaction_status() == TransactionStatus::Error {
435345
return Err(PgWireError::UserError(Box::new(
436346
pgwire::error::ErrorInfo::new(
437347
"ERROR".to_string(),
@@ -447,12 +357,6 @@ impl SimpleQueryHandler for DfSessionService {
447357
let df = match df_result {
448358
Ok(df) => df,
449359
Err(e) => {
450-
// If we're in a transaction and a query fails, mark transaction as failed
451-
let client_id = Self::get_client_id(client);
452-
if self.get_transaction_state(client_id).await == TransactionState::Active {
453-
self.update_transaction_state(client_id, TransactionState::Failed)
454-
.await;
455-
}
456360
return Err(PgWireError::ApiError(Box::new(e)));
457361
}
458362
};
@@ -557,7 +461,7 @@ impl ExtendedQueryHandler for DfSessionService {
557461
.to_lowercase()
558462
.trim()
559463
.to_string();
560-
log::debug!("Received execute extended query: {}", query); // Log for debugging
464+
log::debug!("Received execute extended query: {query}"); // Log for debugging
561465

562466
// Check permissions for the query (skip for SET and SHOW statements)
563467
if !query.starts_with("set") && !query.starts_with("show") {
@@ -580,9 +484,9 @@ impl ExtendedQueryHandler for DfSessionService {
580484
return Ok(resp);
581485
}
582486

583-
// Check if we're in a failed transaction and block non-transaction commands
584-
let client_id = Self::get_client_id(client);
585-
if self.get_transaction_state(client_id).await == TransactionState::Failed {
487+
// Check if we're in a failed transaction and block non-transaction
488+
// commands
489+
if client.transaction_status() == TransactionStatus::Error {
586490
return Err(PgWireError::UserError(Box::new(
587491
pgwire::error::ErrorInfo::new(
588492
"ERROR".to_string(),
@@ -605,12 +509,6 @@ impl ExtendedQueryHandler for DfSessionService {
605509
let dataframe = match self.session_context.execute_logical_plan(plan).await {
606510
Ok(df) => df,
607511
Err(e) => {
608-
// If we're in a transaction and a query fails, mark transaction as failed
609-
let client_id = Self::get_client_id(client);
610-
if self.get_transaction_state(client_id).await == TransactionState::Active {
611-
self.update_transaction_state(client_id, TransactionState::Failed)
612-
.await;
613-
}
614512
return Err(PgWireError::ApiError(Box::new(e)));
615513
}
616514
};
@@ -633,7 +531,7 @@ impl QueryParser for Parser {
633531
sql: &str,
634532
_types: &[Type],
635533
) -> PgWireResult<Self::Statement> {
636-
log::debug!("Received parse extended query: {}", sql); // Log for debugging
534+
log::debug!("Received parse extended query: {sql}"); // Log for debugging
637535
let context = &self.session_context;
638536
let state = context.state();
639537
let logical_plan = state
@@ -654,134 +552,3 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
654552
types.sort_by(|a, b| a.0.cmp(b.0));
655553
types.into_iter().map(|pt| pt.1.as_ref()).collect()
656554
}
657-
658-
#[cfg(test)]
659-
mod tests {
660-
use super::*;
661-
use datafusion::prelude::SessionContext;
662-
663-
#[tokio::test]
664-
async fn test_transaction_isolation() {
665-
let session_context = Arc::new(SessionContext::new());
666-
let auth_manager = Arc::new(AuthManager::new());
667-
let service = DfSessionService::new(session_context, auth_manager);
668-
669-
// Simulate two different connection IDs
670-
let client_id_1 = 1001;
671-
let client_id_2 = 1002;
672-
673-
// Client 1 starts a transaction
674-
service
675-
.update_transaction_state(client_id_1, TransactionState::Active)
676-
.await;
677-
678-
// Client 2 starts a transaction
679-
service
680-
.update_transaction_state(client_id_2, TransactionState::Active)
681-
.await;
682-
683-
// Verify both have active transactions independently
684-
{
685-
let states = service.connection_states.read().await;
686-
assert_eq!(
687-
states.get(&client_id_1).map(|s| s.transaction_state),
688-
Some(TransactionState::Active)
689-
);
690-
assert_eq!(
691-
states.get(&client_id_2).map(|s| s.transaction_state),
692-
Some(TransactionState::Active)
693-
);
694-
}
695-
696-
// Client 1 fails a transaction
697-
service
698-
.update_transaction_state(client_id_1, TransactionState::Failed)
699-
.await;
700-
701-
// Verify client 1 is failed but client 2 is still active
702-
{
703-
let states = service.connection_states.read().await;
704-
assert_eq!(
705-
states.get(&client_id_1).map(|s| s.transaction_state),
706-
Some(TransactionState::Failed)
707-
);
708-
assert_eq!(
709-
states.get(&client_id_2).map(|s| s.transaction_state),
710-
Some(TransactionState::Active)
711-
);
712-
}
713-
714-
// Client 1 rollback
715-
service
716-
.update_transaction_state(client_id_1, TransactionState::None)
717-
.await;
718-
719-
// Client 2 commit
720-
service
721-
.update_transaction_state(client_id_2, TransactionState::None)
722-
.await;
723-
724-
// Verify both are back to None state
725-
{
726-
let states = service.connection_states.read().await;
727-
assert_eq!(
728-
states.get(&client_id_1).map(|s| s.transaction_state),
729-
Some(TransactionState::None)
730-
);
731-
assert_eq!(
732-
states.get(&client_id_2).map(|s| s.transaction_state),
733-
Some(TransactionState::None)
734-
);
735-
}
736-
}
737-
738-
#[tokio::test]
739-
async fn test_opportunistic_cleanup() {
740-
let session_context = Arc::new(SessionContext::new());
741-
let auth_manager = Arc::new(AuthManager::new());
742-
let service = DfSessionService::new(session_context, auth_manager);
743-
744-
// Add some connection states
745-
service
746-
.update_transaction_state(2001, TransactionState::Active)
747-
.await;
748-
service
749-
.update_transaction_state(2002, TransactionState::Failed)
750-
.await;
751-
752-
// Manually create an old connection
753-
{
754-
let mut states = service.connection_states.write().await;
755-
states.insert(
756-
2003,
757-
ConnectionState {
758-
transaction_state: TransactionState::Active,
759-
last_activity: Instant::now() - Duration::from_secs(7200), // 2 hours old
760-
},
761-
);
762-
}
763-
764-
// Set cleanup counter to trigger cleanup on next update (fetch_add returns old value)
765-
service.cleanup_counter.store(99, Ordering::Relaxed);
766-
767-
// First update sets counter to 100 (99 + 1)
768-
service
769-
.update_transaction_state(2004, TransactionState::Active)
770-
.await;
771-
772-
// This should trigger cleanup (counter becomes 100, 100 % 100 == 0)
773-
service
774-
.update_transaction_state(2005, TransactionState::Active)
775-
.await;
776-
777-
// Verify only the old connection was removed (cleanup is now inline, no wait needed)
778-
{
779-
let states = service.connection_states.read().await;
780-
assert!(states.contains_key(&2001));
781-
assert!(states.contains_key(&2002));
782-
assert!(!states.contains_key(&2003)); // Old connection should be removed
783-
assert!(states.contains_key(&2004));
784-
assert!(states.contains_key(&2005));
785-
}
786-
}
787-
}

0 commit comments

Comments
 (0)