203 changes: 143 additions & 60 deletions arrow-cast/src/cast/decimal.rs
@@ -145,7 +145,7 @@ impl DecimalCast for i256 {
}
}

pub(crate) fn cast_decimal_to_decimal_error<I, O>(
fn cast_decimal_to_decimal_error<I, O>(
Contributor Author: downgrade the visibility since it's only used in this file

output_precision: u8,
output_scale: i8,
) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
@@ -166,50 +166,86 @@ where
}
}

pub(crate) fn convert_to_smaller_scale_decimal<I, O>(
array: &PrimitiveArray<I>,
/// Construct closures to upscale decimals from `(input_precision, input_scale)` to
/// `(output_precision, output_scale)`.
///
/// Returns `None` if the required scale increase `delta_scale = output_scale - input_scale`
/// exceeds the supported precomputed precision table `O::MAX_FOR_EACH_PRECISION`.
/// In that case, the caller should treat this as an overflow for the output scale
/// and handle it accordingly (e.g., return a cast error).
#[allow(clippy::type_complexity)]
pub fn make_upscaler<I: DecimalType, O: DecimalType>(
input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<O>, ArrowError>
) -> Option<(
impl Fn(I::Native) -> Option<O::Native>,
Option<impl Fn(I::Native) -> O::Native>,
)>
where
I: DecimalType,
O: DecimalType,
I::Native: DecimalCast + ArrowNativeTypeOp,
O::Native: DecimalCast + ArrowNativeTypeOp,
{
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
let delta_scale = input_scale - output_scale;
// if the reduction of the input number through scaling (dividing) is greater
// than a possible precision loss (plus potential increase via rounding)
// every input number will fit into the output type
let delta_scale = output_scale - input_scale;

// O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...).
// Adding 1 yields exactly 10^k without computing a power at runtime.
// Using the precomputed table avoids pow(10, k) and its checked/overflow
// handling, which is faster and simpler for scaling by 10^delta_scale.
let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?;
let mul = max.add_wrapping(O::Native::ONE);
let f = move |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());

// if the gain in precision (digits) is greater than the multiplication due to scaling
// every number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then and decrease the scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
// The rounding may add an additional digit, so the cast to be infallible,
// the output type needs to have at least 3 digits of precision.
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
// then an increase of scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
// needs to provide at least 8 digits precision
let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
Contributor: I see the old code did this too, but it seems like the cast as i8 could potentially convert a number larger than 128 to a negative number -- maybe that is ok

Contributor Author: Yeah, this becomes an issue once we expose the API. Normally it isn’t a concern, since the precision is guaranteed to stay under 128. Let me redesign the public interface.
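
A minimal standalone sketch of the wrap-around being discussed; the value 200 is illustrative and none of this code is part of the PR:

fn main() {
    // A precision stored as u8 but above i8::MAX wraps to a negative value
    // under `as i8`, which would silently flip the is_infallible_cast comparison.
    let precision: u8 = 200;
    let as_signed = precision as i8; // wraps: 200 - 256 = -56
    assert_eq!(as_signed, -56);
    assert!(as_signed < 0);
}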

let f_infallible = is_infallible_cast
.then_some(move |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul));
Some((f, f_infallible))
Contributor Author: Chose to return f_infallible instead of is_infallible_cast because, unlike make_downscaler, we cannot derive an infallible closure from f. So to keep the interface consistent, I applied the same approach to make_downscaler to return (f, f_infallible) as well.
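
A standalone sketch of the asymmetry described above, using plain i64 arithmetic as a stand-in for the decimal native types (these closures are illustrative, not the ones defined in this diff):

fn main() {
    let mul: i64 = 1_000;

    // Upscaler-style: the fallible closure reports overflow as None, so a
    // wrapping (infallible) variant cannot be recovered from it and has to be
    // built as a separate closure.
    let f_up = move |x: i64| x.checked_mul(mul);
    let f_up_infallible = move |x: i64| x.wrapping_mul(mul);

    // Downscaler-style: the division itself never fails, so when the cast is
    // known to be infallible the fallible closure can simply be unwrapped.
    let f_down = move |x: i64| Some(x / mul);
    let f_down_infallible = move |x: i64| f_down(x).unwrap();

    assert_eq!(f_up(5), Some(5_000));
    assert_eq!(f_up_infallible(5), 5_000);
    assert_eq!(f_down(5_000), Some(5));
    assert_eq!(f_down_infallible(5_000), 5);
}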

}

/// Construct closures to downscale decimals from `(input_precision, input_scale)` to
/// `(output_precision, output_scale)`.
///
/// Returns `None` if the required scale reduction `delta_scale = input_scale - output_scale`
/// exceeds the supported precomputed precision table `I::MAX_FOR_EACH_PRECISION`.
/// In this scenario, any value would round to zero (e.g., dividing by 10^k where k exceeds the
/// available precision). Callers should therefore produce zero values (preserving nulls) rather
/// than returning an error.
#[allow(clippy::type_complexity)]
pub fn make_downscaler<I: DecimalType, O: DecimalType>(
input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
) -> Option<(
impl Fn(I::Native) -> Option<O::Native>,
Option<impl Fn(I::Native) -> O::Native>,
)>
where
I::Native: DecimalCast + ArrowNativeTypeOp,
O::Native: DecimalCast + ArrowNativeTypeOp,
{
let delta_scale = input_scale - output_scale;

// delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the
// scale change divides out more digits than the input has precision and the result of the cast
// is always zero. For example, if we try to apply delta_scale=10 to a decimal32 value, the largest
// possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values
// (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even
// smaller results, which also round to zero. In that case, just return an array of zeros.
let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize) else {
let zeros = vec![O::Native::ZERO; array.len()];
return Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned()));
};
let max = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?;

let div = max.add_wrapping(I::Native::ONE);
let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE));
let half_neg = half.neg_wrapping();

let f = |x: I::Native| {
let f = move |x: I::Native| {
// div is >= 10 and so this cannot overflow
let d = x.div_wrapping(div);
let r = x.mod_wrapping(div);
@@ -223,24 +259,49 @@ where
O::Native::from_decimal(adjusted)
};

Ok(if is_infallible_cast {
// make sure we don't perform calculations that don't make sense w/o validation
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed
// to fit into the target type
array.unary(g)
// if the reduction of the input number through scaling (dividing) is greater
// than a possible precision loss (plus potential increase via rounding)
// every input number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then a decrease of the scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
// The rounding may add a digit, so for the cast to be infallible,
// the output type needs to have at least 3 digits of precision.
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
let f_infallible = is_infallible_cast.then_some(move |x| f(x).unwrap());
Some((f, f_infallible))
}
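
To make the rounding performed by the closure above concrete, here is a standalone sketch of the same half-away-from-zero adjustment on plain i64 values with delta_scale = 2 (div = 100); it mirrors the logic but is not code from this diff:

fn downscale_by_100(x: i64) -> i64 {
    let div = 100_i64;
    let half = div / 2; // 50
    let half_neg = -half; // -50
    let d = x / div; // truncating division
    let r = x % div; // remainder keeps the sign of x
    match x >= 0 {
        true if r >= half => d + 1,
        false if r <= half_neg => d - 1,
        _ => d,
    }
}

fn main() {
    assert_eq!(downscale_by_100(99_999), 1_000); // 999.99 rounds up to 1000
    assert_eq!(downscale_by_100(123_449), 1_234); // 1234.49 rounds down to 1234
    assert_eq!(downscale_by_100(-150), -2); // -1.50 rounds away from zero to -2
}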

fn apply_decimal_cast<I: DecimalType, O: DecimalType>(
array: &PrimitiveArray<I>,
output_precision: u8,
output_scale: i8,
f: impl Fn(I::Native) -> Option<O::Native>,
f_infallible: Option<impl Fn(I::Native) -> O::Native>,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<O>, ArrowError>
where
I::Native: DecimalCast + ArrowNativeTypeOp,
O::Native: DecimalCast + ArrowNativeTypeOp,
{
let array = if let Some(f_infallible) = f_infallible {
Contributor: this is a very nice formulation now

array.unary(f_infallible)
} else if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
} else {
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
array.try_unary(|x| {
f(x).ok_or_else(|| error(x)).and_then(|v| {
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
})
})?
})
};
Ok(array)
}

pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>(
fn convert_to_smaller_scale_decimal<I, O>(
array: &PrimitiveArray<I>,
input_precision: u8,
input_scale: i8,
@@ -254,36 +315,58 @@ where
I::Native: DecimalCast + ArrowNativeTypeOp,
O::Native: DecimalCast + ArrowNativeTypeOp,
{
let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale);
let delta_scale = output_scale - input_scale;
let mul = O::Native::from_decimal(10_i128)
.unwrap()
.pow_checked(delta_scale as u32)?;
if let Some((f, f_infallible)) =
make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
{
apply_decimal_cast(
array,
output_precision,
output_scale,
f,
f_infallible,
cast_options,
)
} else {
// Scale reduction exceeds supported precision; result mathematically rounds to zero
let zeros = vec![O::Native::ZERO; array.len()];
Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned()))
}
}

// if the gain in precision (digits) is greater than the multiplication due to scaling
// every number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then an increase of scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
// needs to provide at least 8 digits precision
let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8);
let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok());

Ok(if is_infallible_cast {
// make sure we don't perform calculations that don't make sense w/o validation
validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
// unwrapping is safe since the result is guaranteed to fit into the target type
let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul);
array.unary(f)
} else if cast_options.safe {
array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)))
fn convert_to_bigger_or_equal_scale_decimal<I, O>(
array: &PrimitiveArray<I>,
input_precision: u8,
input_scale: i8,
output_precision: u8,
output_scale: i8,
cast_options: &CastOptions,
) -> Result<PrimitiveArray<O>, ArrowError>
where
I: DecimalType,
O: DecimalType,
I::Native: DecimalCast + ArrowNativeTypeOp,
O::Native: DecimalCast + ArrowNativeTypeOp,
{
if let Some((f, f_infallible)) =
make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
{
apply_decimal_cast(
array,
output_precision,
output_scale,
f,
f_infallible,
cast_options,
)
} else {
array.try_unary(|x| {
f(x).ok_or_else(|| error(x)).and_then(|v| {
O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v)
})
})?
})
// Scale increase exceeds supported precision; return overflow error
Err(ArrowError::CastError(format!(
"Cannot cast to {}({}, {}). Value overflows for output scale",
O::PREFIX,
output_precision,
output_scale
)))
}
}

// Only support one type of decimal cast operations
2 changes: 1 addition & 1 deletion arrow-cast/src/cast/mod.rs
@@ -67,7 +67,7 @@ use arrow_schema::*;
use arrow_select::take::take;
use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive};

pub use decimal::DecimalCast;
pub use decimal::{DecimalCast, make_downscaler, make_upscaler};

/// CastOptions provides a way to override the default cast behaviors
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
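With the re-export above, the helpers become reachable as arrow::compute::{make_upscaler, make_downscaler}. A minimal single-value usage sketch; the (7, 2) to (10, 4) rescale and the Decimal128-to-Decimal128 pairing are illustrative, and the pattern mirrors the parquet-variant-compute call sites below:

use arrow::compute::make_upscaler;
use arrow::datatypes::{Decimal128Type, DecimalType};

/// Rescale one Decimal128 value from (precision 7, scale 2) to (precision 10, scale 4),
/// returning None if the rescaled value would not fit the output precision.
fn upscale_one(value: i128) -> Option<i128> {
    let (f, f_infallible) = make_upscaler::<Decimal128Type, Decimal128Type>(7, 2, 10, 4)?;
    match f_infallible {
        // The output precision is known to be large enough, so the closure cannot fail.
        Some(g) => Some(g(value)),
        // Otherwise scale up and check that the result still fits the output precision.
        None => f(value).filter(|v| Decimal128Type::is_valid_decimal_precision(*v, 10)),
    }
}

fn main() {
    assert_eq!(upscale_one(12_345), Some(1_234_500)); // 123.45 becomes 123.4500
}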
91 changes: 28 additions & 63 deletions parquet-variant-compute/src/type_conversion.rs
@@ -18,7 +18,7 @@
//! Module for transforming a typed arrow `Array` to `VariantArray`.

use arrow::array::ArrowNativeTypeOp;
use arrow::compute::DecimalCast;
use arrow::compute::{DecimalCast, make_downscaler, make_upscaler};
use arrow::datatypes::{
self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type,
DecimalType,
@@ -204,73 +204,38 @@ where
I::Native: DecimalCast,
O::Native: DecimalCast,
{
let delta_scale = output_scale - input_scale;

let (scaled, is_infallible_cast) = if delta_scale >= 0 {
// O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...).
// Adding 1 yields exactly 10^k without computing a power at runtime.
// Using the precomputed table avoids pow(10, k) and its checked/overflow
// handling, which is faster and simpler for scaling by 10^delta_scale.
let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?;
let mul = max.add_wrapping(O::Native::ONE);

// if the gain in precision (digits) is greater than the multiplication due to scaling
// every number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then an increase of scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
// needs to provide at least 8 digits precision
let is_infallible_cast = input_precision as i8 + delta_scale <= output_precision as i8;
let value = O::Native::from_decimal(value);
let scaled = if is_infallible_cast {
Some(value.unwrap().mul_wrapping(mul))
} else {
value.and_then(|x| x.mul_checked(mul).ok())
};
(scaled, is_infallible_cast)
if input_scale <= output_scale {
let (f, f_infallible) =
make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)?;
apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
} else {
// the abs of delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION.
// If so, the scale change divides out more digits than the input has precision and the result
// of the cast is always zero. For example, if we try to apply delta_scale=10 a decimal32 value,
// the largest possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero.
// Smaller values (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000)
// produce even smaller results, which also round to zero. In that case, just return zero.
let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale.unsigned_abs() as usize) else {
let Some((f, f_infallible)) =
make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
else {
// Scale reduction exceeds supported precision; result mathematically rounds to zero
return Some(O::Native::ZERO);
};
let div = max.add_wrapping(I::Native::ONE);
let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE));
let half_neg = half.neg_wrapping();

// div is >= 10 and so this cannot overflow
let d = value.div_wrapping(div);
let r = value.mod_wrapping(div);

// Round result
let adjusted = match value >= I::Native::ZERO {
true if r >= half => d.add_wrapping(I::Native::ONE),
false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
_ => d,
};

// if the reduction of the input number through scaling (dividing) is greater
// than a possible precision loss (plus potential increase via rounding)
// every input number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then and decrease the scale by 3 will have the following effect on the representation:
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
// The rounding may add a digit, so for the cast to be infallible,
// the output type needs to have at least 3 digits of precision.
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
let is_infallible_cast = input_precision as i8 + delta_scale < output_precision as i8;
(O::Native::from_decimal(adjusted), is_infallible_cast)
};
apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
}
}

if is_infallible_cast {
scaled
/// Apply the rescaler function to the value.
/// If the rescaler is infallible, use the infallible function.
/// Otherwise, use the fallible function and validate the precision.
fn apply_rescaler<I: DecimalType, O: DecimalType>(
value: I::Native,
output_precision: u8,
f: impl Fn(I::Native) -> Option<O::Native>,
f_infallible: Option<impl Fn(I::Native) -> O::Native>,
) -> Option<O::Native>
where
I::Native: DecimalCast,
O::Native: DecimalCast,
{
if let Some(f_infallible) = f_infallible {
Some(f_infallible(value))
} else {
scaled.filter(|v| O::is_valid_decimal_precision(*v, output_precision))
f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision))
}
}
