/Users/andrewlamb/Software/arrow-rs/arrow-arith/src/arithmetic.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Defines basic arithmetic kernels for `PrimitiveArrays`. |
19 | | //! |
20 | | //! These kernels can leverage SIMD if available on your system. Currently no runtime |
21 | | //! detection is provided, you should enable the specific SIMD intrinsics using |
22 | | //! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation |
23 | | //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. |
24 | | |
25 | | use crate::arity::*; |
26 | | use arrow_array::types::*; |
27 | | use arrow_array::*; |
28 | | use arrow_buffer::i256; |
29 | | use arrow_buffer::ArrowNativeType; |
30 | | use arrow_schema::*; |
31 | | use std::cmp::min; |
32 | | use std::sync::Arc; |
33 | | |
34 | | /// Returns the precision and scale of the result of a multiplication of two decimal types, |
35 | | /// and the divisor for fixed point multiplication. |
36 | | fn get_fixed_point_info( |
37 | | left: (u8, i8), |
38 | | right: (u8, i8), |
39 | | required_scale: i8, |
40 | | ) -> Result<(u8, i8, i256), ArrowError> { |
41 | | let product_scale = left.1 + right.1; |
42 | | let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION); |
43 | | |
44 | | if required_scale > product_scale { |
45 | | return Err(ArrowError::ComputeError(format!( |
46 | | "Required scale {required_scale} is greater than product scale {product_scale}", |
47 | | ))); |
48 | | } |
49 | | |
50 | | let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32); |
51 | | |
52 | | Ok((precision, product_scale, divisor)) |
53 | | } |
54 | | |
55 | | /// Perform `left * right` operation on two decimal arrays. If either left or right value is |
56 | | /// null then the result is also null. |
57 | | /// |
58 | | /// This performs decimal multiplication which allows precision loss if an exact representation |
59 | | /// is not possible for the result, according to the required scale. In the case, the result |
60 | | /// will be rounded to the required scale. |
61 | | /// |
62 | | /// If the required scale is greater than the product scale, an error is returned. |
63 | | /// |
64 | | /// This doesn't detect overflow. Once overflowing, the result will wrap around. |
65 | | /// |
66 | | /// It is implemented for compatibility with precision loss `multiply` function provided by |
67 | | /// other data processing engines. For multiplication with precision loss detection, use |
68 | | /// `multiply_dyn` or `multiply_dyn_checked` instead. |
69 | | pub fn multiply_fixed_point_dyn( |
70 | | left: &dyn Array, |
71 | | right: &dyn Array, |
72 | | required_scale: i8, |
73 | | ) -> Result<ArrayRef, ArrowError> { |
74 | | match (left.data_type(), right.data_type()) { |
75 | | (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => { |
76 | | let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap(); |
77 | | let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap(); |
78 | | |
79 | 0 | multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef) |
80 | | } |
81 | | (_, _) => Err(ArrowError::CastError(format!( |
82 | | "Unsupported data type {}, {}", |
83 | | left.data_type(), |
84 | | right.data_type() |
85 | | ))), |
86 | | } |
87 | | } |
88 | | |
89 | | /// Perform `left * right` operation on two decimal arrays. If either left or right value is |
90 | | /// null then the result is also null. |
91 | | /// |
92 | | /// This performs decimal multiplication which allows precision loss if an exact representation |
93 | | /// is not possible for the result, according to the required scale. In the case, the result |
94 | | /// will be rounded to the required scale. |
95 | | /// |
96 | | /// If the required scale is greater than the product scale, an error is returned. |
97 | | /// |
98 | | /// It is implemented for compatibility with precision loss `multiply` function provided by |
99 | | /// other data processing engines. For multiplication with precision loss detection, use |
100 | | /// `multiply` or `multiply_checked` instead. |
101 | | pub fn multiply_fixed_point_checked( |
102 | | left: &PrimitiveArray<Decimal128Type>, |
103 | | right: &PrimitiveArray<Decimal128Type>, |
104 | | required_scale: i8, |
105 | | ) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> { |
106 | | let (precision, product_scale, divisor) = get_fixed_point_info( |
107 | | (left.precision(), left.scale()), |
108 | | (right.precision(), right.scale()), |
109 | | required_scale, |
110 | | )?; |
111 | | |
112 | | if required_scale == product_scale { |
113 | 0 | return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))? |
114 | | .with_precision_and_scale(precision, required_scale); |
115 | | } |
116 | | |
117 | 0 | try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| { |
118 | 0 | let a = i256::from_i128(a); |
119 | 0 | let b = i256::from_i128(b); |
120 | | |
121 | 0 | let mut mul = a.wrapping_mul(b); |
122 | 0 | mul = divide_and_round::<Decimal256Type>(mul, divisor); |
123 | 0 | mul.to_i128().ok_or_else(|| { |
124 | 0 | ArrowError::ArithmeticOverflow(format!("Overflow happened on: {a:?} * {b:?}")) |
125 | 0 | }) |
126 | 0 | }) |
127 | 0 | .and_then(|a| a.with_precision_and_scale(precision, required_scale)) |
128 | | } |
129 | | |
130 | | /// Perform `left * right` operation on two decimal arrays. If either left or right value is |
131 | | /// null then the result is also null. |
132 | | /// |
133 | | /// This performs decimal multiplication which allows precision loss if an exact representation |
134 | | /// is not possible for the result, according to the required scale. In the case, the result |
135 | | /// will be rounded to the required scale. |
136 | | /// |
137 | | /// If the required scale is greater than the product scale, an error is returned. |
138 | | /// |
139 | | /// This doesn't detect overflow. Once overflowing, the result will wrap around. |
140 | | /// For an overflow-checking variant, use `multiply_fixed_point_checked` instead. |
141 | | /// |
142 | | /// It is implemented for compatibility with precision loss `multiply` function provided by |
143 | | /// other data processing engines. For multiplication with precision loss detection, use |
144 | | /// `multiply` or `multiply_checked` instead. |
145 | | pub fn multiply_fixed_point( |
146 | | left: &PrimitiveArray<Decimal128Type>, |
147 | | right: &PrimitiveArray<Decimal128Type>, |
148 | | required_scale: i8, |
149 | | ) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> { |
150 | | let (precision, product_scale, divisor) = get_fixed_point_info( |
151 | | (left.precision(), left.scale()), |
152 | | (right.precision(), right.scale()), |
153 | | required_scale, |
154 | | )?; |
155 | | |
156 | | if required_scale == product_scale { |
157 | 0 | return binary(left, right, |a, b| a.mul_wrapping(b))? |
158 | | .with_precision_and_scale(precision, required_scale); |
159 | | } |
160 | | |
161 | 0 | binary::<_, _, _, Decimal128Type>(left, right, |a, b| { |
162 | 0 | let a = i256::from_i128(a); |
163 | 0 | let b = i256::from_i128(b); |
164 | | |
165 | 0 | let mut mul = a.wrapping_mul(b); |
166 | 0 | mul = divide_and_round::<Decimal256Type>(mul, divisor); |
167 | 0 | mul.as_i128() |
168 | 0 | }) |
169 | 0 | .and_then(|a| a.with_precision_and_scale(precision, required_scale)) |
170 | | } |
171 | | |
172 | | /// Divide a decimal native value by given divisor and round the result. |
173 | 0 | fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native |
174 | 0 | where |
175 | 0 | I: DecimalType, |
176 | 0 | I::Native: ArrowNativeTypeOp, |
177 | | { |
178 | 0 | let d = input.div_wrapping(div); |
179 | 0 | let r = input.mod_wrapping(div); |
180 | | |
181 | 0 | let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); |
182 | 0 | let half_neg = half.neg_wrapping(); |
183 | | |
184 | | // Round result |
185 | 0 | match input >= I::Native::ZERO { |
186 | 0 | true if r >= half => d.add_wrapping(I::Native::ONE), |
187 | 0 | false if r <= half_neg => d.sub_wrapping(I::Native::ONE), |
188 | 0 | _ => d, |
189 | | } |
190 | 0 | } |
191 | | |
192 | | #[cfg(test)] |
193 | | mod tests { |
194 | | use super::*; |
195 | | use crate::numeric::mul; |
196 | | |
197 | | #[test] |
198 | | fn test_decimal_multiply_allow_precision_loss() { |
199 | | // Overflow happening as i128 cannot hold multiplying result. |
200 | | // [123456789] |
201 | | let a = Decimal128Array::from(vec![123456789000000000000000000]) |
202 | | .with_precision_and_scale(38, 18) |
203 | | .unwrap(); |
204 | | |
205 | | // [10] |
206 | | let b = Decimal128Array::from(vec![10000000000000000000]) |
207 | | .with_precision_and_scale(38, 18) |
208 | | .unwrap(); |
209 | | |
210 | | let err = mul(&a, &b).unwrap_err(); |
211 | | assert!(err |
212 | | .to_string() |
213 | | .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000")); |
214 | | |
215 | | // Allow precision loss. |
216 | | let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); |
217 | | // [1234567890] |
218 | | let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000]) |
219 | | .with_precision_and_scale(38, 28) |
220 | | .unwrap(); |
221 | | |
222 | | assert_eq!(&expected, &result); |
223 | | assert_eq!( |
224 | | result.value_as_string(0), |
225 | | "1234567890.0000000000000000000000000000" |
226 | | ); |
227 | | |
228 | | // Rounding case |
229 | | // [0.000000000000000001, 123456789.555555555555555555, 1.555555555555555555] |
230 | | let a = Decimal128Array::from(vec![1, 123456789555555555555555555, 1555555555555555555]) |
231 | | .with_precision_and_scale(38, 18) |
232 | | .unwrap(); |
233 | | |
234 | | // [1.555555555555555555, 11.222222222222222222, 0.000000000000000001] |
235 | | let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1]) |
236 | | .with_precision_and_scale(38, 18) |
237 | | .unwrap(); |
238 | | |
239 | | let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); |
240 | | // [ |
241 | | // 0.0000000000000000015555555556, |
242 | | // 1385459527.2345679012071330528765432099, |
243 | | // 0.0000000000000000015555555556 |
244 | | // ] |
245 | | let expected = Decimal128Array::from(vec![ |
246 | | 15555555556, |
247 | | 13854595272345679012071330528765432099, |
248 | | 15555555556, |
249 | | ]) |
250 | | .with_precision_and_scale(38, 28) |
251 | | .unwrap(); |
252 | | |
253 | | assert_eq!(&expected, &result); |
254 | | |
255 | | // Rounded the value "1385459527.234567901207133052876543209876543210". |
256 | | assert_eq!( |
257 | | result.value_as_string(1), |
258 | | "1385459527.2345679012071330528765432099" |
259 | | ); |
260 | | assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556"); |
261 | | assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556"); |
262 | | |
263 | | let a = Decimal128Array::from(vec![1230]) |
264 | | .with_precision_and_scale(4, 2) |
265 | | .unwrap(); |
266 | | |
267 | | let b = Decimal128Array::from(vec![1000]) |
268 | | .with_precision_and_scale(4, 2) |
269 | | .unwrap(); |
270 | | |
271 | | // Required scale is same as the product of the input scales. Behavior is same as multiply. |
272 | | let result = multiply_fixed_point_checked(&a, &b, 4).unwrap(); |
273 | | assert_eq!(result.precision(), 9); |
274 | | assert_eq!(result.scale(), 4); |
275 | | |
276 | | let expected = mul(&a, &b).unwrap(); |
277 | | assert_eq!(expected.as_ref(), &result); |
278 | | |
279 | | // Required scale cannot be larger than the product of the input scales. |
280 | | let result = multiply_fixed_point_checked(&a, &b, 5).unwrap_err(); |
281 | | assert!(result |
282 | | .to_string() |
283 | | .contains("Required scale 5 is greater than product scale 4")); |
284 | | } |
285 | | |
286 | | #[test] |
287 | | fn test_decimal_multiply_allow_precision_loss_overflow() { |
288 | | // [99999999999123456789] |
289 | | let a = Decimal128Array::from(vec![99999999999123456789000000000000000000]) |
290 | | .with_precision_and_scale(38, 18) |
291 | | .unwrap(); |
292 | | |
293 | | // [9999999999910] |
294 | | let b = Decimal128Array::from(vec![9999999999910000000000000000000]) |
295 | | .with_precision_and_scale(38, 18) |
296 | | .unwrap(); |
297 | | |
298 | | let err = multiply_fixed_point_checked(&a, &b, 28).unwrap_err(); |
299 | | assert!(err.to_string().contains( |
300 | | "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000" |
301 | | )); |
302 | | |
303 | | let result = multiply_fixed_point(&a, &b, 28).unwrap(); |
304 | | let expected = Decimal128Array::from(vec![62946009661555981610246871926660136960]) |
305 | | .with_precision_and_scale(38, 28) |
306 | | .unwrap(); |
307 | | |
308 | | assert_eq!(&expected, &result); |
309 | | } |
310 | | |
311 | | #[test] |
312 | | fn test_decimal_multiply_fixed_point() { |
313 | | // [123456789] |
314 | | let a = Decimal128Array::from(vec![123456789000000000000000000]) |
315 | | .with_precision_and_scale(38, 18) |
316 | | .unwrap(); |
317 | | |
318 | | // [10] |
319 | | let b = Decimal128Array::from(vec![10000000000000000000]) |
320 | | .with_precision_and_scale(38, 18) |
321 | | .unwrap(); |
322 | | |
323 | | // `multiply` overflows on this case. |
324 | | let err = mul(&a, &b).unwrap_err(); |
325 | | assert_eq!(err.to_string(), "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000"); |
326 | | |
327 | | // Avoid overflow by reducing the scale. |
328 | | let result = multiply_fixed_point(&a, &b, 28).unwrap(); |
329 | | // [1234567890] |
330 | | let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000]) |
331 | | .with_precision_and_scale(38, 28) |
332 | | .unwrap(); |
333 | | |
334 | | assert_eq!(&expected, &result); |
335 | | assert_eq!( |
336 | | result.value_as_string(0), |
337 | | "1234567890.0000000000000000000000000000" |
338 | | ); |
339 | | } |
340 | | } |