Skip to content

Commit 4fe7d19

Browse files
committed
Move rescale_decimal to arrow-cast
1 parent e32483c commit 4fe7d19

File tree

3 files changed

+74
-75
lines changed

3 files changed

+74
-75
lines changed

arrow-cast/src/cast/decimal.rs

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -145,27 +145,6 @@ impl DecimalCast for i256 {
145145
}
146146
}
147147

148-
fn cast_decimal_to_decimal_error<I, O>(
149-
output_precision: u8,
150-
output_scale: i8,
151-
) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
152-
where
153-
I: DecimalType,
154-
O: DecimalType,
155-
I::Native: DecimalCast + ArrowNativeTypeOp,
156-
O::Native: DecimalCast + ArrowNativeTypeOp,
157-
{
158-
move |x: I::Native| {
159-
ArrowError::CastError(format!(
160-
"Cannot cast to {}({}, {}). Overflowing on {:?}",
161-
O::PREFIX,
162-
output_precision,
163-
output_scale,
164-
x
165-
))
166-
}
167-
}
168-
169148
/// Construct closures to upscale decimals from `(input_precision, input_scale)` to
170149
/// `(output_precision, output_scale)`.
171150
///
@@ -174,7 +153,7 @@ where
174153
/// In that case, the caller should treat this as an overflow for the output scale
175154
/// and handle it accordingly (e.g., return a cast error).
176155
#[allow(clippy::type_complexity)]
177-
pub fn make_upscaler<I: DecimalType, O: DecimalType>(
156+
fn make_upscaler<I: DecimalType, O: DecimalType>(
178157
input_precision: u8,
179158
input_scale: i8,
180159
output_precision: u8,
@@ -218,7 +197,7 @@ where
218197
/// available precision). Callers should therefore produce zero values (preserving nulls) rather
219198
/// than returning an error.
220199
#[allow(clippy::type_complexity)]
221-
pub fn make_downscaler<I: DecimalType, O: DecimalType>(
200+
fn make_downscaler<I: DecimalType, O: DecimalType>(
222201
input_precision: u8,
223202
input_scale: i8,
224203
output_precision: u8,
@@ -274,6 +253,76 @@ where
274253
Some((f, f_infallible))
275254
}
276255

256+
/// Apply the rescaler function to the value.
257+
/// If the rescaler is infallible, use the infallible function.
258+
/// Otherwise, use the fallible function and validate the precision.
259+
fn apply_rescaler<I: DecimalType, O: DecimalType>(
260+
value: I::Native,
261+
output_precision: u8,
262+
f: impl Fn(I::Native) -> Option<O::Native>,
263+
f_infallible: Option<impl Fn(I::Native) -> O::Native>,
264+
) -> Option<O::Native>
265+
where
266+
I::Native: DecimalCast,
267+
O::Native: DecimalCast,
268+
{
269+
if let Some(f_infallible) = f_infallible {
270+
Some(f_infallible(value))
271+
} else {
272+
f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision))
273+
}
274+
}
275+
276+
/// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale)
277+
/// and return the scaled value if it fits the output precision. Similar to the implementation in
278+
/// decimal.rs in arrow-cast.
279+
pub fn rescale_decimal<I: DecimalType, O: DecimalType>(
280+
value: I::Native,
281+
input_precision: u8,
282+
input_scale: i8,
283+
output_precision: u8,
284+
output_scale: i8,
285+
) -> Option<O::Native>
286+
where
287+
I::Native: DecimalCast + ArrowNativeTypeOp,
288+
O::Native: DecimalCast + ArrowNativeTypeOp,
289+
{
290+
if input_scale <= output_scale {
291+
let (f, f_infallible) =
292+
make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)?;
293+
apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
294+
} else {
295+
let Some((f, f_infallible)) =
296+
make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
297+
else {
298+
// Scale reduction exceeds supported precision; result mathematically rounds to zero
299+
return Some(O::Native::ZERO);
300+
};
301+
apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
302+
}
303+
}
304+
305+
fn cast_decimal_to_decimal_error<I, O>(
306+
output_precision: u8,
307+
output_scale: i8,
308+
) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
309+
where
310+
I: DecimalType,
311+
O: DecimalType,
312+
I::Native: DecimalCast + ArrowNativeTypeOp,
313+
O::Native: DecimalCast + ArrowNativeTypeOp,
314+
{
315+
move |x: I::Native| {
316+
ArrowError::CastError(format!(
317+
"Cannot cast to {}({}, {}). Overflowing on {:?}",
318+
O::PREFIX,
319+
output_precision,
320+
output_scale,
321+
x
322+
))
323+
}
324+
}
325+
277326
fn apply_decimal_cast<I: DecimalType, O: DecimalType>(
278327
array: &PrimitiveArray<I>,
279328
output_precision: u8,

arrow-cast/src/cast/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ use arrow_schema::*;
6767
use arrow_select::take::take;
6868
use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive};
6969

70-
pub use decimal::{DecimalCast, make_downscaler, make_upscaler};
70+
pub use decimal::{DecimalCast, rescale_decimal};
7171

7272
/// CastOptions provides a way to override the default cast behaviors
7373
#[derive(Debug, Clone, PartialEq, Eq, Hash)]

parquet-variant-compute/src/type_conversion.rs

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818
//! Module for transforming a typed arrow `Array` to `VariantArray`.
1919
20-
use arrow::array::ArrowNativeTypeOp;
21-
use arrow::compute::{DecimalCast, make_downscaler, make_upscaler};
20+
use arrow::compute::{DecimalCast, rescale_decimal};
2221
use arrow::datatypes::{
2322
self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type,
2423
DecimalType,
@@ -190,55 +189,6 @@ where
190189
}
191190
}
192191

193-
/// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale)
194-
/// and return the scaled value if it fits the output precision. Similar to the implementation in
195-
/// decimal.rs in arrow-cast.
196-
pub(crate) fn rescale_decimal<I: DecimalType, O: DecimalType>(
197-
value: I::Native,
198-
input_precision: u8,
199-
input_scale: i8,
200-
output_precision: u8,
201-
output_scale: i8,
202-
) -> Option<O::Native>
203-
where
204-
I::Native: DecimalCast,
205-
O::Native: DecimalCast,
206-
{
207-
if input_scale <= output_scale {
208-
let (f, f_infallible) =
209-
make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)?;
210-
apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
211-
} else {
212-
let Some((f, f_infallible)) =
213-
make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
214-
else {
215-
// Scale reduction exceeds supported precision; result mathematically rounds to zero
216-
return Some(O::Native::ZERO);
217-
};
218-
apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
219-
}
220-
}
221-
222-
/// Apply the rescaler function to the value.
223-
/// If the rescaler is infallible, use the infallible function.
224-
/// Otherwise, use the fallible function and validate the precision.
225-
fn apply_rescaler<I: DecimalType, O: DecimalType>(
226-
value: I::Native,
227-
output_precision: u8,
228-
f: impl Fn(I::Native) -> Option<O::Native>,
229-
f_infallible: Option<impl Fn(I::Native) -> O::Native>,
230-
) -> Option<O::Native>
231-
where
232-
I::Native: DecimalCast,
233-
O::Native: DecimalCast,
234-
{
235-
if let Some(f_infallible) = f_infallible {
236-
Some(f_infallible(value))
237-
} else {
238-
f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision))
239-
}
240-
}
241-
242192
/// Convert the value at a specific index in the given array into a `Variant`.
243193
macro_rules! non_generic_conversion_single_value {
244194
($array:expr, $cast_fn:expr, $index:expr) => {{

0 commit comments

Comments
 (0)