diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 06d7c98..2e64435 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -320,9 +320,15 @@ impl DfSessionService { match query_lower.trim() { "begin" | "begin transaction" | "begin work" | "start transaction" => { match client.transaction_status() { - TransactionStatus::Idle | TransactionStatus::Transaction => { + TransactionStatus::Idle => { Ok(Some(Response::TransactionStart(Tag::new("BEGIN")))) } + TransactionStatus::Transaction => { + // PostgreSQL behavior: ignore nested BEGIN, just return SUCCESS + // This matches PostgreSQL's handling of nested transaction blocks + log::warn!("BEGIN command ignored: already in transaction block"); + Ok(Some(Response::Execution(Tag::new("BEGIN")))) + } TransactionStatus::Error => { // Can't start new transaction from failed state Err(PgWireError::UserError(Box::new( @@ -417,6 +423,16 @@ impl SimpleQueryHandler for DfSessionService { C: ClientInfo + Unpin + Send + Sync, { log::debug!("Received query: {query}"); // Log the query for debugging + + // Check for transaction commands early to avoid SQL parsing issues with ABORT + let query_lower = query.to_lowercase().trim().to_string(); + if let Some(resp) = self + .try_respond_transaction_statements(client, &query_lower) + .await? + { + return Ok(vec![resp]); + } + let mut statements = parse(query).map_err(|e| PgWireError::ApiError(Box::new(e)))?; // TODO: deal with multiple statements @@ -449,13 +465,6 @@ impl SimpleQueryHandler for DfSessionService { return Ok(vec![resp]); } - if let Some(resp) = self - .try_respond_transaction_statements(client, &query_lower) - .await? - { - return Ok(vec![resp]); - } - if let Some(resp) = self .try_respond_show_statements(client, &query_lower) .await? @@ -697,8 +706,38 @@ impl QueryParser for Parser { sql: &str, _types: &[Type], ) -> PgWireResult { - log::debug!("Received parse extended query: {sql}"); // Log for - // debugging + log::debug!("Received parse extended query: {sql}"); // Log for debugging + + // Check for transaction commands that shouldn't be parsed by DataFusion + let sql_lower = sql.to_lowercase(); + let sql_trimmed = sql_lower.trim(); + if matches!( + sql_trimmed, + "begin" + | "begin transaction" + | "begin work" + | "start transaction" + | "commit" + | "commit transaction" + | "commit work" + | "end" + | "end transaction" + | "rollback" + | "rollback transaction" + | "rollback work" + | "abort" + ) { + // Return a dummy plan for transaction commands - they'll be handled by transaction handler + let dummy_schema = datafusion::common::DFSchema::empty(); + let dummy_plan = datafusion::logical_expr::LogicalPlan::EmptyRelation( + datafusion::logical_expr::EmptyRelation { + produce_one_row: false, + schema: std::sync::Arc::new(dummy_schema), + }, + ); + return Ok((sql.to_string(), dummy_plan)); + } + let mut statements = parse(sql).map_err(|e| PgWireError::ApiError(Box::new(e)))?; let mut statement = statements.remove(0);