diff --git a/Cargo.lock b/Cargo.lock index 9cf4d96415c0..d0af7424ee07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2259,6 +2259,7 @@ version = "51.0.0" dependencies = [ "arrow", "async-trait", + "bitflags 2.10.0", "chrono", "ctor", "datafusion-common", diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 6a75485c6284..dc6df06a8583 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -60,6 +60,7 @@ use crate::schema_equivalence::schema_satisfied_by; use arrow::array::{builder::StringBuilder, RecordBatch}; use arrow::compute::SortOptions; use arrow::datatypes::Schema; +use arrow_schema::Field; use datafusion_catalog::ScanArgs; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::format::ExplainAnalyzeLevel; @@ -2517,7 +2518,9 @@ impl<'a> OptimizationInvariantChecker<'a> { previous_schema: Arc, ) -> Result<()> { // if the rule is not permitted to change the schema, confirm that it did not change. - if self.rule.schema_check() && plan.schema() != previous_schema { + if self.rule.schema_check() + && !is_allowed_schema_change(previous_schema.as_ref(), plan.schema().as_ref()) + { internal_err!("PhysicalOptimizer rule '{}' failed. Schema mismatch. Expected original schema: {:?}, got new schema: {:?}", self.rule.name(), previous_schema, @@ -2533,6 +2536,33 @@ impl<'a> OptimizationInvariantChecker<'a> { } } +/// Checks if the change from `old` schema to `new` is allowed or not. +/// The current implementation only allows nullability of individual fields to change +/// from 'nullable' to 'not nullable'. 
+fn is_allowed_schema_change(old: &Schema, new: &Schema) -> bool { + if new.metadata != old.metadata { + return false; + } + + if new.fields.len() != old.fields.len() { + return false; + } + + let new_fields = new.fields.iter().map(|f| f.as_ref()); + let old_fields = old.fields.iter().map(|f| f.as_ref()); + old_fields + .zip(new_fields) + .all(|(old, new)| is_allowed_field_change(old, new)) +} + +fn is_allowed_field_change(old_field: &Field, new_field: &Field) -> bool { + new_field.name() == old_field.name() + && new_field.data_type() == old_field.data_type() + && new_field.metadata() == old_field.metadata() + && (new_field.is_nullable() == old_field.is_nullable() + || !new_field.is_nullable()) +} + impl<'n> TreeNodeVisitor<'n> for OptimizationInvariantChecker<'_> { type Node = Arc; diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 252d76d0f9d9..3ad74962bc2c 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -1052,9 +1052,12 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { for sql in &sql { let df = ctx.sql(sql).await?; let (state, plan) = df.into_parts(); - let plan = state.optimize(&plan)?; if create_physical { let _ = state.create_physical_plan(&plan).await?; + } else { + // Run the logical optimizer even if we are not creating the physical plan + // to ensure it will properly succeed + let _ = state.optimize(&plan)?; } } diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index 7515b59b9221..478e95520f95 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -1762,6 +1762,16 @@ impl NullableInterval { } } + /// Return true if the value is definitely not true (either null or false). + pub fn is_certainly_not_true(&self) -> bool { + match self { + Self::Null { .. 
} => true, + Self::MaybeNull { values } | Self::NotNull { values } => { + values == &Interval::CERTAINLY_FALSE + } + } + } + /// Return true if the value is definitely false (and not null). pub fn is_certainly_false(&self) -> bool { match self { @@ -1967,6 +1977,7 @@ mod tests { operator::Operator, }; + use crate::interval_arithmetic::NullableInterval; use arrow::datatypes::DataType; use datafusion_common::rounding::{next_down, next_up}; use datafusion_common::{Result, ScalarValue}; @@ -4103,4 +4114,163 @@ mod tests { Ok(()) } + + #[test] + fn test_is_certainly_true() { + let test_cases = vec![ + ( + NullableInterval::Null { + datatype: DataType::Boolean, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_FALSE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::UNCERTAIN, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_TRUE, + }, + true, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_FALSE, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::UNCERTAIN, + }, + false, + ), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_certainly_true(); + assert_eq!(result, expected, "Failed for interval: {interval}",); + } + } + + #[test] + fn test_is_certainly_not_true() { + let test_cases = vec![ + ( + NullableInterval::Null { + datatype: DataType::Boolean, + }, + true, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_FALSE, + }, + true, + ), + ( + NullableInterval::MaybeNull { + values: Interval::UNCERTAIN, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_FALSE, + }, + true, + ), + ( + 
NullableInterval::NotNull { + values: Interval::UNCERTAIN, + }, + false, + ), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_certainly_not_true(); + assert_eq!(result, expected, "Failed for interval: {interval}",); + } + } + + #[test] + fn test_is_certainly_false() { + let test_cases = vec![ + ( + NullableInterval::Null { + datatype: DataType::Boolean, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::CERTAINLY_FALSE, + }, + false, + ), + ( + NullableInterval::MaybeNull { + values: Interval::UNCERTAIN, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_TRUE, + }, + false, + ), + ( + NullableInterval::NotNull { + values: Interval::CERTAINLY_FALSE, + }, + true, + ), + ( + NullableInterval::NotNull { + values: Interval::UNCERTAIN, + }, + false, + ), + ]; + + for (interval, expected) in test_cases { + let result = interval.is_certainly_false(); + assert_eq!(result, expected, "Failed for interval: {interval}",); + } + } } diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 11d6ca1533db..84be57023d9a 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -48,6 +48,7 @@ sql = ["sqlparser"] [dependencies] arrow = { workspace = true } async-trait = { workspace = true } +bitflags = "2.0.0" chrono = { workspace = true } datafusion-common = { workspace = true, default-features = false } datafusion-doc = { workspace = true } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c777c4978f99..94d8009ce814 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -341,6 +341,11 @@ pub fn is_null(expr: Expr) -> Expr { Expr::IsNull(Box::new(expr)) } +/// Create is not null expression +pub fn is_not_null(expr: Expr) -> Expr { + Expr::IsNotNull(Box::new(expr)) +} + /// Create is true expression pub fn 
is_true(expr: Expr) -> Expr { Expr::IsTrue(Box::new(expr)) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9e8d6080b82c..9b471d17ae25 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use super::{Between, Expr, Like}; +use super::{predicate_bounds, Between, Expr, Like}; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, @@ -282,14 +282,53 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(field, _) => Ok(field.is_nullable()), Expr::Literal(value, _) => Ok(value.is_null()), Expr::Case(case) => { - // This expression is nullable if any of the input expressions are nullable - let then_nullable = case + let nullable_then = case .when_then_expr .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { - Ok(true) + .filter_map(|(w, t)| { + let is_nullable = match t.nullable(input_schema) { + Err(e) => return Some(Err(e)), + Ok(n) => n, + }; + + // Branches with a then expression that is not nullable do not impact the + // nullability of the case expression. + if !is_nullable { + return None; + } + + // For case-with-expression assume all 'then' expressions are reachable + if case.expr.is_some() { + return Some(Ok(())); + } + + // For branches with a nullable 'then' expression, try to determine + // if the 'then' expression is ever reachable in the situation where + // it would evaluate to null. + let bounds = match predicate_bounds::evaluate_bounds( + w, + Some(t), + input_schema, + ) { + Err(e) => return Some(Err(e)), + Ok(b) => b, + }; + + if bounds.is_certainly_not_true() { + // The predicate will never evaluate to true, so the 'then' expression + // is never reachable. 
+ // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. + None + } else { + // The branch might be taken + Some(Ok(())) + } + }) + .next(); + + if let Some(nullable_then) = nullable_then { + // There is at least one reachable nullable then + nullable_then.map(|_| true) } else if let Some(e) = &case.else_expr { e.nullable(input_schema) } else { @@ -773,7 +812,7 @@ mod tests { use std::collections::HashMap; use super::*; - use crate::{col, lit, out_ref_col_with_metadata}; + use crate::{and, col, lit, not, or, out_ref_col_with_metadata, when}; use datafusion_common::{internal_err, DFSchema, ScalarValue}; @@ -826,6 +865,137 @@ mod tests { assert!(expr.nullable(&get_schema(false)).unwrap()); } + fn assert_nullability(expr: &Expr, schema: &dyn ExprSchema, expected: bool) { + assert_eq!( + expr.nullable(schema).unwrap(), + expected, + "Nullability of '{expr}' should be {expected}" + ); + } + + fn assert_not_nullable(expr: &Expr, schema: &dyn ExprSchema) { + assert_nullability(expr, schema, false); + } + + fn assert_nullable(expr: &Expr, schema: &dyn ExprSchema) { + assert_nullability(expr, schema, true); + } + + #[test] + fn test_case_expression_nullability() -> Result<()> { + let nullable_schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(true); + + let not_nullable_schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_nullable(false); + + // CASE WHEN x IS NOT NULL THEN x ELSE 0 + let e = when(col("x").is_not_null(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN NOT x IS NULL THEN x ELSE 0 + let e = when(not(col("x").is_null()), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN X = 5 THEN x ELSE 0 + let e = when(col("x").eq(lit(5)), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + 
assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT NULL AND x = 5 THEN x ELSE 0 + let e = when(and(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x = 5 AND x IS NOT NULL THEN x ELSE 0 + let e = when(and(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0 + let e = when(or(col("x").is_not_null(), col("x").eq(lit(5))), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x = 5 OR x IS NOT NULL THEN x ELSE 0 + let e = when(or(col("x").eq(lit(5)), col("x").is_not_null()), col("x")) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN (x = 5 AND x IS NOT NULL) OR (x = bar AND x IS NOT NULL) THEN x ELSE 0 + let e = when( + or( + and(col("x").eq(lit(5)), col("x").is_not_null()), + and(col("x").eq(col("bar")), col("x").is_not_null()), + ), + col("x"), + ) + .otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x = 5 OR x IS NULL THEN x ELSE 0 + let e = when(or(col("x").eq(lit(5)), col("x").is_null()), col("x")) + .otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS TRUE THEN x ELSE 0 + let e = when(col("x").is_true(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT TRUE THEN x ELSE 0 + let e = when(col("x").is_not_true(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + 
+ // CASE WHEN x IS FALSE THEN x ELSE 0 + let e = when(col("x").is_false(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT FALSE THEN x ELSE 0 + let e = when(col("x").is_not_false(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS UNKNOWN THEN x ELSE 0 + let e = when(col("x").is_unknown(), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x IS NOT UNKNOWN THEN x ELSE 0 + let e = when(col("x").is_not_unknown(), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN x LIKE 'x' THEN x ELSE 0 + let e = when(col("x").like(lit("x")), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN 0 THEN x ELSE 0 + let e = when(lit(0), col("x")).otherwise(lit(0))?; + assert_not_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + // CASE WHEN 1 THEN x ELSE 0 + let e = when(lit(1), col("x")).otherwise(lit(0))?; + assert_nullable(&e, &nullable_schema); + assert_not_nullable(&e, ¬_nullable_schema); + + Ok(()) + } + #[test] fn test_inlist_nullability() { let get_schema = |nullable| { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 885e582ea6d4..c82b56aa58a3 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -73,6 +73,7 @@ pub mod async_udf; pub mod statistics { pub use datafusion_expr_common::statistics::*; } +mod predicate_bounds; pub mod ptr_eq; pub mod test; pub mod tree_node; diff --git a/datafusion/expr/src/predicate_bounds.rs b/datafusion/expr/src/predicate_bounds.rs new file mode 100644 index 000000000000..547db3c9b0f0 --- /dev/null +++ 
b/datafusion/expr/src/predicate_bounds.rs @@ -0,0 +1,1045 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{BinaryExpr, Expr, ExprSchemable}; +use arrow::datatypes::DataType; +use bitflags::bitflags; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{DataFusionError, ExprSchema, Result, ScalarValue}; +use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; +use datafusion_expr_common::operator::Operator; + +bitflags! { + /// A set representing the possible outcomes of a SQL boolean expression + #[derive(PartialEq, Eq, Clone, Debug)] + struct TernarySet: u8 { + const TRUE = 0b1; + const FALSE = 0b10; + const UNKNOWN = 0b100; + } +} + +impl TernarySet { + /// Returns the set of possible values after applying the `is true` test on all + /// values in this set. + /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. 
+ fn is_true(&self) -> Self { + let mut is_true = Self::empty(); + if self.contains(Self::TRUE) { + is_true.toggle(Self::TRUE); + } + if self.intersects(Self::UNKNOWN | Self::FALSE) { + is_true.toggle(Self::FALSE); + } + is_true + } + + /// Returns the set of possible values after applying the `is false` test on all + /// values in this set. + /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. + fn is_false(&self) -> Self { + let mut is_false = Self::empty(); + if self.contains(Self::FALSE) { + is_false.toggle(Self::TRUE); + } + if self.intersects(Self::UNKNOWN | Self::TRUE) { + is_false.toggle(Self::FALSE); + } + is_false + } + + /// Returns the set of possible values after applying the `is unknown` test on all + /// values in this set. + /// The resulting set can only contain 'TRUE' and/or 'FALSE', never 'UNKNOWN'. + fn is_unknown(&self) -> Self { + let mut is_unknown = Self::empty(); + if self.contains(Self::UNKNOWN) { + is_unknown.toggle(Self::TRUE); + } + if self.intersects(Self::TRUE | Self::FALSE) { + is_unknown.toggle(Self::FALSE); + } + is_unknown + } + + /// Returns the set of possible values after applying SQL three-valued logical NOT + /// on each value in `value`. + /// + /// This method uses the following truth table. + /// + /// ```text + /// A | ¬A + /// ----|---- + /// F | T + /// U | U + /// T | F + /// ``` + fn not(set: &Self) -> Self { + let mut not = Self::empty(); + if set.contains(Self::TRUE) { + not.toggle(Self::FALSE); + } + if set.contains(Self::FALSE) { + not.toggle(Self::TRUE); + } + if set.contains(Self::UNKNOWN) { + not.toggle(Self::UNKNOWN); + } + not + } + + /// Returns the set of possible values after applying SQL three-valued logical AND + /// on each combination of values from `lhs` and `rhs`. + /// + /// This method uses the following truth table. 
+ /// + /// ```text + /// A ∧ B │ F U T + /// ──────┼────── + /// F │ F F F + /// U │ F U U + /// T │ F U T + /// ``` + fn and(lhs: &Self, rhs: &Self) -> Self { + if lhs.is_empty() || rhs.is_empty() { + return Self::empty(); + } + + let mut and = Self::empty(); + if lhs.contains(Self::FALSE) || rhs.contains(Self::FALSE) { + and.toggle(Self::FALSE); + } + + if (lhs.contains(Self::UNKNOWN) && rhs.intersects(Self::TRUE | Self::UNKNOWN)) + || (rhs.contains(Self::UNKNOWN) && lhs.intersects(Self::TRUE | Self::UNKNOWN)) + { + and.toggle(Self::UNKNOWN); + } + + if lhs.contains(Self::TRUE) && rhs.contains(Self::TRUE) { + and.toggle(Self::TRUE); + } + + and + } + + /// Returns the set of possible values after applying SQL three-valued logical OR + /// on each combination of values from `lhs` and `rhs`. + /// + /// This method uses the following truth table. + /// + /// ```text + /// A ∨ B │ F U T + /// ──────┼────── + /// F │ F U T + /// U │ U U T + /// T │ T T T + /// ``` + fn or(lhs: &Self, rhs: &Self) -> Self { + let mut or = Self::empty(); + if lhs.contains(Self::TRUE) || rhs.contains(Self::TRUE) { + or.toggle(Self::TRUE); + } + + if (lhs.contains(Self::UNKNOWN) && rhs.intersects(Self::FALSE | Self::UNKNOWN)) + || (rhs.contains(Self::UNKNOWN) + && lhs.intersects(Self::FALSE | Self::UNKNOWN)) + { + or.toggle(Self::UNKNOWN); + } + + if lhs.contains(Self::FALSE) && rhs.contains(Self::FALSE) { + or.toggle(Self::FALSE); + } + + or + } +} + +impl TryFrom<&ScalarValue> for TernarySet { + type Error = DataFusionError; + + fn try_from(value: &ScalarValue) -> Result { + Ok(match value { + ScalarValue::Null => TernarySet::UNKNOWN, + ScalarValue::Boolean(b) => match b { + Some(true) => TernarySet::TRUE, + Some(false) => TernarySet::FALSE, + None => TernarySet::UNKNOWN, + }, + _ => { + let b = value.cast_to(&DataType::Boolean)?; + Self::try_from(&b)? + } + }) + } +} + +/// Computes the output interval for the given boolean expression based on statically +/// available information. 
+/// +/// # Arguments +/// +/// * `predicate` - The boolean expression to analyze +/// * `certainly_null_expr` - An optional expression that the caller assumes evaluates to NULL +/// in the context being analyzed. It is used as follows: +/// - nullable expressions equal to it are treated as certainly NULL while analyzing `predicate` +/// - enclosing `NOT`, negation and cast wrappers are stripped from it before comparing +/// - `None` means no such additional knowledge is available +/// +/// This parameter allows the caller to provide context-specific knowledge about expression +/// nullability that cannot be determined from the schema alone. For example, it can be used +/// to test whether a predicate can still evaluate to TRUE in the situation where a particular +/// column reference is known to be NULL. +/// +/// * `input_schema` - Schema information for resolving expression types and nullability +/// +/// # Return Value +/// +/// The function returns a [NullableInterval] that describes the possible boolean values the +/// predicate can evaluate to.
The return value will be one of the following: +/// +/// * `NullableInterval::NotNull { values: Interval::CERTAINLY_TRUE }` - The predicate will +/// always evaluate to TRUE (never FALSE or NULL) +/// +/// * `NullableInterval::NotNull { values: Interval::CERTAINLY_FALSE }` - The predicate will +/// always evaluate to FALSE (never TRUE or NULL) +/// +/// * `NullableInterval::NotNull { values: Interval::UNCERTAIN }` - The predicate will never +/// evaluate to NULL, but may be either TRUE or FALSE +/// +/// * `NullableInterval::Null { datatype: DataType::Boolean }` - The predicate will always +/// evaluate to NULL (SQL UNKNOWN in three-valued logic) +/// +/// * `NullableInterval::MaybeNull { values: Interval::CERTAINLY_TRUE }` - The predicate may +/// evaluate to TRUE or NULL, but never FALSE +/// +/// * `NullableInterval::MaybeNull { values: Interval::CERTAINLY_FALSE }` - The predicate may +/// evaluate to FALSE or NULL, but never TRUE +/// +/// * `NullableInterval::MaybeNull { values: Interval::UNCERTAIN }` - The predicate may +/// evaluate to any of TRUE, FALSE, or NULL +/// +pub(super) fn evaluate_bounds( + predicate: &Expr, + certainly_null_expr: Option<&Expr>, + input_schema: &dyn ExprSchema, +) -> Result { + let evaluator = PredicateBoundsEvaluator { + input_schema, + certainly_null_expr: certainly_null_expr.map(unwrap_certainly_null_expr), + }; + let possible_results = evaluator.evaluate_bounds(predicate)?; + + let interval = if possible_results.is_empty() || possible_results == TernarySet::all() + { + NullableInterval::MaybeNull { + values: Interval::UNCERTAIN, + } + } else if possible_results == TernarySet::TRUE { + NullableInterval::NotNull { + values: Interval::CERTAINLY_TRUE, + } + } else if possible_results == TernarySet::FALSE { + NullableInterval::NotNull { + values: Interval::CERTAINLY_FALSE, + } + } else if possible_results == TernarySet::UNKNOWN { + NullableInterval::Null { + datatype: DataType::Boolean, + } + } else { + let t = 
possible_results.contains(TernarySet::TRUE); + let f = possible_results.contains(TernarySet::FALSE); + let values = if t && f { + Interval::UNCERTAIN + } else if t { + Interval::CERTAINLY_TRUE + } else { + Interval::CERTAINLY_FALSE + }; + + if possible_results.contains(TernarySet::UNKNOWN) { + NullableInterval::MaybeNull { values } + } else { + NullableInterval::NotNull { values } + } + }; + + Ok(interval) +} + +/// Returns the innermost [Expr] that is provably null if `expr` is null. +fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr { + match expr { + Expr::Not(e) => unwrap_certainly_null_expr(e), + Expr::Negative(e) => unwrap_certainly_null_expr(e), + Expr::Cast(e) => unwrap_certainly_null_expr(e.expr.as_ref()), + _ => expr, + } +} + +struct PredicateBoundsEvaluator<'a> { + input_schema: &'a dyn ExprSchema, + certainly_null_expr: Option<&'a Expr>, +} + +impl PredicateBoundsEvaluator<'_> { + /// Derives the bounds of the given boolean expression + fn evaluate_bounds(&self, predicate: &Expr) -> Result { + Ok(match predicate { + Expr::Literal(scalar, _) => { + // Interpret literals as boolean, coercing if necessary + TernarySet::try_from(scalar)? + } + Expr::IsNull(e) => { + // If `e` is not nullable, then `e IS NULL` is provably false + if !e.nullable(self.input_schema)? { + TernarySet::FALSE + } else { + match e.get_type(self.input_schema)? { + // If `e` is a boolean expression, check if `e` is provably 'unknown'. + DataType::Boolean => self.evaluate_bounds(e)?.is_unknown(), + // If `e` is not a boolean expression, check if `e` is provably null + _ => self.is_null(e), + } + } + } + Expr::IsNotNull(e) => { + // If `e` is not nullable, then `e IS NOT NULL` is provably true + if !e.nullable(self.input_schema)? { + TernarySet::TRUE + } else { + match e.get_type(self.input_schema)? 
{ + // If `e` is a boolean expression, try to evaluate it and test for not unknown + DataType::Boolean => { + TernarySet::not(&self.evaluate_bounds(e)?.is_unknown()) + } + // If `e` is not a boolean expression, check if `e` is provably null + _ => TernarySet::not(&self.is_null(e)), + } + } + } + Expr::IsTrue(e) => self.evaluate_bounds(e)?.is_true(), + Expr::IsNotTrue(e) => TernarySet::not(&self.evaluate_bounds(e)?.is_true()), + Expr::IsFalse(e) => self.evaluate_bounds(e)?.is_false(), + Expr::IsNotFalse(e) => TernarySet::not(&self.evaluate_bounds(e)?.is_false()), + Expr::IsUnknown(e) => self.evaluate_bounds(e)?.is_unknown(), + Expr::IsNotUnknown(e) => { + TernarySet::not(&self.evaluate_bounds(e)?.is_unknown()) + } + Expr::Not(e) => TernarySet::not(&self.evaluate_bounds(e)?), + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And, + right, + }) => TernarySet::and( + &self.evaluate_bounds(left)?, + &self.evaluate_bounds(right)?, + ), + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Or, + right, + }) => TernarySet::or( + &self.evaluate_bounds(left)?, + &self.evaluate_bounds(right)?, + ), + e => { + let mut result = TernarySet::empty(); + let is_null = self.is_null(e); + + // If an expression is null, then it's value is UNKNOWN + if is_null.contains(TernarySet::TRUE) { + result |= TernarySet::UNKNOWN + } + + // If an expression is not null, then it's either TRUE or FALSE + if is_null.contains(TernarySet::FALSE) { + result |= TernarySet::TRUE | TernarySet::FALSE + } + + result + } + }) + } + + /// Determines if the given expression can evaluate to `NULL`. + /// + /// This method only returns sets containing `TRUE`, `FALSE`, or both. 
+ fn is_null(&self, expr: &Expr) -> TernarySet { + // Fast path for literals + if let Expr::Literal(scalar, _) = expr { + if scalar.is_null() { + return TernarySet::TRUE; + } else { + return TernarySet::FALSE; + } + } + + // If `expr` is not nullable, we can be certain `expr` is not null + if let Ok(false) = expr.nullable(self.input_schema) { + return TernarySet::FALSE; + } + + // Check if the expression is the `certainly_null_expr` that was passed in. + if let Some(certainly_null_expr) = &self.certainly_null_expr { + if expr.eq(certainly_null_expr) { + return TernarySet::TRUE; + } + } + + // `expr` is nullable, so our default answer for `is null` is going to be `{ TRUE, FALSE }`. + // Try to see if we can narrow it down to just one option. + match expr { + Expr::BinaryExpr(BinaryExpr { op, .. }) if op.returns_null_on_null() => { + self.is_null_if_any_child_null(expr) + } + Expr::Alias(_) + | Expr::Cast(_) + | Expr::Like(_) + | Expr::Negative(_) + | Expr::Not(_) + | Expr::SimilarTo(_) => self.is_null_if_any_child_null(expr), + _ => TernarySet::TRUE | TernarySet::FALSE, + } + } + + fn is_null_if_any_child_null(&self, expr: &Expr) -> TernarySet { + // These expressions are null if any of their direct children is null + // If any child is inconclusive, the result for this expression is also inconclusive + let mut is_null = TernarySet::FALSE.clone(); + let _ = expr.apply_children(|child| { + let child_is_null = self.is_null(child); + + if child_is_null.contains(TernarySet::TRUE) { + // If a child might be null, then the result may also be null + is_null.insert(TernarySet::TRUE); + } + + if !child_is_null.contains(TernarySet::FALSE) { + // If the child is never not null, then the result can also never be not null + // and we can stop traversing the children + is_null.remove(TernarySet::FALSE); + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }); + is_null + } +} + +#[cfg(test)] +mod tests { + use crate::expr::ScalarFunction; + use 
crate::predicate_bounds::{evaluate_bounds, TernarySet}; + use crate::{ + binary_expr, col, create_udf, is_false, is_not_false, is_not_null, is_not_true, + is_not_unknown, is_null, is_true, is_unknown, lit, not, Expr, + }; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{DFSchema, ExprSchema, Result, ScalarValue}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::operator::Operator::{And, Eq, Or}; + use datafusion_expr_common::signature::Volatility; + use std::sync::Arc; + + #[test] + fn tristate_bool_from_scalar() { + let cases = vec![ + (ScalarValue::Null, TernarySet::UNKNOWN), + (ScalarValue::Boolean(None), TernarySet::UNKNOWN), + (ScalarValue::Boolean(Some(true)), TernarySet::TRUE), + (ScalarValue::Boolean(Some(false)), TernarySet::FALSE), + (ScalarValue::UInt8(None), TernarySet::UNKNOWN), + (ScalarValue::UInt8(Some(0)), TernarySet::FALSE), + (ScalarValue::UInt8(Some(1)), TernarySet::TRUE), + ]; + + for case in cases { + assert_eq!(TernarySet::try_from(&case.0).unwrap(), case.1); + } + + let error_cases = vec![ScalarValue::Utf8(Some("abc".to_string()))]; + + for case in error_cases { + assert!(TernarySet::try_from(&case).is_err()); + } + } + + #[test] + fn tristate_bool_not() { + let cases = vec![ + (TernarySet::UNKNOWN, TernarySet::UNKNOWN), + (TernarySet::TRUE, TernarySet::FALSE), + (TernarySet::FALSE, TernarySet::TRUE), + ( + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE, + ), + ( + TernarySet::TRUE | TernarySet::UNKNOWN, + TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ( + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::UNKNOWN, + ), + ( + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ]; + + for case in cases { + assert_eq!(TernarySet::not(&case.0), case.1); + } + } + + #[test] + fn tristate_bool_and() { + let cases = vec![ + ( + TernarySet::UNKNOWN, + 
TernarySet::UNKNOWN, + TernarySet::UNKNOWN, + ), + (TernarySet::UNKNOWN, TernarySet::TRUE, TernarySet::UNKNOWN), + (TernarySet::UNKNOWN, TernarySet::FALSE, TernarySet::FALSE), + (TernarySet::TRUE, TernarySet::TRUE, TernarySet::TRUE), + (TernarySet::TRUE, TernarySet::FALSE, TernarySet::FALSE), + (TernarySet::FALSE, TernarySet::FALSE, TernarySet::FALSE), + ( + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::FALSE, + TernarySet::FALSE, + ), + ( + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE, + TernarySet::TRUE | TernarySet::FALSE, + ), + ( + TernarySet::TRUE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE | TernarySet::UNKNOWN, + ), + ( + TernarySet::TRUE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ( + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ( + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ( + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ( + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ( + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ]; + + for case in cases { + assert_eq!( + TernarySet::and(&case.0, &case.1), + case.2.clone(), + "{:?} & {:?} = {:?}", + case.0.clone(), + case.1.clone(), + case.2.clone() + ); + assert_eq!( + TernarySet::and(&case.1, &case.0), + case.2.clone(), + "{:?} & {:?} = {:?}", + case.1, + case.0, + case.2 + ); + } + } + + #[test] + fn tristate_bool_or() { + let cases = vec![ + ( + TernarySet::UNKNOWN, + 
TernarySet::UNKNOWN, + TernarySet::UNKNOWN, + ), + (TernarySet::UNKNOWN, TernarySet::TRUE, TernarySet::TRUE), + (TernarySet::UNKNOWN, TernarySet::FALSE, TernarySet::UNKNOWN), + (TernarySet::TRUE, TernarySet::TRUE, TernarySet::TRUE), + (TernarySet::TRUE, TernarySet::FALSE, TernarySet::TRUE), + (TernarySet::FALSE, TernarySet::FALSE, TernarySet::FALSE), + ( + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE, + ), + ( + TernarySet::TRUE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE, + ), + ( + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE, + ), + ( + TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ( + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE, + TernarySet::TRUE, + ), + ( + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ( + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + TernarySet::TRUE | TernarySet::FALSE | TernarySet::UNKNOWN, + ), + ]; + + for case in cases { + assert_eq!( + TernarySet::or(&case.0, &case.1), + case.2.clone(), + "{:?} | {:?} = {:?}", + case.0, + case.1, + case.2 + ); + assert_eq!( + TernarySet::or(&case.1, &case.0), + case.2.clone(), + "{:?} | {:?} = {:?}", + case.1, + case.0, + case.2 + ); + } + } + + fn try_eval_predicate_bounds( + predicate: &Expr, + evaluates_to_null: Option<&Expr>, + input_schema: &dyn ExprSchema, + ) -> Result> { + let bounds = evaluate_bounds(predicate, evaluates_to_null, input_schema)?; + + Ok(if bounds.is_certainly_true() { + Some(true) + } else if bounds.is_certainly_not_true() { + Some(false) + } else { + None + }) + } + + fn eval_predicate_bounds( + predicate: &Expr, + evaluates_to_null: 
Option<&Expr>, + input_schema: &dyn ExprSchema, + ) -> Option { + try_eval_predicate_bounds(predicate, evaluates_to_null, input_schema).unwrap() + } + + fn try_eval_bounds(predicate: &Expr) -> Result> { + let schema = DFSchema::try_from(Schema::empty())?; + try_eval_predicate_bounds(predicate, None, &schema) + } + + fn eval_bounds(predicate: &Expr) -> Option { + try_eval_bounds(predicate).unwrap() + } + + #[test] + fn evaluate_bounds_literal() { + assert_eq!(eval_bounds(&lit(ScalarValue::Null)), Some(false)); + + assert_eq!(eval_bounds(&lit(false)), Some(false)); + assert_eq!(eval_bounds(&lit(true)), Some(true)); + + assert_eq!(eval_bounds(&lit(0)), Some(false)); + assert_eq!(eval_bounds(&lit(1)), Some(true)); + + assert_eq!(eval_bounds(&lit(ScalarValue::Utf8(None))), Some(false)); + assert!(try_eval_bounds(&lit("foo")).is_err()); + } + + #[test] + fn evaluate_bounds_and() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let func = make_scalar_func_expr(); + + assert_eq!( + eval_bounds(&binary_expr(null.clone(), And, null.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(null.clone(), And, one.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(null.clone(), And, zero.clone())), + Some(false) + ); + + assert_eq!( + eval_bounds(&binary_expr(one.clone(), And, one.clone())), + Some(true) + ); + assert_eq!( + eval_bounds(&binary_expr(one.clone(), And, zero.clone())), + Some(false) + ); + + assert_eq!( + eval_bounds(&binary_expr(null.clone(), And, t.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(t.clone(), And, null.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(null.clone(), And, f.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(f.clone(), And, null.clone())), + Some(false) + ); + + assert_eq!( + eval_bounds(&binary_expr(t.clone(), And, t.clone())), + Some(true) + ); + assert_eq!( + 
eval_bounds(&binary_expr(t.clone(), And, f.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(f.clone(), And, t.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(f.clone(), And, f.clone())), + Some(false) + ); + + assert_eq!( + eval_bounds(&binary_expr(t.clone(), And, func.clone())), + None + ); + assert_eq!( + eval_bounds(&binary_expr(func.clone(), And, t.clone())), + None + ); + assert_eq!( + eval_bounds(&binary_expr(f.clone(), And, func.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(func.clone(), And, f.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(null.clone(), And, func.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(func.clone(), And, null.clone())), + Some(false) + ); + } + + #[test] + fn evaluate_bounds_or() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let func = make_scalar_func_expr(); + + assert_eq!( + eval_bounds(&binary_expr(null.clone(), Or, null.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(null.clone(), Or, one.clone())), + Some(true) + ); + assert_eq!( + eval_bounds(&binary_expr(null.clone(), Or, zero.clone())), + Some(false) + ); + + assert_eq!( + eval_bounds(&binary_expr(one.clone(), Or, one.clone())), + Some(true) + ); + assert_eq!( + eval_bounds(&binary_expr(one.clone(), Or, zero.clone())), + Some(true) + ); + + assert_eq!( + eval_bounds(&binary_expr(null.clone(), Or, t.clone())), + Some(true) + ); + assert_eq!( + eval_bounds(&binary_expr(t.clone(), Or, null.clone())), + Some(true) + ); + assert_eq!( + eval_bounds(&binary_expr(null.clone(), Or, f.clone())), + Some(false) + ); + assert_eq!( + eval_bounds(&binary_expr(f.clone(), Or, null.clone())), + Some(false) + ); + + assert_eq!( + eval_bounds(&binary_expr(t.clone(), Or, t.clone())), + Some(true) + ); + assert_eq!( + eval_bounds(&binary_expr(t.clone(), Or, f.clone())), + 
Some(true) + ); + assert_eq!( + eval_bounds(&binary_expr(f.clone(), Or, t.clone())), + Some(true) + ); + assert_eq!( + eval_bounds(&binary_expr(f.clone(), Or, f.clone())), + Some(false) + ); + + assert_eq!( + eval_bounds(&binary_expr(t.clone(), Or, func.clone())), + Some(true) + ); + assert_eq!( + eval_bounds(&binary_expr(func.clone(), Or, t.clone())), + Some(true) + ); + assert_eq!(eval_bounds(&binary_expr(f.clone(), Or, func.clone())), None); + assert_eq!(eval_bounds(&binary_expr(func.clone(), Or, f.clone())), None); + assert_eq!( + eval_bounds(&binary_expr(null.clone(), Or, func.clone())), + None + ); + assert_eq!( + eval_bounds(&binary_expr(func.clone(), Or, null.clone())), + None + ); + } + + #[test] + fn evaluate_bounds_not() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let func = make_scalar_func_expr(); + + assert_eq!(eval_bounds(¬(null.clone())), Some(false)); + assert_eq!(eval_bounds(¬(one.clone())), Some(false)); + assert_eq!(eval_bounds(¬(zero.clone())), Some(true)); + + assert_eq!(eval_bounds(¬(t.clone())), Some(false)); + assert_eq!(eval_bounds(¬(f.clone())), Some(true)); + + assert_eq!(eval_bounds(¬(func.clone())), None); + } + + #[test] + fn evaluate_bounds_is() { + let null = lit(ScalarValue::Null); + let zero = lit(0); + let one = lit(1); + let t = lit(true); + let f = lit(false); + let col = col("col"); + let nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::UInt8, + true, + )])) + .unwrap(); + let not_nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "col", + DataType::UInt8, + false, + )])) + .unwrap(); + + assert_eq!(eval_bounds(&is_null(null.clone())), Some(true)); + assert_eq!(eval_bounds(&is_null(one.clone())), Some(false)); + let predicate = &is_null(col.clone()); + assert_eq!( + eval_predicate_bounds(predicate, Some(&col), &nullable_schema), + Some(true) + ); + let predicate = &is_null(col.clone()); + 
assert_eq!( + eval_predicate_bounds(predicate, Some(&col), ¬_nullable_schema), + Some(false) + ); + + assert_eq!(eval_bounds(&is_not_null(null.clone())), Some(false)); + assert_eq!(eval_bounds(&is_not_null(one.clone())), Some(true)); + let predicate = &is_not_null(col.clone()); + assert_eq!( + eval_predicate_bounds(predicate, Some(&col), &nullable_schema), + Some(false) + ); + let predicate = &is_not_null(col.clone()); + assert_eq!( + eval_predicate_bounds(predicate, Some(&col), ¬_nullable_schema), + Some(true) + ); + + assert_eq!(eval_bounds(&is_true(null.clone())), Some(false)); + assert_eq!(eval_bounds(&is_true(t.clone())), Some(true)); + assert_eq!(eval_bounds(&is_true(f.clone())), Some(false)); + assert_eq!(eval_bounds(&is_true(zero.clone())), Some(false)); + assert_eq!(eval_bounds(&is_true(one.clone())), Some(true)); + + assert_eq!(eval_bounds(&is_not_true(null.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_true(t.clone())), Some(false)); + assert_eq!(eval_bounds(&is_not_true(f.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_true(zero.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_true(one.clone())), Some(false)); + + assert_eq!(eval_bounds(&is_false(null.clone())), Some(false)); + assert_eq!(eval_bounds(&is_false(t.clone())), Some(false)); + assert_eq!(eval_bounds(&is_false(f.clone())), Some(true)); + assert_eq!(eval_bounds(&is_false(zero.clone())), Some(true)); + assert_eq!(eval_bounds(&is_false(one.clone())), Some(false)); + + assert_eq!(eval_bounds(&is_not_false(null.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_false(t.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_false(f.clone())), Some(false)); + assert_eq!(eval_bounds(&is_not_false(zero.clone())), Some(false)); + assert_eq!(eval_bounds(&is_not_false(one.clone())), Some(true)); + + assert_eq!(eval_bounds(&is_unknown(null.clone())), Some(true)); + assert_eq!(eval_bounds(&is_unknown(t.clone())), Some(false)); + 
assert_eq!(eval_bounds(&is_unknown(f.clone())), Some(false)); + assert_eq!(eval_bounds(&is_unknown(zero.clone())), Some(false)); + assert_eq!(eval_bounds(&is_unknown(one.clone())), Some(false)); + + assert_eq!(eval_bounds(&is_not_unknown(null.clone())), Some(false)); + assert_eq!(eval_bounds(&is_not_unknown(t.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_unknown(f.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_unknown(zero.clone())), Some(true)); + assert_eq!(eval_bounds(&is_not_unknown(one.clone())), Some(true)); + } + + #[test] + fn evaluate_bounds_udf() { + let func = make_scalar_func_expr(); + + assert_eq!(eval_bounds(&func.clone()), None); + assert_eq!(eval_bounds(¬(func.clone())), None); + assert_eq!( + eval_bounds(&binary_expr(func.clone(), And, func.clone())), + None + ); + } + + fn make_scalar_func_expr() -> Expr { + let scalar_func_impl = + |_: &[ColumnarValue]| Ok(ColumnarValue::Scalar(ScalarValue::Null)); + let udf = create_udf( + "foo", + vec![], + DataType::Boolean, + Volatility::Stable, + Arc::new(scalar_func_impl), + ); + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), vec![])) + } + + #[test] + fn evaluate_bounds_when_then() { + let nullable_schema = + DFSchema::try_from(Schema::new(vec![Field::new("x", DataType::UInt8, true)])) + .unwrap(); + let not_nullable_schema = DFSchema::try_from(Schema::new(vec![Field::new( + "x", + DataType::UInt8, + false, + )])) + .unwrap(); + + let x = col("x"); + + // CASE WHEN x IS NOT NULL OR x = 5 THEN x ELSE 0 END + let when = binary_expr( + is_not_null(x.clone()), + Or, + binary_expr(x.clone(), Eq, lit(5)), + ); + + assert_eq!( + eval_predicate_bounds(&when, Some(&x), &nullable_schema), + Some(false) + ); + assert_eq!( + eval_predicate_bounds(&when, Some(&x), ¬_nullable_schema), + Some(true) + ); + } +} diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 7a33aa95c56b..0d090d6b8001 100644 --- 
a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -17,7 +17,7 @@ use super::{Column, Literal}; use crate::expressions::case::ResultState::{Complete, Empty, Partial}; -use crate::expressions::try_cast; +use crate::expressions::{lit, try_cast}; use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; @@ -34,13 +34,14 @@ use datafusion_common::{ Result, ScalarValue, }; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::datum::compare_with_eq; -use itertools::Itertools; use std::borrow::Cow; -use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::{any::Any, sync::Arc}; +use datafusion_physical_expr_common::datum::compare_with_eq; +use itertools::Itertools; +use std::fmt::{Debug, Formatter}; + type WhenThen = (Arc, Arc); #[derive(Debug, Hash, PartialEq, Eq)] @@ -1282,15 +1283,57 @@ impl PhysicalExpr for CaseExpr { } fn nullable(&self, input_schema: &Schema) -> Result { - // this expression is nullable if any of the input expressions are nullable - let then_nullable = self + let nullable_then = self .body .when_then_expr .iter() - .map(|(_, t)| t.nullable(input_schema)) - .collect::>>()?; - if then_nullable.contains(&true) { - Ok(true) + .filter_map(|(w, t)| { + let is_nullable = match t.nullable(input_schema) { + // Pass on error determining nullability verbatim + Err(e) => return Some(Err(e)), + Ok(n) => n, + }; + + // Branches with a then expression that is not nullable do not impact the + // nullability of the case expression. + if !is_nullable { + return None; + } + + // For case-with-expression assume all 'then' expressions are reachable + if self.body.expr.is_some() { + return Some(Ok(())); + } + + // For branches with a nullable 'then' expression, try to determine + // if the 'then' expression is ever reachable in the situation where + // it would evaluate to null. 
+ + // Replace the `then` expression with `NULL` in the `when` expression + let with_null = match replace_with_null(w, t.as_ref(), input_schema) { + Err(e) => return Some(Err(e)), + Ok(e) => e, + }; + + // Try to const evaluate the modified `when` expression. + let predicate_result = match evaluate_predicate(&with_null) { + Err(e) => return Some(Err(e)), + Ok(b) => b, + }; + + match predicate_result { + // Evaluation was inconclusive or true, so the 'then' expression is reachable + None | Some(true) => Some(Ok(())), + // Evaluation proves the branch will never be taken. + // The most common pattern for this is `WHEN x IS NOT NULL THEN x`. + Some(false) => None, + } + }) + .next(); + + if let Some(nullable_then) = nullable_then { + // There is at least one reachable nullable then + nullable_then.map(|_| true) } else if let Some(e) = &self.body.else_expr { e.nullable(input_schema) } else { @@ -1394,6 +1437,51 @@ impl PhysicalExpr for CaseExpr { } } +/// Attempts to const evaluate the given `predicate`. +/// Returns: +/// - `Some(true)` if the predicate evaluates to a truthy value. +/// - `Some(false)` if the predicate evaluates to a falsy value. +/// - `None` if the predicate could not be evaluated. +fn evaluate_predicate(predicate: &Arc) -> Result> { + // Create a dummy record with no columns and one row + let batch = RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(1)), + )?; + + // Evaluate the predicate and interpret the result as a boolean + let result = match predicate.evaluate(&batch) { + // An error during evaluation means we couldn't const evaluate the predicate, so return `None` + Err(_) => None, + Ok(ColumnarValue::Array(array)) => Some( + ScalarValue::try_from_array(array.as_ref(), 0)? 
+ .cast_to(&DataType::Boolean)?, + ), + Ok(ColumnarValue::Scalar(scalar)) => Some(scalar.cast_to(&DataType::Boolean)?), + }; + Ok(result.map(|v| matches!(v, ScalarValue::Boolean(Some(true))))) +} + +fn replace_with_null( + expr: &Arc, + expr_to_replace: &dyn PhysicalExpr, + input_schema: &Schema, +) -> Result, DataFusionError> { + let with_null = Arc::clone(expr) + .transform_down(|e| { + if e.as_ref().dyn_eq(expr_to_replace) { + let data_type = e.data_type(input_schema)?; + let null_literal = lit(ScalarValue::try_new_null(&data_type)?); + Ok(Transformed::yes(null_literal)) + } else { + Ok(Transformed::no(e)) + } + })? + .data; + Ok(with_null) +} + /// Create a CASE expression pub fn case( expr: Option>, @@ -1407,7 +1495,8 @@ pub fn case( mod tests { use super::*; - use crate::expressions::{binary, cast, col, lit, BinaryExpr}; + use crate::expressions; + use crate::expressions::{binary, cast, col, is_not_null, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::Field; @@ -1415,7 +1504,7 @@ mod tests { use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::type_coercion::binary::comparison_coercion; - use datafusion_expr::Operator; + use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::physical_expr::fmt_sql; #[test] @@ -2292,4 +2381,182 @@ mod tests { assert!(merged.is_valid(2)); assert_eq!(merged.value(2), "C"); } + + fn when_then_else( + when: &Arc, + then: &Arc, + els: &Arc, + ) -> Result> { + let case = CaseExpr::try_new( + None, + vec![(Arc::clone(when), Arc::clone(then))], + Some(Arc::clone(els)), + )?; + Ok(Arc::new(case)) + } + + #[test] + fn test_case_expression_nullability_with_nullable_column() -> Result<()> { + case_expression_nullability(true) + } + + #[test] + fn test_case_expression_nullability_with_not_nullable_column() -> Result<()> { + case_expression_nullability(false) + } + + fn 
case_expression_nullability(col_is_nullable: bool) -> Result<()> { + let schema = + Schema::new(vec![Field::new("foo", DataType::Int32, col_is_nullable)]); + + let foo = col("foo", &schema)?; + let foo_is_not_null = is_not_null(Arc::clone(&foo))?; + let foo_is_null = expressions::is_null(Arc::clone(&foo))?; + let not_foo_is_null = expressions::not(Arc::clone(&foo_is_null))?; + let zero = lit(0); + let foo_eq_zero = + binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?; + + assert_not_nullable(when_then_else(&foo_is_not_null, &foo, &zero)?, &schema); + assert_not_nullable(when_then_else(¬_foo_is_null, &foo, &zero)?, &schema); + assert_not_nullable(when_then_else(&foo_eq_zero, &foo, &zero)?, &schema); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::And, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_eq_zero), + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_is_not_null), + Operator::Or, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_not_nullable( + when_then_else( + &binary( + Arc::clone(&foo_eq_zero), + Operator::Or, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + assert_nullability( + when_then_else( + &binary( + Arc::clone(&foo_is_null), + Operator::Or, + Arc::clone(&foo_eq_zero), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_nullability( + when_then_else( + &binary( + binary(Arc::clone(&foo), Operator::Eq, Arc::clone(&zero), &schema)?, + Operator::Or, + Arc::clone(&foo_is_null), + &schema, + )?, + &foo, + &zero, + )?, + &schema, + col_is_nullable, + ); + + assert_not_nullable( + when_then_else( + &binary( + binary( + binary( + 
Arc::clone(&foo), + Operator::Eq, + Arc::clone(&zero), + &schema, + )?, + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + Operator::Or, + binary( + binary( + Arc::clone(&foo), + Operator::Eq, + Arc::clone(&foo), + &schema, + )?, + Operator::And, + Arc::clone(&foo_is_not_null), + &schema, + )?, + &schema, + )?, + &foo, + &zero, + )?, + &schema, + ); + + Ok(()) + } + + fn assert_not_nullable(expr: Arc, schema: &Schema) { + assert!(!expr.nullable(schema).unwrap()); + } + + fn assert_nullable(expr: Arc, schema: &Schema) { + assert!(expr.nullable(schema).unwrap()); + } + + fn assert_nullability(expr: Arc, schema: &Schema, nullable: bool) { + if nullable { + assert_nullable(expr, schema); + } else { + assert_not_nullable(expr, schema); + } + } } diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 1a4b6a7a2b4a..3905575d22dc 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -683,3 +683,23 @@ FROM ( 10 10 100 -20 20 200 NULL 30 300 + +# Case-with-expression that was incorrectly classified as not-nullable, but evaluates to null +query I +SELECT CASE 0 WHEN 0 THEN NULL WHEN SUM(1) + COUNT(*) THEN 10 ELSE 20 END +---- +NULL + +query TT +EXPLAIN SELECT CASE WHEN CASE WHEN a IS NOT NULL THEN a ELSE 1 END IS NOT NULL THEN a ELSE 1 END FROM ( + VALUES (10), (20), (30) + ) t(a); +---- +logical_plan +01)Projection: t.a AS CASE WHEN CASE WHEN t.a IS NOT NULL THEN t.a ELSE Int64(1) END IS NOT NULL THEN t.a ELSE Int64(1) END +02)--SubqueryAlias: t +03)----Projection: column1 AS a +04)------Values: (Int64(10)), (Int64(20)), (Int64(30)) +physical_plan +01)ProjectionExec: expr=[column1@0 as CASE WHEN CASE WHEN t.a IS NOT NULL THEN t.a ELSE Int64(1) END IS NOT NULL THEN t.a ELSE Int64(1) END] +02)--DataSourceExec: partitions=1, partition_sizes=[1]