diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 8774219..fc3ca35 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -58,27 +58,18 @@ pub struct DfSessionService { session_context: Arc, parser: Arc, timezone: Arc>, - catalog_name: String, } impl DfSessionService { - pub fn new(session_context: SessionContext, catalog_name: Option) -> DfSessionService { + pub fn new(session_context: SessionContext) -> DfSessionService { let session_context = Arc::new(session_context); let parser = Arc::new(Parser { session_context: session_context.clone(), }); - let catalog_name = catalog_name.unwrap_or_else(|| { - session_context - .catalog_names() - .first() - .cloned() - .unwrap_or_else(|| "datafusion".to_string()) - }); DfSessionService { session_context, parser, timezone: Arc::new(Mutex::new("UTC".to_string())), - catalog_name, } } @@ -103,35 +94,40 @@ impl DfSessionService { // Mock pg_namespace response async fn mock_pg_namespace<'a>(&self) -> PgWireResult> { - let fields = vec![FieldInfo::new( + let fields = Arc::new(vec![FieldInfo::new( "nspname".to_string(), None, None, Type::VARCHAR, FieldFormat::Text, - )]; + )]); - let row = { - let mut encoder = pgwire::api::results::DataRowEncoder::new(Arc::new(fields.clone())); - encoder.encode_field(&Some(&self.catalog_name))?; // Return catalog_name as a schema - encoder.finish() - }; - - let row_stream = futures::stream::once(async move { row }); - Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream))) + let fields_ref = fields.clone(); + let rows = self + .session_context + .catalog_names() + .into_iter() + .map(move |name| { + let mut encoder = pgwire::api::results::DataRowEncoder::new(fields_ref.clone()); + encoder.encode_field(&Some(&name))?; // Return catalog_name as a schema + encoder.finish() + }); + + let row_stream = futures::stream::iter(rows); + Ok(QueryResponse::new(fields.clone(), Box::pin(row_stream))) } async fn try_respond_set_time_zone<'a>( &self, query_lower: &str, - ) -> PgWireResult>>> { + ) -> PgWireResult>> { if query_lower.starts_with("set time zone") { let parts: Vec<&str> = query_lower.split_whitespace().collect(); if parts.len() >= 4 { let tz = parts[3].trim_matches('"'); let mut timezone = self.timezone.lock().await; *timezone = tz.to_string(); - Ok(Some(vec![Response::Execution(Tag::new("SET"))])) + Ok(Some(Response::Execution(Tag::new("SET")))) } else { Err(PgWireError::UserError(Box::new( pgwire::error::ErrorInfo::new( @@ -149,32 +145,33 @@ impl DfSessionService { async fn try_respond_show_statements<'a>( &self, query_lower: &str, - ) -> PgWireResult>>> { + ) -> PgWireResult>> { if query_lower.starts_with("show ") { - match query_lower { + match query_lower.strip_suffix(";").unwrap_or(query_lower) { "show time zone" => { let timezone = self.timezone.lock().await.clone(); let resp = Self::mock_show_response("TimeZone", &timezone)?; - Ok(Some(vec![Response::Query(resp)])) + Ok(Some(Response::Query(resp))) } "show server_version" => { let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?; - Ok(Some(vec![Response::Query(resp)])) + Ok(Some(Response::Query(resp))) } "show transaction_isolation" => { let resp = Self::mock_show_response("transaction_isolation", "read uncommitted")?; - Ok(Some(vec![Response::Query(resp)])) + Ok(Some(Response::Query(resp))) } "show catalogs" => { let catalogs = self.session_context.catalog_names(); let value = catalogs.join(", "); let resp = Self::mock_show_response("Catalogs", &value)?; - Ok(Some(vec![Response::Query(resp)])) + Ok(Some(Response::Query(resp))) } "show search_path" => { - let resp = Self::mock_show_response("search_path", &self.catalog_name)?; - Ok(Some(vec![Response::Query(resp)])) + let default_catalog = "datafusion"; + let resp = Self::mock_show_response("search_path", default_catalog)?; + Ok(Some(Response::Query(resp))) } _ => Err(PgWireError::UserError(Box::new( pgwire::error::ErrorInfo::new( @@ -192,31 +189,31 @@ impl DfSessionService { async fn try_respond_information_schema<'a>( &self, query_lower: &str, - ) -> PgWireResult>>> { + ) -> PgWireResult>> { if query_lower.contains("information_schema.schemata") { let df = schemata_df(&self.session_context) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; - return Ok(Some(vec![Response::Query(resp)])); + return Ok(Some(Response::Query(resp))); } else if query_lower.contains("information_schema.tables") { let df = tables_df(&self.session_context) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; - return Ok(Some(vec![Response::Query(resp)])); + return Ok(Some(Response::Query(resp))); } else if query_lower.contains("information_schema.columns") { let df = columns_df(&self.session_context) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let resp = datatypes::encode_dataframe(df, &Format::UnifiedText).await?; - return Ok(Some(vec![Response::Query(resp)])); + return Ok(Some(Response::Query(resp))); } // Handle pg_catalog.pg_namespace for pgcli compatibility if query_lower.contains("pg_catalog.pg_namespace") { let resp = self.mock_pg_namespace().await?; - return Ok(Some(vec![Response::Query(resp)])); + return Ok(Some(Response::Query(resp))); } Ok(None) @@ -233,15 +230,15 @@ impl SimpleQueryHandler for DfSessionService { log::debug!("Received query: {}", query); // Log the query for debugging if let Some(resp) = self.try_respond_set_time_zone(&query_lower).await? { - return Ok(resp); + return Ok(vec![resp]); } if let Some(resp) = self.try_respond_show_statements(&query_lower).await? { - return Ok(resp); + return Ok(vec![resp]); } if let Some(resp) = self.try_respond_information_schema(&query_lower).await? { - return Ok(resp); + return Ok(vec![resp]); } let df = self @@ -352,67 +349,12 @@ impl ExtendedQueryHandler for DfSessionService { .to_string(); log::debug!("Received extended query: {}", query); // Log for debugging - if query.starts_with("show ") { - match query.as_str() { - "show time zone" => { - let timezone = self.timezone.lock().await.clone(); - let resp = Self::mock_show_response("TimeZone", &timezone)?; - return Ok(Response::Query(resp)); - } - "show server_version" => { - let resp = Self::mock_show_response("server_version", "15.0 (DataFusion)")?; - return Ok(Response::Query(resp)); - } - "show transaction_isolation" => { - let resp = - Self::mock_show_response("transaction_isolation", "read uncommitted")?; - return Ok(Response::Query(resp)); - } - "show catalogs" => { - let catalogs = self.session_context.catalog_names(); - let value = catalogs.join(", "); - let resp = Self::mock_show_response("Catalogs", &value)?; - return Ok(Response::Query(resp)); - } - "show search_path" => { - let resp = Self::mock_show_response("search_path", &self.catalog_name)?; - return Ok(Response::Query(resp)); - } - _ => { - return Err(PgWireError::UserError(Box::new( - pgwire::error::ErrorInfo::new( - "ERROR".to_string(), - "42704".to_string(), - format!("Unrecognized SHOW command: {}", query), - ), - ))); - } - } - } - - if query.contains("information_schema.schemata") { - let df = schemata_df(&self.session_context) - .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?; - return Ok(Response::Query(resp)); - } else if query.contains("information_schema.tables") { - let df = tables_df(&self.session_context) - .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?; - return Ok(Response::Query(resp)); - } else if query.contains("information_schema.columns") { - let df = columns_df(&self.session_context) - .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - let resp = datatypes::encode_dataframe(df, &portal.result_column_format).await?; - return Ok(Response::Query(resp)); + if let Some(resp) = self.try_respond_show_statements(&query).await? { + return Ok(resp); } - if query.contains("pg_catalog.pg_namespace") { - let resp = self.mock_pg_namespace().await?; - return Ok(Response::Query(resp)); + if let Some(resp) = self.try_respond_information_schema(&query).await? { + return Ok(resp); } let plan = &portal.statement.statement; diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index 7f16ba2..af1bbc4 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -39,16 +39,9 @@ pub async fn serve( session_context: SessionContext, opts: &ServerOptions, ) -> Result<(), std::io::Error> { - // Get the first catalog name from the session context - let catalog_name = session_context - .catalog_names() // Fixed: Removed .catalog_list() - .first() - .cloned(); - // Create the handler factory with the session context and catalog name let factory = Arc::new(HandlerFactory(Arc::new(DfSessionService::new( session_context, - catalog_name, )))); // Bind to the specified host and port @@ -58,15 +51,20 @@ pub async fn serve( // Accept incoming connections loop { - if let Ok((socket, addr)) = listener.accept().await { - let factory_ref = factory.clone(); - println!("Accepted connection from {}", addr); + match listener.accept().await { + Ok((socket, addr)) => { + let factory_ref = factory.clone(); + println!("Accepted connection from {}", addr); - tokio::spawn(async move { - if let Err(e) = process_socket(socket, None, factory_ref).await { - eprintln!("Error processing socket: {}", e); - } - }); - }; + tokio::spawn(async move { + if let Err(e) = process_socket(socket, None, factory_ref).await { + eprintln!("Error processing socket: {}", e); + } + }); + } + Err(e) => { + eprintln!("Error accept socket: {}", e); + } + } } }