From b4d285cc3d17b4b87ae88f843a535db3bbd76afc Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Fri, 13 Jun 2025 21:12:59 -0700 Subject: [PATCH 1/3] feat: initial commit for arrow-pg library --- Cargo.lock | 14 +- Cargo.toml | 6 +- arrow-pg/Cargo.toml | 20 + arrow-pg/src/lib.rs | 511 ++++++++++++++++++ .../encoder => arrow-pg/src}/list_encoder.rs | 57 +- .../src}/struct_encoder.rs | 22 +- datafusion-postgres/Cargo.toml | 4 +- 7 files changed, 576 insertions(+), 58 deletions(-) create mode 100644 arrow-pg/Cargo.toml create mode 100644 arrow-pg/src/lib.rs rename {datafusion-postgres/src/encoder => arrow-pg/src}/list_encoder.rs (90%) rename {datafusion-postgres/src/encoder => arrow-pg/src}/struct_encoder.rs (87%) diff --git a/Cargo.lock b/Cargo.lock index 79ae693..4571382 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -351,6 +351,18 @@ dependencies = [ "arrow-select", ] +[[package]] +name = "arrow-pg" +version = "0.1.0" +dependencies = [ + "arrow", + "bytes", + "chrono", + "pgwire", + "postgres-types", + "rust_decimal", +] + [[package]] name = "arrow-row" version = "55.1.0" @@ -1520,14 +1532,12 @@ version = "0.5.1" dependencies = [ "async-trait", "bytes", - "chrono", "datafusion", "futures", "getset", "log", "pgwire", "postgres-types", - "rust_decimal", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index 8d85415..a77c1db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["datafusion-postgres", "datafusion-postgres-cli"] +members = ["datafusion-postgres", "datafusion-postgres-cli", "arrow-pg"] [workspace.package] version = "0.5.1" @@ -14,8 +14,10 @@ repository = "https://github.com/datafusion-contrib/datafusion-postgres/" documentation = "https://docs.rs/crate/datafusion-postgres/" [workspace.dependencies] -pgwire = "0.30.2" +arrow = "55" datafusion = { version = "47", default-features = false } +pgwire = "0.30.2" +postgres-types = "0.2" tokio = { version = "1", default-features = false } [profile.release] diff --git a/arrow-pg/Cargo.toml b/arrow-pg/Cargo.toml new file mode 100644 index 0000000..4a44aa5 --- /dev/null +++ b/arrow-pg/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "arrow-pg" +description = "Arrow data mapping and encoding/decoding for Postgres" +version = "0.1.0" +edition.workspace = true +license.workspace = true +authors.workspace = true +keywords.workspace = true +homepage.workspace = true +repository.workspace = true +documentation.workspace = true +readme = "../README.md" + +[dependencies] +arrow.workspace = true +bytes = "1" +chrono = { version = "0.4", features = ["std"] } +pgwire = { workspace = true, default-features = false } +postgres-types.workspace = true +rust_decimal = { version = "1.37", features = ["db-postgres"] } diff --git a/arrow-pg/src/lib.rs b/arrow-pg/src/lib.rs new file mode 100644 index 0000000..b742aad --- /dev/null +++ b/arrow-pg/src/lib.rs @@ -0,0 +1,511 @@ +use std::io::Write; +use std::str::FromStr; +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::*; +use bytes::BufMut; +use bytes::BytesMut; +use chrono::{NaiveDate, NaiveDateTime}; +use list_encoder::encode_list; +use pgwire::types::ToSqlText; +use postgres_types::{ToSql, Type}; +use rust_decimal::Decimal; +use struct_encoder::encode_struct; +use timezone::Tz; + +pub mod list_encoder; +pub mod struct_encoder; + +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +#[repr(i16)] +pub enum FieldFormat { + Text = 0, + Binary = 1, +} + +pub type ToSqlError = Box; +pub type Result = std::result::Result; + +pub trait Encoder { + fn encode_field_with_type_and_format( + &mut self, + value: &T, + data_type: &Type, + format: FieldFormat, + ) -> Result<()> + where + T: ToSql + ToSqlText + Sized; +} + +pub(crate) struct EncodedValue { + pub(crate) bytes: BytesMut, +} + +impl std::fmt::Debug for EncodedValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EncodedValue").finish() + } +} + +impl ToSql for EncodedValue { + fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result + where + Self: Sized, + { + out.writer().write_all(&self.bytes)?; + Ok(postgres_types::IsNull::No) + } + + fn accepts(_ty: &Type) -> bool + where + Self: Sized, + { + true + } + + fn to_sql_checked(&self, ty: &Type, out: &mut BytesMut) -> Result { + self.to_sql(ty, out) + } +} + +impl ToSqlText for EncodedValue { + fn to_sql_text(&self, _ty: &Type, out: &mut BytesMut) -> Result + where + Self: Sized, + { + out.writer().write_all(&self.bytes)?; + Ok(postgres_types::IsNull::No) + } +} + +fn get_bool_value(arr: &Arc, idx: usize) -> Option { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +macro_rules! get_primitive_value { + ($name:ident, $t:ty, $pt:ty) => { + fn $name(arr: &Arc, idx: usize) -> Option<$pt> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::>() + .unwrap() + .value(idx) + }) + } + }; +} + +get_primitive_value!(get_i8_value, Int8Type, i8); +get_primitive_value!(get_i16_value, Int16Type, i16); +get_primitive_value!(get_i32_value, Int32Type, i32); +get_primitive_value!(get_i64_value, Int64Type, i64); +get_primitive_value!(get_u8_value, UInt8Type, u8); +get_primitive_value!(get_u16_value, UInt16Type, u16); +get_primitive_value!(get_u32_value, UInt32Type, u32); +get_primitive_value!(get_u64_value, UInt64Type, u64); +get_primitive_value!(get_f32_value, Float32Type, f32); +get_primitive_value!(get_f64_value, Float64Type, f64); + +fn get_utf8_view_value(arr: &Arc, idx: usize) -> Option<&str> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_utf8_value(arr: &Arc, idx: usize) -> Option<&str> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_large_utf8_value(arr: &Arc, idx: usize) -> Option<&str> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_binary_value(arr: &Arc, idx: usize) -> Option<&[u8]> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_large_binary_value(arr: &Arc, idx: usize) -> Option<&[u8]> { + (!arr.is_null(idx)).then(|| { + arr.as_any() + .downcast_ref::() + .unwrap() + .value(idx) + }) +} + +fn get_date32_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_date(idx) +} + +fn get_date64_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_date(idx) +} + +fn get_time32_second_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(idx) +} + +fn get_time32_millisecond_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(idx) +} + +fn get_time64_microsecond_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(idx) +} +fn get_time64_nanosecond_value(arr: &Arc, idx: usize) -> Option { + if arr.is_null(idx) { + return None; + } + arr.as_any() + .downcast_ref::() + .unwrap() + .value_as_datetime(idx) +} + +fn get_numeric_128_value(arr: &Arc, idx: usize, scale: u32) -> Result> { + if arr.is_null(idx) { + return Ok(None); + } + + let array = arr.as_any().downcast_ref::().unwrap(); + let value = array.value(idx); + Decimal::try_from_i128_with_scale(value, scale) + .map_err(|e| { + let message = match e { + rust_decimal::Error::ExceedsMaximumPossibleValue => { + "Exceeds maximum possible value" + } + rust_decimal::Error::LessThanMinimumPossibleValue => { + "Less than minimum possible value" + } + rust_decimal::Error::ScaleExceedsMaximumPrecision(_) => { + "Scale exceeds maximum precision" + } + _ => unreachable!(), + }; + ToSqlError::from(message) + }) + .map(Some) +} + +pub fn encode_value( + encoder: &mut T, + arr: &Arc, + idx: usize, + type_: &Type, + format: FieldFormat, +) -> Result<()> { + match arr.data_type() { + DataType::Null => encoder.encode_field_with_type_and_format(&None::, type_, format)?, + DataType::Boolean => { + encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)? + } + DataType::Int8 => { + encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)? + } + DataType::Int16 => { + encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)? + } + DataType::Int32 => { + encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)? + } + DataType::Int64 => { + encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)? + } + DataType::UInt8 => encoder.encode_field_with_type_and_format( + &(get_u8_value(arr, idx).map(|x| x as i8)), + type_, + format, + )?, + DataType::UInt16 => encoder.encode_field_with_type_and_format( + &(get_u16_value(arr, idx).map(|x| x as i16)), + type_, + format, + )?, + DataType::UInt32 => { + encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)? + } + DataType::UInt64 => encoder.encode_field_with_type_and_format( + &(get_u64_value(arr, idx).map(|x| x as i64)), + type_, + format, + )?, + DataType::Float32 => { + encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)? + } + DataType::Float64 => { + encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)? + } + DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format( + &get_numeric_128_value(arr, idx, *s as u32)?, + type_, + format, + )?, + DataType::Utf8 => { + encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)? + } + DataType::Utf8View => encoder.encode_field_with_type_and_format( + &get_utf8_view_value(arr, idx), + type_, + format, + )?, + DataType::LargeUtf8 => encoder.encode_field_with_type_and_format( + &get_large_utf8_value(arr, idx), + type_, + format, + )?, + DataType::Binary => { + encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)? + } + DataType::LargeBinary => encoder.encode_field_with_type_and_format( + &get_large_binary_value(arr, idx), + type_, + format, + )?, + DataType::Date32 => { + encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)? + } + DataType::Date64 => { + encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)? + } + DataType::Time32(unit) => match unit { + TimeUnit::Second => encoder.encode_field_with_type_and_format( + &get_time32_second_value(arr, idx), + type_, + format, + )?, + TimeUnit::Millisecond => encoder.encode_field_with_type_and_format( + &get_time32_millisecond_value(arr, idx), + type_, + format, + )?, + _ => {} + }, + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => encoder.encode_field_with_type_and_format( + &get_time64_microsecond_value(arr, idx), + type_, + format, + )?, + TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format( + &get_time64_nanosecond_value(arr, idx), + type_, + format, + )?, + _ => {} + }, + DataType::Timestamp(unit, timezone) => match unit { + TimeUnit::Second => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format( + &None::, + type_, + format, + ); + } + let ts_array = arr.as_any().downcast_ref::().unwrap(); + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; + let value = ts_array + .value_as_datetime_with_tz(idx, tz) + .map(|d| d.fixed_offset()); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } else { + let value = ts_array.value_as_datetime(idx); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } + } + TimeUnit::Millisecond => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format( + &None::, + type_, + format, + ); + } + let ts_array = arr + .as_any() + .downcast_ref::() + .unwrap(); + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; + let value = ts_array + .value_as_datetime_with_tz(idx, tz) + .map(|d| d.fixed_offset()); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } else { + let value = ts_array.value_as_datetime(idx); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } + } + TimeUnit::Microsecond => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format( + &None::, + type_, + format, + ); + } + let ts_array = arr + .as_any() + .downcast_ref::() + .unwrap(); + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; + let value = ts_array + .value_as_datetime_with_tz(idx, tz) + .map(|d| d.fixed_offset()); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } else { + let value = ts_array.value_as_datetime(idx); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } + } + TimeUnit::Nanosecond => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format( + &None::, + type_, + format, + ); + } + let ts_array = arr + .as_any() + .downcast_ref::() + .unwrap(); + if let Some(tz) = timezone { + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; + let value = ts_array + .value_as_datetime_with_tz(idx, tz) + .map(|d| d.fixed_offset()); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } else { + let value = ts_array.value_as_datetime(idx); + encoder.encode_field_with_type_and_format(&value, type_, format)?; + } + } + }, + DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format); + } + let array = arr.as_any().downcast_ref::().unwrap().value(idx); + let value = encode_list(array, type_, format)?; + encoder.encode_field_with_type_and_format(&value, type_, format)? + } + DataType::Struct(_) => { + let fields = match type_.kind() { + postgres_types::Kind::Composite(fields) => fields, + _ => { + return Err(ToSqlError::from(format!( + "Failed to unwrap a composite type from type {}", + type_ + ))) + } + }; + let value = encode_struct(arr, idx, fields, format)?; + encoder.encode_field_with_type_and_format(&value, type_, format)? + } + DataType::Dictionary(_, value_type) => { + if arr.is_null(idx) { + return encoder.encode_field_with_type_and_format(&None::, type_, format); + } + // Get the dictionary values, ignoring keys + // We'll use Int32Type as a common key type, but we're only interested in values + macro_rules! get_dict_values { + ($key_type:ty) => { + arr.as_any() + .downcast_ref::>() + .map(|dict| dict.values()) + }; + } + + // Try to extract values using different key types + let values = get_dict_values!(Int8Type) + .or_else(|| get_dict_values!(Int16Type)) + .or_else(|| get_dict_values!(Int32Type)) + .or_else(|| get_dict_values!(Int64Type)) + .or_else(|| get_dict_values!(UInt8Type)) + .or_else(|| get_dict_values!(UInt16Type)) + .or_else(|| get_dict_values!(UInt32Type)) + .or_else(|| get_dict_values!(UInt64Type)) + .ok_or_else(|| { + ToSqlError::from(format!( + "Unsupported dictionary key type for value type {}", + value_type + )) + })?; + + // If the dictionary has only one value, treat it as a primitive + if values.len() == 1 { + encode_value(encoder, values, 0, type_, format)? + } else { + // Otherwise, use value directly indexed by values array + encode_value(encoder, values, idx, type_, format)? + } + } + _ => { + return Err(ToSqlError::from(format!( + "Unsupported Datatype {} and array {:?}", + arr.data_type(), + &arr + ))) + } + } + + Ok(()) +} diff --git a/datafusion-postgres/src/encoder/list_encoder.rs b/arrow-pg/src/list_encoder.rs similarity index 90% rename from datafusion-postgres/src/encoder/list_encoder.rs rename to arrow-pg/src/list_encoder.rs index a8758ef..14a8f41 100644 --- a/datafusion-postgres/src/encoder/list_encoder.rs +++ b/arrow-pg/src/list_encoder.rs @@ -1,14 +1,12 @@ -use std::{error::Error, str::FromStr, sync::Arc}; +use std::{str::FromStr, sync::Arc}; -use bytes::{BufMut, BytesMut}; -use chrono::{DateTime, TimeZone, Utc}; -use datafusion::arrow::array::{ +use arrow::array::{ timezone::Tz, Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, LargeBinaryArray, PrimitiveArray, StringArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }; -use datafusion::arrow::{ +use arrow::{ datatypes::{ DataType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, @@ -16,15 +14,13 @@ use datafusion::arrow::{ }, temporal_conversions::{as_date, as_time}, }; -use pgwire::{ - api::results::FieldFormat, - error::{ErrorInfo, PgWireError}, - types::{ToSqlText, QUOTE_ESCAPE}, -}; +use bytes::{BufMut, BytesMut}; +use chrono::{DateTime, TimeZone, Utc}; +use pgwire::types::{ToSqlText, QUOTE_ESCAPE}; use postgres_types::{ToSql, Type}; use rust_decimal::Decimal; -use super::{struct_encoder::encode_struct, EncodedValue}; +use super::{struct_encoder::encode_struct, EncodedValue, FieldFormat, Result, ToSqlError}; fn get_bool_list_value(arr: &Arc) -> Vec> { arr.as_any() @@ -76,7 +72,7 @@ fn encode_field( t: &[T], type_: &Type, format: FieldFormat, -) -> Result> { +) -> Result { let mut bytes = BytesMut::new(); match format { FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?, @@ -89,7 +85,7 @@ pub(crate) fn encode_list( arr: Arc, type_: &Type, format: FieldFormat, -) -> Result> { +) -> Result { match arr.data_type() { DataType::Null => { let mut bytes = BytesMut::new(); @@ -227,8 +223,7 @@ pub(crate) fn encode_list( .iter(); if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; let value: Vec<_> = array_iter .map(|i| { i.and_then(|i| { @@ -258,8 +253,7 @@ pub(crate) fn encode_list( .iter(); if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; let value: Vec<_> = array_iter .map(|i| { i.and_then(|i| { @@ -291,8 +285,7 @@ pub(crate) fn encode_list( .iter(); if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; let value: Vec<_> = array_iter .map(|i| { i.and_then(|i| { @@ -324,8 +317,7 @@ pub(crate) fn encode_list( .iter(); if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; let value: Vec<_> = array_iter .map(|i| { i.map(|i| { @@ -363,12 +355,9 @@ pub(crate) fn encode_list( type_.kind() )), }) - .map_err(|err| { - let err = ErrorInfo::new("ERROR".to_owned(), "XX000".to_owned(), err); - Box::new(PgWireError::UserError(Box::new(err))) - })?; + .map_err(ToSqlError::from)?; - let values: Result, _> = (0..arr.len()) + let values: Result> = (0..arr.len()) .map(|row| encode_struct(&arr, row, fields, format)) .map(|x| { if matches!(format, FieldFormat::Text) { @@ -396,17 +385,9 @@ pub(crate) fn encode_list( encode_field(&values?, type_, format) } // TODO: more types - list_type => { - let err = PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!( - "Unsupported List Datatype {} and array {:?}", - list_type, &arr - ), - ))); - - Err(Box::new(err)) - } + list_type => Err(ToSqlError::from(format!( + "Unsupported List Datatype {} and array {:?}", + list_type, &arr + ))), } } diff --git a/datafusion-postgres/src/encoder/struct_encoder.rs b/arrow-pg/src/struct_encoder.rs similarity index 87% rename from datafusion-postgres/src/encoder/struct_encoder.rs rename to arrow-pg/src/struct_encoder.rs index d0b1523..20f368f 100644 --- a/datafusion-postgres/src/encoder/struct_encoder.rs +++ b/arrow-pg/src/struct_encoder.rs @@ -1,22 +1,18 @@ -use std::{error::Error, sync::Arc}; +use std::sync::Arc; +use arrow::array::{Array, StructArray}; use bytes::{BufMut, BytesMut}; -use datafusion::arrow::array::{Array, StructArray}; -use pgwire::{ - api::results::FieldFormat, - error::PgWireResult, - types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}, -}; +use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}; use postgres_types::{Field, IsNull, ToSql, Type}; -use super::{encode_value, EncodedValue}; +use super::{encode_value, EncodedValue, FieldFormat, Result}; -pub fn encode_struct( +pub(crate) fn encode_struct( arr: &Arc, idx: usize, fields: &[Field], format: FieldFormat, -) -> Result, Box> { +) -> Result> { let arr = arr.as_any().downcast_ref::().unwrap(); if arr.is_null(idx) { return Ok(None); @@ -32,14 +28,14 @@ pub fn encode_struct( })) } -struct StructEncoder { +pub(crate) struct StructEncoder { num_cols: usize, curr_col: usize, row_buffer: BytesMut, } impl StructEncoder { - fn new(num_cols: usize) -> Self { + pub(crate) fn new(num_cols: usize) -> Self { Self { num_cols, curr_col: 0, @@ -54,7 +50,7 @@ impl super::Encoder for StructEncoder { value: &T, data_type: &Type, format: FieldFormat, - ) -> PgWireResult<()> + ) -> Result<()> where T: ToSql + ToSqlText + Sized, { diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index 49b8a88..021ad0d 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -14,12 +14,10 @@ readme = "../README.md" [dependencies] async-trait = "0.1" bytes = "1.10.1" -chrono = { version = "0.4", features = ["std"] } datafusion = { workspace = true } futures = "0.3" getset = "0.1" log = "0.4" pgwire = { workspace = true } -postgres-types = "0.2" -rust_decimal = { version = "1.37", features = ["db-postgres"] } +postgres-types.workspace = true tokio = { version = "1.45", features = ["sync", "net"] } From 85038a7177ad679fdadc60a06102f6ab07f9d1df Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sat, 14 Jun 2025 17:50:10 -0700 Subject: [PATCH 2/3] refactor: move more components to arrow-pg --- Cargo.lock | 4 + Cargo.toml | 1 + arrow-pg/Cargo.toml | 3 +- arrow-pg/src/datatypes.rs | 129 +++++ .../encoder/mod.rs => arrow-pg/src/encoder.rs | 80 ++- arrow-pg/src/error.rs | 1 + arrow-pg/src/lib.rs | 513 +----------------- arrow-pg/src/list_encoder.rs | 19 +- .../encoder => arrow-pg/src}/row_encoder.rs | 4 +- arrow-pg/src/struct_encoder.rs | 10 +- datafusion-postgres/Cargo.toml | 5 +- datafusion-postgres/src/datatypes.rs | 129 +---- datafusion-postgres/src/handlers.rs | 9 +- datafusion-postgres/src/lib.rs | 1 - 14 files changed, 208 insertions(+), 700 deletions(-) create mode 100644 arrow-pg/src/datatypes.rs rename datafusion-postgres/src/encoder/mod.rs => arrow-pg/src/encoder.rs (88%) create mode 100644 arrow-pg/src/error.rs rename {datafusion-postgres/src/encoder => arrow-pg/src}/row_encoder.rs (94%) diff --git a/Cargo.lock b/Cargo.lock index 4571382..37ea8ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,6 +358,7 @@ dependencies = [ "arrow", "bytes", "chrono", + "futures", "pgwire", "postgres-types", "rust_decimal", @@ -1530,14 +1531,17 @@ dependencies = [ name = "datafusion-postgres" version = "0.5.1" dependencies = [ + "arrow-pg", "async-trait", "bytes", + "chrono", "datafusion", "futures", "getset", "log", "pgwire", "postgres-types", + "rust_decimal", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index a77c1db..7a17670 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ documentation = "https://docs.rs/crate/datafusion-postgres/" [workspace.dependencies] arrow = "55" datafusion = { version = "47", default-features = false } +futures = "0.3" pgwire = "0.30.2" postgres-types = "0.2" tokio = { version = "1", default-features = false } diff --git a/arrow-pg/Cargo.toml b/arrow-pg/Cargo.toml index 4a44aa5..0494305 100644 --- a/arrow-pg/Cargo.toml +++ b/arrow-pg/Cargo.toml @@ -15,6 +15,7 @@ readme = "../README.md" arrow.workspace = true bytes = "1" chrono = { version = "0.4", features = ["std"] } -pgwire = { workspace = true, default-features = false } +futures.workspace = true +pgwire.workspace = true postgres-types.workspace = true rust_decimal = { version = "1.37", features = ["db-postgres"] } diff --git a/arrow-pg/src/datatypes.rs b/arrow-pg/src/datatypes.rs new file mode 100644 index 0000000..06dafe2 --- /dev/null +++ b/arrow-pg/src/datatypes.rs @@ -0,0 +1,129 @@ +use std::sync::Arc; + +use arrow::datatypes::*; +use arrow::record_batch::RecordBatch; +use pgwire::api::portal::Format; +use pgwire::api::results::FieldInfo; +use pgwire::api::Type; +use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::messages::data::DataRow; +use postgres_types::Kind; + +use crate::row_encoder::RowEncoder; + +pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult { + Ok(match arrow_type { + DataType::Null => Type::UNKNOWN, + DataType::Boolean => Type::BOOL, + DataType::Int8 | DataType::UInt8 => Type::CHAR, + DataType::Int16 | DataType::UInt16 => Type::INT2, + DataType::Int32 | DataType::UInt32 => Type::INT4, + DataType::Int64 | DataType::UInt64 => Type::INT8, + DataType::Timestamp(_, tz) => { + if tz.is_some() { + Type::TIMESTAMPTZ + } else { + Type::TIMESTAMP + } + } + DataType::Time32(_) | DataType::Time64(_) => Type::TIME, + DataType::Date32 | DataType::Date64 => Type::DATE, + DataType::Interval(_) => Type::INTERVAL, + DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA, + DataType::Float16 | DataType::Float32 => Type::FLOAT4, + DataType::Float64 => Type::FLOAT8, + DataType::Decimal128(_, _) => Type::NUMERIC, + DataType::Utf8 => Type::VARCHAR, + DataType::LargeUtf8 => Type::TEXT, + DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { + match field.data_type() { + DataType::Boolean => Type::BOOL_ARRAY, + DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY, + DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY, + DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY, + DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY, + DataType::Timestamp(_, tz) => { + if tz.is_some() { + Type::TIMESTAMPTZ_ARRAY + } else { + Type::TIMESTAMP_ARRAY + } + } + DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY, + DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY, + DataType::Interval(_) => Type::INTERVAL_ARRAY, + DataType::FixedSizeBinary(_) | DataType::Binary => Type::BYTEA_ARRAY, + DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, + DataType::Float64 => Type::FLOAT8_ARRAY, + DataType::Utf8 => Type::VARCHAR_ARRAY, + DataType::LargeUtf8 => Type::TEXT_ARRAY, + struct_type @ DataType::Struct(_) => Type::new( + Type::RECORD_ARRAY.name().into(), + Type::RECORD_ARRAY.oid(), + Kind::Array(into_pg_type(struct_type)?), + Type::RECORD_ARRAY.schema().into(), + ), + list_type => { + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!("Unsupported List Datatype {list_type}"), + )))); + } + } + } + DataType::Utf8View => Type::TEXT, + DataType::Dictionary(_, value_type) => into_pg_type(value_type)?, + DataType::Struct(fields) => { + let name: String = fields + .iter() + .map(|x| x.name().clone()) + .reduce(|a, b| a + ", " + &b) + .map(|x| format!("({x})")) + .unwrap_or("()".to_string()); + let kind = Kind::Composite( + fields + .iter() + .map(|x| { + into_pg_type(x.data_type()) + .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) + }) + .collect::, PgWireError>>()?, + ); + Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into()) + } + _ => { + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_owned(), + "XX000".to_owned(), + format!("Unsupported Datatype {arrow_type}"), + )))); + } + }) +} + +pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult> { + schema + .fields() + .iter() + .enumerate() + .map(|(idx, f)| { + let pg_type = into_pg_type(f.data_type())?; + Ok(FieldInfo::new( + f.name().into(), + None, + None, + pg_type, + format.format_for(idx), + )) + }) + .collect::>>() +} + +pub fn encode_recordbatch( + fields: Arc>, + record_batch: RecordBatch, +) -> Box>> { + let mut row_stream = RowEncoder::new(record_batch, fields); + Box::new(std::iter::from_fn(move || row_stream.next_row())) +} diff --git a/datafusion-postgres/src/encoder/mod.rs b/arrow-pg/src/encoder.rs similarity index 88% rename from datafusion-postgres/src/encoder/mod.rs rename to arrow-pg/src/encoder.rs index 233d24c..cda5ba7 100644 --- a/datafusion-postgres/src/encoder/mod.rs +++ b/arrow-pg/src/encoder.rs @@ -1,27 +1,27 @@ +use std::error::Error; use std::io::Write; use std::str::FromStr; use std::sync::Arc; +use arrow::array::*; +use arrow::datatypes::*; use bytes::BufMut; use bytes::BytesMut; use chrono::{NaiveDate, NaiveDateTime}; -use datafusion::arrow::array::*; -use datafusion::arrow::datatypes::*; -use list_encoder::encode_list; use pgwire::api::results::DataRowEncoder; use pgwire::api::results::FieldFormat; -use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::error::PgWireError; +use pgwire::error::PgWireResult; use pgwire::types::ToSqlText; use postgres_types::{ToSql, Type}; use rust_decimal::Decimal; -use struct_encoder::encode_struct; use timezone::Tz; -pub mod list_encoder; -pub mod row_encoder; -pub mod struct_encoder; +use crate::error::ToSqlError; +use crate::list_encoder::encode_list; +use crate::struct_encoder::encode_struct; -trait Encoder { +pub trait Encoder { fn encode_field_with_type_and_format( &mut self, value: &T, @@ -61,7 +61,7 @@ impl ToSql for EncodedValue { &self, _ty: &Type, out: &mut BytesMut, - ) -> Result> + ) -> Result> where Self: Sized, { @@ -80,7 +80,7 @@ impl ToSql for EncodedValue { &self, ty: &Type, out: &mut BytesMut, - ) -> Result> { + ) -> Result> { self.to_sql(ty, out) } } @@ -90,7 +90,7 @@ impl ToSqlText for EncodedValue { &self, _ty: &Type, out: &mut BytesMut, - ) -> Result> + ) -> Result> where Self: Sized, { @@ -261,16 +261,13 @@ fn get_numeric_128_value( } _ => unreachable!(), }; - PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - message.to_owned(), - ))) + // TODO: add error type in PgWireError + PgWireError::ApiError(ToSqlError::from(message)) }) .map(Some) } -fn encode_value( +pub fn encode_value( encoder: &mut T, arr: &Arc, idx: usize, @@ -387,8 +384,7 @@ fn encode_value( } let ts_array = arr.as_any().downcast_ref::().unwrap(); if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; let value = ts_array .value_as_datetime_with_tz(idx, tz) .map(|d| d.fixed_offset()); @@ -411,8 +407,7 @@ fn encode_value( .downcast_ref::() .unwrap(); if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; let value = ts_array .value_as_datetime_with_tz(idx, tz) .map(|d| d.fixed_offset()); @@ -435,8 +430,7 @@ fn encode_value( .downcast_ref::() .unwrap(); if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; let value = ts_array .value_as_datetime_with_tz(idx, tz) .map(|d| d.fixed_offset()); @@ -459,8 +453,7 @@ fn encode_value( .downcast_ref::() .unwrap(); if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; let value = ts_array .value_as_datetime_with_tz(idx, tz) .map(|d| d.fixed_offset()); @@ -483,11 +476,10 @@ fn encode_value( let fields = match type_.kind() { postgres_types::Kind::Composite(fields) => fields, _ => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!("Failed to unwrap a composite type from type {}", type_), - )))) + return Err(PgWireError::ApiError(ToSqlError::from(format!( + "Failed to unwrap a composite type from type {}", + type_ + )))); } }; let value = encode_struct(arr, idx, fields, format)?; @@ -517,14 +509,10 @@ fn encode_value( .or_else(|| get_dict_values!(UInt32Type)) .or_else(|| get_dict_values!(UInt64Type)) .ok_or_else(|| { - PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!( - "Unsupported dictionary key type for value type {}", - value_type - ), - ))) + ToSqlError::from(format!( + "Unsupported dictionary key type for value type {}", + value_type + )) })?; // If the dictionary has only one value, treat it as a primitive @@ -536,15 +524,11 @@ fn encode_value( } } _ => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!( - "Unsupported Datatype {} and array {:?}", - arr.data_type(), - &arr - ), - )))) + return Err(PgWireError::ApiError(ToSqlError::from(format!( + "Unsupported Datatype {} and array {:?}", + arr.data_type(), + &arr + )))); } } diff --git a/arrow-pg/src/error.rs b/arrow-pg/src/error.rs new file mode 100644 index 0000000..9dca31b --- /dev/null +++ b/arrow-pg/src/error.rs @@ -0,0 +1 @@ +pub type ToSqlError = Box; diff --git a/arrow-pg/src/lib.rs b/arrow-pg/src/lib.rs index b742aad..dd77bce 100644 --- a/arrow-pg/src/lib.rs +++ b/arrow-pg/src/lib.rs @@ -1,511 +1,6 @@ -use std::io::Write; -use std::str::FromStr; -use std::sync::Arc; - -use arrow::array::*; -use arrow::datatypes::*; -use bytes::BufMut; -use bytes::BytesMut; -use chrono::{NaiveDate, NaiveDateTime}; -use list_encoder::encode_list; -use pgwire::types::ToSqlText; -use postgres_types::{ToSql, Type}; -use rust_decimal::Decimal; -use struct_encoder::encode_struct; -use timezone::Tz; - +pub mod datatypes; +pub mod encoder; +mod error; pub mod list_encoder; +pub mod row_encoder; pub mod struct_encoder; - -#[derive(Debug, Eq, PartialEq, Clone, Copy)] -#[repr(i16)] -pub enum FieldFormat { - Text = 0, - Binary = 1, -} - -pub type ToSqlError = Box; -pub type Result = std::result::Result; - -pub trait Encoder { - fn encode_field_with_type_and_format( - &mut self, - value: &T, - data_type: &Type, - format: FieldFormat, - ) -> Result<()> - where - T: ToSql + ToSqlText + Sized; -} - -pub(crate) struct EncodedValue { - pub(crate) bytes: BytesMut, -} - -impl std::fmt::Debug for EncodedValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EncodedValue").finish() - } -} - -impl ToSql for EncodedValue { - fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result - where - Self: Sized, - { - out.writer().write_all(&self.bytes)?; - Ok(postgres_types::IsNull::No) - } - - fn accepts(_ty: &Type) -> bool - where - Self: Sized, - { - true - } - - fn to_sql_checked(&self, ty: &Type, out: &mut BytesMut) -> Result { - self.to_sql(ty, out) - } -} - -impl ToSqlText for EncodedValue { - fn to_sql_text(&self, _ty: &Type, out: &mut BytesMut) -> Result - where - Self: Sized, - { - out.writer().write_all(&self.bytes)?; - Ok(postgres_types::IsNull::No) - } -} - -fn get_bool_value(arr: &Arc, idx: usize) -> Option { - (!arr.is_null(idx)).then(|| { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) - }) -} - -macro_rules! get_primitive_value { - ($name:ident, $t:ty, $pt:ty) => { - fn $name(arr: &Arc, idx: usize) -> Option<$pt> { - (!arr.is_null(idx)).then(|| { - arr.as_any() - .downcast_ref::>() - .unwrap() - .value(idx) - }) - } - }; -} - -get_primitive_value!(get_i8_value, Int8Type, i8); -get_primitive_value!(get_i16_value, Int16Type, i16); -get_primitive_value!(get_i32_value, Int32Type, i32); -get_primitive_value!(get_i64_value, Int64Type, i64); -get_primitive_value!(get_u8_value, UInt8Type, u8); -get_primitive_value!(get_u16_value, UInt16Type, u16); -get_primitive_value!(get_u32_value, UInt32Type, u32); -get_primitive_value!(get_u64_value, UInt64Type, u64); -get_primitive_value!(get_f32_value, Float32Type, f32); -get_primitive_value!(get_f64_value, Float64Type, f64); - -fn get_utf8_view_value(arr: &Arc, idx: usize) -> Option<&str> { - (!arr.is_null(idx)).then(|| { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) - }) -} - -fn get_utf8_value(arr: &Arc, idx: usize) -> Option<&str> { - (!arr.is_null(idx)).then(|| { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) - }) -} - -fn get_large_utf8_value(arr: &Arc, idx: usize) -> Option<&str> { - (!arr.is_null(idx)).then(|| { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) - }) -} - -fn get_binary_value(arr: &Arc, idx: usize) -> Option<&[u8]> { - (!arr.is_null(idx)).then(|| { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) - }) -} - -fn get_large_binary_value(arr: &Arc, idx: usize) -> Option<&[u8]> { - (!arr.is_null(idx)).then(|| { - arr.as_any() - .downcast_ref::() - .unwrap() - .value(idx) - }) -} - -fn get_date32_value(arr: &Arc, idx: usize) -> Option { - if arr.is_null(idx) { - return None; - } - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_date(idx) -} - -fn get_date64_value(arr: &Arc, idx: usize) -> Option { - if arr.is_null(idx) { - return None; - } - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_date(idx) -} - -fn get_time32_second_value(arr: &Arc, idx: usize) -> Option { - if arr.is_null(idx) { - return None; - } - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(idx) -} - -fn get_time32_millisecond_value(arr: &Arc, idx: usize) -> Option { - if arr.is_null(idx) { - return None; - } - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(idx) -} - -fn get_time64_microsecond_value(arr: &Arc, idx: usize) -> Option { - if arr.is_null(idx) { - return None; - } - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(idx) -} -fn get_time64_nanosecond_value(arr: &Arc, idx: usize) -> Option { - if arr.is_null(idx) { - return None; - } - arr.as_any() - .downcast_ref::() - .unwrap() - .value_as_datetime(idx) -} - -fn get_numeric_128_value(arr: &Arc, idx: usize, scale: u32) -> Result> { - if arr.is_null(idx) { - return Ok(None); - } - - let array = arr.as_any().downcast_ref::().unwrap(); - let value = array.value(idx); - Decimal::try_from_i128_with_scale(value, scale) - .map_err(|e| { - let message = match e { - rust_decimal::Error::ExceedsMaximumPossibleValue => { - "Exceeds maximum possible value" - } - rust_decimal::Error::LessThanMinimumPossibleValue => { - "Less than minimum possible value" - } - rust_decimal::Error::ScaleExceedsMaximumPrecision(_) => { - "Scale exceeds maximum precision" - } - _ => unreachable!(), - }; - ToSqlError::from(message) - }) - .map(Some) -} - -pub fn encode_value( - encoder: &mut T, - arr: &Arc, - idx: usize, - type_: &Type, - format: FieldFormat, -) -> Result<()> { - match arr.data_type() { - DataType::Null => encoder.encode_field_with_type_and_format(&None::, type_, format)?, - DataType::Boolean => { - encoder.encode_field_with_type_and_format(&get_bool_value(arr, idx), type_, format)? - } - DataType::Int8 => { - encoder.encode_field_with_type_and_format(&get_i8_value(arr, idx), type_, format)? - } - DataType::Int16 => { - encoder.encode_field_with_type_and_format(&get_i16_value(arr, idx), type_, format)? - } - DataType::Int32 => { - encoder.encode_field_with_type_and_format(&get_i32_value(arr, idx), type_, format)? - } - DataType::Int64 => { - encoder.encode_field_with_type_and_format(&get_i64_value(arr, idx), type_, format)? - } - DataType::UInt8 => encoder.encode_field_with_type_and_format( - &(get_u8_value(arr, idx).map(|x| x as i8)), - type_, - format, - )?, - DataType::UInt16 => encoder.encode_field_with_type_and_format( - &(get_u16_value(arr, idx).map(|x| x as i16)), - type_, - format, - )?, - DataType::UInt32 => { - encoder.encode_field_with_type_and_format(&get_u32_value(arr, idx), type_, format)? - } - DataType::UInt64 => encoder.encode_field_with_type_and_format( - &(get_u64_value(arr, idx).map(|x| x as i64)), - type_, - format, - )?, - DataType::Float32 => { - encoder.encode_field_with_type_and_format(&get_f32_value(arr, idx), type_, format)? - } - DataType::Float64 => { - encoder.encode_field_with_type_and_format(&get_f64_value(arr, idx), type_, format)? - } - DataType::Decimal128(_, s) => encoder.encode_field_with_type_and_format( - &get_numeric_128_value(arr, idx, *s as u32)?, - type_, - format, - )?, - DataType::Utf8 => { - encoder.encode_field_with_type_and_format(&get_utf8_value(arr, idx), type_, format)? - } - DataType::Utf8View => encoder.encode_field_with_type_and_format( - &get_utf8_view_value(arr, idx), - type_, - format, - )?, - DataType::LargeUtf8 => encoder.encode_field_with_type_and_format( - &get_large_utf8_value(arr, idx), - type_, - format, - )?, - DataType::Binary => { - encoder.encode_field_with_type_and_format(&get_binary_value(arr, idx), type_, format)? - } - DataType::LargeBinary => encoder.encode_field_with_type_and_format( - &get_large_binary_value(arr, idx), - type_, - format, - )?, - DataType::Date32 => { - encoder.encode_field_with_type_and_format(&get_date32_value(arr, idx), type_, format)? - } - DataType::Date64 => { - encoder.encode_field_with_type_and_format(&get_date64_value(arr, idx), type_, format)? - } - DataType::Time32(unit) => match unit { - TimeUnit::Second => encoder.encode_field_with_type_and_format( - &get_time32_second_value(arr, idx), - type_, - format, - )?, - TimeUnit::Millisecond => encoder.encode_field_with_type_and_format( - &get_time32_millisecond_value(arr, idx), - type_, - format, - )?, - _ => {} - }, - DataType::Time64(unit) => match unit { - TimeUnit::Microsecond => encoder.encode_field_with_type_and_format( - &get_time64_microsecond_value(arr, idx), - type_, - format, - )?, - TimeUnit::Nanosecond => encoder.encode_field_with_type_and_format( - &get_time64_nanosecond_value(arr, idx), - type_, - format, - )?, - _ => {} - }, - DataType::Timestamp(unit, timezone) => match unit { - TimeUnit::Second => { - if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format( - &None::, - type_, - format, - ); - } - let ts_array = arr.as_any().downcast_ref::().unwrap(); - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; - let value = ts_array - .value_as_datetime_with_tz(idx, tz) - .map(|d| d.fixed_offset()); - encoder.encode_field_with_type_and_format(&value, type_, format)?; - } else { - let value = ts_array.value_as_datetime(idx); - encoder.encode_field_with_type_and_format(&value, type_, format)?; - } - } - TimeUnit::Millisecond => { - if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format( - &None::, - type_, - format, - ); - } - let ts_array = arr - .as_any() - .downcast_ref::() - .unwrap(); - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; - let value = ts_array - .value_as_datetime_with_tz(idx, tz) - .map(|d| d.fixed_offset()); - encoder.encode_field_with_type_and_format(&value, type_, format)?; - } else { - let value = ts_array.value_as_datetime(idx); - encoder.encode_field_with_type_and_format(&value, type_, format)?; - } - } - TimeUnit::Microsecond => { - if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format( - &None::, - type_, - format, - ); - } - let ts_array = arr - .as_any() - .downcast_ref::() - .unwrap(); - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; - let value = ts_array - .value_as_datetime_with_tz(idx, tz) - .map(|d| d.fixed_offset()); - encoder.encode_field_with_type_and_format(&value, type_, format)?; - } else { - let value = ts_array.value_as_datetime(idx); - encoder.encode_field_with_type_and_format(&value, type_, format)?; - } - } - TimeUnit::Nanosecond => { - if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format( - &None::, - type_, - format, - ); - } - let ts_array = arr - .as_any() - .downcast_ref::() - .unwrap(); - if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; - let value = ts_array - .value_as_datetime_with_tz(idx, tz) - .map(|d| d.fixed_offset()); - encoder.encode_field_with_type_and_format(&value, type_, format)?; - } else { - let value = ts_array.value_as_datetime(idx); - encoder.encode_field_with_type_and_format(&value, type_, format)?; - } - } - }, - DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { - if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format(&None::<&[i8]>, type_, format); - } - let array = arr.as_any().downcast_ref::().unwrap().value(idx); - let value = encode_list(array, type_, format)?; - encoder.encode_field_with_type_and_format(&value, type_, format)? - } - DataType::Struct(_) => { - let fields = match type_.kind() { - postgres_types::Kind::Composite(fields) => fields, - _ => { - return Err(ToSqlError::from(format!( - "Failed to unwrap a composite type from type {}", - type_ - ))) - } - }; - let value = encode_struct(arr, idx, fields, format)?; - encoder.encode_field_with_type_and_format(&value, type_, format)? - } - DataType::Dictionary(_, value_type) => { - if arr.is_null(idx) { - return encoder.encode_field_with_type_and_format(&None::, type_, format); - } - // Get the dictionary values, ignoring keys - // We'll use Int32Type as a common key type, but we're only interested in values - macro_rules! get_dict_values { - ($key_type:ty) => { - arr.as_any() - .downcast_ref::>() - .map(|dict| dict.values()) - }; - } - - // Try to extract values using different key types - let values = get_dict_values!(Int8Type) - .or_else(|| get_dict_values!(Int16Type)) - .or_else(|| get_dict_values!(Int32Type)) - .or_else(|| get_dict_values!(Int64Type)) - .or_else(|| get_dict_values!(UInt8Type)) - .or_else(|| get_dict_values!(UInt16Type)) - .or_else(|| get_dict_values!(UInt32Type)) - .or_else(|| get_dict_values!(UInt64Type)) - .ok_or_else(|| { - ToSqlError::from(format!( - "Unsupported dictionary key type for value type {}", - value_type - )) - })?; - - // If the dictionary has only one value, treat it as a primitive - if values.len() == 1 { - encode_value(encoder, values, 0, type_, format)? - } else { - // Otherwise, use value directly indexed by values array - encode_value(encoder, values, idx, type_, format)? - } - } - _ => { - return Err(ToSqlError::from(format!( - "Unsupported Datatype {} and array {:?}", - arr.data_type(), - &arr - ))) - } - } - - Ok(()) -} diff --git a/arrow-pg/src/list_encoder.rs b/arrow-pg/src/list_encoder.rs index 14a8f41..766da3f 100644 --- a/arrow-pg/src/list_encoder.rs +++ b/arrow-pg/src/list_encoder.rs @@ -16,11 +16,15 @@ use arrow::{ }; use bytes::{BufMut, BytesMut}; use chrono::{DateTime, TimeZone, Utc}; +use pgwire::api::results::FieldFormat; +use pgwire::error::{PgWireError, PgWireResult}; use pgwire::types::{ToSqlText, QUOTE_ESCAPE}; use postgres_types::{ToSql, Type}; use rust_decimal::Decimal; -use super::{struct_encoder::encode_struct, EncodedValue, FieldFormat, Result, ToSqlError}; +use crate::encoder::EncodedValue; +use crate::error::ToSqlError; +use crate::struct_encoder::encode_struct; fn get_bool_list_value(arr: &Arc) -> Vec> { arr.as_any() @@ -72,7 +76,7 @@ fn encode_field( t: &[T], type_: &Type, format: FieldFormat, -) -> Result { +) -> PgWireResult { let mut bytes = BytesMut::new(); match format { FieldFormat::Text => t.to_sql_text(type_, &mut bytes)?, @@ -85,7 +89,7 @@ pub(crate) fn encode_list( arr: Arc, type_: &Type, format: FieldFormat, -) -> Result { +) -> PgWireResult { match arr.data_type() { DataType::Null => { let mut bytes = BytesMut::new(); @@ -223,7 +227,8 @@ pub(crate) fn encode_list( .iter(); if let Some(tz) = timezone { - let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?; + let tz = Tz::from_str(tz.as_ref()) + .map_err(|e| PgWireError::ApiError(ToSqlError::from(e)))?; let value: Vec<_> = array_iter .map(|i| { i.and_then(|i| { @@ -357,7 +362,7 @@ pub(crate) fn encode_list( }) .map_err(ToSqlError::from)?; - let values: Result> = (0..arr.len()) + let values: PgWireResult> = (0..arr.len()) .map(|row| encode_struct(&arr, row, fields, format)) .map(|x| { if matches!(format, FieldFormat::Text) { @@ -385,9 +390,9 @@ pub(crate) fn encode_list( encode_field(&values?, type_, format) } // TODO: more types - list_type => Err(ToSqlError::from(format!( + list_type => Err(PgWireError::ApiError(ToSqlError::from(format!( "Unsupported List Datatype {} and array {:?}", list_type, &arr - ))), + )))), } } diff --git a/datafusion-postgres/src/encoder/row_encoder.rs b/arrow-pg/src/row_encoder.rs similarity index 94% rename from datafusion-postgres/src/encoder/row_encoder.rs rename to arrow-pg/src/row_encoder.rs index 9d48145..3eab8c7 100644 --- a/datafusion-postgres/src/encoder/row_encoder.rs +++ b/arrow-pg/src/row_encoder.rs @@ -1,13 +1,13 @@ use std::sync::Arc; -use datafusion::arrow::array::RecordBatch; +use arrow::array::RecordBatch; use pgwire::{ api::results::{DataRowEncoder, FieldInfo}, error::PgWireResult, messages::data::DataRow, }; -use super::encode_value; +use crate::encoder::encode_value; pub struct RowEncoder { rb: RecordBatch, diff --git a/arrow-pg/src/struct_encoder.rs b/arrow-pg/src/struct_encoder.rs index 20f368f..96c9467 100644 --- a/arrow-pg/src/struct_encoder.rs +++ b/arrow-pg/src/struct_encoder.rs @@ -2,17 +2,19 @@ use std::sync::Arc; use arrow::array::{Array, StructArray}; use bytes::{BufMut, BytesMut}; +use pgwire::api::results::FieldFormat; +use pgwire::error::PgWireResult; use pgwire::types::{ToSqlText, QUOTE_CHECK, QUOTE_ESCAPE}; use postgres_types::{Field, IsNull, ToSql, Type}; -use super::{encode_value, EncodedValue, FieldFormat, Result}; +use crate::encoder::{encode_value, EncodedValue, Encoder}; pub(crate) fn encode_struct( arr: &Arc, idx: usize, fields: &[Field], format: FieldFormat, -) -> Result> { +) -> PgWireResult> { let arr = arr.as_any().downcast_ref::().unwrap(); if arr.is_null(idx) { return Ok(None); @@ -44,13 +46,13 @@ impl StructEncoder { } } -impl super::Encoder for StructEncoder { +impl Encoder for StructEncoder { fn encode_field_with_type_and_format( &mut self, value: &T, data_type: &Type, format: FieldFormat, - ) -> Result<()> + ) -> PgWireResult<()> where T: ToSql + ToSqlText + Sized, { diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index 021ad0d..c3fe9c4 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -12,12 +12,15 @@ documentation.workspace = true readme = "../README.md" [dependencies] +arrow-pg = { path = "../arrow-pg", version = "0.1" } async-trait = "0.1" bytes = "1.10.1" +chrono = { version = "0.4", features = ["std"] } datafusion = { workspace = true } -futures = "0.3" +futures.workspace = true getset = "0.1" log = "0.4" pgwire = { workspace = true } postgres-types.workspace = true +rust_decimal = { version = "1.37", features = ["db-postgres"] } tokio = { version = "1.45", features = ["sync", "net"] } diff --git a/datafusion-postgres/src/datatypes.rs b/datafusion-postgres/src/datatypes.rs index bf5223b..cbae22b 100644 --- a/datafusion-postgres/src/datatypes.rs +++ b/datafusion-postgres/src/datatypes.rs @@ -3,140 +3,27 @@ use std::sync::Arc; use chrono::{DateTime, FixedOffset}; use chrono::{NaiveDate, NaiveDateTime}; -use datafusion::arrow::datatypes::*; +use datafusion::arrow::datatypes::{DataType, Date32Type}; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::common::{DFSchema, ParamValues}; +use datafusion::common::ParamValues; use datafusion::prelude::*; use datafusion::scalar::ScalarValue; use futures::{stream, StreamExt}; use pgwire::api::portal::{Format, Portal}; -use pgwire::api::results::{FieldInfo, QueryResponse}; +use pgwire::api::results::QueryResponse; use pgwire::api::Type; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; -use postgres_types::Kind; use rust_decimal::prelude::ToPrimitive; use rust_decimal::Decimal; -use crate::encoder::row_encoder::RowEncoder; - -pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult { - Ok(match df_type { - DataType::Null => Type::UNKNOWN, - DataType::Boolean => Type::BOOL, - DataType::Int8 | DataType::UInt8 => Type::CHAR, - DataType::Int16 | DataType::UInt16 => Type::INT2, - DataType::Int32 | DataType::UInt32 => Type::INT4, - DataType::Int64 | DataType::UInt64 => Type::INT8, - DataType::Timestamp(_, tz) => { - if tz.is_some() { - Type::TIMESTAMPTZ - } else { - Type::TIMESTAMP - } - } - DataType::Time32(_) | DataType::Time64(_) => Type::TIME, - DataType::Date32 | DataType::Date64 => Type::DATE, - DataType::Interval(_) => Type::INTERVAL, - DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA, - DataType::Float16 | DataType::Float32 => Type::FLOAT4, - DataType::Float64 => Type::FLOAT8, - DataType::Decimal128(_, _) => Type::NUMERIC, - DataType::Utf8 => Type::VARCHAR, - DataType::LargeUtf8 => Type::TEXT, - DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { - match field.data_type() { - DataType::Boolean => Type::BOOL_ARRAY, - DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY, - DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY, - DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY, - DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY, - DataType::Timestamp(_, tz) => { - if tz.is_some() { - Type::TIMESTAMPTZ_ARRAY - } else { - Type::TIMESTAMP_ARRAY - } - } - DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY, - DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY, - DataType::Interval(_) => Type::INTERVAL_ARRAY, - DataType::FixedSizeBinary(_) | DataType::Binary => Type::BYTEA_ARRAY, - DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY, - DataType::Float64 => Type::FLOAT8_ARRAY, - DataType::Utf8 => Type::VARCHAR_ARRAY, - DataType::LargeUtf8 => Type::TEXT_ARRAY, - struct_type @ DataType::Struct(_) => Type::new( - Type::RECORD_ARRAY.name().into(), - Type::RECORD_ARRAY.oid(), - Kind::Array(into_pg_type(struct_type)?), - Type::RECORD_ARRAY.schema().into(), - ), - list_type => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!("Unsupported List Datatype {list_type}"), - )))); - } - } - } - DataType::Utf8View => Type::TEXT, - DataType::Dictionary(_, value_type) => into_pg_type(value_type)?, - DataType::Struct(fields) => { - let name: String = fields - .iter() - .map(|x| x.name().clone()) - .reduce(|a, b| a + ", " + &b) - .map(|x| format!("({x})")) - .unwrap_or("()".to_string()); - let kind = Kind::Composite( - fields - .iter() - .map(|x| { - into_pg_type(x.data_type()) - .map(|_type| postgres_types::Field::new(x.name().clone(), _type)) - }) - .collect::, PgWireError>>()?, - ); - Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into()) - } - _ => { - return Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "XX000".to_owned(), - format!("Unsupported Datatype {df_type}"), - )))); - } - }) -} - -pub(crate) fn df_schema_to_pg_fields( - schema: &DFSchema, - format: &Format, -) -> PgWireResult> { - schema - .fields() - .iter() - .enumerate() - .map(|(idx, f)| { - let pg_type = into_pg_type(f.data_type())?; - Ok(FieldInfo::new( - f.name().into(), - None, - None, - pg_type, - format.format_for(idx), - )) - }) - .collect::>>() -} +use arrow_pg::datatypes::{arrow_schema_to_pg_fields, encode_recordbatch, into_pg_type}; pub(crate) async fn encode_dataframe<'a>( df: DataFrame, format: &Format, ) -> PgWireResult> { - let fields = Arc::new(df_schema_to_pg_fields(df.schema(), format)?); + let fields = Arc::new(arrow_schema_to_pg_fields(df.schema().as_arrow(), format)?); let recordbatch_stream = df .execute_stream() @@ -148,11 +35,7 @@ pub(crate) async fn encode_dataframe<'a>( .map(move |rb: datafusion::error::Result| { let row_stream: Box> + Send + Sync> = match rb { - Ok(rb) => { - let fields = fields_ref.clone(); - let mut row_stream = RowEncoder::new(rb, fields); - Box::new(std::iter::from_fn(move || row_stream.next_row())) - } + Ok(rb) => encode_recordbatch(fields_ref.clone(), rb), Err(e) => Box::new(iter::once(Err(PgWireError::ApiError(e.into())))), }; stream::iter(row_stream) diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index a77f92a..1b122c1 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -16,10 +16,11 @@ use pgwire::api::results::{ use pgwire::api::stmt::QueryParser; use pgwire::api::stmt::StoredStatement; use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; +use pgwire::error::{PgWireError, PgWireResult}; use tokio::sync::Mutex; use crate::datatypes; -use pgwire::error::{PgWireError, PgWireResult}; +use arrow_pg::datatypes::{arrow_schema_to_pg_fields, into_pg_type}; pub struct HandlerFactory(pub Arc); @@ -237,7 +238,7 @@ impl ExtendedQueryHandler for DfSessionService { { let (_, plan) = &target.statement; let schema = plan.schema(); - let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), &Format::UnifiedBinary)?; + let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?; let params = plan .get_parameter_types() .map_err(|e| PgWireError::ApiError(Box::new(e)))?; @@ -246,7 +247,7 @@ impl ExtendedQueryHandler for DfSessionService { for param_type in ordered_param_types(¶ms).iter() { // Fixed: Use ¶ms if let Some(datatype) = param_type { - let pgtype = datatypes::into_pg_type(datatype)?; + let pgtype = into_pg_type(datatype)?; param_types.push(pgtype); } else { param_types.push(Type::UNKNOWN); @@ -267,7 +268,7 @@ impl ExtendedQueryHandler for DfSessionService { let (_, plan) = &target.statement.statement; let format = &target.result_column_format; let schema = plan.schema(); - let fields = datatypes::df_schema_to_pg_fields(schema.as_ref(), format)?; + let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?; Ok(DescribePortalResponse::new(fields)) } diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index da1af1a..db957c5 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -1,5 +1,4 @@ mod datatypes; -mod encoder; mod handlers; pub mod pg_catalog; From 335d99e1e9dc6b13c44564fbee766f8950701ae0 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Sun, 15 Jun 2025 13:13:29 -0700 Subject: [PATCH 3/3] chore: tune workspace dependencies Signed-off-by: Ning Sun --- Cargo.lock | 2 +- Cargo.toml | 3 +++ arrow-pg/Cargo.toml | 8 ++++---- datafusion-postgres/Cargo.toml | 12 ++++++------ 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 37ea8ba..640cbb7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -353,7 +353,7 @@ dependencies = [ [[package]] name = "arrow-pg" -version = "0.1.0" +version = "0.0.1" dependencies = [ "arrow", "bytes", diff --git a/Cargo.toml b/Cargo.toml index 7a17670..cd7b686 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,10 +15,13 @@ documentation = "https://docs.rs/crate/datafusion-postgres/" [workspace.dependencies] arrow = "55" +bytes = "1.10.1" +chrono = { version = "0.4", features = ["std"] } datafusion = { version = "47", default-features = false } futures = "0.3" pgwire = "0.30.2" postgres-types = "0.2" +rust_decimal = { version = "1.37", features = ["db-postgres"] } tokio = { version = "1", default-features = false } [profile.release] diff --git a/arrow-pg/Cargo.toml b/arrow-pg/Cargo.toml index 0494305..0df800d 100644 --- a/arrow-pg/Cargo.toml +++ b/arrow-pg/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "arrow-pg" description = "Arrow data mapping and encoding/decoding for Postgres" -version = "0.1.0" +version = "0.0.1" edition.workspace = true license.workspace = true authors.workspace = true @@ -13,9 +13,9 @@ readme = "../README.md" [dependencies] arrow.workspace = true -bytes = "1" -chrono = { version = "0.4", features = ["std"] } +bytes.workspace = true +chrono.workspace = true futures.workspace = true pgwire.workspace = true postgres-types.workspace = true -rust_decimal = { version = "1.37", features = ["db-postgres"] } +rust_decimal.workspace = true diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index c3fe9c4..e80cb69 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -12,15 +12,15 @@ documentation.workspace = true readme = "../README.md" [dependencies] -arrow-pg = { path = "../arrow-pg", version = "0.1" } +arrow-pg = { path = "../arrow-pg", version = "0.0.1" } +bytes.workspace = true async-trait = "0.1" -bytes = "1.10.1" -chrono = { version = "0.4", features = ["std"] } -datafusion = { workspace = true } +chrono.workspace = true +datafusion.workspace = true futures.workspace = true getset = "0.1" log = "0.4" -pgwire = { workspace = true } +pgwire.workspace = true postgres-types.workspace = true -rust_decimal = { version = "1.37", features = ["db-postgres"] } +rust_decimal.workspace = true tokio = { version = "1.45", features = ["sync", "net"] }