Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,7 @@ cc_library(
"//extensions/protobuf/internal:map_reflection",
"//extensions/protobuf/internal:qualify",
"//internal:casts",
"//internal:empty_descriptors",
"//internal:json",
"//internal:manual",
"//internal:message_equality",
Expand Down
6 changes: 4 additions & 2 deletions common/values/message_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,13 @@ common_internal::ValueVariant MessageValue::ToValueVariant() && {

common_internal::StructValueVariant MessageValue::ToStructValueVariant()
const& {
return absl::get<ParsedMessageValue>(variant_);
return common_internal::StructValueVariant(
absl::get<ParsedMessageValue>(variant_));
}

common_internal::StructValueVariant MessageValue::ToStructValueVariant() && {
return absl::get<ParsedMessageValue>(std::move(variant_));
return common_internal::StructValueVariant(
absl::get<ParsedMessageValue>(std::move(variant_)));
}

} // namespace cel
96 changes: 64 additions & 32 deletions common/values/parsed_message_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
#include <cstdint>
#include <limits>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "google/protobuf/empty.pb.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
Expand All @@ -33,6 +35,7 @@
#include "common/memory.h"
#include "common/value.h"
#include "extensions/protobuf/internal/qualify.h"
#include "internal/empty_descriptors.h"
#include "internal/json.h"
#include "internal/message_equality.h"
#include "internal/status_macros.h"
Expand All @@ -42,16 +45,37 @@
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"

namespace cel {

namespace {

using ::cel::well_known_types::ValueReflection;

template <typename T>
std::enable_if_t<std::is_base_of_v<google::protobuf::Message, T>,
absl::Nonnull<const google::protobuf::Message*>>
EmptyParsedMessageValue() {
return &T::default_instance();
}

template <typename T>
std::enable_if_t<
std::conjunction_v<std::is_base_of<google::protobuf::MessageLite, T>,
std::negation<std::is_base_of<google::protobuf::Message, T>>>,
absl::Nonnull<const google::protobuf::Message*>>
EmptyParsedMessageValue() {
return internal::GetEmptyDefaultInstance();
}

} // namespace

ParsedMessageValue::ParsedMessageValue()
: value_(EmptyParsedMessageValue<google::protobuf::Empty>()),
arena_(nullptr) {}

bool ParsedMessageValue::IsZeroValue() const {
ABSL_DCHECK(*this);
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
return true;
}
const auto* reflection = GetReflection();
if (!reflection->GetUnknownFields(*value_).empty()) {
return false;
Expand All @@ -62,9 +86,6 @@ bool ParsedMessageValue::IsZeroValue() const {
}

std::string ParsedMessageValue::DebugString() const {
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
return "INVALID";
}
return absl::StrCat(*value_);
}

Expand All @@ -75,11 +96,6 @@ absl::Status ParsedMessageValue::SerializeTo(
ABSL_DCHECK(descriptor_pool != nullptr);
ABSL_DCHECK(message_factory != nullptr);
ABSL_DCHECK(output != nullptr);
ABSL_DCHECK(*this);

if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
return absl::OkStatus();
}

if (!value_->SerializePartialToZeroCopyStream(output)) {
return absl::UnknownError(
Expand All @@ -97,16 +113,11 @@ absl::Status ParsedMessageValue::ConvertToJson(
ABSL_DCHECK(json != nullptr);
ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE);
ABSL_DCHECK(*this);

ValueReflection value_reflection;
CEL_RETURN_IF_ERROR(value_reflection.Initialize(json->GetDescriptor()));
google::protobuf::Message* json_object = value_reflection.MutableStructValue(json);

if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
json_object->Clear();
return absl::OkStatus();
}
return internal::MessageToJson(*value_, descriptor_pool, message_factory,
json_object);
}
Expand All @@ -120,12 +131,7 @@ absl::Status ParsedMessageValue::ConvertToJsonObject(
ABSL_DCHECK(json != nullptr);
ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(),
google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT);
ABSL_DCHECK(*this);

if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
json->Clear();
return absl::OkStatus();
}
return internal::MessageToJson(*value_, descriptor_pool, message_factory,
json);
}
Expand All @@ -135,7 +141,11 @@ absl::Status ParsedMessageValue::Equal(
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result) const {
ABSL_DCHECK(*this);
ABSL_DCHECK(descriptor_pool != nullptr);
ABSL_DCHECK(message_factory != nullptr);
ABSL_DCHECK(arena != nullptr);
ABSL_DCHECK(result != nullptr);

if (auto other_message = other.AsParsedMessage(); other_message) {
CEL_ASSIGN_OR_RETURN(
auto equal, internal::MessageEquals(*value_, **other_message,
Expand All @@ -154,10 +164,8 @@ absl::Status ParsedMessageValue::Equal(

ParsedMessageValue ParsedMessageValue::Clone(
absl::Nonnull<google::protobuf::Arena*> arena) const {
ABSL_DCHECK(*this);
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
return ParsedMessageValue();
}
ABSL_DCHECK(arena != nullptr);

if (arena_ == arena) {
return *this;
}
Expand All @@ -171,6 +179,11 @@ absl::Status ParsedMessageValue::GetFieldByName(
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result) const {
ABSL_DCHECK(descriptor_pool != nullptr);
ABSL_DCHECK(message_factory != nullptr);
ABSL_DCHECK(arena != nullptr);
ABSL_DCHECK(result != nullptr);

const auto* descriptor = GetDescriptor();
const auto* field = descriptor->FindFieldByName(name);
if (field == nullptr) {
Expand All @@ -190,6 +203,11 @@ absl::Status ParsedMessageValue::GetFieldByNumber(
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result) const {
ABSL_DCHECK(descriptor_pool != nullptr);
ABSL_DCHECK(message_factory != nullptr);
ABSL_DCHECK(arena != nullptr);
ABSL_DCHECK(result != nullptr);

const auto* descriptor = GetDescriptor();
if (number < std::numeric_limits<int32_t>::min() ||
number > std::numeric_limits<int32_t>::max()) {
Expand Down Expand Up @@ -238,10 +256,10 @@ absl::Status ParsedMessageValue::ForEachField(
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
absl::Nonnull<google::protobuf::Arena*> arena) const {
ABSL_DCHECK(*this);
if (ABSL_PREDICT_FALSE(value_ == nullptr)) {
return absl::OkStatus();
}
ABSL_DCHECK(descriptor_pool != nullptr);
ABSL_DCHECK(message_factory != nullptr);
ABSL_DCHECK(arena != nullptr);

std::vector<const google::protobuf::FieldDescriptor*> fields;
const auto* reflection = GetReflection();
reflection->ListFields(*value_, &fields);
Expand Down Expand Up @@ -322,7 +340,13 @@ absl::Status ParsedMessageValue::Qualify(
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result,
absl::Nonnull<int*> count) const {
ABSL_DCHECK(*this);
ABSL_DCHECK(!qualifiers.empty());
ABSL_DCHECK(descriptor_pool != nullptr);
ABSL_DCHECK(message_factory != nullptr);
ABSL_DCHECK(arena != nullptr);
ABSL_DCHECK(result != nullptr);
ABSL_DCHECK(count != nullptr);

if (ABSL_PREDICT_FALSE(qualifiers.empty())) {
return absl::InvalidArgumentError("invalid select qualifier path.");
}
Expand Down Expand Up @@ -357,13 +381,21 @@ absl::Status ParsedMessageValue::GetField(
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
absl::Nonnull<google::protobuf::Arena*> arena, absl::Nonnull<Value*> result) const {
ABSL_DCHECK(field != nullptr);
ABSL_DCHECK(descriptor_pool != nullptr);
ABSL_DCHECK(message_factory != nullptr);
ABSL_DCHECK(arena != nullptr);
ABSL_DCHECK(result != nullptr);

*result = Value::WrapField(unboxing_options, value_, field, descriptor_pool,
message_factory, arena);
return absl::OkStatus();
}

bool ParsedMessageValue::HasField(
absl::Nonnull<const google::protobuf::FieldDescriptor*> field) const {
ABSL_DCHECK(field != nullptr);

const auto* reflection = GetReflection();
if (field->is_map() || field->is_repeated()) {
return reflection->FieldSize(*value_, field) > 0;
Expand Down
17 changes: 6 additions & 11 deletions common/values/parsed_message_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ class ParsedMessageValue final
ABSL_DCHECK_OK(CheckArena(value_, arena_));
}

// Places the `ParsedMessageValue` into an invalid state. Anything except
// assigning to `MessageValue` is undefined behavior.
ParsedMessageValue() = default;

// Places the `ParsedMessageValue` into a special state where it is logically
// equivalent to the default instance of `google.protobuf.Empty`, however
// dereferencing via `operator*` or `operator->` is not allowed.
ParsedMessageValue();
ParsedMessageValue(const ParsedMessageValue&) = default;
ParsedMessageValue(ParsedMessageValue&&) = default;
ParsedMessageValue& operator=(const ParsedMessageValue&) = default;
Expand All @@ -96,13 +96,11 @@ class ParsedMessageValue final
}

const google::protobuf::Message& operator*() const ABSL_ATTRIBUTE_LIFETIME_BOUND {
ABSL_DCHECK(*this);
return *value_;
}

absl::Nonnull<const google::protobuf::Message*> operator->() const
ABSL_ATTRIBUTE_LIFETIME_BOUND {
ABSL_DCHECK(*this);
return value_;
}

Expand Down Expand Up @@ -171,9 +169,6 @@ class ParsedMessageValue final
absl::Nonnull<int*> count) const;
using StructValueMixin::Qualify;

// Returns `true` if `ParsedMessageValue` is in a valid state.
explicit operator bool() const { return value_ != nullptr; }

friend void swap(ParsedMessageValue& lhs, ParsedMessageValue& rhs) noexcept {
using std::swap;
swap(lhs.value_, rhs.value_);
Expand Down Expand Up @@ -205,8 +200,8 @@ class ParsedMessageValue final

bool HasField(absl::Nonnull<const google::protobuf::FieldDescriptor*> field) const;

absl::Nullable<const google::protobuf::Message*> value_ = nullptr;
absl::Nullable<google::protobuf::Arena*> arena_ = nullptr;
absl::Nonnull<const google::protobuf::Message*> value_;
absl::Nullable<google::protobuf::Arena*> arena_;
};

inline std::ostream& operator<<(std::ostream& out,
Expand Down
10 changes: 0 additions & 10 deletions common/values/parsed_message_value_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,6 @@ using TestAllTypesProto3 = ::cel::expr::conformance::proto3::TestAllTypes;

using ParsedMessageValueTest = common_internal::ValueTest<>;

TEST_F(ParsedMessageValueTest, Default) {
ParsedMessageValue value;
EXPECT_FALSE(value);
}

TEST_F(ParsedMessageValueTest, Field) {
ParsedMessageValue value = MakeParsedMessage<TestAllTypesProto3>();
EXPECT_TRUE(value);
}

TEST_F(ParsedMessageValueTest, Kind) {
ParsedMessageValue value = MakeParsedMessage<TestAllTypesProto3>();
EXPECT_EQ(value.kind(), ParsedMessageValue::kKind);
Expand Down
Loading