/Users/andrewlamb/Software/arrow-rs/arrow-cast/src/cast/decimal.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | use crate::cast::*; |
19 | | |
/// A utility trait that provides checked conversions between
/// decimal types inspired by [`NumCast`]
///
/// Every conversion returns an `Option` so that an implementation can signal
/// a value that does not fit in the target type by returning `None`.
pub(crate) trait DecimalCast: Sized {
    /// Converts `self` to an `i32`, returning `None` if the conversion fails.
    fn to_i32(self) -> Option<i32>;

    /// Converts `self` to an `i64`, returning `None` if the conversion fails.
    fn to_i64(self) -> Option<i64>;

    /// Converts `self` to an `i128`, returning `None` if the conversion fails.
    fn to_i128(self) -> Option<i128>;

    /// Converts `self` to an `i256`, returning `None` if the conversion fails.
    fn to_i256(self) -> Option<i256>;

    /// Builds a `Self` from any other [`DecimalCast`] value, returning `None`
    /// if the conversion fails.
    fn from_decimal<T: DecimalCast>(n: T) -> Option<Self>;

    /// Builds a `Self` from an `f64`, returning `None` if the conversion fails.
    fn from_f64(n: f64) -> Option<Self>;
}
35 | | |
36 | | impl DecimalCast for i32 { |
37 | 0 | fn to_i32(self) -> Option<i32> { |
38 | 0 | Some(self) |
39 | 0 | } |
40 | | |
41 | 0 | fn to_i64(self) -> Option<i64> { |
42 | 0 | Some(self as i64) |
43 | 0 | } |
44 | | |
45 | 0 | fn to_i128(self) -> Option<i128> { |
46 | 0 | Some(self as i128) |
47 | 0 | } |
48 | | |
49 | 0 | fn to_i256(self) -> Option<i256> { |
50 | 0 | Some(i256::from_i128(self as i128)) |
51 | 0 | } |
52 | | |
53 | 0 | fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> { |
54 | 0 | n.to_i32() |
55 | 0 | } |
56 | | |
57 | 0 | fn from_f64(n: f64) -> Option<Self> { |
58 | 0 | n.to_i32() |
59 | 0 | } |
60 | | } |
61 | | |
62 | | impl DecimalCast for i64 { |
63 | 0 | fn to_i32(self) -> Option<i32> { |
64 | 0 | i32::try_from(self).ok() |
65 | 0 | } |
66 | | |
67 | 0 | fn to_i64(self) -> Option<i64> { |
68 | 0 | Some(self) |
69 | 0 | } |
70 | | |
71 | 0 | fn to_i128(self) -> Option<i128> { |
72 | 0 | Some(self as i128) |
73 | 0 | } |
74 | | |
75 | 0 | fn to_i256(self) -> Option<i256> { |
76 | 0 | Some(i256::from_i128(self as i128)) |
77 | 0 | } |
78 | | |
79 | 0 | fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> { |
80 | 0 | n.to_i64() |
81 | 0 | } |
82 | | |
83 | 0 | fn from_f64(n: f64) -> Option<Self> { |
84 | 0 | n.to_i64() |
85 | 0 | } |
86 | | } |
87 | | |
88 | | impl DecimalCast for i128 { |
89 | 0 | fn to_i32(self) -> Option<i32> { |
90 | 0 | i32::try_from(self).ok() |
91 | 0 | } |
92 | | |
93 | 0 | fn to_i64(self) -> Option<i64> { |
94 | 0 | i64::try_from(self).ok() |
95 | 0 | } |
96 | | |
97 | 0 | fn to_i128(self) -> Option<i128> { |
98 | 0 | Some(self) |
99 | 0 | } |
100 | | |
101 | 0 | fn to_i256(self) -> Option<i256> { |
102 | 0 | Some(i256::from_i128(self)) |
103 | 0 | } |
104 | | |
105 | 0 | fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> { |
106 | 0 | n.to_i128() |
107 | 0 | } |
108 | | |
109 | 0 | fn from_f64(n: f64) -> Option<Self> { |
110 | 0 | n.to_i128() |
111 | 0 | } |
112 | | } |
113 | | |
114 | | impl DecimalCast for i256 { |
115 | 0 | fn to_i32(self) -> Option<i32> { |
116 | 0 | self.to_i128().map(|x| i32::try_from(x).ok())? |
117 | 0 | } |
118 | | |
119 | 0 | fn to_i64(self) -> Option<i64> { |
120 | 0 | self.to_i128().map(|x| i64::try_from(x).ok())? |
121 | 0 | } |
122 | | |
123 | 0 | fn to_i128(self) -> Option<i128> { |
124 | 0 | self.to_i128() |
125 | 0 | } |
126 | | |
127 | 0 | fn to_i256(self) -> Option<i256> { |
128 | 0 | Some(self) |
129 | 0 | } |
130 | | |
131 | 0 | fn from_decimal<T: DecimalCast>(n: T) -> Option<Self> { |
132 | 0 | n.to_i256() |
133 | 0 | } |
134 | | |
135 | 0 | fn from_f64(n: f64) -> Option<Self> { |
136 | 0 | i256::from_f64(n) |
137 | 0 | } |
138 | | } |
139 | | |
140 | 0 | pub(crate) fn cast_decimal_to_decimal_error<I, O>( |
141 | 0 | output_precision: u8, |
142 | 0 | output_scale: i8, |
143 | 0 | ) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError |
144 | 0 | where |
145 | 0 | I: DecimalType, |
146 | 0 | O: DecimalType, |
147 | 0 | I::Native: DecimalCast + ArrowNativeTypeOp, |
148 | 0 | O::Native: DecimalCast + ArrowNativeTypeOp, |
149 | | { |
150 | 0 | move |x: I::Native| { |
151 | 0 | ArrowError::CastError(format!( |
152 | 0 | "Cannot cast to {}({}, {}). Overflowing on {:?}", |
153 | 0 | O::PREFIX, |
154 | 0 | output_precision, |
155 | 0 | output_scale, |
156 | 0 | x |
157 | 0 | )) |
158 | 0 | } |
159 | 0 | } |
160 | | |
161 | 0 | pub(crate) fn convert_to_smaller_scale_decimal<I, O>( |
162 | 0 | array: &PrimitiveArray<I>, |
163 | 0 | input_precision: u8, |
164 | 0 | input_scale: i8, |
165 | 0 | output_precision: u8, |
166 | 0 | output_scale: i8, |
167 | 0 | cast_options: &CastOptions, |
168 | 0 | ) -> Result<PrimitiveArray<O>, ArrowError> |
169 | 0 | where |
170 | 0 | I: DecimalType, |
171 | 0 | O: DecimalType, |
172 | 0 | I::Native: DecimalCast + ArrowNativeTypeOp, |
173 | 0 | O::Native: DecimalCast + ArrowNativeTypeOp, |
174 | | { |
175 | 0 | let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale); |
176 | 0 | let delta_scale = input_scale - output_scale; |
177 | | // if the reduction of the input number through scaling (dividing) is greater |
178 | | // than a possible precision loss (plus potential increase via rounding) |
179 | | // every input number will fit into the output type |
180 | | // Example: If we are starting with any number of precision 5 [xxxxx], |
181 | | // then and decrease the scale by 3 will have the following effect on the representation: |
182 | | // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). |
183 | | // The rounding may add an additional digit, so the cast to be infallible, |
184 | | // the output type needs to have at least 3 digits of precision. |
185 | | // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: |
186 | | // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible |
187 | 0 | let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); |
188 | | |
189 | 0 | let div = I::Native::from_decimal(10_i128) |
190 | 0 | .unwrap() |
191 | 0 | .pow_checked(delta_scale as u32)?; |
192 | | |
193 | 0 | let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); |
194 | 0 | let half_neg = half.neg_wrapping(); |
195 | | |
196 | 0 | let f = |x: I::Native| { |
197 | | // div is >= 10 and so this cannot overflow |
198 | 0 | let d = x.div_wrapping(div); |
199 | 0 | let r = x.mod_wrapping(div); |
200 | | |
201 | | // Round result |
202 | 0 | let adjusted = match x >= I::Native::ZERO { |
203 | 0 | true if r >= half => d.add_wrapping(I::Native::ONE), |
204 | 0 | false if r <= half_neg => d.sub_wrapping(I::Native::ONE), |
205 | 0 | _ => d, |
206 | | }; |
207 | 0 | O::Native::from_decimal(adjusted) |
208 | 0 | }; |
209 | | |
210 | 0 | Ok(if is_infallible_cast { |
211 | | // make sure we don't perform calculations that don't make sense w/o validation |
212 | 0 | validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?; |
213 | 0 | let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed |
214 | | // to fit into the target type |
215 | 0 | array.unary(g) |
216 | 0 | } else if cast_options.safe { |
217 | 0 | array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) |
218 | | } else { |
219 | 0 | array.try_unary(|x| { |
220 | 0 | f(x).ok_or_else(|| error(x)) |
221 | 0 | .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v)) |
222 | 0 | })? |
223 | | }) |
224 | 0 | } |
225 | | |
226 | 0 | pub(crate) fn convert_to_bigger_or_equal_scale_decimal<I, O>( |
227 | 0 | array: &PrimitiveArray<I>, |
228 | 0 | input_precision: u8, |
229 | 0 | input_scale: i8, |
230 | 0 | output_precision: u8, |
231 | 0 | output_scale: i8, |
232 | 0 | cast_options: &CastOptions, |
233 | 0 | ) -> Result<PrimitiveArray<O>, ArrowError> |
234 | 0 | where |
235 | 0 | I: DecimalType, |
236 | 0 | O: DecimalType, |
237 | 0 | I::Native: DecimalCast + ArrowNativeTypeOp, |
238 | 0 | O::Native: DecimalCast + ArrowNativeTypeOp, |
239 | | { |
240 | 0 | let error = cast_decimal_to_decimal_error::<I, O>(output_precision, output_scale); |
241 | 0 | let delta_scale = output_scale - input_scale; |
242 | 0 | let mul = O::Native::from_decimal(10_i128) |
243 | 0 | .unwrap() |
244 | 0 | .pow_checked(delta_scale as u32)?; |
245 | | |
246 | | // if the gain in precision (digits) is greater than the multiplication due to scaling |
247 | | // every number will fit into the output type |
248 | | // Example: If we are starting with any number of precision 5 [xxxxx], |
249 | | // then an increase of scale by 3 will have the following effect on the representation: |
250 | | // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type |
251 | | // needs to provide at least 8 digits precision |
252 | 0 | let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); |
253 | 0 | let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); |
254 | | |
255 | 0 | Ok(if is_infallible_cast { |
256 | | // make sure we don't perform calculations that don't make sense w/o validation |
257 | 0 | validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?; |
258 | | // unwrapping is safe since the result is guaranteed to fit into the target type |
259 | 0 | let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul); |
260 | 0 | array.unary(f) |
261 | 0 | } else if cast_options.safe { |
262 | 0 | array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) |
263 | | } else { |
264 | 0 | array.try_unary(|x| { |
265 | 0 | f(x).ok_or_else(|| error(x)) |
266 | 0 | .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v)) |
267 | 0 | })? |
268 | | }) |
269 | 0 | } |
270 | | |
271 | | // Only support one type of decimal cast operations |
272 | 0 | pub(crate) fn cast_decimal_to_decimal_same_type<T>( |
273 | 0 | array: &PrimitiveArray<T>, |
274 | 0 | input_precision: u8, |
275 | 0 | input_scale: i8, |
276 | 0 | output_precision: u8, |
277 | 0 | output_scale: i8, |
278 | 0 | cast_options: &CastOptions, |
279 | 0 | ) -> Result<ArrayRef, ArrowError> |
280 | 0 | where |
281 | 0 | T: DecimalType, |
282 | 0 | T::Native: DecimalCast + ArrowNativeTypeOp, |
283 | | { |
284 | 0 | let array: PrimitiveArray<T> = |
285 | 0 | if input_scale == output_scale && input_precision <= output_precision { |
286 | 0 | array.clone() |
287 | 0 | } else if input_scale <= output_scale { |
288 | 0 | convert_to_bigger_or_equal_scale_decimal::<T, T>( |
289 | 0 | array, |
290 | 0 | input_precision, |
291 | 0 | input_scale, |
292 | 0 | output_precision, |
293 | 0 | output_scale, |
294 | 0 | cast_options, |
295 | 0 | )? |
296 | | } else { |
297 | | // input_scale > output_scale |
298 | 0 | convert_to_smaller_scale_decimal::<T, T>( |
299 | 0 | array, |
300 | 0 | input_precision, |
301 | 0 | input_scale, |
302 | 0 | output_precision, |
303 | 0 | output_scale, |
304 | 0 | cast_options, |
305 | 0 | )? |
306 | | }; |
307 | | |
308 | 0 | Ok(Arc::new(array.with_precision_and_scale( |
309 | 0 | output_precision, |
310 | 0 | output_scale, |
311 | 0 | )?)) |
312 | 0 | } |
313 | | |
314 | | // Support two different types of decimal cast operations |
315 | 0 | pub(crate) fn cast_decimal_to_decimal<I, O>( |
316 | 0 | array: &PrimitiveArray<I>, |
317 | 0 | input_precision: u8, |
318 | 0 | input_scale: i8, |
319 | 0 | output_precision: u8, |
320 | 0 | output_scale: i8, |
321 | 0 | cast_options: &CastOptions, |
322 | 0 | ) -> Result<ArrayRef, ArrowError> |
323 | 0 | where |
324 | 0 | I: DecimalType, |
325 | 0 | O: DecimalType, |
326 | 0 | I::Native: DecimalCast + ArrowNativeTypeOp, |
327 | 0 | O::Native: DecimalCast + ArrowNativeTypeOp, |
328 | | { |
329 | 0 | let array: PrimitiveArray<O> = if input_scale > output_scale { |
330 | 0 | convert_to_smaller_scale_decimal::<I, O>( |
331 | 0 | array, |
332 | 0 | input_precision, |
333 | 0 | input_scale, |
334 | 0 | output_precision, |
335 | 0 | output_scale, |
336 | 0 | cast_options, |
337 | 0 | )? |
338 | | } else { |
339 | 0 | convert_to_bigger_or_equal_scale_decimal::<I, O>( |
340 | 0 | array, |
341 | 0 | input_precision, |
342 | 0 | input_scale, |
343 | 0 | output_precision, |
344 | 0 | output_scale, |
345 | 0 | cast_options, |
346 | 0 | )? |
347 | | }; |
348 | | |
349 | 0 | Ok(Arc::new(array.with_precision_and_scale( |
350 | 0 | output_precision, |
351 | 0 | output_scale, |
352 | 0 | )?)) |
353 | 0 | } |
354 | | |
355 | | /// Parses given string to specified decimal native (i128/i256) based on given |
356 | | /// scale. Returns an `Err` if it cannot parse given string. |
357 | 0 | pub(crate) fn parse_string_to_decimal_native<T: DecimalType>( |
358 | 0 | value_str: &str, |
359 | 0 | scale: usize, |
360 | 0 | ) -> Result<T::Native, ArrowError> |
361 | 0 | where |
362 | 0 | T::Native: DecimalCast + ArrowNativeTypeOp, |
363 | | { |
364 | 0 | let value_str = value_str.trim(); |
365 | 0 | let parts: Vec<&str> = value_str.split('.').collect(); |
366 | 0 | if parts.len() > 2 { |
367 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
368 | 0 | "Invalid decimal format: {value_str:?}" |
369 | 0 | ))); |
370 | 0 | } |
371 | | |
372 | 0 | let (negative, first_part) = if parts[0].is_empty() { |
373 | 0 | (false, parts[0]) |
374 | | } else { |
375 | 0 | match parts[0].as_bytes()[0] { |
376 | 0 | b'-' => (true, &parts[0][1..]), |
377 | 0 | b'+' => (false, &parts[0][1..]), |
378 | 0 | _ => (false, parts[0]), |
379 | | } |
380 | | }; |
381 | | |
382 | 0 | let integers = first_part; |
383 | 0 | let decimals = if parts.len() == 2 { parts[1] } else { "" }; |
384 | | |
385 | 0 | if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() { |
386 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
387 | 0 | "Invalid decimal format: {value_str:?}" |
388 | 0 | ))); |
389 | 0 | } |
390 | | |
391 | 0 | if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() { |
392 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
393 | 0 | "Invalid decimal format: {value_str:?}" |
394 | 0 | ))); |
395 | 0 | } |
396 | | |
397 | | // Adjust decimal based on scale |
398 | 0 | let mut number_decimals = if decimals.len() > scale { |
399 | 0 | let decimal_number = i256::from_string(decimals).ok_or_else(|| { |
400 | 0 | ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}")) |
401 | 0 | })?; |
402 | | |
403 | 0 | let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?; |
404 | | |
405 | 0 | let half = div.div_wrapping(i256::from_i128(2)); |
406 | 0 | let half_neg = half.neg_wrapping(); |
407 | | |
408 | 0 | let d = decimal_number.div_wrapping(div); |
409 | 0 | let r = decimal_number.mod_wrapping(div); |
410 | | |
411 | | // Round result |
412 | 0 | let adjusted = match decimal_number >= i256::ZERO { |
413 | 0 | true if r >= half => d.add_wrapping(i256::ONE), |
414 | 0 | false if r <= half_neg => d.sub_wrapping(i256::ONE), |
415 | 0 | _ => d, |
416 | | }; |
417 | | |
418 | 0 | let integers = if !integers.is_empty() { |
419 | 0 | i256::from_string(integers) |
420 | 0 | .ok_or_else(|| { |
421 | 0 | ArrowError::InvalidArgumentError(format!( |
422 | 0 | "Cannot parse decimal format: {value_str}" |
423 | 0 | )) |
424 | 0 | }) |
425 | 0 | .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))? |
426 | | } else { |
427 | 0 | i256::ZERO |
428 | | }; |
429 | | |
430 | 0 | format!("{}", integers.add_wrapping(adjusted)) |
431 | | } else { |
432 | 0 | let padding = if scale > decimals.len() { scale } else { 0 }; |
433 | | |
434 | 0 | let decimals = format!("{decimals:0<padding$}"); |
435 | 0 | format!("{integers}{decimals}") |
436 | | }; |
437 | | |
438 | 0 | if negative { |
439 | 0 | number_decimals.insert(0, '-'); |
440 | 0 | } |
441 | | |
442 | 0 | let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| { |
443 | 0 | ArrowError::InvalidArgumentError(format!( |
444 | 0 | "Cannot convert {} to {}: Overflow", |
445 | 0 | value_str, |
446 | 0 | T::PREFIX |
447 | 0 | )) |
448 | 0 | })?; |
449 | | |
450 | 0 | T::Native::from_decimal(value).ok_or_else(|| { |
451 | 0 | ArrowError::InvalidArgumentError(format!("Cannot convert {} to {}", value_str, T::PREFIX)) |
452 | 0 | }) |
453 | 0 | } |
454 | | |
455 | 0 | pub(crate) fn generic_string_to_decimal_cast<'a, T, S>( |
456 | 0 | from: &'a S, |
457 | 0 | precision: u8, |
458 | 0 | scale: i8, |
459 | 0 | cast_options: &CastOptions, |
460 | 0 | ) -> Result<PrimitiveArray<T>, ArrowError> |
461 | 0 | where |
462 | 0 | T: DecimalType, |
463 | 0 | T::Native: DecimalCast + ArrowNativeTypeOp, |
464 | 0 | &'a S: StringArrayType<'a>, |
465 | | { |
466 | 0 | if cast_options.safe { |
467 | 0 | let iter = from.iter().map(|v| { |
468 | 0 | v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok()) |
469 | 0 | .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v)) |
470 | 0 | }); |
471 | | // Benefit: |
472 | | // 20% performance improvement |
473 | | // Soundness: |
474 | | // The iterator is trustedLen because it comes from an `StringArray`. |
475 | | Ok(unsafe { |
476 | 0 | PrimitiveArray::<T>::from_trusted_len_iter(iter) |
477 | 0 | .with_precision_and_scale(precision, scale)? |
478 | | }) |
479 | | } else { |
480 | 0 | let vec = from |
481 | 0 | .iter() |
482 | 0 | .map(|v| { |
483 | 0 | v.map(|v| { |
484 | 0 | parse_string_to_decimal_native::<T>(v, scale as usize) |
485 | 0 | .map_err(|_| { |
486 | 0 | ArrowError::CastError(format!( |
487 | 0 | "Cannot cast string '{}' to value of {:?} type", |
488 | 0 | v, |
489 | 0 | T::DATA_TYPE, |
490 | 0 | )) |
491 | 0 | }) |
492 | 0 | .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v)) |
493 | 0 | }) |
494 | 0 | .transpose() |
495 | 0 | }) |
496 | 0 | .collect::<Result<Vec<_>, _>>()?; |
497 | | // Benefit: |
498 | | // 20% performance improvement |
499 | | // Soundness: |
500 | | // The iterator is trustedLen because it comes from an `StringArray`. |
501 | | Ok(unsafe { |
502 | 0 | PrimitiveArray::<T>::from_trusted_len_iter(vec.iter()) |
503 | 0 | .with_precision_and_scale(precision, scale)? |
504 | | }) |
505 | | } |
506 | 0 | } |
507 | | |
508 | 0 | pub(crate) fn string_to_decimal_cast<T, Offset: OffsetSizeTrait>( |
509 | 0 | from: &GenericStringArray<Offset>, |
510 | 0 | precision: u8, |
511 | 0 | scale: i8, |
512 | 0 | cast_options: &CastOptions, |
513 | 0 | ) -> Result<PrimitiveArray<T>, ArrowError> |
514 | 0 | where |
515 | 0 | T: DecimalType, |
516 | 0 | T::Native: DecimalCast + ArrowNativeTypeOp, |
517 | | { |
518 | 0 | generic_string_to_decimal_cast::<T, GenericStringArray<Offset>>( |
519 | 0 | from, |
520 | 0 | precision, |
521 | 0 | scale, |
522 | 0 | cast_options, |
523 | | ) |
524 | 0 | } |
525 | | |
526 | 0 | pub(crate) fn string_view_to_decimal_cast<T>( |
527 | 0 | from: &StringViewArray, |
528 | 0 | precision: u8, |
529 | 0 | scale: i8, |
530 | 0 | cast_options: &CastOptions, |
531 | 0 | ) -> Result<PrimitiveArray<T>, ArrowError> |
532 | 0 | where |
533 | 0 | T: DecimalType, |
534 | 0 | T::Native: DecimalCast + ArrowNativeTypeOp, |
535 | | { |
536 | 0 | generic_string_to_decimal_cast::<T, StringViewArray>(from, precision, scale, cast_options) |
537 | 0 | } |
538 | | |
539 | | /// Cast Utf8 to decimal |
540 | 0 | pub(crate) fn cast_string_to_decimal<T, Offset: OffsetSizeTrait>( |
541 | 0 | from: &dyn Array, |
542 | 0 | precision: u8, |
543 | 0 | scale: i8, |
544 | 0 | cast_options: &CastOptions, |
545 | 0 | ) -> Result<ArrayRef, ArrowError> |
546 | 0 | where |
547 | 0 | T: DecimalType, |
548 | 0 | T::Native: DecimalCast + ArrowNativeTypeOp, |
549 | | { |
550 | 0 | if scale < 0 { |
551 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
552 | 0 | "Cannot cast string to decimal with negative scale {scale}" |
553 | 0 | ))); |
554 | 0 | } |
555 | | |
556 | 0 | if scale > T::MAX_SCALE { |
557 | 0 | return Err(ArrowError::InvalidArgumentError(format!( |
558 | 0 | "Cannot cast string to decimal greater than maximum scale {}", |
559 | 0 | T::MAX_SCALE |
560 | 0 | ))); |
561 | 0 | } |
562 | | |
563 | 0 | let result = match from.data_type() { |
564 | 0 | DataType::Utf8View => string_view_to_decimal_cast::<T>( |
565 | 0 | from.as_any().downcast_ref::<StringViewArray>().unwrap(), |
566 | 0 | precision, |
567 | 0 | scale, |
568 | 0 | cast_options, |
569 | 0 | )?, |
570 | 0 | DataType::Utf8 | DataType::LargeUtf8 => string_to_decimal_cast::<T, Offset>( |
571 | 0 | from.as_any() |
572 | 0 | .downcast_ref::<GenericStringArray<Offset>>() |
573 | 0 | .unwrap(), |
574 | 0 | precision, |
575 | 0 | scale, |
576 | 0 | cast_options, |
577 | 0 | )?, |
578 | 0 | other => { |
579 | 0 | return Err(ArrowError::ComputeError(format!( |
580 | 0 | "Cannot cast {other:?} to decimal", |
581 | 0 | ))) |
582 | | } |
583 | | }; |
584 | | |
585 | 0 | Ok(Arc::new(result)) |
586 | 0 | } |
587 | | |
588 | 0 | pub(crate) fn cast_floating_point_to_decimal<T: ArrowPrimitiveType, D>( |
589 | 0 | array: &PrimitiveArray<T>, |
590 | 0 | precision: u8, |
591 | 0 | scale: i8, |
592 | 0 | cast_options: &CastOptions, |
593 | 0 | ) -> Result<ArrayRef, ArrowError> |
594 | 0 | where |
595 | 0 | <T as ArrowPrimitiveType>::Native: AsPrimitive<f64>, |
596 | 0 | D: DecimalType + ArrowPrimitiveType, |
597 | 0 | <D as ArrowPrimitiveType>::Native: DecimalCast, |
598 | | { |
599 | 0 | let mul = 10_f64.powi(scale as i32); |
600 | | |
601 | 0 | if cast_options.safe { |
602 | 0 | array |
603 | 0 | .unary_opt::<_, D>(|v| { |
604 | 0 | D::Native::from_f64((mul * v.as_()).round()) |
605 | 0 | .filter(|v| D::is_valid_decimal_precision(*v, precision)) |
606 | 0 | }) |
607 | 0 | .with_precision_and_scale(precision, scale) |
608 | 0 | .map(|a| Arc::new(a) as ArrayRef) |
609 | | } else { |
610 | 0 | array |
611 | 0 | .try_unary::<_, D, _>(|v| { |
612 | 0 | D::Native::from_f64((mul * v.as_()).round()) |
613 | 0 | .ok_or_else(|| { |
614 | 0 | ArrowError::CastError(format!( |
615 | 0 | "Cannot cast to {}({}, {}). Overflowing on {:?}", |
616 | 0 | D::PREFIX, |
617 | 0 | precision, |
618 | 0 | scale, |
619 | 0 | v |
620 | 0 | )) |
621 | 0 | }) |
622 | 0 | .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v)) |
623 | 0 | })? |
624 | 0 | .with_precision_and_scale(precision, scale) |
625 | 0 | .map(|a| Arc::new(a) as ArrayRef) |
626 | | } |
627 | 0 | } |
628 | | |
629 | 0 | pub(crate) fn cast_decimal_to_integer<D, T>( |
630 | 0 | array: &dyn Array, |
631 | 0 | base: D::Native, |
632 | 0 | scale: i8, |
633 | 0 | cast_options: &CastOptions, |
634 | 0 | ) -> Result<ArrayRef, ArrowError> |
635 | 0 | where |
636 | 0 | T: ArrowPrimitiveType, |
637 | 0 | <T as ArrowPrimitiveType>::Native: NumCast, |
638 | 0 | D: DecimalType + ArrowPrimitiveType, |
639 | 0 | <D as ArrowPrimitiveType>::Native: ArrowNativeTypeOp + ToPrimitive, |
640 | | { |
641 | 0 | let array = array.as_primitive::<D>(); |
642 | | |
643 | 0 | let div: D::Native = base.pow_checked(scale as u32).map_err(|_| { |
644 | 0 | ArrowError::CastError(format!( |
645 | 0 | "Cannot cast to {:?}. The scale {} causes overflow.", |
646 | 0 | D::PREFIX, |
647 | 0 | scale, |
648 | 0 | )) |
649 | 0 | })?; |
650 | | |
651 | 0 | let mut value_builder = PrimitiveBuilder::<T>::with_capacity(array.len()); |
652 | | |
653 | 0 | if cast_options.safe { |
654 | 0 | for i in 0..array.len() { |
655 | 0 | if array.is_null(i) { |
656 | 0 | value_builder.append_null(); |
657 | 0 | } else { |
658 | 0 | let v = array |
659 | 0 | .value(i) |
660 | 0 | .div_checked(div) |
661 | 0 | .ok() |
662 | 0 | .and_then(<T::Native as NumCast>::from::<D::Native>); |
663 | 0 |
|
664 | 0 | value_builder.append_option(v); |
665 | 0 | } |
666 | | } |
667 | | } else { |
668 | 0 | for i in 0..array.len() { |
669 | 0 | if array.is_null(i) { |
670 | 0 | value_builder.append_null(); |
671 | 0 | } else { |
672 | 0 | let v = array.value(i).div_checked(div)?; |
673 | | |
674 | 0 | let value = <T::Native as NumCast>::from::<D::Native>(v).ok_or_else(|| { |
675 | 0 | ArrowError::CastError(format!( |
676 | 0 | "value of {:?} is out of range {}", |
677 | 0 | v, |
678 | 0 | T::DATA_TYPE |
679 | 0 | )) |
680 | 0 | })?; |
681 | | |
682 | 0 | value_builder.append_value(value); |
683 | | } |
684 | | } |
685 | | } |
686 | 0 | Ok(Arc::new(value_builder.finish())) |
687 | 0 | } |
688 | | |
689 | | /// Cast a decimal array to a floating point array. |
690 | | /// |
691 | | /// Conversion is lossy and follows standard floating point semantics. Values |
692 | | /// that exceed the representable range become `INFINITY` or `-INFINITY` without |
693 | | /// returning an error. |
694 | 0 | pub(crate) fn cast_decimal_to_float<D: DecimalType, T: ArrowPrimitiveType, F>( |
695 | 0 | array: &dyn Array, |
696 | 0 | op: F, |
697 | 0 | ) -> Result<ArrayRef, ArrowError> |
698 | 0 | where |
699 | 0 | F: Fn(D::Native) -> T::Native, |
700 | | { |
701 | 0 | let array = array.as_primitive::<D>(); |
702 | 0 | let array = array.unary::<_, T>(op); |
703 | 0 | Ok(Arc::new(array)) |
704 | 0 | } |
705 | | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks padding, truncation and rounding of the fractional part for a
    /// range of inputs and scales.
    #[test]
    fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
        // (input, scale, expected unscaled value)
        let cases: [(&str, usize, i128); 8] = [
            ("0", 0, 0),
            ("0", 5, 0),
            ("123", 0, 123),
            ("123", 5, 12300000),
            ("123.45", 0, 123),
            ("123.45", 5, 12345000),
            ("123.4567891", 0, 123),
            ("123.4567891", 5, 12345679),
        ];

        for (input, scale, expected) in cases {
            let parsed = parse_string_to_decimal_native::<Decimal128Type>(input, scale)?;
            assert_eq!(parsed, expected, "input {input:?} with scale {scale}");
        }
        Ok(())
    }
}