Coverage Report

Created: 2025-08-26 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/Users/andrewlamb/Software/arrow-rs/arrow-arith/src/arithmetic.rs
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
//! Defines basic arithmetic kernels for `PrimitiveArrays`.
19
//!
20
//! These kernels can leverage SIMD if available on your system.  Currently no runtime
21
//! detection is provided, you should enable the specific SIMD intrinsics using
22
//! `RUSTFLAGS="-C target-feature=+avx2"` for example.  See the documentation
23
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
24
25
use crate::arity::*;
26
use arrow_array::types::*;
27
use arrow_array::*;
28
use arrow_buffer::i256;
29
use arrow_buffer::ArrowNativeType;
30
use arrow_schema::*;
31
use std::cmp::min;
32
use std::sync::Arc;
33
34
/// Returns the precision and scale of the result of a multiplication of two decimal types,
35
/// and the divisor for fixed point multiplication.
36
fn get_fixed_point_info(
37
    left: (u8, i8),
38
    right: (u8, i8),
39
    required_scale: i8,
40
) -> Result<(u8, i8, i256), ArrowError> {
41
    let product_scale = left.1 + right.1;
42
    let precision = min(left.0 + right.0 + 1, DECIMAL128_MAX_PRECISION);
43
44
    if required_scale > product_scale {
45
        return Err(ArrowError::ComputeError(format!(
46
            "Required scale {required_scale} is greater than product scale {product_scale}",
47
        )));
48
    }
49
50
    let divisor = i256::from_i128(10).pow_wrapping((product_scale - required_scale) as u32);
51
52
    Ok((precision, product_scale, divisor))
53
}
54
55
/// Perform `left * right` operation on two decimal arrays. If either left or right value is
56
/// null then the result is also null.
57
///
58
/// This performs decimal multiplication which allows precision loss if an exact representation
59
/// is not possible for the result, according to the required scale. In the case, the result
60
/// will be rounded to the required scale.
61
///
62
/// If the required scale is greater than the product scale, an error is returned.
63
///
64
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
65
///
66
/// It is implemented for compatibility with precision loss `multiply` function provided by
67
/// other data processing engines. For multiplication with precision loss detection, use
68
/// `multiply_dyn` or `multiply_dyn_checked` instead.
69
pub fn multiply_fixed_point_dyn(
70
    left: &dyn Array,
71
    right: &dyn Array,
72
    required_scale: i8,
73
) -> Result<ArrayRef, ArrowError> {
74
    match (left.data_type(), right.data_type()) {
75
        (DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => {
76
            let left = left.as_any().downcast_ref::<Decimal128Array>().unwrap();
77
            let right = right.as_any().downcast_ref::<Decimal128Array>().unwrap();
78
79
0
            multiply_fixed_point(left, right, required_scale).map(|a| Arc::new(a) as ArrayRef)
80
        }
81
        (_, _) => Err(ArrowError::CastError(format!(
82
            "Unsupported data type {}, {}",
83
            left.data_type(),
84
            right.data_type()
85
        ))),
86
    }
87
}
88
89
/// Perform `left * right` operation on two decimal arrays. If either left or right value is
90
/// null then the result is also null.
91
///
92
/// This performs decimal multiplication which allows precision loss if an exact representation
93
/// is not possible for the result, according to the required scale. In the case, the result
94
/// will be rounded to the required scale.
95
///
96
/// If the required scale is greater than the product scale, an error is returned.
97
///
98
/// It is implemented for compatibility with precision loss `multiply` function provided by
99
/// other data processing engines. For multiplication with precision loss detection, use
100
/// `multiply` or `multiply_checked` instead.
101
pub fn multiply_fixed_point_checked(
102
    left: &PrimitiveArray<Decimal128Type>,
103
    right: &PrimitiveArray<Decimal128Type>,
104
    required_scale: i8,
105
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
106
    let (precision, product_scale, divisor) = get_fixed_point_info(
107
        (left.precision(), left.scale()),
108
        (right.precision(), right.scale()),
109
        required_scale,
110
    )?;
111
112
    if required_scale == product_scale {
113
0
        return try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| a.mul_checked(b))?
114
            .with_precision_and_scale(precision, required_scale);
115
    }
116
117
0
    try_binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
118
0
        let a = i256::from_i128(a);
119
0
        let b = i256::from_i128(b);
120
121
0
        let mut mul = a.wrapping_mul(b);
122
0
        mul = divide_and_round::<Decimal256Type>(mul, divisor);
123
0
        mul.to_i128().ok_or_else(|| {
124
0
            ArrowError::ArithmeticOverflow(format!("Overflow happened on: {a:?} * {b:?}"))
125
0
        })
126
0
    })
127
0
    .and_then(|a| a.with_precision_and_scale(precision, required_scale))
128
}
129
130
/// Perform `left * right` operation on two decimal arrays. If either left or right value is
131
/// null then the result is also null.
132
///
133
/// This performs decimal multiplication which allows precision loss if an exact representation
134
/// is not possible for the result, according to the required scale. In the case, the result
135
/// will be rounded to the required scale.
136
///
137
/// If the required scale is greater than the product scale, an error is returned.
138
///
139
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
140
/// For an overflow-checking variant, use `multiply_fixed_point_checked` instead.
141
///
142
/// It is implemented for compatibility with precision loss `multiply` function provided by
143
/// other data processing engines. For multiplication with precision loss detection, use
144
/// `multiply` or `multiply_checked` instead.
145
pub fn multiply_fixed_point(
146
    left: &PrimitiveArray<Decimal128Type>,
147
    right: &PrimitiveArray<Decimal128Type>,
148
    required_scale: i8,
149
) -> Result<PrimitiveArray<Decimal128Type>, ArrowError> {
150
    let (precision, product_scale, divisor) = get_fixed_point_info(
151
        (left.precision(), left.scale()),
152
        (right.precision(), right.scale()),
153
        required_scale,
154
    )?;
155
156
    if required_scale == product_scale {
157
0
        return binary(left, right, |a, b| a.mul_wrapping(b))?
158
            .with_precision_and_scale(precision, required_scale);
159
    }
160
161
0
    binary::<_, _, _, Decimal128Type>(left, right, |a, b| {
162
0
        let a = i256::from_i128(a);
163
0
        let b = i256::from_i128(b);
164
165
0
        let mut mul = a.wrapping_mul(b);
166
0
        mul = divide_and_round::<Decimal256Type>(mul, divisor);
167
0
        mul.as_i128()
168
0
    })
169
0
    .and_then(|a| a.with_precision_and_scale(precision, required_scale))
170
}
171
172
/// Divide a decimal native value by given divisor and round the result.
173
0
fn divide_and_round<I>(input: I::Native, div: I::Native) -> I::Native
174
0
where
175
0
    I: DecimalType,
176
0
    I::Native: ArrowNativeTypeOp,
177
{
178
0
    let d = input.div_wrapping(div);
179
0
    let r = input.mod_wrapping(div);
180
181
0
    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
182
0
    let half_neg = half.neg_wrapping();
183
184
    // Round result
185
0
    match input >= I::Native::ZERO {
186
0
        true if r >= half => d.add_wrapping(I::Native::ONE),
187
0
        false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
188
0
        _ => d,
189
    }
190
0
}
191
192
#[cfg(test)]
193
mod tests {
194
    use super::*;
195
    use crate::numeric::mul;
196
197
    #[test]
198
    fn test_decimal_multiply_allow_precision_loss() {
199
        // Overflow happening as i128 cannot hold multiplying result.
200
        // [123456789]
201
        let a = Decimal128Array::from(vec![123456789000000000000000000])
202
            .with_precision_and_scale(38, 18)
203
            .unwrap();
204
205
        // [10]
206
        let b = Decimal128Array::from(vec![10000000000000000000])
207
            .with_precision_and_scale(38, 18)
208
            .unwrap();
209
210
        let err = mul(&a, &b).unwrap_err();
211
        assert!(err
212
            .to_string()
213
            .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000"));
214
215
        // Allow precision loss.
216
        let result = multiply_fixed_point_checked(&a, &b, 28).unwrap();
217
        // [1234567890]
218
        let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000])
219
            .with_precision_and_scale(38, 28)
220
            .unwrap();
221
222
        assert_eq!(&expected, &result);
223
        assert_eq!(
224
            result.value_as_string(0),
225
            "1234567890.0000000000000000000000000000"
226
        );
227
228
        // Rounding case
229
        // [0.000000000000000001, 123456789.555555555555555555, 1.555555555555555555]
230
        let a = Decimal128Array::from(vec![1, 123456789555555555555555555, 1555555555555555555])
231
            .with_precision_and_scale(38, 18)
232
            .unwrap();
233
234
        // [1.555555555555555555, 11.222222222222222222, 0.000000000000000001]
235
        let b = Decimal128Array::from(vec![1555555555555555555, 11222222222222222222, 1])
236
            .with_precision_and_scale(38, 18)
237
            .unwrap();
238
239
        let result = multiply_fixed_point_checked(&a, &b, 28).unwrap();
240
        // [
241
        //    0.0000000000000000015555555556,
242
        //    1385459527.2345679012071330528765432099,
243
        //    0.0000000000000000015555555556
244
        // ]
245
        let expected = Decimal128Array::from(vec![
246
            15555555556,
247
            13854595272345679012071330528765432099,
248
            15555555556,
249
        ])
250
        .with_precision_and_scale(38, 28)
251
        .unwrap();
252
253
        assert_eq!(&expected, &result);
254
255
        // Rounded the value "1385459527.234567901207133052876543209876543210".
256
        assert_eq!(
257
            result.value_as_string(1),
258
            "1385459527.2345679012071330528765432099"
259
        );
260
        assert_eq!(result.value_as_string(0), "0.0000000000000000015555555556");
261
        assert_eq!(result.value_as_string(2), "0.0000000000000000015555555556");
262
263
        let a = Decimal128Array::from(vec![1230])
264
            .with_precision_and_scale(4, 2)
265
            .unwrap();
266
267
        let b = Decimal128Array::from(vec![1000])
268
            .with_precision_and_scale(4, 2)
269
            .unwrap();
270
271
        // Required scale is same as the product of the input scales. Behavior is same as multiply.
272
        let result = multiply_fixed_point_checked(&a, &b, 4).unwrap();
273
        assert_eq!(result.precision(), 9);
274
        assert_eq!(result.scale(), 4);
275
276
        let expected = mul(&a, &b).unwrap();
277
        assert_eq!(expected.as_ref(), &result);
278
279
        // Required scale cannot be larger than the product of the input scales.
280
        let result = multiply_fixed_point_checked(&a, &b, 5).unwrap_err();
281
        assert!(result
282
            .to_string()
283
            .contains("Required scale 5 is greater than product scale 4"));
284
    }
285
286
    #[test]
287
    fn test_decimal_multiply_allow_precision_loss_overflow() {
288
        // [99999999999123456789]
289
        let a = Decimal128Array::from(vec![99999999999123456789000000000000000000])
290
            .with_precision_and_scale(38, 18)
291
            .unwrap();
292
293
        // [9999999999910]
294
        let b = Decimal128Array::from(vec![9999999999910000000000000000000])
295
            .with_precision_and_scale(38, 18)
296
            .unwrap();
297
298
        let err = multiply_fixed_point_checked(&a, &b, 28).unwrap_err();
299
        assert!(err.to_string().contains(
300
            "Overflow happened on: 99999999999123456789000000000000000000 * 9999999999910000000000000000000"
301
        ));
302
303
        let result = multiply_fixed_point(&a, &b, 28).unwrap();
304
        let expected = Decimal128Array::from(vec![62946009661555981610246871926660136960])
305
            .with_precision_and_scale(38, 28)
306
            .unwrap();
307
308
        assert_eq!(&expected, &result);
309
    }
310
311
    #[test]
312
    fn test_decimal_multiply_fixed_point() {
313
        // [123456789]
314
        let a = Decimal128Array::from(vec![123456789000000000000000000])
315
            .with_precision_and_scale(38, 18)
316
            .unwrap();
317
318
        // [10]
319
        let b = Decimal128Array::from(vec![10000000000000000000])
320
            .with_precision_and_scale(38, 18)
321
            .unwrap();
322
323
        // `multiply` overflows on this case.
324
        let err = mul(&a, &b).unwrap_err();
325
        assert_eq!(err.to_string(), "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000");
326
327
        // Avoid overflow by reducing the scale.
328
        let result = multiply_fixed_point(&a, &b, 28).unwrap();
329
        // [1234567890]
330
        let expected = Decimal128Array::from(vec![12345678900000000000000000000000000000])
331
            .with_precision_and_scale(38, 28)
332
            .unwrap();
333
334
        assert_eq!(&expected, &result);
335
        assert_eq!(
336
            result.value_as_string(0),
337
            "1234567890.0000000000000000000000000000"
338
        );
339
    }
340
}