Skip to content

Commit 651dce7

Browse files
authored
feat: provide function to customize pgwire handlers (#103)
Signed-off-by: Ning Sun <[email protected]>
1 parent a06a293 commit 651dce7

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ impl PgWireServerHandlers for HandlerFactory {
6363
}
6464
}
6565

66+
/// The pgwire handler backed by a datafusion `SessionContext`
6667
pub struct DfSessionService {
6768
session_context: Arc<SessionContext>,
6869
parser: Arc<Parser>,

datafusion-postgres/src/lib.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use datafusion::prelude::SessionContext;
1010
pub mod auth;
1111
use getset::{Getters, Setters, WithSetters};
1212
use log::{info, warn};
13+
use pgwire::api::PgWireServerHandlers;
1314
use pgwire::tokio::process_socket;
1415
use rustls_pemfile::{certs, pkcs8_private_keys};
1516
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
@@ -81,6 +82,18 @@ pub async fn serve(
8182
// Create the handler factory with authentication
8283
let factory = Arc::new(HandlerFactory::new(session_context, auth_manager));
8384

85+
serve_with_handlers(factory, opts).await
86+
}
87+
88+
/// Serve with custom pgwire handlers
89+
///
90+
/// This function allows you to rewrite some of the built-in logic including
91+
/// authentication and query processing. You can Implement your own
92+
/// `PgWireServerHandlers` by reusing `DfSessionService`.
93+
pub async fn serve_with_handlers(
94+
handlers: Arc<impl PgWireServerHandlers + Sync + Send + 'static>,
95+
opts: &ServerOptions,
96+
) -> Result<(), std::io::Error> {
8497
// Set up TLS if configured
8598
let tls_acceptor =
8699
if let (Some(cert_path), Some(key_path)) = (&opts.tls_cert_path, &opts.tls_key_path) {
@@ -112,9 +125,8 @@ pub async fn serve(
112125
loop {
113126
match listener.accept().await {
114127
Ok((socket, _addr)) => {
115-
let factory_ref = factory.clone();
128+
let factory_ref = handlers.clone();
116129
let tls_acceptor_ref = tls_acceptor.clone();
117-
// Connection accepted from {addr} - log appropriately based on your logging strategy
118130

119131
tokio::spawn(async move {
120132
if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {

0 commit comments

Comments
 (0)