Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 17 additions & 250 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use crate::auth::{AuthManager, Permission, ResourceType};
Expand All @@ -19,20 +18,12 @@ use pgwire::api::stmt::QueryParser;
use pgwire::api::stmt::StoredStatement;
use pgwire::api::{ClientInfo, PgWireServerHandlers, Type};
use pgwire::error::{PgWireError, PgWireResult};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, RwLock};
use pgwire::messages::response::TransactionStatus;
use tokio::sync::Mutex;

use arrow_pg::datatypes::df;
use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type};

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TransactionState {
None,
Active,
Failed,
}

/// Simple startup handler that does no authentication
/// For production, use DfAuthSource with proper pgwire authentication handlers
pub struct SimpleStartupHandler;
Expand Down Expand Up @@ -66,26 +57,12 @@ impl PgWireServerHandlers for HandlerFactory {
}
}

/// Per-connection transaction state storage
/// We use a hash of both PID and secret key as the connection identifier for better uniqueness
pub type ConnectionId = u64;

#[derive(Debug, Clone)]
struct ConnectionState {
transaction_state: TransactionState,
last_activity: Instant,
}

type ConnectionStates = Arc<RwLock<HashMap<ConnectionId, ConnectionState>>>;

/// The pgwire handler backed by a datafusion `SessionContext`
pub struct DfSessionService {
session_context: Arc<SessionContext>,
parser: Arc<Parser>,
timezone: Arc<Mutex<String>>,
connection_states: ConnectionStates,
auth_manager: Arc<AuthManager>,
cleanup_counter: AtomicU64,
}

impl DfSessionService {
Expand All @@ -100,57 +77,10 @@ impl DfSessionService {
session_context,
parser,
timezone: Arc::new(Mutex::new("UTC".to_string())),
connection_states: Arc::new(RwLock::new(HashMap::new())),
auth_manager,
cleanup_counter: AtomicU64::new(0),
}
}

async fn get_transaction_state(&self, client_id: ConnectionId) -> TransactionState {
self.connection_states
.read()
.await
.get(&client_id)
.map(|s| s.transaction_state)
.unwrap_or(TransactionState::None)
}

async fn update_transaction_state(&self, client_id: ConnectionId, new_state: TransactionState) {
let mut states = self.connection_states.write().await;

// Update or insert state using entry API
states
.entry(client_id)
.and_modify(|s| {
s.transaction_state = new_state;
s.last_activity = Instant::now();
})
.or_insert(ConnectionState {
transaction_state: new_state,
last_activity: Instant::now(),
});

// Inline cleanup every 100 operations
if self.cleanup_counter.fetch_add(1, Ordering::Relaxed) % 100 == 0 {
let cutoff = Instant::now() - Duration::from_secs(3600);
states.retain(|_, state| state.last_activity > cutoff);
}
}

fn get_client_id<C: ClientInfo>(client: &C) -> ConnectionId {
// Use a hash of PID, secret key, and socket address for better uniqueness
let (pid, secret) = client.pid_and_secret_key();
let socket_addr = client.socket_addr();

// Create a hash of all identifying values
let mut hasher = std::collections::hash_map::DefaultHasher::new();
pid.hash(&mut hasher);
secret.hash(&mut hasher);
socket_addr.hash(&mut hasher);

hasher.finish()
}

/// Check if the current user has permission to execute a query
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
where
Expand Down Expand Up @@ -290,24 +220,15 @@ impl DfSessionService {
where
C: ClientInfo,
{
let client_id = Self::get_client_id(client);

// Transaction handling based on pgwire example:
// https://github.com/sunng87/pgwire/blob/master/examples/transaction.rs#L57
match query_lower.trim() {
"begin" | "begin transaction" | "begin work" | "start transaction" => {
match self.get_transaction_state(client_id).await {
TransactionState::None => {
self.update_transaction_state(client_id, TransactionState::Active)
.await;
Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
}
TransactionState::Active => {
// Already in transaction, PostgreSQL allows this but issues a warning
// For simplicity, we'll just return BEGIN again
match client.transaction_status() {
TransactionStatus::Idle | TransactionStatus::Transaction => {
Ok(Some(Response::TransactionStart(Tag::new("BEGIN"))))
}
TransactionState::Failed => {
TransactionStatus::Error => {
// Can't start new transaction from failed state
Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
Expand All @@ -320,27 +241,16 @@ impl DfSessionService {
}
}
"commit" | "commit transaction" | "commit work" | "end" | "end transaction" => {
match self.get_transaction_state(client_id).await {
TransactionState::Active => {
self.update_transaction_state(client_id, TransactionState::None)
.await;
Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
}
TransactionState::None => {
// PostgreSQL allows COMMIT outside transaction with warning
match client.transaction_status() {
TransactionStatus::Idle | TransactionStatus::Transaction => {
Ok(Some(Response::TransactionEnd(Tag::new("COMMIT"))))
}
TransactionState::Failed => {
// COMMIT in failed transaction is treated as ROLLBACK
self.update_transaction_state(client_id, TransactionState::None)
.await;
TransactionStatus::Error => {
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
}
}
}
"rollback" | "rollback transaction" | "rollback work" | "abort" => {
self.update_transaction_state(client_id, TransactionState::None)
.await;
Ok(Some(Response::TransactionEnd(Tag::new("ROLLBACK"))))
}
_ => Ok(None),
Expand Down Expand Up @@ -399,7 +309,7 @@ impl SimpleQueryHandler for DfSessionService {
C: ClientInfo + Unpin + Send + Sync,
{
let query_lower = query.to_lowercase().trim().to_string();
log::debug!("Received query: {}", query); // Log the query for debugging
log::debug!("Received query: {query}"); // Log the query for debugging

// Check permissions for the query (skip for SET, transaction, and SHOW statements)
if !query_lower.starts_with("set")
Expand Down Expand Up @@ -429,9 +339,9 @@ impl SimpleQueryHandler for DfSessionService {
return Ok(vec![resp]);
}

// Check if we're in a failed transaction and block non-transaction commands
let client_id = Self::get_client_id(client);
if self.get_transaction_state(client_id).await == TransactionState::Failed {
// Check if we're in a failed transaction and block non-transaction
// commands
if client.transaction_status() == TransactionStatus::Error {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
Expand All @@ -447,12 +357,6 @@ impl SimpleQueryHandler for DfSessionService {
let df = match df_result {
Ok(df) => df,
Err(e) => {
// If we're in a transaction and a query fails, mark transaction as failed
let client_id = Self::get_client_id(client);
if self.get_transaction_state(client_id).await == TransactionState::Active {
self.update_transaction_state(client_id, TransactionState::Failed)
.await;
}
return Err(PgWireError::ApiError(Box::new(e)));
}
};
Expand Down Expand Up @@ -557,7 +461,7 @@ impl ExtendedQueryHandler for DfSessionService {
.to_lowercase()
.trim()
.to_string();
log::debug!("Received execute extended query: {}", query); // Log for debugging
log::debug!("Received execute extended query: {query}"); // Log for debugging

// Check permissions for the query (skip for SET and SHOW statements)
if !query.starts_with("set") && !query.starts_with("show") {
Expand All @@ -580,9 +484,9 @@ impl ExtendedQueryHandler for DfSessionService {
return Ok(resp);
}

// Check if we're in a failed transaction and block non-transaction commands
let client_id = Self::get_client_id(client);
if self.get_transaction_state(client_id).await == TransactionState::Failed {
// Check if we're in a failed transaction and block non-transaction
// commands
if client.transaction_status() == TransactionStatus::Error {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
Expand All @@ -605,12 +509,6 @@ impl ExtendedQueryHandler for DfSessionService {
let dataframe = match self.session_context.execute_logical_plan(plan).await {
Ok(df) => df,
Err(e) => {
// If we're in a transaction and a query fails, mark transaction as failed
let client_id = Self::get_client_id(client);
if self.get_transaction_state(client_id).await == TransactionState::Active {
self.update_transaction_state(client_id, TransactionState::Failed)
.await;
}
return Err(PgWireError::ApiError(Box::new(e)));
}
};
Expand All @@ -633,7 +531,7 @@ impl QueryParser for Parser {
sql: &str,
_types: &[Type],
) -> PgWireResult<Self::Statement> {
log::debug!("Received parse extended query: {}", sql); // Log for debugging
log::debug!("Received parse extended query: {sql}"); // Log for debugging
let context = &self.session_context;
let state = context.state();
let logical_plan = state
Expand All @@ -654,134 +552,3 @@ fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<
types.sort_by(|a, b| a.0.cmp(b.0));
types.into_iter().map(|pt| pt.1.as_ref()).collect()
}

#[cfg(test)]
mod tests {
use super::*;
use datafusion::prelude::SessionContext;

#[tokio::test]
async fn test_transaction_isolation() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());
let service = DfSessionService::new(session_context, auth_manager);

// Simulate two different connection IDs
let client_id_1 = 1001;
let client_id_2 = 1002;

// Client 1 starts a transaction
service
.update_transaction_state(client_id_1, TransactionState::Active)
.await;

// Client 2 starts a transaction
service
.update_transaction_state(client_id_2, TransactionState::Active)
.await;

// Verify both have active transactions independently
{
let states = service.connection_states.read().await;
assert_eq!(
states.get(&client_id_1).map(|s| s.transaction_state),
Some(TransactionState::Active)
);
assert_eq!(
states.get(&client_id_2).map(|s| s.transaction_state),
Some(TransactionState::Active)
);
}

// Client 1 fails a transaction
service
.update_transaction_state(client_id_1, TransactionState::Failed)
.await;

// Verify client 1 is failed but client 2 is still active
{
let states = service.connection_states.read().await;
assert_eq!(
states.get(&client_id_1).map(|s| s.transaction_state),
Some(TransactionState::Failed)
);
assert_eq!(
states.get(&client_id_2).map(|s| s.transaction_state),
Some(TransactionState::Active)
);
}

// Client 1 rollback
service
.update_transaction_state(client_id_1, TransactionState::None)
.await;

// Client 2 commit
service
.update_transaction_state(client_id_2, TransactionState::None)
.await;

// Verify both are back to None state
{
let states = service.connection_states.read().await;
assert_eq!(
states.get(&client_id_1).map(|s| s.transaction_state),
Some(TransactionState::None)
);
assert_eq!(
states.get(&client_id_2).map(|s| s.transaction_state),
Some(TransactionState::None)
);
}
}

#[tokio::test]
async fn test_opportunistic_cleanup() {
let session_context = Arc::new(SessionContext::new());
let auth_manager = Arc::new(AuthManager::new());
let service = DfSessionService::new(session_context, auth_manager);

// Add some connection states
service
.update_transaction_state(2001, TransactionState::Active)
.await;
service
.update_transaction_state(2002, TransactionState::Failed)
.await;

// Manually create an old connection
{
let mut states = service.connection_states.write().await;
states.insert(
2003,
ConnectionState {
transaction_state: TransactionState::Active,
last_activity: Instant::now() - Duration::from_secs(7200), // 2 hours old
},
);
}

// Set cleanup counter to trigger cleanup on next update (fetch_add returns old value)
service.cleanup_counter.store(99, Ordering::Relaxed);

// First update sets counter to 100 (99 + 1)
service
.update_transaction_state(2004, TransactionState::Active)
.await;

// This should trigger cleanup (counter becomes 100, 100 % 100 == 0)
service
.update_transaction_state(2005, TransactionState::Active)
.await;

// Verify only the old connection was removed (cleanup is now inline, no wait needed)
{
let states = service.connection_states.read().await;
assert!(states.contains_key(&2001));
assert!(states.contains_key(&2002));
assert!(!states.contains_key(&2003)); // Old connection should be removed
assert!(states.contains_key(&2004));
assert!(states.contains_key(&2005));
}
}
}
Loading
Loading