Skip to content

Commit 04505a0

Browse files
committed
Use scaler fns in variant decimal rescaling
1 parent 9b9800f commit 04505a0

File tree

3 files changed

+30
-74
lines changed

3 files changed

+30
-74
lines changed

arrow-cast/src/cast/decimal.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ where
174174
/// In that case, the caller should treat this as an overflow for the output scale
175175
/// and handle it accordingly (e.g., return a cast error).
176176
#[allow(clippy::type_complexity)]
177-
fn make_upscaler<I: DecimalType, O: DecimalType>(
177+
pub fn make_upscaler<I: DecimalType, O: DecimalType>(
178178
input_precision: u8,
179179
input_scale: i8,
180180
output_precision: u8,
@@ -218,7 +218,7 @@ where
218218
/// available precision). Callers should therefore produce zero values (preserving nulls) rather
219219
/// than returning an error.
220220
#[allow(clippy::type_complexity)]
221-
fn make_downscaler<I: DecimalType, O: DecimalType>(
221+
pub fn make_downscaler<I: DecimalType, O: DecimalType>(
222222
input_precision: u8,
223223
input_scale: i8,
224224
output_precision: u8,

arrow-cast/src/cast/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ use arrow_schema::*;
6767
use arrow_select::take::take;
6868
use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive};
6969

70-
pub use decimal::DecimalCast;
70+
pub use decimal::{DecimalCast, make_downscaler, make_upscaler};
7171

7272
/// CastOptions provides a way to override the default cast behaviors
7373
#[derive(Debug, Clone, PartialEq, Eq, Hash)]

parquet-variant-compute/src/type_conversion.rs

Lines changed: 27 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! Module for transforming a typed arrow `Array` to `VariantArray`.
1919
2020
use arrow::array::ArrowNativeTypeOp;
21-
use arrow::compute::DecimalCast;
21+
use arrow::compute::{DecimalCast, make_downscaler, make_upscaler};
2222
use arrow::datatypes::{
2323
self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type,
2424
DecimalType,
@@ -189,93 +189,49 @@ where
189189
/// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale)
190190
/// and return the scaled value if it fits the output precision. Similar to the implementation in
191191
/// decimal.rs in arrow-cast.
192-
pub(crate) fn rescale_decimal<I, O>(
192+
pub(crate) fn rescale_decimal<I: DecimalType, O: DecimalType>(
193193
value: I::Native,
194194
input_precision: u8,
195195
input_scale: i8,
196196
output_precision: u8,
197197
output_scale: i8,
198198
) -> Option<O::Native>
199199
where
200-
I: DecimalType,
201-
O: DecimalType,
202200
I::Native: DecimalCast,
203201
O::Native: DecimalCast,
204202
{
205-
let delta_scale = output_scale - input_scale;
206-
207-
// Determine if the cast is infallible based on precision/scale math
208-
let is_infallible_cast =
209-
is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale);
210-
211-
let scaled = if delta_scale == 0 {
212-
O::Native::from_decimal(value)
213-
} else if delta_scale > 0 {
214-
let mul = O::Native::from_decimal(10_i128)
215-
.and_then(|t| t.pow_checked(delta_scale as u32).ok())?;
216-
O::Native::from_decimal(value).and_then(|x| x.mul_checked(mul).ok())
203+
if input_scale <= output_scale {
204+
let (f, f_infallible) =
205+
make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)?;
206+
apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
217207
} else {
218-
// delta_scale is guaranteed to be < 0 here, and its magnitude may be larger than I::MAX_PRECISION. If so, the
219-
// scale change divides out more digits than the input has precision and the result of the cast
220-
// is always zero. For example, if we try to apply delta_scale=10 to a decimal32 value, the largest
221-
// possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values
222-
// (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even
223-
// smaller results, which also round to zero. In that case, just return zero.
224-
let delta_scale = delta_scale.unsigned_abs() as usize;
225-
let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale) else {
208+
let Some((f, f_infallible)) =
209+
make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
210+
else {
211+
// Scale reduction exceeds supported precision; result mathematically rounds to zero
226212
return Some(O::Native::ZERO);
227213
};
228-
let div = max.add_wrapping(I::Native::ONE);
229-
let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE));
230-
let half_neg = half.neg_wrapping();
231-
232-
// div is >= 10 and so this cannot overflow
233-
let d = value.div_wrapping(div);
234-
let r = value.mod_wrapping(div);
235-
236-
// Round result
237-
let adjusted = match value >= I::Native::ZERO {
238-
true if r >= half => d.add_wrapping(I::Native::ONE),
239-
false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
240-
_ => d,
241-
};
242-
O::Native::from_decimal(adjusted)
243-
};
244-
245-
scaled.filter(|v| is_infallible_cast || O::is_valid_decimal_precision(*v, output_precision))
214+
apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
215+
}
246216
}
247217

248-
/// Returns true if casting from (input_precision, input_scale) to
249-
/// (output_precision, output_scale) is infallible based on precision/scale math.
250-
fn is_infallible_decimal_cast(
251-
input_precision: u8,
252-
input_scale: i8,
218+
/// Apply the rescaler function to the value.
219+
/// If the rescaler is infallible, use the infallible function.
220+
/// Otherwise, use the fallible function and validate the precision.
221+
fn apply_rescaler<I: DecimalType, O: DecimalType>(
222+
value: I::Native,
253223
output_precision: u8,
254-
output_scale: i8,
255-
) -> bool {
256-
let delta_scale = output_scale - input_scale;
257-
let input_precision = input_precision as i8;
258-
let output_precision = output_precision as i8;
259-
if delta_scale >= 0 {
260-
// if the gain in precision (digits) is greater than the multiplication due to scaling
261-
// every number will fit into the output type
262-
// Example: If we are starting with any number of precision 5 [xxxxx],
263-
// then an increase of scale by 3 will have the following effect on the representation:
264-
// [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type
265-
// needs to provide at least 8 digits precision
266-
input_precision + delta_scale <= output_precision
224+
f: impl Fn(I::Native) -> Option<O::Native>,
225+
f_infallible: Option<impl Fn(I::Native) -> O::Native>,
226+
) -> Option<O::Native>
227+
where
228+
I::Native: DecimalCast,
229+
O::Native: DecimalCast,
230+
{
231+
if let Some(f_infallible) = f_infallible {
232+
Some(f_infallible(value))
267233
} else {
268-
// if the reduction of the input number through scaling (dividing) is greater
269-
// than a possible precision loss (plus potential increase via rounding)
270-
// every input number will fit into the output type
271-
// Example: If we are starting with any number of precision 5 [xxxxx],
272-
// then a decrease of the scale by 3 will have the following effect on the representation:
273-
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
274-
// The rounding may add an additional digit, so for the cast to be infallible,
275-
// the output type needs to have at least 3 digits of precision.
276-
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
277-
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
278-
input_precision + delta_scale < output_precision
234+
f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision))
279235
}
280236
}
281237

0 commit comments

Comments
 (0)