Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ impl PgWireServerHandlers for HandlerFactory {
}
}

/// The pgwire handler backed by a datafusion `SessionContext`
pub struct DfSessionService {
session_context: Arc<SessionContext>,
parser: Arc<Parser>,
Expand Down
16 changes: 14 additions & 2 deletions datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use datafusion::prelude::SessionContext;
pub mod auth;
use getset::{Getters, Setters, WithSetters};
use log::{info, warn};
use pgwire::api::PgWireServerHandlers;
use pgwire::tokio::process_socket;
use rustls_pemfile::{certs, pkcs8_private_keys};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
Expand Down Expand Up @@ -81,6 +82,18 @@ pub async fn serve(
// Create the handler factory with authentication
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));

serve_with_handlers(factory, opts).await
}

/// Serve with custom pgwire handlers
///
/// This function allows you to rewrite some of the built-in logic including
/// authentication and query processing. You can Implement your own
/// `PgWireServerHandlers` by reusing `DfSessionService`.
pub async fn serve_with_handlers(
handlers: Arc<impl PgWireServerHandlers + Sync + Send + 'static>,
opts: &ServerOptions,
) -> Result<(), std::io::Error> {
// Set up TLS if configured
let tls_acceptor =
if let (Some(cert_path), Some(key_path)) = (&opts.tls_cert_path, &opts.tls_key_path) {
Expand Down Expand Up @@ -112,9 +125,8 @@ pub async fn serve(
loop {
match listener.accept().await {
Ok((socket, _addr)) => {
let factory_ref = factory.clone();
let factory_ref = handlers.clone();
let tls_acceptor_ref = tls_acceptor.clone();
// Connection accepted from {addr} - log appropriately based on your logging strategy

tokio::spawn(async move {
if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
Expand Down
Loading