/Users/andrewlamb/Software/arrow-rs/arrow-data/src/equal/mod.rs
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | //! Module containing functionality to compute array equality. |
19 | | //! This module uses [ArrayData] and does not |
20 | | //! depend on dynamic casting of `Array`. |
21 | | |
22 | | use crate::data::ArrayData; |
23 | | use arrow_buffer::i256; |
24 | | use arrow_schema::{DataType, IntervalUnit}; |
25 | | use half::f16; |
26 | | |
27 | | mod boolean; |
28 | | mod byte_view; |
29 | | mod dictionary; |
30 | | mod fixed_binary; |
31 | | mod fixed_list; |
32 | | mod list; |
33 | | mod null; |
34 | | mod primitive; |
35 | | mod run; |
36 | | mod structure; |
37 | | mod union; |
38 | | mod utils; |
39 | | mod variable_size; |
40 | | |
41 | | // these methods assume the same type, len and null count. |
42 | | // For this reason, they are not exposed and are instead used |
43 | | // to build the generic functions below (`equal_range` and `equal`). |
44 | | use boolean::boolean_equal; |
45 | | use byte_view::byte_view_equal; |
46 | | use dictionary::dictionary_equal; |
47 | | use fixed_binary::fixed_binary_equal; |
48 | | use fixed_list::fixed_list_equal; |
49 | | use list::list_equal; |
50 | | use null::null_equal; |
51 | | use primitive::primitive_equal; |
52 | | use structure::struct_equal; |
53 | | use union::union_equal; |
54 | | use variable_size::variable_sized_equal; |
55 | | |
56 | | use self::run::run_equal; |
57 | | |
58 | | /// Compares the values of two [ArrayData] for `len` slots, starting at `lhs_start` in `lhs` and at `rhs_start` in `rhs`. |
59 | | /// The comparison helper is selected from `lhs.data_type()` alone (the type of `rhs` is not inspected here); several types are compared with the same helper instantiation (e.g. `Date32`, `Time32` and year-month intervals as `i32`, `Map` through the `i32`-offset list comparison). Panics with `unimplemented!` for `ListView`/`LargeListView`, and with `unreachable!` for a dictionary whose key type is not one of the eight integer types. |
60 | | #[inline] |
61 | 697 | fn equal_values( |
62 | 697 | lhs: &ArrayData, |
63 | 697 | rhs: &ArrayData, |
64 | 697 | lhs_start: usize, |
65 | 697 | rhs_start: usize, |
66 | 697 | len: usize, |
67 | 697 | ) -> bool { |
68 | 697 | match lhs.data_type() { |
69 | 0 | DataType::Null => null_equal(lhs, rhs, lhs_start, rhs_start, len), |
70 | 37 | DataType::Boolean => boolean_equal(lhs, rhs, lhs_start, rhs_start, len), |
71 | 0 | DataType::UInt8 => primitive_equal::<u8>(lhs, rhs, lhs_start, rhs_start, len), |
72 | 0 | DataType::UInt16 => primitive_equal::<u16>(lhs, rhs, lhs_start, rhs_start, len), |
73 | 0 | DataType::UInt32 => primitive_equal::<u32>(lhs, rhs, lhs_start, rhs_start, len), |
74 | 0 | DataType::UInt64 => primitive_equal::<u64>(lhs, rhs, lhs_start, rhs_start, len), |
75 | 0 | DataType::Int8 => primitive_equal::<i8>(lhs, rhs, lhs_start, rhs_start, len), |
76 | 0 | DataType::Int16 => primitive_equal::<i16>(lhs, rhs, lhs_start, rhs_start, len), |
77 | 168 | DataType::Int32 => primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len), |
78 | 43 | DataType::Int64 => primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len), |
79 | 41 | DataType::Float32 => primitive_equal::<f32>(lhs, rhs, lhs_start, rhs_start, len), |
80 | 63 | DataType::Float64 => primitive_equal::<f64>(lhs, rhs, lhs_start, rhs_start, len), |
81 | 0 | DataType::Decimal32(_, _) => primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len), |
82 | 0 | DataType::Decimal64(_, _) => primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len), |
83 | 8 | DataType::Decimal128(_, _) => primitive_equal::<i128>(lhs, rhs, lhs_start, rhs_start, len), |
84 | 0 | DataType::Decimal256(_, _) => primitive_equal::<i256>(lhs, rhs, lhs_start, rhs_start, len), |
85 | | DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { |
86 | 0 | primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len) |
87 | | } |
88 | | DataType::Date64 |
89 | | | DataType::Interval(IntervalUnit::DayTime) |
90 | | | DataType::Time64(_) |
91 | | | DataType::Timestamp(_, _) |
92 | 33 | | DataType::Duration(_) => primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len), |
93 | | DataType::Interval(IntervalUnit::MonthDayNano) => { |
94 | 2 | primitive_equal::<i128>(lhs, rhs, lhs_start, rhs_start, len) |
95 | | } |
96 | | DataType::Utf8 | DataType::Binary => { |
97 | 148 | variable_sized_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len) |
98 | | } |
99 | | DataType::LargeUtf8 | DataType::LargeBinary => { |
100 | 0 | variable_sized_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len) |
101 | | } |
102 | 7 | DataType::FixedSizeBinary(_) => fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len), |
103 | | DataType::BinaryView | DataType::Utf8View => { |
104 | 0 | byte_view_equal(lhs, rhs, lhs_start, rhs_start, len) |
105 | | } |
106 | 62 | DataType::List(_) => list_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len), |
107 | | DataType::ListView(_) | DataType::LargeListView(_) => { |
108 | 0 | unimplemented!("ListView/LargeListView not yet implemented") |
109 | | } |
110 | 0 | DataType::LargeList(_) => list_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len), |
111 | 0 | DataType::FixedSizeList(_, _) => fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len), |
112 | 62 | DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), |
113 | 0 | DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), |
114 | 6 | DataType::Dictionary(data_type, _) => match data_type.as_ref() { |
115 | 0 | DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start, rhs_start, len), |
116 | 0 | DataType::Int16 => dictionary_equal::<i16>(lhs, rhs, lhs_start, rhs_start, len), |
117 | 6 | DataType::Int32 => dictionary_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len), |
118 | 0 | DataType::Int64 => dictionary_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len), |
119 | 0 | DataType::UInt8 => dictionary_equal::<u8>(lhs, rhs, lhs_start, rhs_start, len), |
120 | 0 | DataType::UInt16 => dictionary_equal::<u16>(lhs, rhs, lhs_start, rhs_start, len), |
121 | 0 | DataType::UInt32 => dictionary_equal::<u32>(lhs, rhs, lhs_start, rhs_start, len), |
122 | 0 | DataType::UInt64 => dictionary_equal::<u64>(lhs, rhs, lhs_start, rhs_start, len), |
123 | 0 | _ => unreachable!(), |
124 | | }, |
125 | 0 | DataType::Float16 => primitive_equal::<f16>(lhs, rhs, lhs_start, rhs_start, len), |
126 | 17 | DataType::Map(_, _) => list_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len), |
127 | 0 | DataType::RunEndEncoded(_, _) => run_equal(lhs, rhs, lhs_start, rhs_start, len), |
128 | | } |
129 | 697 | } |
130 | | |
131 | 258 | fn equal_range(/* Compares `len` slots of `lhs` (from `lhs_start`) against `len` slots of `rhs` (from `rhs_start`): returns true only when both `utils::equal_nulls` and `equal_values` return true for that range. */ |
132 | 258 | lhs: &ArrayData, |
133 | 258 | rhs: &ArrayData, |
134 | 258 | lhs_start: usize, |
135 | 258 | rhs_start: usize, |
136 | 258 | len: usize, |
137 | 258 | ) -> bool {/* `&&` short-circuits: `equal_values` is only evaluated when the null comparison has already succeeded. */ |
138 | 258 | utils::equal_nulls(lhs, rhs, lhs_start, rhs_start, len) |
139 | 258 | && equal_values(lhs, rhs, lhs_start, rhs_start, len) |
140 | 258 | } |
141 | | |
142 | | /// Logically compares two [ArrayData]. |
143 | | /// |
144 | | /// Two arrays are logically equal if and only if: |
145 | | /// * their data types are equal |
146 | | /// * their lengths are equal |
147 | | /// * their null counts are equal |
148 | | /// * their null bitmaps are equal |
149 | | /// * each of their items is equal |
150 | | /// |
151 | | /// Two items are equal when their in-memory representation is physically equal |
152 | | /// (i.e. has the same bit content). |
153 | | /// |
154 | | /// The physical comparison depends on the data type; `equal_values` selects it from `lhs.data_type()`. The checks are chained with `&&`, so a later check only runs once every earlier one has passed. |
155 | | /// |
156 | | /// # Panics |
157 | | /// |
158 | | /// This function may panic whenever any of the [ArrayData] does not follow the |
159 | | /// Arrow specification. (e.g. wrong number of buffers, buffer `len` does not |
160 | | /// correspond to the declared `len`). It may also panic for `ListView`/`LargeListView` data, whose comparison is still `unimplemented!` in `equal_values`. |
161 | 439 | pub fn equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { |
162 | 439 | utils::base_equal(lhs, rhs) |
163 | 439 | && lhs.null_count() == rhs.null_count() |
164 | 439 | && utils::equal_nulls(lhs, rhs, 0, 0, lhs.len()) |
165 | 439 | && equal_values(lhs, rhs, 0, 0, lhs.len()) |
166 | 439 | } |
167 | | |
168 | | // See arrow/tests/array_equal.rs for tests |