|
18 | 18 | //! Module for transforming a typed arrow `Array` to `VariantArray`. |
19 | 19 |
|
20 | 20 | use arrow::array::ArrowNativeTypeOp; |
21 | | -use arrow::compute::DecimalCast; |
| 21 | +use arrow::compute::{DecimalCast, make_downscaler, make_upscaler}; |
22 | 22 | use arrow::datatypes::{ |
23 | 23 | self, ArrowPrimitiveType, ArrowTimestampType, Decimal32Type, Decimal64Type, Decimal128Type, |
24 | 24 | DecimalType, |
@@ -189,93 +189,49 @@ where |
189 | 189 | /// Rescale a decimal from (input_precision, input_scale) to (output_precision, output_scale) |
190 | 190 | /// and return the scaled value if it fits the output precision. Similar to the implementation in |
191 | 191 | /// decimal.rs in arrow-cast. |
192 | | -pub(crate) fn rescale_decimal<I, O>( |
| 192 | +pub(crate) fn rescale_decimal<I: DecimalType, O: DecimalType>( |
193 | 193 | value: I::Native, |
194 | 194 | input_precision: u8, |
195 | 195 | input_scale: i8, |
196 | 196 | output_precision: u8, |
197 | 197 | output_scale: i8, |
198 | 198 | ) -> Option<O::Native> |
199 | 199 | where |
200 | | - I: DecimalType, |
201 | | - O: DecimalType, |
202 | 200 | I::Native: DecimalCast, |
203 | 201 | O::Native: DecimalCast, |
204 | 202 | { |
205 | | - let delta_scale = output_scale - input_scale; |
206 | | - |
207 | | - // Determine if the cast is infallible based on precision/scale math |
208 | | - let is_infallible_cast = |
209 | | - is_infallible_decimal_cast(input_precision, input_scale, output_precision, output_scale); |
210 | | - |
211 | | - let scaled = if delta_scale == 0 { |
212 | | - O::Native::from_decimal(value) |
213 | | - } else if delta_scale > 0 { |
214 | | - let mul = O::Native::from_decimal(10_i128) |
215 | | - .and_then(|t| t.pow_checked(delta_scale as u32).ok())?; |
216 | | - O::Native::from_decimal(value).and_then(|x| x.mul_checked(mul).ok()) |
| 203 | + if input_scale <= output_scale { |
| 204 | + let (f, f_infallible) = |
| 205 | + make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)?; |
| 206 | + apply_rescaler::<I, O>(value, output_precision, f, f_infallible) |
217 | 207 | } else { |
218 | | - // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the |
219 | | - // scale change divides out more digits than the input has precision and the result of the cast |
220 | | - // is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, the largest |
221 | | - // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values |
222 | | - // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even |
223 | | - // smaller results, which also round to zero. In that case, just return an array of zeros. |
224 | | - let delta_scale = delta_scale.unsigned_abs() as usize; |
225 | | - let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale) else { |
| 208 | + let Some((f, f_infallible)) = |
| 209 | + make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale) |
| 210 | + else { |
| 211 | + // Scale reduction exceeds supported precision; result mathematically rounds to zero |
226 | 212 | return Some(O::Native::ZERO); |
227 | 213 | }; |
228 | | - let div = max.add_wrapping(I::Native::ONE); |
229 | | - let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); |
230 | | - let half_neg = half.neg_wrapping(); |
231 | | - |
232 | | - // div is >= 10 and so this cannot overflow |
233 | | - let d = value.div_wrapping(div); |
234 | | - let r = value.mod_wrapping(div); |
235 | | - |
236 | | - // Round result |
237 | | - let adjusted = match value >= I::Native::ZERO { |
238 | | - true if r >= half => d.add_wrapping(I::Native::ONE), |
239 | | - false if r <= half_neg => d.sub_wrapping(I::Native::ONE), |
240 | | - _ => d, |
241 | | - }; |
242 | | - O::Native::from_decimal(adjusted) |
243 | | - }; |
244 | | - |
245 | | - scaled.filter(|v| is_infallible_cast || O::is_valid_decimal_precision(*v, output_precision)) |
| 214 | + apply_rescaler::<I, O>(value, output_precision, f, f_infallible) |
| 215 | + } |
246 | 216 | } |
247 | 217 |
|
248 | | -/// Returns true if casting from (input_precision, input_scale) to |
249 | | -/// (output_precision, output_scale) is infallible based on precision/scale math. |
250 | | -fn is_infallible_decimal_cast( |
251 | | - input_precision: u8, |
252 | | - input_scale: i8, |
| 218 | +/// Apply the rescaler function to the value. |
| 219 | +/// If the rescaler is infallible, use the infallible function. |
| 220 | +/// Otherwise, use the fallible function and validate the precision. |
| 221 | +fn apply_rescaler<I: DecimalType, O: DecimalType>( |
| 222 | + value: I::Native, |
253 | 223 | output_precision: u8, |
254 | | - output_scale: i8, |
255 | | -) -> bool { |
256 | | - let delta_scale = output_scale - input_scale; |
257 | | - let input_precision = input_precision as i8; |
258 | | - let output_precision = output_precision as i8; |
259 | | - if delta_scale >= 0 { |
260 | | - // if the gain in precision (digits) is greater than the multiplication due to scaling |
261 | | - // every number will fit into the output type |
262 | | - // Example: If we are starting with any number of precision 5 [xxxxx], |
263 | | - // then an increase of scale by 3 will have the following effect on the representation: |
264 | | - // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type |
265 | | - // needs to provide at least 8 digits precision |
266 | | - input_precision + delta_scale <= output_precision |
| 224 | + f: impl Fn(I::Native) -> Option<O::Native>, |
| 225 | + f_infallible: Option<impl Fn(I::Native) -> O::Native>, |
| 226 | +) -> Option<O::Native> |
| 227 | +where |
| 228 | + I::Native: DecimalCast, |
| 229 | + O::Native: DecimalCast, |
| 230 | +{ |
| 231 | + if let Some(f_infallible) = f_infallible { |
| 232 | + Some(f_infallible(value)) |
267 | 233 | } else { |
268 | | - // if the reduction of the input number through scaling (dividing) is greater |
269 | | - // than a possible precision loss (plus potential increase via rounding) |
270 | | - // every input number will fit into the output type |
271 | | - // Example: If we are starting with any number of precision 5 [xxxxx], |
272 | | - // then and decrease the scale by 3 will have the following effect on the representation: |
273 | | - // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). |
274 | | - // The rounding may add an additional digit, so for the cast to be infallible, |
275 | | - // the output type needs to have at least 3 digits of precision. |
276 | | - // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: |
277 | | - // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible |
278 | | - input_precision + delta_scale < output_precision |
| 234 | + f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision)) |
279 | 235 | } |
280 | 236 | } |
281 | 237 |
|
|
0 commit comments