Skip to content

Commit 0013170

Browse files
authored
collation the validate precision code for decimal array (#2446)
1 parent 2185ce2 commit 0013170

File tree

1 file changed

+72
-95
lines changed

1 file changed

+72
-95
lines changed

arrow/src/array/array_decimal.rs

Lines changed: 72 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -249,39 +249,6 @@ impl<const BYTE_WIDTH: usize> BasicDecimalArray<BYTE_WIDTH> {
249249
fn raw_value_data_ptr(&self) -> *const u8 {
250250
self.value_data.as_ptr()
251251
}
252-
}
253-
254-
impl Decimal128Array {
255-
/// Creates a [Decimal128Array] with default precision and scale,
256-
/// based on an iterator of `i128` values without nulls
257-
pub fn from_iter_values<I: IntoIterator<Item = i128>>(iter: I) -> Self {
258-
let val_buf: Buffer = iter.into_iter().collect();
259-
let data = unsafe {
260-
ArrayData::new_unchecked(
261-
Self::default_type(),
262-
val_buf.len() / std::mem::size_of::<i128>(),
263-
None,
264-
None,
265-
0,
266-
vec![val_buf],
267-
vec![],
268-
)
269-
};
270-
Decimal128Array::from(data)
271-
}
272-
273-
// Validates decimal values in this array can be properly interpreted
274-
// with the specified precision.
275-
fn validate_decimal_precision(&self, precision: usize) -> Result<()> {
276-
(0..self.len()).try_for_each(|idx| {
277-
if self.is_valid(idx) {
278-
let decimal = unsafe { self.value_unchecked(idx) };
279-
validate_decimal_precision(decimal.as_i128(), precision)
280-
} else {
281-
Ok(())
282-
}
283-
})
284-
}
285252

286253
/// Returns a Decimal array with the same data as self, with the
287254
/// specified precision.
@@ -294,6 +261,23 @@ impl Decimal128Array {
294261
where
295262
Self: Sized,
296263
{
264+
// validate precision and scale
265+
self.validate_precision_scale(precision, scale)?;
266+
267+
// Ensure that all values are within the requested
268+
// precision. For performance, only check if the precision is
269+
// decreased
270+
if precision < self.precision {
271+
self.validate_data(precision)?;
272+
}
273+
274+
// safety: self.data is valid DataType::Decimal as checked above
275+
let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale);
276+
Ok(self.data().clone().with_data_type(new_data_type).into())
277+
}
278+
279+
// validate that the new precision and scale are valid or not
280+
fn validate_precision_scale(&self, precision: usize, scale: usize) -> Result<()> {
297281
if precision > Self::MAX_PRECISION {
298282
return Err(ArrowError::InvalidArgumentError(format!(
299283
"precision {} is greater than max {}",
@@ -314,26 +298,67 @@ impl Decimal128Array {
314298
scale, precision
315299
)));
316300
}
317-
318-
// Ensure that all values are within the requested
319-
// precision. For performance, only check if the precision is
320-
// decreased
321-
if precision < self.precision {
322-
self.validate_decimal_precision(precision)?;
323-
}
324-
325301
let data_type = Self::TYPE_CONSTRUCTOR(self.precision, self.scale);
326302
assert_eq!(self.data().data_type(), &data_type);
327303

328-
// safety: self.data is valid DataType::Decimal as checked above
329-
let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale);
304+
Ok(())
305+
}
306+
307+
// validate all the data in the array are valid within the new precision or not
308+
fn validate_data(&self, precision: usize) -> Result<()> {
309+
match BYTE_WIDTH {
310+
16 => self
311+
.as_any()
312+
.downcast_ref::<Decimal128Array>()
313+
.unwrap()
314+
.validate_decimal_precision(precision),
315+
32 => self
316+
.as_any()
317+
.downcast_ref::<Decimal256Array>()
318+
.unwrap()
319+
.validate_decimal_precision(precision),
320+
other_width => {
321+
panic!("invalid byte width {}", other_width);
322+
}
323+
}
324+
}
325+
}
330326

331-
Ok(self.data().clone().with_data_type(new_data_type).into())
327+
impl Decimal128Array {
328+
/// Creates a [Decimal128Array] with default precision and scale,
329+
/// based on an iterator of `i128` values without nulls
330+
pub fn from_iter_values<I: IntoIterator<Item = i128>>(iter: I) -> Self {
331+
let val_buf: Buffer = iter.into_iter().collect();
332+
let data = unsafe {
333+
ArrayData::new_unchecked(
334+
Self::default_type(),
335+
val_buf.len() / std::mem::size_of::<i128>(),
336+
None,
337+
None,
338+
0,
339+
vec![val_buf],
340+
vec![],
341+
)
342+
};
343+
Decimal128Array::from(data)
344+
}
345+
346+
// Validates decimal128 values in this array can be properly interpreted
347+
// with the specified precision.
348+
fn validate_decimal_precision(&self, precision: usize) -> Result<()> {
349+
(0..self.len()).try_for_each(|idx| {
350+
if self.is_valid(idx) {
351+
let decimal = unsafe { self.value_unchecked(idx) };
352+
validate_decimal_precision(decimal.as_i128(), precision)
353+
} else {
354+
Ok(())
355+
}
356+
})
332357
}
333358
}
334359

335360
impl Decimal256Array {
336-
// Validates decimal values in this array can be properly interpreted
361+
// Validates decimal256 values in this array can be properly interpreted
337362
// with the specified precision.
338363
fn validate_decimal_precision(&self, precision: usize) -> Result<()> {
339364
(0..self.len()).try_for_each(|idx| {
@@ -351,54 +376,6 @@ impl Decimal256Array {
351376
}
352377
})
353378
}
354-
355-
/// Returns a Decimal array with the same data as self, with the
356-
/// specified precision.
357-
///
358-
/// Returns an Error if:
359-
/// 1. `precision` is larger than [`Self::MAX_PRECISION`]
360-
/// 2. `scale` is larger than [`Self::MAX_SCALE`];
361-
/// 3. `scale` is > `precision`
362-
pub fn with_precision_and_scale(self, precision: usize, scale: usize) -> Result<Self>
363-
where
364-
Self: Sized,
365-
{
366-
if precision > Self::MAX_PRECISION {
367-
return Err(ArrowError::InvalidArgumentError(format!(
368-
"precision {} is greater than max {}",
369-
precision,
370-
Self::MAX_PRECISION
371-
)));
372-
}
373-
if scale > Self::MAX_SCALE {
374-
return Err(ArrowError::InvalidArgumentError(format!(
375-
"scale {} is greater than max {}",
376-
scale,
377-
Self::MAX_SCALE
378-
)));
379-
}
380-
if scale > precision {
381-
return Err(ArrowError::InvalidArgumentError(format!(
382-
"scale {} is greater than precision {}",
383-
scale, precision
384-
)));
385-
}
386-
387-
// Ensure that all values are within the requested
388-
// precision. For performance, only check if the precision is
389-
// decreased
390-
if precision < self.precision {
391-
self.validate_decimal_precision(precision)?;
392-
}
393-
394-
let data_type = Self::TYPE_CONSTRUCTOR(self.precision, self.scale);
395-
assert_eq!(self.data().data_type(), &data_type);
396-
397-
// safety: self.data is valid DataType::Decimal as checked above
398-
let new_data_type = Self::TYPE_CONSTRUCTOR(precision, scale);
399-
400-
Ok(self.data().clone().with_data_type(new_data_type).into())
401-
}
402379
}
403380

404381
impl<const BYTE_WIDTH: usize> From<ArrayData> for BasicDecimalArray<BYTE_WIDTH> {
@@ -942,7 +919,7 @@ mod tests {
942919
Decimal256::from_big_int(
943920
&value1,
944921
DECIMAL256_MAX_PRECISION,
945-
DECIMAL_DEFAULT_SCALE
922+
DECIMAL_DEFAULT_SCALE,
946923
)
947924
.unwrap(),
948925
array.value(0)
@@ -953,7 +930,7 @@ mod tests {
953930
Decimal256::from_big_int(
954931
&value2,
955932
DECIMAL256_MAX_PRECISION,
956-
DECIMAL_DEFAULT_SCALE
933+
DECIMAL_DEFAULT_SCALE,
957934
)
958935
.unwrap(),
959936
array.value(2)

0 commit comments

Comments
 (0)