|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -use arrow::array::{new_null_array, BooleanArray}; |
19 | | -use arrow::compute::kernels::zip::zip; |
20 | | -use arrow::compute::{and, is_not_null, is_null}; |
21 | 18 | use arrow::datatypes::{DataType, Field, FieldRef}; |
22 | | -use datafusion_common::{exec_err, internal_err, Result}; |
| 19 | +use datafusion_common::{exec_err, internal_err, plan_err, Result}; |
23 | 20 | use datafusion_expr::binary::try_type_union_resolution; |
| 21 | +use datafusion_expr::conditional_expressions::CaseBuilder; |
| 22 | +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; |
24 | 23 | use datafusion_expr::{ |
25 | | - ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, |
| 24 | + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, |
26 | 25 | }; |
27 | 26 | use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; |
28 | 27 | use datafusion_macros::user_doc; |
@@ -95,61 +94,36 @@ impl ScalarUDFImpl for CoalesceFunc { |
95 | 94 | Ok(Field::new(self.name(), return_type, nullable).into()) |
96 | 95 | } |
97 | 96 |
|
98 | | - /// coalesce evaluates to the first value which is not NULL |
99 | | - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
100 | | - let args = args.args; |
101 | | - // do not accept 0 arguments. |
| 97 | + fn simplify( |
| 98 | + &self, |
| 99 | + args: Vec<Expr>, |
| 100 | + _info: &dyn SimplifyInfo, |
| 101 | + ) -> Result<ExprSimplifyResult> { |
102 | 102 | if args.is_empty() { |
103 | | - return exec_err!( |
104 | | - "coalesce was called with {} arguments. It requires at least 1.", |
105 | | - args.len() |
106 | | - ); |
| 103 | + return plan_err!("coalesce must have at least one argument"); |
107 | 104 | } |
108 | | - |
109 | | - let return_type = args[0].data_type(); |
110 | | - let mut return_array = args.iter().filter_map(|x| match x { |
111 | | - ColumnarValue::Array(array) => Some(array.len()), |
112 | | - _ => None, |
113 | | - }); |
114 | | - |
115 | | - if let Some(size) = return_array.next() { |
116 | | - // start with nulls as default output |
117 | | - let mut current_value = new_null_array(&return_type, size); |
118 | | - let mut remainder = BooleanArray::from(vec![true; size]); |
119 | | - |
120 | | - for arg in args { |
121 | | - match arg { |
122 | | - ColumnarValue::Array(ref array) => { |
123 | | - let to_apply = and(&remainder, &is_not_null(array.as_ref())?)?; |
124 | | - current_value = zip(&to_apply, array, ¤t_value)?; |
125 | | - remainder = and(&remainder, &is_null(array)?)?; |
126 | | - } |
127 | | - ColumnarValue::Scalar(value) => { |
128 | | - if value.is_null() { |
129 | | - continue; |
130 | | - } else { |
131 | | - let last_value = value.to_scalar()?; |
132 | | - current_value = zip(&remainder, &last_value, ¤t_value)?; |
133 | | - break; |
134 | | - } |
135 | | - } |
136 | | - } |
137 | | - if remainder.iter().all(|x| x == Some(false)) { |
138 | | - break; |
139 | | - } |
140 | | - } |
141 | | - Ok(ColumnarValue::Array(current_value)) |
142 | | - } else { |
143 | | - let result = args |
144 | | - .iter() |
145 | | - .filter_map(|x| match x { |
146 | | - ColumnarValue::Scalar(s) if !s.is_null() => Some(x.clone()), |
147 | | - _ => None, |
148 | | - }) |
149 | | - .next() |
150 | | - .unwrap_or_else(|| args[0].clone()); |
151 | | - Ok(result) |
| 105 | + if args.len() == 1 { |
| 106 | + return Ok(ExprSimplifyResult::Simplified( |
| 107 | + args.into_iter().next().unwrap(), |
| 108 | + )); |
152 | 109 | } |
| 110 | + |
| 111 | + let n = args.len(); |
| 112 | + let (init, last_elem) = args.split_at(n - 1); |
| 113 | + let whens = init |
| 114 | + .iter() |
| 115 | + .map(|x| x.clone().is_not_null()) |
| 116 | + .collect::<Vec<_>>(); |
| 117 | + let cases = init.to_vec(); |
| 118 | + Ok(ExprSimplifyResult::Simplified( |
| 119 | + CaseBuilder::new(None, whens, cases, Some(Box::new(last_elem[0].clone()))) |
| 120 | + .end()?, |
| 121 | + )) |
| 122 | + } |
| 123 | + |
| 124 | + /// coalesce evaluates to the first value which is not NULL |
| 125 | + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
| 126 | + internal_err!("coalesce should have been simplified to case") |
153 | 127 | } |
154 | 128 |
|
155 | 129 | fn short_circuits(&self) -> bool { |
|
0 commit comments