-
Notifications
You must be signed in to change notification settings - Fork 1k
Refactor arrow-cast decimal casting to unify the rescale logic used in Parquet variant casts #8689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
3f792b4
a5fcc37
9a88258
9b9800f
04505a0
e32483c
5900bd4
49e72cd
2a73ffb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -145,7 +145,7 @@ impl DecimalCast for i256 { | |
| } | ||
| } | ||
|
|
||
| pub(crate) fn cast_decimal_to_decimal_error<I, O>( | ||
| fn cast_decimal_to_decimal_error<I, O>( | ||
| output_precision: u8, | ||
| output_scale: i8, | ||
| ) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError | ||
|
|
@@ -166,50 +166,86 @@ where | |
| } | ||
| } | ||
|
|
||
| pub(crate) fn convert_to_smaller_scale_decimal<I, O>( | ||
| array: &PrimitiveArray<I>, | ||
| /// Construct closures to upscale decimals from `(input_precision, input_scale)` to | ||
| /// `(output_precision, output_scale)`. | ||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| /// | ||
| /// Returns `None` if the required scale increase `delta_scale = output_scale - input_scale` | ||
| /// exceeds the supported precomputed precision table `O::MAX_FOR_EACH_PRECISION`. | ||
| /// In that case, the caller should treat this as an overflow for the output scale | ||
| /// and handle it accordingly (e.g., return a cast error). | ||
| #[allow(clippy::type_complexity)] | ||
| pub fn make_upscaler<I: DecimalType, O: DecimalType>( | ||
| input_precision: u8, | ||
| input_scale: i8, | ||
| output_precision: u8, | ||
| output_scale: i8, | ||
| cast_options: &CastOptions, | ||
| ) -> Result<PrimitiveArray<O>, ArrowError> | ||
| ) -> Option<( | ||
| impl Fn(I::Native) -> Option<O::Native>, | ||
| Option<impl Fn(I::Native) -> O::Native>, | ||
| )> | ||
| where | ||
| I: DecimalType, | ||
| O: DecimalType, | ||
| I::Native: DecimalCast + ArrowNativeTypeOp, | ||
| O::Native: DecimalCast + ArrowNativeTypeOp, | ||
| { | ||
| let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale); | ||
| let delta_scale = input_scale - output_scale; | ||
| // if the reduction of the input number through scaling (dividing) is greater | ||
| // than a possible precision loss (plus potential increase via rounding) | ||
| // every input number will fit into the output type | ||
| let delta_scale = output_scale - input_scale; | ||
|
|
||
| // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...). | ||
| // Adding 1 yields exactly 10^k without computing a power at runtime. | ||
| // Using the precomputed table avoids pow(10, k) and its checked/overflow | ||
| // handling, which is faster and simpler for scaling by 10^delta_scale. | ||
| let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; | ||
| let mul = max.add_wrapping(O::Native::ONE); | ||
| let f = move |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); | ||
|
|
||
| // if the gain in precision (digits) is greater than the multiplication due to scaling | ||
| // every number will fit into the output type | ||
| // Example: If we are starting with any number of precision 5 [xxxxx], | ||
| // then a decrease of the scale by 3 will have the following effect on the representation: | ||
| // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). | ||
| // The rounding may add an additional digit, so for the cast to be infallible, | ||
| // the output type needs to have at least 3 digits of precision. | ||
| // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: | ||
| // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible | ||
| let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); | ||
| // then an increase of scale by 3 will have the following effect on the representation: | ||
| // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type | ||
| // needs to provide at least 8 digits precision | ||
| let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see the old code did this too, but it seems like the cast There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, this becomes an issue once we expose the API. Normally it isn’t a concern, since the precision is guaranteed to stay under 128. Let me redesign the public interface |
||
| let f_infallible = is_infallible_cast | ||
| .then_some(move |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul)); | ||
| Some((f, f_infallible)) | ||
|
||
| } | ||
|
|
||
| /// Construct closures to downscale decimals from `(input_precision, input_scale)` to | ||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| /// `(output_precision, output_scale)`. | ||
| /// | ||
| /// Returns `None` if the required scale reduction `delta_scale = input_scale - output_scale` | ||
| /// exceeds the supported precomputed precision table `I::MAX_FOR_EACH_PRECISION`. | ||
| /// In this scenario, any value would round to zero (e.g., dividing by 10^k where k exceeds the | ||
| /// available precision). Callers should therefore produce zero values (preserving nulls) rather | ||
| /// than returning an error. | ||
| #[allow(clippy::type_complexity)] | ||
| pub fn make_downscaler<I: DecimalType, O: DecimalType>( | ||
| input_precision: u8, | ||
| input_scale: i8, | ||
| output_precision: u8, | ||
| output_scale: i8, | ||
| ) -> Option<( | ||
| impl Fn(I::Native) -> Option<O::Native>, | ||
| Option<impl Fn(I::Native) -> O::Native>, | ||
| )> | ||
| where | ||
| I::Native: DecimalCast + ArrowNativeTypeOp, | ||
| O::Native: DecimalCast + ArrowNativeTypeOp, | ||
| { | ||
| let delta_scale = input_scale - output_scale; | ||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the | ||
| // scale change divides out more digits than the input has precision and the result of the cast | ||
| // is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, the largest | ||
| // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values | ||
| // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even | ||
| // smaller results, which also round to zero. In that case, just return an array of zeros. | ||
| let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize) else { | ||
| let zeros = vec![O::Native::ZERO; array.len()]; | ||
| return Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned())); | ||
| }; | ||
| let max = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; | ||
|
|
||
| let div = max.add_wrapping(I::Native::ONE); | ||
| let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); | ||
| let half_neg = half.neg_wrapping(); | ||
|
|
||
| let f = |x: I::Native| { | ||
| let f = move |x: I::Native| { | ||
| // div is >= 10 and so this cannot overflow | ||
| let d = x.div_wrapping(div); | ||
| let r = x.mod_wrapping(div); | ||
|
|
@@ -223,24 +259,49 @@ where | |
| O::Native::from_decimal(adjusted) | ||
| }; | ||
|
|
||
| Ok(if is_infallible_cast { | ||
| // make sure we don't perform calculations that don't make sense w/o validation | ||
| validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?; | ||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed | ||
| // to fit into the target type | ||
| array.unary(g) | ||
| // if the reduction of the input number through scaling (dividing) is greater | ||
| // than a possible precision loss (plus potential increase via rounding) | ||
| // every input number will fit into the output type | ||
| // Example: If we are starting with any number of precision 5 [xxxxx], | ||
| // then a decrease of the scale by 3 will have the following effect on the representation: | ||
| // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). | ||
| // The rounding may add a digit, so for the cast to be infallible, | ||
| // the output type needs to have at least 3 digits of precision. | ||
| // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: | ||
| // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible | ||
| let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); | ||
| let f_infallible = is_infallible_cast.then_some(move |x| f(x).unwrap()); | ||
| Some((f, f_infallible)) | ||
| } | ||
|
|
||
| fn apply_decimal_cast<I: DecimalType, O: DecimalType>( | ||
| array: &PrimitiveArray<I>, | ||
| output_precision: u8, | ||
| output_scale: i8, | ||
| f: impl Fn(I::Native) -> Option<O::Native>, | ||
| f_infallible: Option<impl Fn(I::Native) -> O::Native>, | ||
| cast_options: &CastOptions, | ||
| ) -> Result<PrimitiveArray<O>, ArrowError> | ||
| where | ||
| I::Native: DecimalCast + ArrowNativeTypeOp, | ||
| O::Native: DecimalCast + ArrowNativeTypeOp, | ||
| { | ||
| let array = if let Some(f_infallible) = f_infallible { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is a very nice formulation now |
||
| array.unary(f_infallible) | ||
| } else if cast_options.safe { | ||
| array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) | ||
| } else { | ||
| let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale); | ||
| array.try_unary(|x| { | ||
| f(x).ok_or_else(|| error(x)).and_then(|v| { | ||
| O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) | ||
| }) | ||
| })? | ||
| }) | ||
| }; | ||
| Ok(array) | ||
| } | ||
|
|
||
| pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>( | ||
| fn convert_to_smaller_scale_decimal<I, O>( | ||
| array: &PrimitiveArray<I>, | ||
| input_precision: u8, | ||
| input_scale: i8, | ||
|
|
@@ -254,36 +315,58 @@ where | |
| I::Native: DecimalCast + ArrowNativeTypeOp, | ||
| O::Native: DecimalCast + ArrowNativeTypeOp, | ||
| { | ||
| let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale); | ||
| let delta_scale = output_scale - input_scale; | ||
| let mul = O::Native::from_decimal(10_i128) | ||
| .unwrap() | ||
| .pow_checked(delta_scale as u32)?; | ||
| if let Some((f, f_infallible)) = | ||
| make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale) | ||
| { | ||
| apply_decimal_cast( | ||
| array, | ||
| output_precision, | ||
| output_scale, | ||
| f, | ||
| f_infallible, | ||
| cast_options, | ||
| ) | ||
| } else { | ||
| // Scale reduction exceeds supported precision; result mathematically rounds to zero | ||
| let zeros = vec![O::Native::ZERO; array.len()]; | ||
| Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned())) | ||
| } | ||
| } | ||
|
|
||
| // if the gain in precision (digits) is greater than the multiplication due to scaling | ||
| // every number will fit into the output type | ||
| // Example: If we are starting with any number of precision 5 [xxxxx], | ||
| // then an increase of scale by 3 will have the following effect on the representation: | ||
| // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type | ||
| // needs to provide at least 8 digits precision | ||
| let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); | ||
| let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); | ||
|
|
||
| Ok(if is_infallible_cast { | ||
| // make sure we don't perform calculations that don't make sense w/o validation | ||
| validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?; | ||
| // unwrapping is safe since the result is guaranteed to fit into the target type | ||
| let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul); | ||
| array.unary(f) | ||
| } else if cast_options.safe { | ||
| array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) | ||
| fn convert_to_bigger_or_equal_scale_decimal<I, O>( | ||
| array: &PrimitiveArray<I>, | ||
| input_precision: u8, | ||
| input_scale: i8, | ||
| output_precision: u8, | ||
| output_scale: i8, | ||
| cast_options: &CastOptions, | ||
| ) -> Result<PrimitiveArray<O>, ArrowError> | ||
| where | ||
| I: DecimalType, | ||
| O: DecimalType, | ||
| I::Native: DecimalCast + ArrowNativeTypeOp, | ||
| O::Native: DecimalCast + ArrowNativeTypeOp, | ||
| { | ||
| if let Some((f, f_infallible)) = | ||
| make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale) | ||
| { | ||
| apply_decimal_cast( | ||
| array, | ||
| output_precision, | ||
| output_scale, | ||
| f, | ||
| f_infallible, | ||
| cast_options, | ||
| ) | ||
| } else { | ||
| array.try_unary(|x| { | ||
| f(x).ok_or_else(|| error(x)).and_then(|v| { | ||
| O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) | ||
| }) | ||
| })? | ||
| }) | ||
| // Scale increase exceeds supported precision; return overflow error | ||
| Err(ArrowError::CastError(format!( | ||
| "Cannot cast to {}({}, {}). Value overflows for output scale", | ||
| O::PREFIX, | ||
| output_precision, | ||
| output_scale | ||
| ))) | ||
| } | ||
| } | ||
|
|
||
| // Only support one type of decimal cast operations | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
downgrade the visibility since it's only used in this file