diff --git a/src/iceberg/CMakeLists.txt b/src/iceberg/CMakeLists.txt index ed3a693e..53e78fd3 100644 --- a/src/iceberg/CMakeLists.txt +++ b/src/iceberg/CMakeLists.txt @@ -20,7 +20,10 @@ set(ICEBERG_INCLUDES "$" set(ICEBERG_SOURCES catalog/in_memory_catalog.cc expression/expression.cc + expression/expressions.cc expression/literal.cc + expression/predicate.cc + expression/term.cc file_reader.cc file_writer.cc inheritable_metadata.cc diff --git a/src/iceberg/expression/expression.cc b/src/iceberg/expression/expression.cc index c6fa9406..f6f6d0f7 100644 --- a/src/iceberg/expression/expression.cc +++ b/src/iceberg/expression/expression.cc @@ -20,6 +20,10 @@ #include "iceberg/expression/expression.h" #include +#include + +#include "iceberg/util/formatter_internal.h" +#include "iceberg/util/macros.h" namespace iceberg { @@ -29,7 +33,7 @@ const std::shared_ptr& True::Instance() { return instance; } -std::shared_ptr True::Negate() const { return False::Instance(); } +Result> True::Negate() const { return False::Instance(); } // False implementation const std::shared_ptr& False::Instance() { @@ -37,7 +41,7 @@ const std::shared_ptr& False::Instance() { return instance; } -std::shared_ptr False::Negate() const { return True::Instance(); } +Result> False::Negate() const { return True::Instance(); } // And implementation And::And(std::shared_ptr left, std::shared_ptr right) @@ -47,11 +51,11 @@ std::string And::ToString() const { return std::format("({} and {})", left_->ToString(), right_->ToString()); } -std::shared_ptr And::Negate() const { +Result> And::Negate() const { // De Morgan's law: not(A and B) = (not A) or (not B) - auto left_negated = left_->Negate(); - auto right_negated = right_->Negate(); - return std::make_shared(left_negated, right_negated); + ICEBERG_ASSIGN_OR_RAISE(auto left_negated, left_->Negate()); + ICEBERG_ASSIGN_OR_RAISE(auto right_negated, right_->Negate()); + return std::make_shared(std::move(left_negated), std::move(right_negated)); } bool And::Equals(const Expression& expr) const { @@ -71,11 +75,11 @@ std::string Or::ToString() const { return std::format("({} or {})", left_->ToString(), right_->ToString()); } -std::shared_ptr Or::Negate() const { +Result> Or::Negate() const { // De Morgan's law: not(A or B) = (not A) and (not B) - auto left_negated = left_->Negate(); - auto right_negated = right_->Negate(); - return std::make_shared(left_negated, right_negated); + ICEBERG_ASSIGN_OR_RAISE(auto left_negated, left_->Negate()); + ICEBERG_ASSIGN_OR_RAISE(auto right_negated, right_->Negate()); + return std::make_shared(std::move(left_negated), std::move(right_negated)); } bool Or::Equals(const Expression& expr) const { @@ -87,4 +91,104 @@ bool Or::Equals(const Expression& expr) const { return false; } +std::string_view ToString(Expression::Operation op) { + switch (op) { + case Expression::Operation::kAnd: + return "AND"; + case Expression::Operation::kOr: + return "OR"; + case Expression::Operation::kTrue: + return "TRUE"; + case Expression::Operation::kFalse: + return "FALSE"; + case Expression::Operation::kIsNull: + return "IS_NULL"; + case Expression::Operation::kNotNull: + return "NOT_NULL"; + case Expression::Operation::kIsNan: + return "IS_NAN"; + case Expression::Operation::kNotNan: + return "NOT_NAN"; + case Expression::Operation::kLt: + return "LT"; + case Expression::Operation::kLtEq: + return "LT_EQ"; + case Expression::Operation::kGt: + return "GT"; + case Expression::Operation::kGtEq: + return "GT_EQ"; + case Expression::Operation::kEq: + return "EQ"; + case Expression::Operation::kNotEq: + return "NOT_EQ"; + case Expression::Operation::kIn: + return "IN"; + case Expression::Operation::kNotIn: + return "NOT_IN"; + case Expression::Operation::kStartsWith: + return "STARTS_WITH"; + case Expression::Operation::kNotStartsWith: + return "NOT_STARTS_WITH"; + case Expression::Operation::kCount: + return "COUNT"; + case Expression::Operation::kNot: + return "NOT"; + case Expression::Operation::kCountStar: + return "COUNT_STAR"; + case Expression::Operation::kMax: + return "MAX"; + case Expression::Operation::kMin: + return "MIN"; + } + std::unreachable(); +} + +Result Negate(Expression::Operation op) { + switch (op) { + case Expression::Operation::kIsNull: + return Expression::Operation::kNotNull; + case Expression::Operation::kNotNull: + return Expression::Operation::kIsNull; + case Expression::Operation::kIsNan: + return Expression::Operation::kNotNan; + case Expression::Operation::kNotNan: + return Expression::Operation::kIsNan; + case Expression::Operation::kLt: + return Expression::Operation::kGtEq; + case Expression::Operation::kLtEq: + return Expression::Operation::kGt; + case Expression::Operation::kGt: + return Expression::Operation::kLtEq; + case Expression::Operation::kGtEq: + return Expression::Operation::kLt; + case Expression::Operation::kEq: + return Expression::Operation::kNotEq; + case Expression::Operation::kNotEq: + return Expression::Operation::kEq; + case Expression::Operation::kIn: + return Expression::Operation::kNotIn; + case Expression::Operation::kNotIn: + return Expression::Operation::kIn; + case Expression::Operation::kStartsWith: + return Expression::Operation::kNotStartsWith; + case Expression::Operation::kNotStartsWith: + return Expression::Operation::kStartsWith; + case Expression::Operation::kTrue: + return Expression::Operation::kFalse; + case Expression::Operation::kFalse: + return Expression::Operation::kTrue; + case Expression::Operation::kAnd: + return Expression::Operation::kOr; + case Expression::Operation::kOr: + return Expression::Operation::kAnd; + case Expression::Operation::kNot: + case Expression::Operation::kCountStar: + case Expression::Operation::kMax: + case Expression::Operation::kMin: + case Expression::Operation::kCount: + return InvalidArgument("No negation for operation: {}", op); + } + std::unreachable(); +} + } // namespace iceberg diff --git a/src/iceberg/expression/expression.h b/src/iceberg/expression/expression.h index 9ceae1c6..e0708c4e 100644 --- a/src/iceberg/expression/expression.h +++ b/src/iceberg/expression/expression.h @@ -25,13 +25,14 @@ #include #include -#include "iceberg/exception.h" #include "iceberg/iceberg_export.h" +#include "iceberg/result.h" +#include "iceberg/util/formattable.h" namespace iceberg { /// \brief Represents a boolean expression tree. -class ICEBERG_EXPORT Expression { +class ICEBERG_EXPORT Expression : public util::Formattable { public: /// Operation types for expressions enum class Operation { @@ -66,8 +67,8 @@ class ICEBERG_EXPORT Expression { virtual Operation op() const = 0; /// \brief Returns the negation of this expression, equivalent to not(this). - virtual std::shared_ptr Negate() const { - throw IcebergError("Expression cannot be negated"); + virtual Result> Negate() const { + return NotSupported("Expression cannot be negated"); } /// \brief Returns whether this expression will accept the same values as another. @@ -78,7 +79,7 @@ class ICEBERG_EXPORT Expression { return false; } - virtual std::string ToString() const { return "Expression"; } + std::string ToString() const override { return "Expression"; } }; /// \brief An Expression that is always true. @@ -93,7 +94,7 @@ class ICEBERG_EXPORT True : public Expression { std::string ToString() const override { return "true"; } - std::shared_ptr Negate() const override; + Result> Negate() const override; bool Equals(const Expression& other) const override { return other.op() == Operation::kTrue; @@ -113,7 +114,7 @@ class ICEBERG_EXPORT False : public Expression { std::string ToString() const override { return "false"; } - std::shared_ptr Negate() const override; + Result> Negate() const override; bool Equals(const Expression& other) const override { return other.op() == Operation::kFalse; @@ -149,7 +150,7 @@ class ICEBERG_EXPORT And : public Expression { std::string ToString() const override; - std::shared_ptr Negate() const override; + Result> Negate() const override; bool Equals(const Expression& other) const override; @@ -184,7 +185,7 @@ class ICEBERG_EXPORT Or : public Expression { std::string ToString() const override; - std::shared_ptr Negate() const override; + Result> Negate() const override; bool Equals(const Expression& other) const override; @@ -193,4 +194,10 @@ class ICEBERG_EXPORT Or : public Expression { std::shared_ptr right_; }; +/// \brief Returns a string representation of an expression operation. +ICEBERG_EXPORT std::string_view ToString(Expression::Operation op); + +/// \brief Returns the negated operation. +ICEBERG_EXPORT Result Negate(Expression::Operation op); + } // namespace iceberg diff --git a/src/iceberg/expression/expressions.cc b/src/iceberg/expression/expressions.cc new file mode 100644 index 00000000..a775c993 --- /dev/null +++ b/src/iceberg/expression/expressions.cc @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/expression/expressions.h" + +#include "iceberg/exception.h" +#include "iceberg/transform.h" +#include "iceberg/type.h" + +namespace iceberg { + +// Transform functions + +std::shared_ptr Expressions::Bucket(std::string name, + int32_t num_buckets) { + return std::make_shared(Ref(std::move(name)), + Transform::Bucket(num_buckets)); +} + +std::shared_ptr Expressions::Year(std::string name) { + return std::make_shared(Ref(std::move(name)), Transform::Year()); +} + +std::shared_ptr Expressions::Month(std::string name) { + return std::make_shared(Ref(std::move(name)), Transform::Month()); +} + +std::shared_ptr Expressions::Day(std::string name) { + return std::make_shared(Ref(std::move(name)), Transform::Day()); +} + +std::shared_ptr Expressions::Hour(std::string name) { + return std::make_shared(Ref(std::move(name)), Transform::Hour()); +} + +std::shared_ptr Expressions::Truncate(std::string name, int32_t width) { + return std::make_shared(Ref(std::move(name)), + Transform::Truncate(width)); +} + +std::shared_ptr Expressions::Transform( + std::string name, std::shared_ptr<::iceberg::Transform> transform) { + return std::make_shared(Ref(std::move(name)), std::move(transform)); +} + +// Template implementations for unary predicates + +std::shared_ptr> Expressions::IsNull(std::string name) { + return IsNull(Ref(std::move(name))); +} + +template +std::shared_ptr> Expressions::IsNull( + std::shared_ptr> expr) { + return std::make_shared>(Expression::Operation::kIsNull, + std::move(expr)); +} + +std::shared_ptr> Expressions::NotNull(std::string name) { + return NotNull(Ref(std::move(name))); +} + +template +std::shared_ptr> Expressions::NotNull( + std::shared_ptr> expr) { + return std::make_shared>(Expression::Operation::kNotNull, + std::move(expr)); +} + +std::shared_ptr> Expressions::IsNaN(std::string name) { + return IsNaN(Ref(std::move(name))); +} + +template +std::shared_ptr> Expressions::IsNaN( + std::shared_ptr> expr) { + return std::make_shared>(Expression::Operation::kIsNan, + std::move(expr)); +} + +std::shared_ptr> Expressions::NotNaN(std::string name) { + return NotNaN(Ref(std::move(name))); +} + +template +std::shared_ptr> Expressions::NotNaN( + std::shared_ptr> expr) { + return std::make_shared>(Expression::Operation::kNotNan, + std::move(expr)); +} + +// Template implementations for comparison predicates + +std::shared_ptr> Expressions::LessThan(std::string name, + Literal value) { + return LessThan(Ref(std::move(name)), std::move(value)); +} + +template +std::shared_ptr> Expressions::LessThan( + std::shared_ptr> expr, Literal value) { + return std::make_shared>(Expression::Operation::kLt, + std::move(expr), std::move(value)); +} + +std::shared_ptr> Expressions::LessThanOrEqual( + std::string name, Literal value) { + return LessThanOrEqual(Ref(std::move(name)), std::move(value)); +} + +template +std::shared_ptr> Expressions::LessThanOrEqual( + std::shared_ptr> expr, Literal value) { + return std::make_shared>(Expression::Operation::kLtEq, + std::move(expr), std::move(value)); +} + +std::shared_ptr> Expressions::GreaterThan( + std::string name, Literal value) { + return GreaterThan(Ref(std::move(name)), std::move(value)); +} + +template +std::shared_ptr> Expressions::GreaterThan( + std::shared_ptr> expr, Literal value) { + return std::make_shared>(Expression::Operation::kGt, + std::move(expr), std::move(value)); +} + +std::shared_ptr> Expressions::GreaterThanOrEqual( + std::string name, Literal value) { + return GreaterThanOrEqual(Ref(std::move(name)), std::move(value)); +} + +template +std::shared_ptr> Expressions::GreaterThanOrEqual( + std::shared_ptr> expr, Literal value) { + return std::make_shared>(Expression::Operation::kGtEq, + std::move(expr), std::move(value)); +} + +std::shared_ptr> Expressions::Equal(std::string name, + Literal value) { + return Equal(Ref(std::move(name)), std::move(value)); +} + +template +std::shared_ptr> Expressions::Equal( + std::shared_ptr> expr, Literal value) { + return std::make_shared>(Expression::Operation::kEq, + std::move(expr), std::move(value)); +} + +std::shared_ptr> Expressions::NotEqual(std::string name, + Literal value) { + return NotEqual(Ref(std::move(name)), std::move(value)); +} + +template +std::shared_ptr> Expressions::NotEqual( + std::shared_ptr> expr, Literal value) { + return std::make_shared>(Expression::Operation::kNotEq, + std::move(expr), std::move(value)); +} + +// String predicates + +std::shared_ptr> Expressions::StartsWith( + std::string name, std::string value) { + return StartsWith(Ref(std::move(name)), std::move(value)); +} + +template +std::shared_ptr> Expressions::StartsWith( + std::shared_ptr> expr, std::string value) { + return std::make_shared>(Expression::Operation::kStartsWith, + std::move(expr), + Literal::String(std::move(value))); +} + +std::shared_ptr> Expressions::NotStartsWith( + std::string name, std::string value) { + return NotStartsWith(Ref(std::move(name)), std::move(value)); +} + +template +std::shared_ptr> Expressions::NotStartsWith( + std::shared_ptr> expr, std::string value) { + return std::make_shared>(Expression::Operation::kNotStartsWith, + std::move(expr), + Literal::String(std::move(value))); +} + +// Template implementations for set predicates + +std::shared_ptr> Expressions::In( + std::string name, std::vector values) { + return In(Ref(std::move(name)), std::move(values)); +} + +template +std::shared_ptr> Expressions::In(std::shared_ptr> expr, + std::vector values) { + return std::make_shared>(Expression::Operation::kIn, + std::move(expr), std::move(values)); +} + +std::shared_ptr> Expressions::In( + std::string name, std::initializer_list values) { + return In(Ref(std::move(name)), std::vector(values)); +} + +template +std::shared_ptr> Expressions::In( + std::shared_ptr> expr, std::initializer_list values) { + return In(std::move(expr), std::vector(values)); +} + +std::shared_ptr> Expressions::NotIn( + std::string name, std::vector values) { + return NotIn(Ref(std::move(name)), std::move(values)); +} + +template +std::shared_ptr> Expressions::NotIn( + std::shared_ptr> expr, std::vector values) { + return std::make_shared>(Expression::Operation::kNotIn, + std::move(expr), std::move(values)); +} + +std::shared_ptr> Expressions::NotIn( + std::string name, std::initializer_list values) { + return NotIn(Ref(std::move(name)), std::vector(values)); +} + +template +std::shared_ptr> Expressions::NotIn( + std::shared_ptr> expr, std::initializer_list values) { + return NotIn(expr, std::vector(values)); +} + +// Template implementations for generic predicate factory + +std::shared_ptr> Expressions::Predicate( + Expression::Operation op, std::string name, Literal value) { + return std::make_shared>(op, Ref(std::move(name)), + std::move(value)); +} + +std::shared_ptr> Expressions::Predicate( + Expression::Operation op, std::string name, std::vector values) { + return std::make_shared>(op, Ref(std::move(name)), + std::move(values)); +} + +std::shared_ptr> Expressions::Predicate( + Expression::Operation op, std::string name, std::initializer_list values) { + return Predicate(op, name, std::vector(values)); +} + +std::shared_ptr> Expressions::Predicate( + Expression::Operation op, std::string name) { + return std::make_shared>(op, Ref(std::move(name))); +} + +template +std::shared_ptr> Expressions::Predicate( + Expression::Operation op, std::shared_ptr> expr, + std::vector values) { + return std::make_shared>(op, std::move(expr), std::move(values)); +} + +template +std::shared_ptr> Expressions::Predicate( + Expression::Operation op, std::shared_ptr> expr, + std::initializer_list values) { + return Predicate(op, std::move(expr), std::vector(values)); +} + +template +std::shared_ptr> Expressions::Predicate( + Expression::Operation op, std::shared_ptr> expr) { + return std::make_shared>(op, std::move(expr)); +} + +// Constants + +std::shared_ptr Expressions::AlwaysTrue() { return True::Instance(); } + +std::shared_ptr Expressions::AlwaysFalse() { return False::Instance(); } + +// Utilities + +std::shared_ptr Expressions::Ref(std::string name) { + return std::make_shared(std::move(name)); +} + +Literal Expressions::Lit(Literal::Value value, std::shared_ptr type) { + throw IcebergError("Literal creation is not implemented"); +} + +} // namespace iceberg diff --git a/src/iceberg/expression/expressions.h b/src/iceberg/expression/expressions.h new file mode 100644 index 00000000..7d9f9a1d --- /dev/null +++ b/src/iceberg/expression/expressions.h @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +/// \file iceberg/expression/expressions.h +/// Factory methods for creating expressions. + +#include +#include +#include +#include + +#include "iceberg/expression/literal.h" +#include "iceberg/expression/predicate.h" +#include "iceberg/expression/term.h" +#include "iceberg/iceberg_export.h" + +namespace iceberg { + +/// \brief Factory methods for creating expressions. +class ICEBERG_EXPORT Expressions { + public: + // Logical operations + + /// \brief Create an AND expression. + template + static std::shared_ptr And(std::shared_ptr left, + std::shared_ptr right, + Args&&... args) + requires std::conjunction_v>...> + { + if constexpr (sizeof...(args) == 0) { + if (left->op() == Expression::Operation::kFalse || + right->op() == Expression::Operation::kFalse) { + return AlwaysFalse(); + } + + if (left->op() == Expression::Operation::kTrue) { + return right; + } + + if (right->op() == Expression::Operation::kTrue) { + return left; + } + + return std::make_shared<::iceberg::And>(std::move(left), std::move(right)); + } else { + return And(And(std::move(left), std::move(right)), std::forward(args)...); + } + } + + /// \brief Create an OR expression. + template + static std::shared_ptr Or(std::shared_ptr left, + std::shared_ptr right, Args&&... args) + requires std::conjunction_v>...> + { + if constexpr (sizeof...(args) == 0) { + if (left->op() == Expression::Operation::kTrue || + right->op() == Expression::Operation::kTrue) { + return AlwaysTrue(); + } + + if (left->op() == Expression::Operation::kFalse) { + return right; + } + + if (right->op() == Expression::Operation::kFalse) { + return left; + } + + return std::make_shared<::iceberg::Or>(std::move(left), std::move(right)); + } else { + return Or(Or(std::move(left), std::move(right)), std::forward(args)...); + } + } + + // Transform functions + + /// \brief Create a bucket transform term. + static std::shared_ptr Bucket(std::string name, int32_t num_buckets); + + /// \brief Create a year transform term. + static std::shared_ptr Year(std::string name); + + /// \brief Create a month transform term. + static std::shared_ptr Month(std::string name); + + /// \brief Create a day transform term. + static std::shared_ptr Day(std::string name); + + /// \brief Create an hour transform term. + static std::shared_ptr Hour(std::string name); + + /// \brief Create a truncate transform term. + static std::shared_ptr Truncate(std::string name, int32_t width); + + /// \brief Create a transform expression. + static std::shared_ptr Transform( + std::string name, std::shared_ptr transform); + + // Unary predicates + + /// \brief Create an IS NULL predicate for a field name. + static std::shared_ptr> IsNull(std::string name); + + /// \brief Create an IS NULL predicate for an unbound term. + template + static std::shared_ptr> IsNull( + std::shared_ptr> expr); + + /// \brief Create a NOT NULL predicate for a field name. + static std::shared_ptr> NotNull(std::string name); + + /// \brief Create a NOT NULL predicate for an unbound term. + template + static std::shared_ptr> NotNull( + std::shared_ptr> expr); + + /// \brief Create an IS NaN predicate for a field name. + static std::shared_ptr> IsNaN(std::string name); + + /// \brief Create an IS NaN predicate for an unbound term. + template + static std::shared_ptr> IsNaN(std::shared_ptr> expr); + + /// \brief Create a NOT NaN predicate for a field name. + static std::shared_ptr> NotNaN(std::string name); + + /// \brief Create a NOT NaN predicate for an unbound term. + template + static std::shared_ptr> NotNaN( + std::shared_ptr> expr); + + // Comparison predicates + + /// \brief Create a less than predicate for a field name. + static std::shared_ptr> LessThan(std::string name, + Literal value); + + /// \brief Create a less than predicate for an unbound term. + template + static std::shared_ptr> LessThan( + std::shared_ptr> expr, Literal value); + + /// \brief Create a less than or equal predicate for a field name. + static std::shared_ptr> LessThanOrEqual( + std::string name, Literal value); + + /// \brief Create a less than or equal predicate for an unbound term. + template + static std::shared_ptr> LessThanOrEqual( + std::shared_ptr> expr, Literal value); + + /// \brief Create a greater than predicate for a field name. + static std::shared_ptr> GreaterThan(std::string name, + Literal value); + + /// \brief Create a greater than predicate for an unbound term. + template + static std::shared_ptr> GreaterThan( + std::shared_ptr> expr, Literal value); + + /// \brief Create a greater than or equal predicate for a field name. + static std::shared_ptr> GreaterThanOrEqual( + std::string name, Literal value); + + /// \brief Create a greater than or equal predicate for an unbound term. + template + static std::shared_ptr> GreaterThanOrEqual( + std::shared_ptr> expr, Literal value); + + /// \brief Create an equal predicate for a field name. + static std::shared_ptr> Equal(std::string name, + Literal value); + + /// \brief Create an equal predicate for an unbound term. + template + static std::shared_ptr> Equal(std::shared_ptr> expr, + Literal value); + + /// \brief Create a not equal predicate for a field name. + static std::shared_ptr> NotEqual(std::string name, + Literal value); + + /// \brief Create a not equal predicate for an unbound term. + template + static std::shared_ptr> NotEqual( + std::shared_ptr> expr, Literal value); + + // String predicates + + /// \brief Create a starts with predicate for a field name. + static std::shared_ptr> StartsWith(std::string name, + std::string value); + + /// \brief Create a starts with predicate for an unbound term. + template + static std::shared_ptr> StartsWith( + std::shared_ptr> expr, std::string value); + + /// \brief Create a not starts with predicate for a field name. + static std::shared_ptr> NotStartsWith( + std::string name, std::string value); + + /// \brief Create a not starts with predicate for an unbound term. + template + static std::shared_ptr> NotStartsWith( + std::shared_ptr> expr, std::string value); + + // Set predicates + + /// \brief Create an IN predicate for a field name. + static std::shared_ptr> In( + std::string name, std::vector values); + + /// \brief Create an IN predicate for an unbound term. + template + static std::shared_ptr> In(std::shared_ptr> expr, + std::vector values); + + /// \brief Create an IN predicate for a field name with initializer list. + static std::shared_ptr> In( + std::string name, std::initializer_list values); + + /// \brief Create an IN predicate for an unbound term with initializer list. + template + static std::shared_ptr> In(std::shared_ptr> expr, + std::initializer_list values); + + /// \brief Create a NOT IN predicate for a field name. + static std::shared_ptr> NotIn( + std::string name, std::vector values); + + /// \brief Create a NOT IN predicate for an unbound term. + template + static std::shared_ptr> NotIn(std::shared_ptr> expr, + std::vector values); + + /// \brief Create a NOT IN predicate for a field name with initializer list. + static std::shared_ptr> NotIn( + std::string name, std::initializer_list values); + + /// \brief Create a NOT IN predicate for an unbound term with initializer list. + template + static std::shared_ptr> NotIn( + std::shared_ptr> expr, std::initializer_list values); + + // Generic predicate factory + + /// \brief Create a predicate with operation and single value. + static std::shared_ptr> Predicate( + Expression::Operation op, std::string name, Literal value); + + /// \brief Create a predicate with operation and multiple values. + static std::shared_ptr> Predicate( + Expression::Operation op, std::string name, std::vector values); + + /// \brief Create a predicate with operation and multiple values. + static std::shared_ptr> Predicate( + Expression::Operation op, std::string name, std::initializer_list values); + + /// \brief Create a unary predicate (no values). + static std::shared_ptr> Predicate( + Expression::Operation op, std::string name); + + /// \brief Create a predicate for unbound term with multiple values. + template + static std::shared_ptr> Predicate( + Expression::Operation op, std::shared_ptr> expr, + std::vector values); + + /// \brief Create a predicate with operation and multiple values. + template + static std::shared_ptr> Predicate( + Expression::Operation op, std::shared_ptr> expr, + std::initializer_list values); + + /// \brief Create a unary predicate for unbound term. + template + static std::shared_ptr> Predicate( + Expression::Operation op, std::shared_ptr> expr); + + // Constants + + /// \brief Return the always true expression. + static std::shared_ptr AlwaysTrue(); + + /// \brief Return the always false expression. + static std::shared_ptr AlwaysFalse(); + + // Utilities + + /// \brief Create a named reference to a field. + static std::shared_ptr Ref(std::string name); + + /// \brief Create a literal from a value. + static Literal Lit(Literal::Value value, std::shared_ptr type); +}; + +} // namespace iceberg diff --git a/src/iceberg/expression/literal.h b/src/iceberg/expression/literal.h index 4c880ef3..1c16b8ed 100644 --- a/src/iceberg/expression/literal.h +++ b/src/iceberg/expression/literal.h @@ -27,11 +27,12 @@ #include "iceberg/result.h" #include "iceberg/type.h" +#include "iceberg/util/formattable.h" namespace iceberg { /// \brief Literal is a literal value that is associated with a primitive type. -class ICEBERG_EXPORT Literal { +class ICEBERG_EXPORT Literal : public util::Formattable { public: /// \brief Sentinel value to indicate that the literal value is below the valid range /// of a specific primitive type. It can happen when casting a literal to a narrower @@ -138,7 +139,7 @@ class ICEBERG_EXPORT Literal { /// \return true if this literal is null, false otherwise bool IsNull() const; - std::string ToString() const; + std::string ToString() const override; private: Literal(Value value, std::shared_ptr type); diff --git a/src/iceberg/expression/predicate.cc b/src/iceberg/expression/predicate.cc new file mode 100644 index 00000000..144ef2b0 --- /dev/null +++ b/src/iceberg/expression/predicate.cc @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/expression/predicate.h" + +#include +#include + +#include "iceberg/exception.h" +#include "iceberg/expression/expressions.h" +#include "iceberg/expression/literal.h" +#include "iceberg/result.h" +#include "iceberg/type.h" +#include "iceberg/util/checked_cast.h" +#include "iceberg/util/formatter_internal.h" +#include "iceberg/util/macros.h" + +namespace iceberg { + +// Predicate template implementations +template +Predicate::Predicate(Expression::Operation op, std::shared_ptr term) + : operation_(op), term_(std::move(term)) {} + +template +Predicate::~Predicate() = default; + +// UnboundPredicate template implementations +template +UnboundPredicate::UnboundPredicate(Expression::Operation op, + std::shared_ptr> term) + : BASE(op, std::move(term)) {} + +template +UnboundPredicate::UnboundPredicate(Expression::Operation op, + std::shared_ptr> term, Literal value) + : BASE(op, std::move(term)), values_{std::move(value)} {} + +template +UnboundPredicate::UnboundPredicate(Expression::Operation op, + std::shared_ptr> term, + std::vector values) + : BASE(op, std::move(term)), values_(std::move(values)) {} + +template +UnboundPredicate::~UnboundPredicate() = default; + +namespace {} + +template +std::string UnboundPredicate::ToString() const { + auto invalid_predicate_string = [](Expression::Operation op) { + return std::format("Invalid predicate: operation = {}", op); + }; + + const auto& term = *BASE::term(); + const auto op = BASE::op(); + + switch (op) { + case Expression::Operation::kIsNull: + return std::format("is_null({})", term); + case Expression::Operation::kNotNull: + return std::format("not_null({})", term); + case Expression::Operation::kIsNan: + return std::format("is_nan({})", term); + case Expression::Operation::kNotNan: + return std::format("not_nan({})", term); + case Expression::Operation::kLt: + return values_.size() == 1 ? std::format("{} < {}", term, values_[0]) + : invalid_predicate_string(op); + case Expression::Operation::kLtEq: + return values_.size() == 1 ? std::format("{} <= {}", term, values_[0]) + : invalid_predicate_string(op); + case Expression::Operation::kGt: + return values_.size() == 1 ? std::format("{} > {}", term, values_[0]) + : invalid_predicate_string(op); + case Expression::Operation::kGtEq: + return values_.size() == 1 ? std::format("{} >= {}", term, values_[0]) + : invalid_predicate_string(op); + case Expression::Operation::kEq: + return values_.size() == 1 ? std::format("{} == {}", term, values_[0]) + : invalid_predicate_string(op); + case Expression::Operation::kNotEq: + return values_.size() == 1 ? std::format("{} != {}", term, values_[0]) + : invalid_predicate_string(op); + case Expression::Operation::kStartsWith: + return values_.size() == 1 ? std::format("{} startsWith \"{}\"", term, values_[0]) + : invalid_predicate_string(op); + case Expression::Operation::kNotStartsWith: + return values_.size() == 1 + ? std::format("{} notStartsWith \"{}\"", term, values_[0]) + : invalid_predicate_string(op); + case Expression::Operation::kIn: + return std::format("{} in {}", term, values_); + case Expression::Operation::kNotIn: + return std::format("{} not in {}", term, values_); + default: + return invalid_predicate_string(op); + } +} + +template +Result> UnboundPredicate::Negate() const { + ICEBERG_ASSIGN_OR_RAISE(auto negated_op, ::iceberg::Negate(BASE::op())); + return std::make_shared(negated_op, BASE::term(), values_); +} + +template +Result> UnboundPredicate::Bind(const Schema& schema, + bool case_sensitive) const { + ICEBERG_ASSIGN_OR_RAISE(auto bound_term, BASE::term()->Bind(schema, case_sensitive)); + + if (values_.empty()) { + return BindUnaryOperation(std::move(bound_term)); + } + + if (BASE::op() == Expression::Operation::kIn || + BASE::op() == Expression::Operation::kNotIn) { + return BindInOperation(std::move(bound_term)); + } + + return BindLiteralOperation(std::move(bound_term)); +} + +namespace { + +bool IsFloatingType(TypeId type) { + return type == TypeId::kFloat || type == TypeId::kDouble; +} + +} // namespace + +template +Result> UnboundPredicate::BindUnaryOperation( + std::shared_ptr bound_term) const { + switch (BASE::op()) { + case Expression::Operation::kIsNull: + if (!bound_term->MayProduceNull()) { + return Expressions::AlwaysFalse(); + } + // TODO(gangwu): deal with UnknownType + return std::make_shared(Expression::Operation::kIsNull, + std::move(bound_term)); + case Expression::Operation::kNotNull: + if (!bound_term->MayProduceNull()) { + return Expressions::AlwaysTrue(); + } + return std::make_shared(Expression::Operation::kNotNull, + std::move(bound_term)); + case Expression::Operation::kIsNan: + case Expression::Operation::kNotNan: + if (!IsFloatingType(bound_term->type()->type_id())) { + return InvalidExpression("{} cannot be used with a non-floating-point column", + BASE::op()); + } + return std::make_shared(BASE::op(), std::move(bound_term)); + default: + return InvalidExpression("Operation must be IS_NULL, NOT_NULL, IS_NAN, or NOT_NAN"); + } +} + +template +Result> UnboundPredicate::BindLiteralOperation( + std::shared_ptr bound_term) const { + if (BASE::op() == Expression::Operation::kStartsWith || + BASE::op() == Expression::Operation::kNotStartsWith) { + if (bound_term->type()->type_id() != TypeId::kString) { + return InvalidExpression( + "Term for STARTS_WITH or NOT_STARTS_WITH must produce a string: {}: {}", + *bound_term, *bound_term->type()); + } + } + + if (values_.size() != 1) { + return InvalidExpression("Literal operation requires a single value but got {}", + values_.size()); + } + + ICEBERG_ASSIGN_OR_RAISE(auto literal, + values_[0].CastTo(internal::checked_pointer_cast( + bound_term->type()))); + + if (literal.IsNull()) { + return InvalidExpression("Invalid value for conversion to type {}: {} ({})", + *bound_term->type(), literal.ToString(), *literal.type()); + } else if (literal.IsAboveMax()) { + switch (BASE::op()) { + case Expression::Operation::kLt: + case Expression::Operation::kLtEq: + case Expression::Operation::kNotEq: + return Expressions::AlwaysTrue(); + case Expression::Operation::kGt: + case Expression::Operation::kGtEq: + case Expression::Operation::kEq: + return Expressions::AlwaysFalse(); + default: + break; + } + } else if (literal.IsBelowMin()) { + switch (BASE::op()) { + case Expression::Operation::kGt: + case Expression::Operation::kGtEq: + case Expression::Operation::kNotEq: + return Expressions::AlwaysTrue(); + case Expression::Operation::kLt: + case Expression::Operation::kLtEq: + case Expression::Operation::kEq: + return Expressions::AlwaysFalse(); + default: + break; + } + } + + // TODO(gangwu): translate truncate(col) == value to startsWith(value) + return std::make_shared(BASE::op(), std::move(bound_term), + std::move(literal)); +} + +template +Result> UnboundPredicate::BindInOperation( + std::shared_ptr bound_term) const { + std::vector converted_literals; + for (const auto& literal : values_) { + auto primitive_type = + internal::checked_pointer_cast(bound_term->type()); + ICEBERG_ASSIGN_OR_RAISE(auto converted, literal.CastTo(primitive_type)); + if (converted.IsNull()) { + return InvalidExpression("Invalid value for conversion to type {}: {} ({})", + *bound_term->type(), literal.ToString(), *literal.type()); + } + // Filter out literals that are out of range after conversion. + if (!converted.IsBelowMin() && !converted.IsAboveMax()) { + converted_literals.push_back(std::move(converted)); + } + } + + // If no valid literals remain after conversion and filtering + if (converted_literals.empty()) { + switch (BASE::op()) { + case Expression::Operation::kIn: + return Expressions::AlwaysFalse(); + case Expression::Operation::kNotIn: + return Expressions::AlwaysTrue(); + default: + return InvalidExpression("Operation must be IN or NOT_IN"); + } + } + + // If only one unique literal remains, convert to equality/inequality + if (converted_literals.size() == 1) { + const auto& single_literal = converted_literals[0]; + switch (BASE::op()) { + case Expression::Operation::kIn: + return std::make_shared( + Expression::Operation::kEq, std::move(bound_term), single_literal); + case Expression::Operation::kNotIn: + return std::make_shared( + Expression::Operation::kNotEq, std::move(bound_term), single_literal); + default: + return InvalidExpression("Operation must be IN or NOT_IN"); + } + } + + // Multiple literals - create a set predicate + return std::make_shared( + BASE::op(), std::move(bound_term), std::span(converted_literals)); +} + +// BoundPredicate implementation +BoundPredicate::BoundPredicate(Expression::Operation op, std::shared_ptr term) + : Predicate(op, std::move(term)) {} + +BoundPredicate::~BoundPredicate() = default; + +Result BoundPredicate::Evaluate(const StructLike& data) const { + ICEBERG_ASSIGN_OR_RAISE(auto eval_result, term_->Evaluate(data)); + ICEBERG_ASSIGN_OR_RAISE(auto test_result, Test(eval_result)); + return Literal::Value{test_result}; +} + +// BoundUnaryPredicate implementation +BoundUnaryPredicate::BoundUnaryPredicate(Expression::Operation op, + std::shared_ptr term) + : BoundPredicate(op, std::move(term)) {} + +BoundUnaryPredicate::~BoundUnaryPredicate() = default; + +Result BoundUnaryPredicate::Test(const Literal::Value& value) const { + return NotImplemented("BoundUnaryPredicate::Test not implemented"); +} + +bool BoundUnaryPredicate::Equals(const Expression& other) const { + throw IcebergError("BoundUnaryPredicate::Equals not implemented"); +} + +std::string BoundUnaryPredicate::ToString() const { + switch (op()) { + case Expression::Operation::kIsNull: + return std::format("is_null({})", *term()); + case Expression::Operation::kNotNull: + return std::format("not_null({})", *term()); + case Expression::Operation::kIsNan: + return std::format("is_nan({})", *term()); + case Expression::Operation::kNotNan: + return std::format("not_nan({})", *term()); + default: + return std::format("Invalid unary predicate: operation = {}", op()); + } +} + +// BoundLiteralPredicate implementation +BoundLiteralPredicate::BoundLiteralPredicate(Expression::Operation op, + std::shared_ptr term, + Literal literal) + : BoundPredicate(op, std::move(term)), literal_(std::move(literal)) {} + +BoundLiteralPredicate::~BoundLiteralPredicate() = default; + +Result BoundLiteralPredicate::Test(const Literal::Value& value) const { + return NotImplemented("BoundLiteralPredicate::Test not implemented"); +} + +bool BoundLiteralPredicate::Equals(const Expression& other) const { + throw IcebergError("BoundLiteralPredicate::Equals not implemented"); +} + +std::string BoundLiteralPredicate::ToString() const { + switch (op()) { + case Expression::Operation::kLt: + return std::format("{} < {}", *term(), literal()); + case Expression::Operation::kLtEq: + return std::format("{} <= {}", *term(), literal()); + case Expression::Operation::kGt: + return std::format("{} > {}", *term(), literal()); + case Expression::Operation::kGtEq: + return std::format("{} >= {}", *term(), literal()); + case Expression::Operation::kEq: + return std::format("{} == {}", *term(), literal()); + case Expression::Operation::kNotEq: + return std::format("{} != {}", *term(), literal()); + case Expression::Operation::kStartsWith: + return std::format("{} startsWith \"{}\"", *term(), literal()); + case Expression::Operation::kNotStartsWith: + return std::format("{} notStartsWith \"{}\"", *term(), literal()); + case Expression::Operation::kIn: + return std::format("{} in ({})", *term(), literal()); + case Expression::Operation::kNotIn: + return std::format("{} not in ({})", *term(), literal()); + default: + return std::format("Invalid literal predicate: operation = {}", op()); + } +} + +// BoundSetPredicate implementation +BoundSetPredicate::BoundSetPredicate(Expression::Operation op, + std::shared_ptr term, + std::span literals) + : BoundPredicate(op, std::move(term)) { + for (const auto& literal : literals) { + ICEBERG_DCHECK((*literal.type() == *term_->type()), + "Literal type does not match term type"); + value_set_.push_back(literal.value()); + } +} + +BoundSetPredicate::~BoundSetPredicate() = default; + +Result BoundSetPredicate::Test(const Literal::Value& value) const { + return NotImplemented("BoundSetPredicate::Test not implemented"); +} + +bool BoundSetPredicate::Equals(const Expression& other) const { + throw IcebergError("BoundSetPredicate::Equals not implemented"); +} + +std::string BoundSetPredicate::ToString() const { + // TODO(gangwu): Literal::Value does not have std::format support. + throw IcebergError("BoundSetPredicate::ToString not implemented"); +} + +// Explicit template instantiations +template class Predicate>; +template class Predicate>; +template class Predicate; + +template class UnboundPredicate; +template class UnboundPredicate; + +} // namespace iceberg diff --git a/src/iceberg/expression/predicate.h b/src/iceberg/expression/predicate.h new file mode 100644 index 00000000..3c40af69 --- /dev/null +++ b/src/iceberg/expression/predicate.h @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +/// \file iceberg/expression/predicate.h +/// Predicate interface for boolean expressions that test terms. + +#include + +#include "iceberg/expression/expression.h" +#include "iceberg/expression/term.h" + +namespace iceberg { + +template +concept TermType = std::derived_from; + +/// \brief A predicate is a boolean expression that tests a term against some criteria. +/// +/// \tparam TermType The type of the term being tested +template +class ICEBERG_EXPORT Predicate : public Expression { + public: + /// \brief Create a predicate with an operation and term. + /// + /// \param op The operation this predicate performs + /// \param term The term this predicate tests + Predicate(Expression::Operation op, std::shared_ptr term); + + ~Predicate() override; + + Expression::Operation op() const override { return operation_; } + + /// \brief Returns the term this predicate tests. + const std::shared_ptr& term() const { return term_; } + + protected: + Expression::Operation operation_; + std::shared_ptr term_; +}; + +/// \brief Unbound predicates contain unbound terms and must be bound to a concrete schema +/// before they can be evaluated. +/// +/// \tparam B The bound type this predicate produces when binding is successful +template +class ICEBERG_EXPORT UnboundPredicate : public Predicate>, + public Unbound { + using BASE = Predicate>; + + public: + UnboundPredicate(Expression::Operation op, std::shared_ptr> term); + UnboundPredicate(Expression::Operation op, std::shared_ptr> term, + Literal value); + UnboundPredicate(Expression::Operation op, std::shared_ptr> term, + std::vector values); + + ~UnboundPredicate() override; + + std::shared_ptr reference() override { + return BASE::term()->reference(); + } + + std::string ToString() const override; + + /// \brief Bind this UnboundPredicate. + Result> Bind(const Schema& schema, + bool case_sensitive) const override; + + Result> Negate() const override; + + private: + Result> BindUnaryOperation( + std::shared_ptr bound_term) const; + Result> BindLiteralOperation( + std::shared_ptr bound_term) const; + Result> BindInOperation( + std::shared_ptr bound_term) const; + + private: + std::vector values_; +}; + +/// \brief Bound predicates contain bound terms and can be evaluated. +class ICEBERG_EXPORT BoundPredicate : public Predicate, public Bound { + public: + BoundPredicate(Expression::Operation op, std::shared_ptr term); + + ~BoundPredicate() override; + + using Predicate::op; + + using Predicate::term; + + std::shared_ptr reference() override { return term_->reference(); } + + Result Evaluate(const StructLike& data) const override; + + /// \brief Test a value against this predicate. + /// + /// \param value The value to test + /// \return true if the predicate passes, false otherwise + virtual Result Test(const Literal::Value& value) const = 0; + + enum class Kind : int8_t { + // A unary predicate (tests for null, not-null, etc.). + kUnary = 0, + // A literal predicate (compares against a literal). + kLiteral, + // A set predicate (tests membership in a set). + kSet, + }; + + /// \brief Returns the kind of this bound predicate. + virtual Kind kind() const = 0; +}; + +/// \brief Bound unary predicate (null, not-null, etc.). +class ICEBERG_EXPORT BoundUnaryPredicate : public BoundPredicate { + public: + /// \brief Create a bound unary predicate. + /// + /// \param op The unary operation (kIsNull, kNotNull, kIsNan, kNotNan) + /// \param term The bound term to test + BoundUnaryPredicate(Expression::Operation op, std::shared_ptr term); + + ~BoundUnaryPredicate() override; + + Result Test(const Literal::Value& value) const override; + + Kind kind() const override { return Kind::kUnary; } + + std::string ToString() const override; + + bool Equals(const Expression& other) const override; +}; + +/// \brief Bound literal predicate (comparison against a single value). +class ICEBERG_EXPORT BoundLiteralPredicate : public BoundPredicate { + public: + /// \brief Create a bound literal predicate. + /// + /// \param op The comparison operation (kLt, kLtEq, kGt, kGtEq, kEq, kNotEq) + /// \param term The bound term to compare + /// \param literal The literal value to compare against + BoundLiteralPredicate(Expression::Operation op, std::shared_ptr term, + Literal literal); + + ~BoundLiteralPredicate() override; + + /// \brief Returns the literal being compared against. + const Literal& literal() const { return literal_; } + + Result Test(const Literal::Value& value) const override; + + Kind kind() const override { return Kind::kLiteral; } + + std::string ToString() const override; + + bool Equals(const Expression& other) const override; + + private: + Literal literal_; +}; + +/// \brief Bound set predicate (membership testing against a set of values). +class ICEBERG_EXPORT BoundSetPredicate : public BoundPredicate { + public: + /// \brief Create a bound set predicate. + /// + /// \param op The set operation (kIn, kNotIn) + /// \param term The bound term to test for membership + /// \param literals The set of literal values to test against + BoundSetPredicate(Expression::Operation op, std::shared_ptr term, + std::span literals); + + ~BoundSetPredicate() override; + + /// \brief Returns the set of literals to test against. + const std::vector& literal_set() const { return value_set_; } + + Result Test(const Literal::Value& value) const override; + + Kind kind() const override { return Kind::kSet; } + + std::string ToString() const override; + + bool Equals(const Expression& other) const override; + + private: + /// FIXME: Literal::Value does not have hash support. We need to add this + /// and replace the vector with a unordered_set. + std::vector value_set_; +}; + +} // namespace iceberg diff --git a/src/iceberg/expression/term.cc b/src/iceberg/expression/term.cc new file mode 100644 index 00000000..5bb9b71d --- /dev/null +++ b/src/iceberg/expression/term.cc @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/expression/term.h" + +#include + +#include "iceberg/exception.h" +#include "iceberg/result.h" +#include "iceberg/schema.h" +#include "iceberg/transform.h" +#include "iceberg/util/checked_cast.h" +#include "iceberg/util/macros.h" + +namespace iceberg { + +Bound::~Bound() = default; + +BoundTerm::~BoundTerm() = default; + +Reference::~Reference() = default; + +template +Result> Unbound::Bind(const Schema& schema) const { + return Bind(schema, /*case_sensitive=*/true); +} + +// NamedReference implementation +NamedReference::NamedReference(std::string field_name) + : field_name_(std::move(field_name)) {} + +NamedReference::~NamedReference() = default; + +Result> NamedReference::Bind(const Schema& schema, + bool case_sensitive) const { + ICEBERG_ASSIGN_OR_RAISE(auto field_opt, + schema.GetFieldByName(field_name_, case_sensitive)); + if (!field_opt.has_value()) { + return InvalidExpression("Cannot find field '{}' in struct: {}", field_name_, + schema.ToString()); + } + return std::make_shared(field_opt.value().get()); +} + +std::string NamedReference::ToString() const { + return std::format("ref(name=\"{}\")", field_name_); +} + +// BoundReference implementation +BoundReference::BoundReference(SchemaField field) : field_(std::move(field)) {} + +BoundReference::~BoundReference() = default; + +std::string BoundReference::ToString() const { + return std::format("ref(id={}, type={})", field_.field_id(), field_.type()->ToString()); +} + +Result BoundReference::Evaluate(const StructLike& data) const { + return NotImplemented("BoundReference::Evaluate(StructLike) not implemented"); +} + +bool BoundReference::Equals(const BoundTerm& other) const { + if (other.kind() != Term::Kind::kReference) { + return false; + } + + const auto& other_ref = internal::checked_cast(other); + return field_.field_id() == other_ref.field_.field_id() && + field_.optional() == other_ref.field_.optional() && + *field_.type() == *other_ref.field_.type(); +} + +// UnboundTransform implementation +UnboundTransform::UnboundTransform(std::shared_ptr ref, + std::shared_ptr transform) + : ref_(std::move(ref)), transform_(std::move(transform)) {} + +UnboundTransform::~UnboundTransform() = default; + +std::string UnboundTransform::ToString() const { + return std::format("{}({})", transform_->ToString(), ref_->ToString()); +} + +Result> UnboundTransform::Bind( + const Schema& schema, bool case_sensitive) const { + ICEBERG_ASSIGN_OR_RAISE(auto bound_ref, ref_->Bind(schema, case_sensitive)); + ICEBERG_ASSIGN_OR_RAISE(auto transform_func, transform_->Bind(bound_ref->type())); + return std::make_shared(std::move(bound_ref), transform_, + std::move(transform_func)); +} + +// BoundTransform implementation +BoundTransform::BoundTransform(std::shared_ptr ref, + std::shared_ptr transform, + std::shared_ptr transform_func) + : ref_(std::move(ref)), + transform_(std::move(transform)), + transform_func_(std::move(transform_func)) {} + +BoundTransform::~BoundTransform() = default; + +std::string BoundTransform::ToString() const { + return std::format("{}({})", transform_->ToString(), ref_->ToString()); +} + +Result BoundTransform::Evaluate(const StructLike& data) const { + throw IcebergError("BoundTransform::Evaluate(StructLike) not implemented"); +} + +bool BoundTransform::MayProduceNull() const { + // transforms must produce null for null input values + // transforms may produce null for non-null inputs when not order-preserving + // FIXME: add Transform::is_order_preserving() + return ref_->MayProduceNull(); // || !transform_->is_order_preserving(); +} + +std::shared_ptr BoundTransform::type() const { + return transform_func_->ResultType(); +} + +bool BoundTransform::Equals(const BoundTerm& other) const { + if (other.kind() == Term::Kind::kTransform) { + const auto& other_transform = internal::checked_cast(other); + return *ref_ == *other_transform.ref_ && *transform_ == *other_transform.transform_; + } + + if (transform_->transform_type() == TransformType::kIdentity && + other.kind() == Term::Kind::kReference) { + return *ref_ == other; + } + + return false; +} + +// Explicit template instantiations +template Result> Unbound::Bind( + const Schema& schema) const; +template Result> Unbound::Bind( + const Schema& schema) const; + +} // namespace iceberg diff --git a/src/iceberg/expression/term.h b/src/iceberg/expression/term.h new file mode 100644 index 00000000..2911dfaa --- /dev/null +++ b/src/iceberg/expression/term.h @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +/// \file iceberg/expression/term.h +/// Term interface for Iceberg expressions - represents values that can be evaluated. + +#include +#include +#include + +#include "iceberg/expression/literal.h" +#include "iceberg/type_fwd.h" +#include "iceberg/util/formattable.h" + +namespace iceberg { + +// TODO(gangwu): add a struct-like interface to wrap a row of data from ArrowArray or +// structs like ManifestFile and ManifestEntry to facilitate generailization of the +// evaluation of expressions on top of different data structures. +class StructLike; + +/// \brief A term is an expression node that produces a typed value when evaluated. +class ICEBERG_EXPORT Term : public util::Formattable { + public: + enum class Kind : uint8_t { kReference = 0, kTransform, kExtract }; + + /// \brief Returns the kind of this term. + virtual Kind kind() const = 0; +}; + +/// \brief Interface for unbound expressions that need schema binding. +/// +/// Unbound expressions contain string-based references that must be resolved +/// against a concrete schema to produce bound expressions that can be evaluated. +/// +/// \tparam B The bound type this term produces when binding is successful +template +class ICEBERG_EXPORT Unbound { + public: + /// \brief Bind this expression to a concrete schema. + /// + /// \param schema The schema to bind against + /// \param case_sensitive Whether field name matching should be case sensitive + /// \return A bound expression or an error if binding fails + virtual Result> Bind(const Schema& schema, + bool case_sensitive) const = 0; + + /// \brief Overloaded Bind method that uses case-sensitive matching by default. + Result> Bind(const Schema& schema) const; + + /// \brief Returns the underlying named reference for this unbound term. + virtual std::shared_ptr reference() = 0; +}; + +/// \brief Interface for bound expressions that can be evaluated. +/// +/// Bound expressions have been resolved against a concrete schema and contain +/// all necessary information to evaluate against data structures. +class ICEBERG_EXPORT Bound { + public: + virtual ~Bound(); + + /// \brief Evaluate this expression against a row-based data. + virtual Result Evaluate(const StructLike& data) const = 0; + + /// \brief Returns the underlying bound reference for this term. + virtual std::shared_ptr reference() = 0; +}; + +/// \brief Base class for unbound terms. +/// +/// \tparam B The bound type this term produces when binding is successful. +template +class ICEBERG_EXPORT UnboundTerm : public Unbound, public Term { + public: + using BoundType = B; +}; + +/// \brief Base class for bound terms. +class ICEBERG_EXPORT BoundTerm : public Bound, public Term { + public: + ~BoundTerm() override; + + /// \brief Returns the type produced by this term. + virtual std::shared_ptr type() const = 0; + + /// \brief Returns whether this term may produce null values. + virtual bool MayProduceNull() const = 0; + + // TODO(gangwu): add a comparator function to Literal and BoundTerm. + + /// \brief Returns whether this term is equivalent to another. + /// + /// Two terms are equivalent if they produce the same values when evaluated. + /// + /// \param other Another bound term to compare against + /// \return true if the terms are equivalent, false otherwise + virtual bool Equals(const BoundTerm& other) const = 0; + + friend bool operator==(const BoundTerm& lhs, const BoundTerm& rhs) { + return lhs.Equals(rhs); + } +}; + +/// \brief A reference represents a named field in an expression. +class ICEBERG_EXPORT Reference { + public: + virtual ~Reference(); + + /// \brief Returns the name of the referenced field. + virtual std::string_view name() const = 0; +}; + +/// \brief A reference to an unbound named field. +class ICEBERG_EXPORT NamedReference + : public Reference, + public UnboundTerm, + public std::enable_shared_from_this { + public: + /// \brief Create a named reference to a field. + /// + /// \param field_name The name of the field to reference + explicit NamedReference(std::string field_name); + + ~NamedReference() override; + + std::string_view name() const override { return field_name_; } + + Result> Bind(const Schema& schema, + bool case_sensitive) const override; + + std::shared_ptr reference() override { return shared_from_this(); } + + std::string ToString() const override; + + Kind kind() const override { return Kind::kReference; } + + private: + std::string field_name_; +}; + +/// \brief A reference to a bound field. +class ICEBERG_EXPORT BoundReference + : public Reference, + public BoundTerm, + public std::enable_shared_from_this { + public: + /// \brief Create a bound reference. + /// + /// \param field The schema field + explicit BoundReference(SchemaField field); + + ~BoundReference() override; + + const SchemaField& field() const { return field_; } + + std::string_view name() const override { return field_.name(); } + + std::string ToString() const override; + + Result Evaluate(const StructLike& data) const override; + + std::shared_ptr reference() override { return shared_from_this(); } + + std::shared_ptr type() const override { return field_.type(); } + + bool MayProduceNull() const override { return field_.optional(); } + + bool Equals(const BoundTerm& other) const override; + + Kind kind() const override { return Kind::kReference; } + + private: + SchemaField field_; +}; + +/// \brief An unbound transform expression. +class ICEBERG_EXPORT UnboundTransform : public UnboundTerm { + public: + /// \brief Create an unbound transform. + /// + /// \param ref The term to apply the transformation to + /// \param transform The transformation function to apply + UnboundTransform(std::shared_ptr ref, + std::shared_ptr transform); + + ~UnboundTransform() override; + + std::string ToString() const override; + + Result> Bind(const Schema& schema, + bool case_sensitive) const override; + + std::shared_ptr reference() override { return ref_; } + + const std::shared_ptr& transform() const { return transform_; } + + Kind kind() const override { return Kind::kTransform; } + + private: + std::shared_ptr ref_; + std::shared_ptr transform_; +}; + +/// \brief A bound transform expression. +class ICEBERG_EXPORT BoundTransform : public BoundTerm { + public: + /// \brief Create a bound transform. + /// + /// \param ref The bound term to apply the transformation to + /// \param transform The transform to apply + /// \param transform_func The bound transform function to apply + BoundTransform(std::shared_ptr ref, + std::shared_ptr transform, + std::shared_ptr transform_func); + + ~BoundTransform() override; + + std::string ToString() const override; + + Result Evaluate(const StructLike& data) const override; + + std::shared_ptr reference() override { return ref_; } + + std::shared_ptr type() const override; + + bool MayProduceNull() const override; + + bool Equals(const BoundTerm& other) const override; + + const std::shared_ptr& transform() const { return transform_; } + + Kind kind() const override { return Kind::kTransform; } + + private: + std::shared_ptr ref_; + std::shared_ptr transform_; + std::shared_ptr transform_func_; +}; + +} // namespace iceberg diff --git a/src/iceberg/test/CMakeLists.txt b/src/iceberg/test/CMakeLists.txt index cb3b6082..ca31682f 100644 --- a/src/iceberg/test/CMakeLists.txt +++ b/src/iceberg/test/CMakeLists.txt @@ -74,7 +74,11 @@ add_iceberg_test(table_test table_test.cc schema_json_test.cc) -add_iceberg_test(expression_test SOURCES expression_test.cc literal_test.cc) +add_iceberg_test(expression_test + SOURCES + expression_test.cc + literal_test.cc + predicate_test.cc) add_iceberg_test(json_serde_test SOURCES diff --git a/src/iceberg/test/expression_test.cc b/src/iceberg/test/expression_test.cc index c14c7d9a..8baaf56f 100644 --- a/src/iceberg/test/expression_test.cc +++ b/src/iceberg/test/expression_test.cc @@ -23,12 +23,16 @@ #include +#include "matchers.h" + namespace iceberg { TEST(TrueFalseTest, Basic) { // Test negation of False returns True auto false_instance = False::Instance(); - auto negated = false_instance->Negate(); + auto negated_result = false_instance->Negate(); + ASSERT_THAT(negated_result, IsOk()); + auto negated = negated_result.value(); // Check that negated expression is True EXPECT_EQ(negated->op(), Expression::Operation::kTrue); @@ -36,7 +40,9 @@ TEST(TrueFalseTest, Basic) { // Test negation of True returns false auto true_instance = True::Instance(); - negated = true_instance->Negate(); + negated_result = true_instance->Negate(); + ASSERT_THAT(negated_result, IsOk()); + negated = negated_result.value(); // Check that negated expression is False EXPECT_EQ(negated->op(), Expression::Operation::kFalse); @@ -77,7 +83,9 @@ TEST(ORTest, Negation) { auto false_expr = False::Instance(); auto or_expr = std::make_shared(true_expr, false_expr); - auto negated_or = or_expr->Negate(); + auto negated_or_result = or_expr->Negate(); + ASSERT_THAT(negated_or_result, IsOk()); + auto negated_or = negated_or_result.value(); // Should become AND expression EXPECT_EQ(negated_or->op(), Expression::Operation::kAnd); @@ -112,7 +120,9 @@ TEST(ANDTest, Negation) { auto false_expr = False::Instance(); auto and_expr = std::make_shared(true_expr, false_expr); - auto negated_and = and_expr->Negate(); + auto negated_and_result = and_expr->Negate(); + ASSERT_THAT(negated_and_result, IsOk()); + auto negated_and = negated_and_result.value(); // Should become OR expression EXPECT_EQ(negated_and->op(), Expression::Operation::kOr); @@ -141,7 +151,7 @@ TEST(ANDTest, Equals) { EXPECT_FALSE(and_expr1->Equals(*or_expr)); } -TEST(ExpressionTest, BaseClassNegateThrowsException) { +TEST(ExpressionTest, BaseClassNegateErrorOut) { // Create a mock expression that doesn't override Negate() class MockExpression : public Expression { public: @@ -151,7 +161,8 @@ TEST(ExpressionTest, BaseClassNegateThrowsException) { auto mock_expr = std::make_shared(); - // Should throw IcebergError when calling Negate() on base class - EXPECT_THROW(mock_expr->Negate(), IcebergError); + // Should return NotSupported error when calling Negate() on base class + auto negate_result = mock_expr->Negate(); + EXPECT_THAT(negate_result, IsError(ErrorKind::kNotSupported)); } } // namespace iceberg diff --git a/src/iceberg/test/predicate_test.cc b/src/iceberg/test/predicate_test.cc new file mode 100644 index 00000000..f34df47d --- /dev/null +++ b/src/iceberg/test/predicate_test.cc @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/expression/expressions.h" +#include "iceberg/schema.h" +#include "iceberg/type.h" +#include "matchers.h" + +namespace iceberg { + +class PredicateTest : public ::testing::Test { + protected: + void SetUp() override { + // Create a simple test schema with various field types + schema_ = std::make_shared( + std::vector{SchemaField::MakeRequired(1, "id", int64()), + SchemaField::MakeOptional(2, "name", string()), + SchemaField::MakeRequired(3, "age", int32()), + SchemaField::MakeOptional(4, "salary", float64()), + SchemaField::MakeRequired(5, "active", boolean())}, + /*schema_id=*/0); + } + + std::shared_ptr schema_; +}; + +TEST_F(PredicateTest, LogicalOperationsAndOr) { + auto true_expr = Expressions::AlwaysTrue(); + auto false_expr = Expressions::AlwaysFalse(); + auto pred1 = Expressions::Equal("age", Literal::Int(25)); + auto pred2 = Expressions::Equal("name", Literal::String("test")); + + // Test AND operations + auto and_true_true = Expressions::And(true_expr, true_expr); + EXPECT_EQ(and_true_true->op(), Expression::Operation::kTrue); + + auto and_true_pred = Expressions::And(true_expr, pred1); + EXPECT_EQ(and_true_pred->op(), Expression::Operation::kEq); + + auto and_pred_true = Expressions::And(pred1, true_expr); + EXPECT_EQ(and_pred_true->op(), Expression::Operation::kEq); + + auto and_false_pred = Expressions::And(false_expr, pred1); + EXPECT_EQ(and_false_pred->op(), Expression::Operation::kFalse); + + auto and_pred_false = Expressions::And(pred1, false_expr); + EXPECT_EQ(and_pred_false->op(), Expression::Operation::kFalse); + + // Test OR operations + auto or_false_false = Expressions::Or(false_expr, false_expr); + EXPECT_EQ(or_false_false->op(), Expression::Operation::kFalse); + + auto or_false_pred = Expressions::Or(false_expr, pred1); + EXPECT_EQ(or_false_pred->op(), Expression::Operation::kEq); + + auto or_pred_false = Expressions::Or(pred1, false_expr); + EXPECT_EQ(or_pred_false->op(), Expression::Operation::kEq); + + auto or_true_pred = Expressions::Or(true_expr, pred1); + EXPECT_EQ(or_true_pred->op(), Expression::Operation::kTrue); + + auto or_pred_true = Expressions::Or(pred1, true_expr); + EXPECT_EQ(or_pred_true->op(), Expression::Operation::kTrue); +} + +TEST_F(PredicateTest, ConstantExpressions) { + auto always_true = Expressions::AlwaysTrue(); + auto always_false = Expressions::AlwaysFalse(); + + EXPECT_EQ(always_true->op(), Expression::Operation::kTrue); + EXPECT_EQ(always_false->op(), Expression::Operation::kFalse); +} + +TEST_F(PredicateTest, UnaryPredicateFactory) { + auto is_null_name = Expressions::IsNull("name"); + EXPECT_EQ(is_null_name->op(), Expression::Operation::kIsNull); + EXPECT_EQ(is_null_name->reference()->name(), "name"); + + auto not_null_name = Expressions::NotNull("active"); + EXPECT_EQ(not_null_name->op(), Expression::Operation::kNotNull); + EXPECT_EQ(not_null_name->reference()->name(), "active"); + + auto is_nan_name = Expressions::IsNaN("salary"); + EXPECT_EQ(is_nan_name->op(), Expression::Operation::kIsNan); + EXPECT_EQ(is_nan_name->reference()->name(), "salary"); + + auto not_nan_name = Expressions::NotNaN("salary"); + EXPECT_EQ(not_nan_name->op(), Expression::Operation::kNotNan); + EXPECT_EQ(not_nan_name->reference()->name(), "salary"); +} + +TEST_F(PredicateTest, ComparisonPredicateFactory) { + auto lt_name = Expressions::LessThan("age", Literal::Int(30)); + EXPECT_EQ(lt_name->op(), Expression::Operation::kLt); + EXPECT_EQ(lt_name->reference()->name(), "age"); + + auto lte_name = Expressions::LessThanOrEqual("salary", Literal::Double(50000.0)); + EXPECT_EQ(lte_name->op(), Expression::Operation::kLtEq); + EXPECT_EQ(lte_name->reference()->name(), "salary"); + + auto gt_name = Expressions::GreaterThan("id", Literal::Long(1000)); + EXPECT_EQ(gt_name->op(), Expression::Operation::kGt); + EXPECT_EQ(gt_name->reference()->name(), "id"); + + auto gte_name = Expressions::GreaterThanOrEqual("age", Literal::Int(18)); + EXPECT_EQ(gte_name->op(), Expression::Operation::kGtEq); + EXPECT_EQ(gte_name->reference()->name(), "age"); + + auto eq_name = Expressions::Equal("name", Literal::String("test")); + EXPECT_EQ(eq_name->op(), Expression::Operation::kEq); + EXPECT_EQ(eq_name->reference()->name(), "name"); + + auto neq_name = Expressions::NotEqual("active", Literal::Boolean(false)); + EXPECT_EQ(neq_name->op(), Expression::Operation::kNotEq); + EXPECT_EQ(neq_name->reference()->name(), "active"); +} + +TEST_F(PredicateTest, StringPredicateFactory) { + auto starts_name = Expressions::StartsWith("name", "John"); + EXPECT_EQ(starts_name->op(), Expression::Operation::kStartsWith); + EXPECT_EQ(starts_name->reference()->name(), "name"); + + auto not_starts_name = Expressions::NotStartsWith("name", "Jane"); + EXPECT_EQ(not_starts_name->op(), Expression::Operation::kNotStartsWith); + EXPECT_EQ(not_starts_name->reference()->name(), "name"); +} + +TEST_F(PredicateTest, SetPredicateFactory) { + std::vector values = {Literal::Int(10), Literal::Int(20), Literal::Int(30)}; + std::initializer_list init_values = {Literal::String("a"), + Literal::String("b")}; + + auto in_name_vec = Expressions::In("age", values); + EXPECT_EQ(in_name_vec->op(), Expression::Operation::kIn); + EXPECT_EQ(in_name_vec->reference()->name(), "age"); + + auto in_name_init = Expressions::In("name", init_values); + EXPECT_EQ(in_name_init->op(), Expression::Operation::kIn); + EXPECT_EQ(in_name_init->reference()->name(), "name"); + + auto not_in_name_vec = Expressions::NotIn("age", values); + EXPECT_EQ(not_in_name_vec->op(), Expression::Operation::kNotIn); + EXPECT_EQ(not_in_name_vec->reference()->name(), "age"); + + auto not_in_name_init = Expressions::NotIn("name", init_values); + EXPECT_EQ(not_in_name_init->op(), Expression::Operation::kNotIn); + EXPECT_EQ(not_in_name_init->reference()->name(), "name"); +} + +TEST_F(PredicateTest, GenericPredicateFactory) { + auto pred_single = + Expressions::Predicate(Expression::Operation::kEq, "age", Literal::Int(25)); + EXPECT_EQ(pred_single->op(), Expression::Operation::kEq); + EXPECT_EQ(pred_single->reference()->name(), "age"); + + std::vector values = {Literal::Int(10), Literal::Int(20)}; + auto pred_multi = Expressions::Predicate(Expression::Operation::kIn, "age", values); + EXPECT_EQ(pred_multi->op(), Expression::Operation::kIn); + EXPECT_EQ(pred_multi->reference()->name(), "age"); + + auto pred_unary = Expressions::Predicate(Expression::Operation::kIsNull, "name"); + EXPECT_EQ(pred_unary->op(), Expression::Operation::kIsNull); + EXPECT_EQ(pred_unary->reference()->name(), "name"); +} + +TEST_F(PredicateTest, TransformFactory) { + auto bucket_transform = Expressions::Bucket("id", 10); + EXPECT_NE(bucket_transform, nullptr); + EXPECT_EQ(bucket_transform->reference()->name(), "id"); + + auto year_transform = Expressions::Year("timestamp_field"); + EXPECT_NE(year_transform, nullptr); + EXPECT_EQ(year_transform->reference()->name(), "timestamp_field"); + + auto month_transform = Expressions::Month("timestamp_field"); + EXPECT_NE(month_transform, nullptr); + EXPECT_EQ(month_transform->reference()->name(), "timestamp_field"); + + auto day_transform = Expressions::Day("timestamp_field"); + EXPECT_NE(day_transform, nullptr); + EXPECT_EQ(day_transform->reference()->name(), "timestamp_field"); + + auto hour_transform = Expressions::Hour("timestamp_field"); + EXPECT_NE(hour_transform, nullptr); + EXPECT_EQ(hour_transform->reference()->name(), "timestamp_field"); + + auto truncate_transform = Expressions::Truncate("string_field", 5); + EXPECT_NE(truncate_transform, nullptr); + EXPECT_EQ(truncate_transform->reference()->name(), "string_field"); +} + +TEST_F(PredicateTest, ReferenceFactory) { + auto ref = Expressions::Ref("test_field"); + EXPECT_EQ(ref->name(), "test_field"); + EXPECT_EQ(ref->ToString(), "ref(name=\"test_field\")"); +} + +TEST_F(PredicateTest, NamedReferenceBasics) { + auto ref = std::make_shared("id"); + EXPECT_EQ(ref->name(), "id"); + EXPECT_EQ(ref->ToString(), "ref(name=\"id\")"); + EXPECT_EQ(ref->reference(), ref); +} + +TEST_F(PredicateTest, NamedReferenceBind) { + auto ref = std::make_shared("id"); + auto bound_result = ref->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_result, IsOk()); + + auto bound_ref = bound_result.value(); + EXPECT_EQ(bound_ref->name(), "id"); + EXPECT_EQ(bound_ref->field().field_id(), 1); + EXPECT_EQ(bound_ref->type()->type_id(), TypeId::kLong); + EXPECT_FALSE(bound_ref->MayProduceNull()); +} + +TEST_F(PredicateTest, NamedReferenceBindNonExistentField) { + auto ref = std::make_shared("non_existent_field"); + auto bound_result = ref->Bind(*schema_, /*case_sensitive=*/true); + EXPECT_THAT(bound_result, IsError(ErrorKind::kInvalidExpression)); +} + +TEST_F(PredicateTest, BoundReferenceEquality) { + auto ref1 = std::make_shared("id"); + auto ref2 = std::make_shared("id"); + auto ref3 = std::make_shared("name"); + + auto bound1 = ref1->Bind(*schema_, true).value(); + auto bound2 = ref2->Bind(*schema_, true).value(); + auto bound3 = ref3->Bind(*schema_, true).value(); + + // Same field should be equal + EXPECT_TRUE(bound1->Equals(*bound2)); + EXPECT_TRUE(bound2->Equals(*bound1)); + + // Different fields should not be equal + EXPECT_FALSE(bound1->Equals(*bound3)); + EXPECT_FALSE(bound3->Equals(*bound1)); +} + +TEST_F(PredicateTest, UnboundPredicateCreation) { + auto is_null_pred = Expressions::IsNull("name"); + EXPECT_EQ(is_null_pred->op(), Expression::Operation::kIsNull); + EXPECT_EQ(is_null_pred->reference()->name(), "name"); + + auto not_null_pred = Expressions::NotNull("name"); + EXPECT_EQ(not_null_pred->op(), Expression::Operation::kNotNull); + + auto equal_pred = Expressions::Equal("age", Literal::Int(25)); + EXPECT_EQ(equal_pred->op(), Expression::Operation::kEq); + + auto greater_than_pred = Expressions::GreaterThan("salary", Literal::Double(50000.0)); + EXPECT_EQ(greater_than_pred->op(), Expression::Operation::kGt); +} + +TEST_F(PredicateTest, UnboundPredicateToString) { + auto equal_pred = Expressions::Equal("age", Literal::Int(25)); + EXPECT_EQ(equal_pred->ToString(), "ref(name=\"age\") == 25"); + + auto is_null_pred = Expressions::IsNull("name"); + EXPECT_EQ(is_null_pred->ToString(), "is_null(ref(name=\"name\"))"); + + auto in_pred = Expressions::In("age", {Literal::Int(10), Literal::Int(20)}); + EXPECT_EQ(in_pred->ToString(), "ref(name=\"age\") in [10, 20]"); + + auto starts_with_pred = Expressions::StartsWith("name", "John"); + EXPECT_EQ(starts_with_pred->ToString(), "ref(name=\"name\") startsWith \"John\""); +} + +TEST_F(PredicateTest, UnboundPredicateNegate) { + auto equal_pred = Expressions::Equal("age", Literal::Int(25)); + auto negated_result = equal_pred->Negate(); + ASSERT_THAT(negated_result, IsOk()); + + auto negated_pred = negated_result.value(); + EXPECT_EQ(negated_pred->op(), Expression::Operation::kNotEq); + + auto is_null_pred = Expressions::IsNull("name"); + auto negated_null_result = is_null_pred->Negate(); + ASSERT_THAT(negated_null_result, IsOk()); + + auto negated_null_pred = negated_null_result.value(); + EXPECT_EQ(negated_null_pred->op(), Expression::Operation::kNotNull); + + auto in_pred = Expressions::In("age", {Literal::Int(10), Literal::Int(20)}); + auto negated_in_result = in_pred->Negate(); + ASSERT_THAT(negated_in_result, IsOk()); + + auto negated_in_pred = negated_in_result.value(); + EXPECT_EQ(negated_in_pred->op(), Expression::Operation::kNotIn); +} + +TEST_F(PredicateTest, UnboundPredicateBindUnary) { + auto is_null_pred = Expressions::IsNull("name"); + auto bound_result = is_null_pred->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_result, IsOk()); + + auto bound_pred = bound_result.value(); + EXPECT_EQ(bound_pred->op(), Expression::Operation::kIsNull); + + // Test NOT NULL on non-nullable field - should return AlwaysTrue + auto not_null_required = Expressions::NotNull("age"); // age is required + auto bound_not_null_result = not_null_required->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_not_null_result, IsOk()); + + auto bound_not_null = bound_not_null_result.value(); + EXPECT_EQ(bound_not_null->op(), Expression::Operation::kTrue); + + // Test IS NULL on non-nullable field - should return AlwaysFalse + auto is_null_required = Expressions::IsNull("age"); // age is required + auto bound_is_null_result = is_null_required->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_is_null_result, IsOk()); + + auto bound_is_null = bound_is_null_result.value(); + EXPECT_EQ(bound_is_null->op(), Expression::Operation::kFalse); +} + +TEST_F(PredicateTest, UnboundPredicateBindLiteral) { + auto equal_pred = Expressions::Equal("age", Literal::Int(25)); + auto bound_result = equal_pred->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_result, IsOk()); + + auto bound_pred = bound_result.value(); + EXPECT_EQ(bound_pred->op(), Expression::Operation::kEq); + + // Test binding with type conversion + auto equal_long_pred = + Expressions::Equal("id", Literal::Int(123)); // int to long conversion + auto bound_long_result = equal_long_pred->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_long_result, IsOk()); + + auto bound_long_pred = bound_long_result.value(); + EXPECT_EQ(bound_long_pred->op(), Expression::Operation::kEq); +} + +TEST_F(PredicateTest, UnboundPredicateBindIn) { + // Test IN operation with single value (should become equality) + auto in_single = Expressions::In("age", {Literal::Int(25)}); + auto bound_single_result = in_single->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_single_result, IsOk()); + + auto bound_single = bound_single_result.value(); + EXPECT_EQ(bound_single->op(), Expression::Operation::kEq); + + // Test NOT IN operation with single value (should become inequality) + auto not_in_single = Expressions::NotIn("age", {Literal::Int(25)}); + auto bound_not_single_result = not_in_single->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_not_single_result, IsOk()); + + auto bound_not_single = bound_not_single_result.value(); + EXPECT_EQ(bound_not_single->op(), Expression::Operation::kNotEq); + + // Test IN operation with multiple values (should stay as IN) + auto in_multi = Expressions::In("age", {Literal::Int(25), Literal::Int(30)}); + auto bound_multi_result = in_multi->Bind(*schema_, true); + ASSERT_THAT(bound_multi_result, IsOk()); + + auto bound_multi = bound_multi_result.value(); + EXPECT_EQ(bound_multi->op(), Expression::Operation::kIn); +} + +TEST_F(PredicateTest, FloatingPointNaNPredicates) { + auto is_nan_float = Expressions::IsNaN("salary"); // salary is float64 + auto bound_nan_result = is_nan_float->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_nan_result, IsOk()); + + auto bound_nan = bound_nan_result.value(); + EXPECT_EQ(bound_nan->op(), Expression::Operation::kIsNan); + + auto is_nan_int = Expressions::IsNaN("age"); // age is int32 + auto bound_nan_int_result = is_nan_int->Bind(*schema_, /*case_sensitive=*/true); + EXPECT_THAT(bound_nan_int_result, IsError(ErrorKind::kInvalidExpression)); +} + +TEST_F(PredicateTest, StringStartsWithPredicates) { + auto starts_with = Expressions::StartsWith("name", "John"); // name is string + auto bound_starts_result = starts_with->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_starts_result, IsOk()); + + auto bound_starts = bound_starts_result.value(); + EXPECT_EQ(bound_starts->op(), Expression::Operation::kStartsWith); + + auto starts_with_int = Expressions::StartsWith("age", "test"); // age is int32 + auto bound_starts_int_result = starts_with_int->Bind(*schema_, /*case_sensitive=*/true); + EXPECT_THAT(bound_starts_int_result, IsError(ErrorKind::kInvalidExpression)); +} + +TEST_F(PredicateTest, LiteralConversionEdgeCases) { + auto large_value_lt = + Expressions::LessThan("age", Literal::Long(std::numeric_limits::max())); + auto bound_large_result = large_value_lt->Bind(*schema_, /*case_sensitive=*/true); + ASSERT_THAT(bound_large_result, IsOk()); + + auto bound_large = bound_large_result.value(); + EXPECT_EQ(bound_large->op(), Expression::Operation::kTrue); +} + +TEST_F(PredicateTest, ComplexExpressionCombinations) { + auto eq_pred = Expressions::Equal("age", Literal::Int(25)); + auto null_pred = Expressions::IsNull("name"); + auto in_pred = + Expressions::In("id", {Literal::Long(1), Literal::Long(2), Literal::Long(3)}); + + // Test AND combinations + auto and_eq_null = Expressions::And(eq_pred, null_pred); + EXPECT_EQ(and_eq_null->op(), Expression::Operation::kAnd); + + auto and_eq_in = Expressions::And(eq_pred, in_pred); + EXPECT_EQ(and_eq_in->op(), Expression::Operation::kAnd); + + // Test OR combinations + auto or_null_in = Expressions::Or(null_pred, in_pred); + EXPECT_EQ(or_null_in->op(), Expression::Operation::kOr); + + // Test nested combinations + auto nested = Expressions::And(and_eq_null, or_null_in); + EXPECT_EQ(nested->op(), Expression::Operation::kAnd); +} + +} // namespace iceberg diff --git a/src/iceberg/transform.cc b/src/iceberg/transform.cc index 1ce1a6e0..dcacf84f 100644 --- a/src/iceberg/transform.cc +++ b/src/iceberg/transform.cc @@ -84,7 +84,7 @@ Transform::Transform(TransformType transform_type, int32_t param) TransformType Transform::transform_type() const { return transform_type_; } -Result> Transform::Bind( +Result> Transform::Bind( const std::shared_ptr& source_type) const { auto type_str = TransformTypeToString(transform_type_); diff --git a/src/iceberg/transform.h b/src/iceberg/transform.h index 6c771fbf..e5a08235 100644 --- a/src/iceberg/transform.h +++ b/src/iceberg/transform.h @@ -147,7 +147,7 @@ class ICEBERG_EXPORT Transform : public util::Formattable { /// parameter. /// \param source_type The source column type to bind to. /// \return A TransformFunction instance wrapped in `expected`, or an error on failure. - Result> Bind( + Result> Bind( const std::shared_ptr& source_type) const; /// \brief Returns a string representation of this transform (e.g., "bucket[16]"). diff --git a/src/iceberg/transform_function.h b/src/iceberg/transform_function.h index 6d810640..165390b1 100644 --- a/src/iceberg/transform_function.h +++ b/src/iceberg/transform_function.h @@ -25,7 +25,7 @@ namespace iceberg { /// \brief Identity transform that returns the input unchanged. -class IdentityTransform : public TransformFunction { +class ICEBERG_EXPORT IdentityTransform : public TransformFunction { public: /// \param source_type Type of the input data. explicit IdentityTransform(std::shared_ptr const& source_type); @@ -44,7 +44,7 @@ class IdentityTransform : public TransformFunction { }; /// \brief Bucket transform that hashes input values into N buckets. -class BucketTransform : public TransformFunction { +class ICEBERG_EXPORT BucketTransform : public TransformFunction { public: /// \param source_type Type of the input data. /// \param num_buckets Number of buckets to hash into. @@ -68,7 +68,7 @@ class BucketTransform : public TransformFunction { }; /// \brief Truncate transform that truncates values to a specified width. -class TruncateTransform : public TransformFunction { +class ICEBERG_EXPORT TruncateTransform : public TransformFunction { public: /// \param source_type Type of the input data. /// \param width The width to truncate to (e.g., for strings or numbers). @@ -92,7 +92,7 @@ class TruncateTransform : public TransformFunction { }; /// \brief Year transform that extracts the year component from timestamp inputs. -class YearTransform : public TransformFunction { +class ICEBERG_EXPORT YearTransform : public TransformFunction { public: /// \param source_type Must be a timestamp type. explicit YearTransform(std::shared_ptr const& source_type); @@ -111,7 +111,7 @@ class YearTransform : public TransformFunction { }; /// \brief Month transform that extracts the month component from timestamp inputs. -class MonthTransform : public TransformFunction { +class ICEBERG_EXPORT MonthTransform : public TransformFunction { public: /// \param source_type Must be a timestamp type. explicit MonthTransform(std::shared_ptr const& source_type); @@ -130,7 +130,7 @@ class MonthTransform : public TransformFunction { }; /// \brief Day transform that extracts the day of the month from timestamp inputs. -class DayTransform : public TransformFunction { +class ICEBERG_EXPORT DayTransform : public TransformFunction { public: /// \param source_type Must be a timestamp type. explicit DayTransform(std::shared_ptr const& source_type); @@ -149,7 +149,7 @@ class DayTransform : public TransformFunction { }; /// \brief Hour transform that extracts the hour component from timestamp inputs. -class HourTransform : public TransformFunction { +class ICEBERG_EXPORT HourTransform : public TransformFunction { public: /// \param source_type Must be a timestamp type. explicit HourTransform(std::shared_ptr const& source_type); @@ -168,7 +168,7 @@ class HourTransform : public TransformFunction { }; /// \brief Void transform that discards the input and always returns null. -class VoidTransform : public TransformFunction { +class ICEBERG_EXPORT VoidTransform : public TransformFunction { public: /// \param source_type Input type (ignored). explicit VoidTransform(std::shared_ptr const& source_type);