1
1
use std:: collections:: HashMap ;
2
- use std:: hash:: { Hash , Hasher } ;
3
2
use std:: sync:: Arc ;
4
3
5
4
use crate :: auth:: { AuthManager , Permission , ResourceType } ;
@@ -19,20 +18,12 @@ use pgwire::api::stmt::QueryParser;
19
18
use pgwire:: api:: stmt:: StoredStatement ;
20
19
use pgwire:: api:: { ClientInfo , PgWireServerHandlers , Type } ;
21
20
use pgwire:: error:: { PgWireError , PgWireResult } ;
22
- use std:: sync:: atomic:: { AtomicU64 , Ordering } ;
23
- use std:: time:: { Duration , Instant } ;
24
- use tokio:: sync:: { Mutex , RwLock } ;
21
+ use pgwire:: messages:: response:: TransactionStatus ;
22
+ use tokio:: sync:: Mutex ;
25
23
26
24
use arrow_pg:: datatypes:: df;
27
25
use arrow_pg:: datatypes:: { arrow_schema_to_pg_fields, into_pg_type} ;
28
26
29
- #[ derive( Debug , Clone , Copy , PartialEq ) ]
30
- pub enum TransactionState {
31
- None ,
32
- Active ,
33
- Failed ,
34
- }
35
-
36
27
/// Simple startup handler that does no authentication
37
28
/// For production, use DfAuthSource with proper pgwire authentication handlers
38
29
pub struct SimpleStartupHandler ;
@@ -66,26 +57,12 @@ impl PgWireServerHandlers for HandlerFactory {
66
57
}
67
58
}
68
59
69
- /// Per-connection transaction state storage
70
- /// We use a hash of both PID and secret key as the connection identifier for better uniqueness
71
- pub type ConnectionId = u64 ;
72
-
73
- #[ derive( Debug , Clone ) ]
74
- struct ConnectionState {
75
- transaction_state : TransactionState ,
76
- last_activity : Instant ,
77
- }
78
-
79
- type ConnectionStates = Arc < RwLock < HashMap < ConnectionId , ConnectionState > > > ;
80
-
81
60
/// The pgwire handler backed by a datafusion `SessionContext`
82
61
pub struct DfSessionService {
83
62
session_context : Arc < SessionContext > ,
84
63
parser : Arc < Parser > ,
85
64
timezone : Arc < Mutex < String > > ,
86
- connection_states : ConnectionStates ,
87
65
auth_manager : Arc < AuthManager > ,
88
- cleanup_counter : AtomicU64 ,
89
66
}
90
67
91
68
impl DfSessionService {
@@ -100,57 +77,10 @@ impl DfSessionService {
100
77
session_context,
101
78
parser,
102
79
timezone : Arc :: new ( Mutex :: new ( "UTC" . to_string ( ) ) ) ,
103
- connection_states : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
104
80
auth_manager,
105
- cleanup_counter : AtomicU64 :: new ( 0 ) ,
106
- }
107
- }
108
-
109
- async fn get_transaction_state ( & self , client_id : ConnectionId ) -> TransactionState {
110
- self . connection_states
111
- . read ( )
112
- . await
113
- . get ( & client_id)
114
- . map ( |s| s. transaction_state )
115
- . unwrap_or ( TransactionState :: None )
116
- }
117
-
118
- async fn update_transaction_state ( & self , client_id : ConnectionId , new_state : TransactionState ) {
119
- let mut states = self . connection_states . write ( ) . await ;
120
-
121
- // Update or insert state using entry API
122
- states
123
- . entry ( client_id)
124
- . and_modify ( |s| {
125
- s. transaction_state = new_state;
126
- s. last_activity = Instant :: now ( ) ;
127
- } )
128
- . or_insert ( ConnectionState {
129
- transaction_state : new_state,
130
- last_activity : Instant :: now ( ) ,
131
- } ) ;
132
-
133
- // Inline cleanup every 100 operations
134
- if self . cleanup_counter . fetch_add ( 1 , Ordering :: Relaxed ) % 100 == 0 {
135
- let cutoff = Instant :: now ( ) - Duration :: from_secs ( 3600 ) ;
136
- states. retain ( |_, state| state. last_activity > cutoff) ;
137
81
}
138
82
}
139
83
140
- fn get_client_id < C : ClientInfo > ( client : & C ) -> ConnectionId {
141
- // Use a hash of PID, secret key, and socket address for better uniqueness
142
- let ( pid, secret) = client. pid_and_secret_key ( ) ;
143
- let socket_addr = client. socket_addr ( ) ;
144
-
145
- // Create a hash of all identifying values
146
- let mut hasher = std:: collections:: hash_map:: DefaultHasher :: new ( ) ;
147
- pid. hash ( & mut hasher) ;
148
- secret. hash ( & mut hasher) ;
149
- socket_addr. hash ( & mut hasher) ;
150
-
151
- hasher. finish ( )
152
- }
153
-
154
84
/// Check if the current user has permission to execute a query
155
85
async fn check_query_permission < C > ( & self , client : & C , query : & str ) -> PgWireResult < ( ) >
156
86
where
@@ -290,24 +220,15 @@ impl DfSessionService {
290
220
where
291
221
C : ClientInfo ,
292
222
{
293
- let client_id = Self :: get_client_id ( client) ;
294
-
295
223
// Transaction handling based on pgwire example:
296
224
// https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
297
225
match query_lower. trim ( ) {
298
226
"begin" | "begin transaction" | "begin work" | "start transaction" => {
299
- match self . get_transaction_state ( client_id) . await {
300
- TransactionState :: None => {
301
- self . update_transaction_state ( client_id, TransactionState :: Active )
302
- . await ;
303
- Ok ( Some ( Response :: TransactionStart ( Tag :: new ( "BEGIN" ) ) ) )
304
- }
305
- TransactionState :: Active => {
306
- // Already in transaction, PostgreSQL allows this but issues a warning
307
- // For simplicity, we'll just return BEGIN again
227
+ match client. transaction_status ( ) {
228
+ TransactionStatus :: Idle | TransactionStatus :: Transaction => {
308
229
Ok ( Some ( Response :: TransactionStart ( Tag :: new ( "BEGIN" ) ) ) )
309
230
}
310
- TransactionState :: Failed => {
231
+ TransactionStatus :: Error => {
311
232
// Can't start new transaction from failed state
312
233
Err ( PgWireError :: UserError ( Box :: new (
313
234
pgwire:: error:: ErrorInfo :: new (
@@ -320,27 +241,16 @@ impl DfSessionService {
320
241
}
321
242
}
322
243
"commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
323
- match self . get_transaction_state ( client_id) . await {
324
- TransactionState :: Active => {
325
- self . update_transaction_state ( client_id, TransactionState :: None )
326
- . await ;
327
- Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "COMMIT" ) ) ) )
328
- }
329
- TransactionState :: None => {
330
- // PostgreSQL allows COMMIT outside transaction with warning
244
+ match client. transaction_status ( ) {
245
+ TransactionStatus :: Idle | TransactionStatus :: Transaction => {
331
246
Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "COMMIT" ) ) ) )
332
247
}
333
- TransactionState :: Failed => {
334
- // COMMIT in failed transaction is treated as ROLLBACK
335
- self . update_transaction_state ( client_id, TransactionState :: None )
336
- . await ;
248
+ TransactionStatus :: Error => {
337
249
Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) )
338
250
}
339
251
}
340
252
}
341
253
"rollback" | "rollback transaction" | "rollback work" | "abort" => {
342
- self . update_transaction_state ( client_id, TransactionState :: None )
343
- . await ;
344
254
Ok ( Some ( Response :: TransactionEnd ( Tag :: new ( "ROLLBACK" ) ) ) )
345
255
}
346
256
_ => Ok ( None ) ,
@@ -399,7 +309,7 @@ impl SimpleQueryHandler for DfSessionService {
399
309
C : ClientInfo + Unpin + Send + Sync ,
400
310
{
401
311
let query_lower = query. to_lowercase ( ) . trim ( ) . to_string ( ) ;
402
- log:: debug!( "Received query: {}" , query ) ; // Log the query for debugging
312
+ log:: debug!( "Received query: {query}" ) ; // Log the query for debugging
403
313
404
314
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
405
315
if !query_lower. starts_with ( "set" )
@@ -429,9 +339,9 @@ impl SimpleQueryHandler for DfSessionService {
429
339
return Ok ( vec ! [ resp] ) ;
430
340
}
431
341
432
- // Check if we're in a failed transaction and block non-transaction commands
433
- let client_id = Self :: get_client_id ( client ) ;
434
- if self . get_transaction_state ( client_id ) . await == TransactionState :: Failed {
342
+ // Check if we're in a failed transaction and block non-transaction
343
+ // commands
344
+ if client . transaction_status ( ) == TransactionStatus :: Error {
435
345
return Err ( PgWireError :: UserError ( Box :: new (
436
346
pgwire:: error:: ErrorInfo :: new (
437
347
"ERROR" . to_string ( ) ,
@@ -447,12 +357,6 @@ impl SimpleQueryHandler for DfSessionService {
447
357
let df = match df_result {
448
358
Ok ( df) => df,
449
359
Err ( e) => {
450
- // If we're in a transaction and a query fails, mark transaction as failed
451
- let client_id = Self :: get_client_id ( client) ;
452
- if self . get_transaction_state ( client_id) . await == TransactionState :: Active {
453
- self . update_transaction_state ( client_id, TransactionState :: Failed )
454
- . await ;
455
- }
456
360
return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
457
361
}
458
362
} ;
@@ -557,7 +461,7 @@ impl ExtendedQueryHandler for DfSessionService {
557
461
. to_lowercase ( )
558
462
. trim ( )
559
463
. to_string ( ) ;
560
- log:: debug!( "Received execute extended query: {}" , query ) ; // Log for debugging
464
+ log:: debug!( "Received execute extended query: {query}" ) ; // Log for debugging
561
465
562
466
// Check permissions for the query (skip for SET and SHOW statements)
563
467
if !query. starts_with ( "set" ) && !query. starts_with ( "show" ) {
@@ -580,9 +484,9 @@ impl ExtendedQueryHandler for DfSessionService {
580
484
return Ok ( resp) ;
581
485
}
582
486
583
- // Check if we're in a failed transaction and block non-transaction commands
584
- let client_id = Self :: get_client_id ( client ) ;
585
- if self . get_transaction_state ( client_id ) . await == TransactionState :: Failed {
487
+ // Check if we're in a failed transaction and block non-transaction
488
+ // commands
489
+ if client . transaction_status ( ) == TransactionStatus :: Error {
586
490
return Err ( PgWireError :: UserError ( Box :: new (
587
491
pgwire:: error:: ErrorInfo :: new (
588
492
"ERROR" . to_string ( ) ,
@@ -605,12 +509,6 @@ impl ExtendedQueryHandler for DfSessionService {
605
509
let dataframe = match self . session_context . execute_logical_plan ( plan) . await {
606
510
Ok ( df) => df,
607
511
Err ( e) => {
608
- // If we're in a transaction and a query fails, mark transaction as failed
609
- let client_id = Self :: get_client_id ( client) ;
610
- if self . get_transaction_state ( client_id) . await == TransactionState :: Active {
611
- self . update_transaction_state ( client_id, TransactionState :: Failed )
612
- . await ;
613
- }
614
512
return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
615
513
}
616
514
} ;
@@ -633,7 +531,7 @@ impl QueryParser for Parser {
633
531
sql : & str ,
634
532
_types : & [ Type ] ,
635
533
) -> PgWireResult < Self :: Statement > {
636
- log:: debug!( "Received parse extended query: {}" , sql ) ; // Log for debugging
534
+ log:: debug!( "Received parse extended query: {sql}" ) ; // Log for debugging
637
535
let context = & self . session_context ;
638
536
let state = context. state ( ) ;
639
537
let logical_plan = state
@@ -654,134 +552,3 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
654
552
types. sort_by ( |a, b| a. 0 . cmp ( b. 0 ) ) ;
655
553
types. into_iter ( ) . map ( |pt| pt. 1 . as_ref ( ) ) . collect ( )
656
554
}
657
-
658
- #[ cfg( test) ]
659
- mod tests {
660
- use super :: * ;
661
- use datafusion:: prelude:: SessionContext ;
662
-
663
- #[ tokio:: test]
664
- async fn test_transaction_isolation ( ) {
665
- let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
666
- let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
667
- let service = DfSessionService :: new ( session_context, auth_manager) ;
668
-
669
- // Simulate two different connection IDs
670
- let client_id_1 = 1001 ;
671
- let client_id_2 = 1002 ;
672
-
673
- // Client 1 starts a transaction
674
- service
675
- . update_transaction_state ( client_id_1, TransactionState :: Active )
676
- . await ;
677
-
678
- // Client 2 starts a transaction
679
- service
680
- . update_transaction_state ( client_id_2, TransactionState :: Active )
681
- . await ;
682
-
683
- // Verify both have active transactions independently
684
- {
685
- let states = service. connection_states . read ( ) . await ;
686
- assert_eq ! (
687
- states. get( & client_id_1) . map( |s| s. transaction_state) ,
688
- Some ( TransactionState :: Active )
689
- ) ;
690
- assert_eq ! (
691
- states. get( & client_id_2) . map( |s| s. transaction_state) ,
692
- Some ( TransactionState :: Active )
693
- ) ;
694
- }
695
-
696
- // Client 1 fails a transaction
697
- service
698
- . update_transaction_state ( client_id_1, TransactionState :: Failed )
699
- . await ;
700
-
701
- // Verify client 1 is failed but client 2 is still active
702
- {
703
- let states = service. connection_states . read ( ) . await ;
704
- assert_eq ! (
705
- states. get( & client_id_1) . map( |s| s. transaction_state) ,
706
- Some ( TransactionState :: Failed )
707
- ) ;
708
- assert_eq ! (
709
- states. get( & client_id_2) . map( |s| s. transaction_state) ,
710
- Some ( TransactionState :: Active )
711
- ) ;
712
- }
713
-
714
- // Client 1 rollback
715
- service
716
- . update_transaction_state ( client_id_1, TransactionState :: None )
717
- . await ;
718
-
719
- // Client 2 commit
720
- service
721
- . update_transaction_state ( client_id_2, TransactionState :: None )
722
- . await ;
723
-
724
- // Verify both are back to None state
725
- {
726
- let states = service. connection_states . read ( ) . await ;
727
- assert_eq ! (
728
- states. get( & client_id_1) . map( |s| s. transaction_state) ,
729
- Some ( TransactionState :: None )
730
- ) ;
731
- assert_eq ! (
732
- states. get( & client_id_2) . map( |s| s. transaction_state) ,
733
- Some ( TransactionState :: None )
734
- ) ;
735
- }
736
- }
737
-
738
- #[ tokio:: test]
739
- async fn test_opportunistic_cleanup ( ) {
740
- let session_context = Arc :: new ( SessionContext :: new ( ) ) ;
741
- let auth_manager = Arc :: new ( AuthManager :: new ( ) ) ;
742
- let service = DfSessionService :: new ( session_context, auth_manager) ;
743
-
744
- // Add some connection states
745
- service
746
- . update_transaction_state ( 2001 , TransactionState :: Active )
747
- . await ;
748
- service
749
- . update_transaction_state ( 2002 , TransactionState :: Failed )
750
- . await ;
751
-
752
- // Manually create an old connection
753
- {
754
- let mut states = service. connection_states . write ( ) . await ;
755
- states. insert (
756
- 2003 ,
757
- ConnectionState {
758
- transaction_state : TransactionState :: Active ,
759
- last_activity : Instant :: now ( ) - Duration :: from_secs ( 7200 ) , // 2 hours old
760
- } ,
761
- ) ;
762
- }
763
-
764
- // Set cleanup counter to trigger cleanup on next update (fetch_add returns old value)
765
- service. cleanup_counter . store ( 99 , Ordering :: Relaxed ) ;
766
-
767
- // First update sets counter to 100 (99 + 1)
768
- service
769
- . update_transaction_state ( 2004 , TransactionState :: Active )
770
- . await ;
771
-
772
- // This should trigger cleanup (counter becomes 100, 100 % 100 == 0)
773
- service
774
- . update_transaction_state ( 2005 , TransactionState :: Active )
775
- . await ;
776
-
777
- // Verify only the old connection was removed (cleanup is now inline, no wait needed)
778
- {
779
- let states = service. connection_states . read ( ) . await ;
780
- assert ! ( states. contains_key( & 2001 ) ) ;
781
- assert ! ( states. contains_key( & 2002 ) ) ;
782
- assert ! ( !states. contains_key( & 2003 ) ) ; // Old connection should be removed
783
- assert ! ( states. contains_key( & 2004 ) ) ;
784
- assert ! ( states. contains_key( & 2005 ) ) ;
785
- }
786
- }
787
- }
0 commit comments