diff --git a/common/BUILD b/common/BUILD index 6f688572c..c2e00722a 100644 --- a/common/BUILD +++ b/common/BUILD @@ -594,6 +594,7 @@ cc_library( ":unknown", ":value_kind", "//base:attributes", + "//base/internal:message_wrapper", "//common/internal:byte_string", "//common/internal:reference_count", "//eval/internal:cel_value_equal", diff --git a/common/legacy_value.cc b/common/legacy_value.cc index a27f24754..83ae901df 100644 --- a/common/legacy_value.cc +++ b/common/legacy_value.cc @@ -36,6 +36,7 @@ #include "absl/types/span.h" #include "absl/types/variant.h" #include "base/attribute.h" +#include "base/internal/message_wrapper.h" #include "common/allocator.h" #include "common/casting.h" #include "common/kind.h" @@ -86,10 +87,27 @@ absl::Status InvalidMapKeyTypeError(ValueKind kind) { absl::StrCat("Invalid map key type: '", ValueKindToString(kind), "'")); } -MessageWrapper AsMessageWrapper( - absl::NullabilityUnknown message_ptr, - absl::NullabilityUnknown type_info) { - return MessageWrapper(message_ptr, type_info); +const CelList* AsCelList(uintptr_t impl) { + return reinterpret_cast(impl); +} + +const CelMap* AsCelMap(uintptr_t impl) { + return reinterpret_cast(impl); +} + +MessageWrapper AsMessageWrapper(uintptr_t message_ptr, uintptr_t type_info) { + if ((message_ptr & base_internal::kMessageWrapperTagMask) == + base_internal::kMessageWrapperTagMessageValue) { + return MessageWrapper::Builder( + static_cast( + reinterpret_cast( + message_ptr & base_internal::kMessageWrapperPtrMask))) + .Build(reinterpret_cast(type_info)); + } else { + return MessageWrapper::Builder( + reinterpret_cast(message_ptr)) + .Build(reinterpret_cast(type_info)); + } } class CelListIterator final : public ValueIterator { @@ -187,7 +205,7 @@ CelValue LegacyTrivialListValue(absl::Nonnull arena, const Value& value) { if (auto legacy_list_value = common_internal::AsLegacyListValue(value); legacy_list_value) { - return CelValue::CreateList(legacy_list_value->cel_list()); + return CelValue::CreateList(AsCelList(legacy_list_value->NativeValue())); } if (auto parsed_repeated_field_value = value.AsParsedRepeatedField(); parsed_repeated_field_value) { @@ -226,7 +244,7 @@ CelValue LegacyTrivialMapValue(absl::Nonnull arena, const Value& value) { if (auto legacy_map_value = common_internal::AsLegacyMapValue(value); legacy_map_value) { - return CelValue::CreateMap(legacy_map_value->cel_map()); + return CelValue::CreateMap(AsCelMap(legacy_map_value->NativeValue())); } if (auto parsed_map_field_value = value.AsParsedMapField(); parsed_map_field_value) { @@ -305,7 +323,7 @@ google::api::expr::runtime::CelValue LegacyTrivialValue( namespace common_internal { std::string LegacyListValue::DebugString() const { - return CelValue::CreateList(impl_).DebugString(); + return CelValue::CreateList(AsCelList(impl_)).DebugString(); } // See `ValueInterface::SerializeTo`. @@ -325,8 +343,9 @@ absl::Status LegacyListValue::SerializeTo( } google::protobuf::Arena arena; - const google::protobuf::Message* wrapped = MaybeWrapValueToMessage( - descriptor, message_factory, CelValue::CreateList(impl_), &arena); + const google::protobuf::Message* wrapped = + MaybeWrapValueToMessage(descriptor, message_factory, + CelValue::CreateList(AsCelList(impl_)), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy map to JSON"); } @@ -351,7 +370,7 @@ absl::Status LegacyListValue::ConvertToJson( google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, - CelValue::CreateList(impl_), &arena); + CelValue::CreateList(AsCelList(impl_)), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy list to JSON"); } @@ -390,7 +409,7 @@ absl::Status LegacyListValue::ConvertToJsonArray( google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, - CelValue::CreateList(impl_), &arena); + CelValue::CreateList(AsCelList(impl_)), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy list to JSON"); } @@ -415,10 +434,10 @@ absl::Status LegacyListValue::ConvertToJsonArray( } } -bool LegacyListValue::IsEmpty() const { return impl_->empty(); } +bool LegacyListValue::IsEmpty() const { return AsCelList(impl_)->empty(); } size_t LegacyListValue::Size() const { - return static_cast(impl_->size()); + return static_cast(AsCelList(impl_)->size()); } // See LegacyListValueInterface::Get for documentation. @@ -426,12 +445,12 @@ absl::Status LegacyListValue::Get( size_t index, absl::Nonnull descriptor_pool, absl::Nonnull message_factory, absl::Nonnull arena, absl::Nonnull result) const { - if (ABSL_PREDICT_FALSE(index < 0 || index >= impl_->size())) { + if (ABSL_PREDICT_FALSE(index < 0 || index >= AsCelList(impl_)->size())) { *result = ErrorValue(absl::InvalidArgumentError("index out of bounds")); return absl::OkStatus(); } - CEL_RETURN_IF_ERROR( - ModernValue(arena, impl_->Get(arena, static_cast(index)), *result)); + CEL_RETURN_IF_ERROR(ModernValue( + arena, AsCelList(impl_)->Get(arena, static_cast(index)), *result)); return absl::OkStatus(); } @@ -440,10 +459,11 @@ absl::Status LegacyListValue::ForEach( absl::Nonnull descriptor_pool, absl::Nonnull message_factory, absl::Nonnull arena) const { - const auto size = impl_->size(); + const auto size = AsCelList(impl_)->size(); Value element; for (int index = 0; index < size; ++index) { - CEL_RETURN_IF_ERROR(ModernValue(arena, impl_->Get(arena, index), element)); + CEL_RETURN_IF_ERROR( + ModernValue(arena, AsCelList(impl_)->Get(arena, index), element)); CEL_ASSIGN_OR_RETURN(auto ok, callback(index, Value(element))); if (!ok) { break; @@ -454,7 +474,7 @@ absl::Status LegacyListValue::ForEach( absl::StatusOr> LegacyListValue::NewIterator() const { - return std::make_unique(impl_); + return std::make_unique(AsCelList(impl_)); } absl::Status LegacyListValue::Contains( @@ -463,7 +483,7 @@ absl::Status LegacyListValue::Contains( absl::Nonnull message_factory, absl::Nonnull arena, absl::Nonnull result) const { CEL_ASSIGN_OR_RETURN(auto legacy_other, LegacyValue(arena, other)); - const auto* cel_list = impl_; + const auto* cel_list = AsCelList(impl_); for (int i = 0; i < cel_list->size(); ++i) { auto element = cel_list->Get(arena, i); absl::optional equal = @@ -480,7 +500,7 @@ absl::Status LegacyListValue::Contains( } std::string LegacyMapValue::DebugString() const { - return CelValue::CreateMap(impl_).DebugString(); + return CelValue::CreateMap(AsCelMap(impl_)).DebugString(); } absl::Status LegacyMapValue::SerializeTo( @@ -498,8 +518,9 @@ absl::Status LegacyMapValue::SerializeTo( } google::protobuf::Arena arena; - const google::protobuf::Message* wrapped = MaybeWrapValueToMessage( - descriptor, message_factory, CelValue::CreateMap(impl_), &arena); + const google::protobuf::Message* wrapped = + MaybeWrapValueToMessage(descriptor, message_factory, + CelValue::CreateMap(AsCelMap(impl_)), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy map to JSON"); } @@ -523,7 +544,7 @@ absl::Status LegacyMapValue::ConvertToJson( google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, - CelValue::CreateMap(impl_), &arena); + CelValue::CreateMap(AsCelMap(impl_)), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy map to JSON"); } @@ -559,7 +580,7 @@ absl::Status LegacyMapValue::ConvertToJsonObject( google::protobuf::Arena arena; const google::protobuf::Message* wrapped = MaybeWrapValueToMessage(json->GetDescriptor(), message_factory, - CelValue::CreateMap(impl_), &arena); + CelValue::CreateMap(AsCelMap(impl_)), &arena); if (wrapped == nullptr) { return absl::UnknownError("failed to convert legacy map to JSON"); } @@ -582,10 +603,10 @@ absl::Status LegacyMapValue::ConvertToJsonObject( return absl::OkStatus(); } -bool LegacyMapValue::IsEmpty() const { return impl_->empty(); } +bool LegacyMapValue::IsEmpty() const { return AsCelMap(impl_)->empty(); } size_t LegacyMapValue::Size() const { - return static_cast(impl_->size()); + return static_cast(AsCelMap(impl_)->size()); } absl::Status LegacyMapValue::Get( @@ -611,7 +632,7 @@ absl::Status LegacyMapValue::Get( return InvalidMapKeyTypeError(key.kind()); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); - auto cel_value = impl_->Get(arena, cel_key); + auto cel_value = AsCelMap(impl_)->Get(arena, cel_key); if (!cel_value.has_value()) { *result = NoSuchKeyError(key.DebugString()); return absl::OkStatus(); @@ -643,7 +664,7 @@ absl::StatusOr LegacyMapValue::Find( return InvalidMapKeyTypeError(key.kind()); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); - auto cel_value = impl_->Get(arena, cel_key); + auto cel_value = AsCelMap(impl_)->Get(arena, cel_key); if (!cel_value.has_value()) { *result = NullValue{}; return false; @@ -675,7 +696,7 @@ absl::Status LegacyMapValue::Has( return InvalidMapKeyTypeError(key.kind()); } CEL_ASSIGN_OR_RETURN(auto cel_key, LegacyValue(arena, key)); - CEL_ASSIGN_OR_RETURN(auto has, impl_->Has(cel_key)); + CEL_ASSIGN_OR_RETURN(auto has, AsCelMap(impl_)->Has(cel_key)); *result = BoolValue{has}; return absl::OkStatus(); } @@ -685,8 +706,9 @@ absl::Status LegacyMapValue::ListKeys( absl::Nonnull message_factory, absl::Nonnull arena, absl::Nonnull result) const { - CEL_ASSIGN_OR_RETURN(auto keys, impl_->ListKeys(arena)); - *result = ListValue{common_internal::LegacyListValue(keys)}; + CEL_ASSIGN_OR_RETURN(auto keys, AsCelMap(impl_)->ListKeys(arena)); + *result = ListValue{ + common_internal::LegacyListValue{reinterpret_cast(keys)}}; return absl::OkStatus(); } @@ -695,13 +717,13 @@ absl::Status LegacyMapValue::ForEach( absl::Nonnull descriptor_pool, absl::Nonnull message_factory, absl::Nonnull arena) const { - CEL_ASSIGN_OR_RETURN(auto keys, impl_->ListKeys(arena)); + CEL_ASSIGN_OR_RETURN(auto keys, AsCelMap(impl_)->ListKeys(arena)); const auto size = keys->size(); Value key; Value value; for (int index = 0; index < size; ++index) { auto cel_key = keys->Get(arena, index); - auto cel_value = *impl_->Get(arena, cel_key); + auto cel_value = *AsCelMap(impl_)->Get(arena, cel_key); CEL_RETURN_IF_ERROR(ModernValue(arena, cel_key, key)); CEL_RETURN_IF_ERROR(ModernValue(arena, cel_value, value)); CEL_ASSIGN_OR_RETURN(auto ok, callback(key, value)); @@ -714,16 +736,16 @@ absl::Status LegacyMapValue::ForEach( absl::StatusOr> LegacyMapValue::NewIterator() const { - return std::make_unique(impl_); + return std::make_unique(AsCelMap(impl_)); } absl::string_view LegacyStructValue::GetTypeName() const { - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); return message_wrapper.legacy_type_info()->GetTypename(message_wrapper); } std::string LegacyStructValue::DebugString() const { - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); return message_wrapper.legacy_type_info()->DebugString(message_wrapper); } @@ -734,7 +756,7 @@ absl::Status LegacyStructValue::SerializeTo( ABSL_DCHECK(descriptor_pool != nullptr); ABSL_DCHECK(message_factory != nullptr); - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); if (ABSL_PREDICT_TRUE( message_wrapper.message_ptr()->SerializePartialToCord(&value))) { return absl::OkStatus(); @@ -752,7 +774,7 @@ absl::Status LegacyStructValue::ConvertToJson( ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_VALUE); - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); return internal::MessageToJson( *google::protobuf::DownCastMessage(message_wrapper.message_ptr()), @@ -769,7 +791,7 @@ absl::Status LegacyStructValue::ConvertToJsonObject( ABSL_DCHECK_EQ(json->GetDescriptor()->well_known_type(), google::protobuf::Descriptor::WELLKNOWNTYPE_STRUCT); - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); return internal::MessageToJson( *google::protobuf::DownCastMessage(message_wrapper.message_ptr()), @@ -783,7 +805,7 @@ absl::Status LegacyStructValue::Equal( absl::Nonnull arena, absl::Nonnull result) const { if (auto legacy_struct_value = common_internal::AsLegacyStructValue(other); legacy_struct_value.has_value()) { - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); const auto* access_apis = message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { @@ -799,7 +821,7 @@ absl::Status LegacyStructValue::Equal( } if (auto struct_value = other.AsStruct(); struct_value.has_value()) { return common_internal::StructValueEqual( - common_internal::LegacyStructValue(message_ptr_, legacy_type_info_), + common_internal::LegacyStructValue(message_ptr_, type_info_), *struct_value, descriptor_pool, message_factory, arena, result); } *result = FalseValue(); @@ -807,7 +829,7 @@ absl::Status LegacyStructValue::Equal( } bool LegacyStructValue::IsZeroValue() const { - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); const auto* access_apis = message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { @@ -821,7 +843,7 @@ absl::Status LegacyStructValue::GetFieldByName( absl::Nonnull descriptor_pool, absl::Nonnull message_factory, absl::Nonnull arena, absl::Nonnull result) const { - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); const auto* access_apis = message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { @@ -847,7 +869,7 @@ absl::Status LegacyStructValue::GetFieldByNumber( absl::StatusOr LegacyStructValue::HasFieldByName( absl::string_view name) const { - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); const auto* access_apis = message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { @@ -866,7 +888,7 @@ absl::Status LegacyStructValue::ForEachField( absl::Nonnull descriptor_pool, absl::Nonnull message_factory, absl::Nonnull arena) const { - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); const auto* access_apis = message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { @@ -899,7 +921,7 @@ absl::Status LegacyStructValue::Qualify( if (ABSL_PREDICT_FALSE(qualifiers.empty())) { return absl::InvalidArgumentError("invalid select qualifier path."); } - auto message_wrapper = AsMessageWrapper(message_ptr_, legacy_type_info_); + auto message_wrapper = AsMessageWrapper(message_ptr_, type_info_); const auto* access_apis = message_wrapper.legacy_type_info()->GetAccessApis(message_wrapper); if (ABSL_PREDICT_FALSE(access_apis == nullptr)) { @@ -956,10 +978,12 @@ absl::Status ModernValue(google::protobuf::Arena* arena, return absl::OkStatus(); case CelValue::Type::kMessage: { auto message_wrapper = legacy_value.MessageWrapperOrDie(); - result = common_internal::LegacyStructValue( - google::protobuf::DownCastMessage( - message_wrapper.message_ptr()), - message_wrapper.legacy_type_info()); + result = common_internal::LegacyStructValue{ + reinterpret_cast(message_wrapper.message_ptr()) | + (message_wrapper.HasFullProto() + ? base_internal::kMessageWrapperTagMessageValue + : uintptr_t{0}), + reinterpret_cast(message_wrapper.legacy_type_info())}; return absl::OkStatus(); } case CelValue::Type::kDuration: @@ -969,12 +993,12 @@ absl::Status ModernValue(google::protobuf::Arena* arena, result = UnsafeTimestampValue(legacy_value.TimestampOrDie()); return absl::OkStatus(); case CelValue::Type::kList: - result = - ListValue(common_internal::LegacyListValue(legacy_value.ListOrDie())); + result = ListValue{common_internal::LegacyListValue{ + reinterpret_cast(legacy_value.ListOrDie())}}; return absl::OkStatus(); case CelValue::Type::kMap: - result = - MapValue(common_internal::LegacyMapValue(legacy_value.MapOrDie())); + result = MapValue{common_internal::LegacyMapValue{ + reinterpret_cast(legacy_value.MapOrDie())}}; return absl::OkStatus(); case CelValue::Type::kUnknownSet: result = UnknownValue{*legacy_value.UnknownSetOrDie()}; @@ -1074,20 +1098,23 @@ absl::StatusOr FromLegacyValue(google::protobuf::Arena* arena, legacy_value.BytesOrDie().value()); case CelValue::Type::kMessage: { auto message_wrapper = legacy_value.MessageWrapperOrDie(); - return common_internal::LegacyStructValue( - google::protobuf::DownCastMessage( - message_wrapper.message_ptr()), - message_wrapper.legacy_type_info()); + return common_internal::LegacyStructValue{ + reinterpret_cast(message_wrapper.message_ptr()) | + (message_wrapper.HasFullProto() + ? base_internal::kMessageWrapperTagMessageValue + : uintptr_t{0}), + reinterpret_cast(message_wrapper.legacy_type_info())}; } case CelValue::Type::kDuration: return UnsafeDurationValue(legacy_value.DurationOrDie()); case CelValue::Type::kTimestamp: return UnsafeTimestampValue(legacy_value.TimestampOrDie()); case CelValue::Type::kList: - return ListValue( - common_internal::LegacyListValue(legacy_value.ListOrDie())); + return ListValue{common_internal::LegacyListValue{ + reinterpret_cast(legacy_value.ListOrDie())}}; case CelValue::Type::kMap: - return MapValue(common_internal::LegacyMapValue(legacy_value.MapOrDie())); + return MapValue{common_internal::LegacyMapValue{ + reinterpret_cast(legacy_value.MapOrDie())}}; case CelValue::Type::kUnknownSet: return UnknownValue{*legacy_value.UnknownSetOrDie()}; case CelValue::Type::kCelType: diff --git a/common/legacy_value.h b/common/legacy_value.h index f1f5928ba..f6523ac70 100644 --- a/common/legacy_value.h +++ b/common/legacy_value.h @@ -79,12 +79,12 @@ inline DoubleValue CreateDoubleValue(double value) { inline ListValue CreateLegacyListValue( const google::api::expr::runtime::CelList* value) { - return common_internal::LegacyListValue(value); + return common_internal::LegacyListValue{reinterpret_cast(value)}; } inline MapValue CreateLegacyMapValue( const google::api::expr::runtime::CelMap* value) { - return common_internal::LegacyMapValue(value); + return common_internal::LegacyMapValue{reinterpret_cast(value)}; } inline Value CreateDurationValue(absl::Duration value, bool unchecked = false) { diff --git a/common/values/legacy_list_value.cc b/common/values/legacy_list_value.cc index 8e2f82eea..2e701888a 100644 --- a/common/values/legacy_list_value.cc +++ b/common/values/legacy_list_value.cc @@ -14,6 +14,8 @@ #include "common/values/legacy_list_value.h" +#include + #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" @@ -60,15 +62,15 @@ absl::optional AsLegacyListValue(const Value& value) { if (auto custom_list_value = value.AsCustomList(); custom_list_value) { NativeTypeId native_type_id = NativeTypeId::Of(*custom_list_value); if (native_type_id == NativeTypeId::For()) { - return LegacyListValue( + return LegacyListValue(reinterpret_cast( static_cast( cel::internal::down_cast( - (*custom_list_value).operator->()))); + (*custom_list_value).operator->())))); } else if (native_type_id == NativeTypeId::For()) { - return LegacyListValue( + return LegacyListValue(reinterpret_cast( static_cast( cel::internal::down_cast( - (*custom_list_value).operator->()))); + (*custom_list_value).operator->())))); } } return absl::nullopt; diff --git a/common/values/legacy_list_value.h b/common/values/legacy_list_value.h index 1f5f7cb90..116c12b8a 100644 --- a/common/values/legacy_list_value.h +++ b/common/values/legacy_list_value.h @@ -19,6 +19,7 @@ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_LIST_VALUE_H_ #include +#include #include #include @@ -35,10 +36,6 @@ #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" -namespace google::api::expr::runtime { -class CelList; -} - namespace cel { class TypeManager; @@ -53,9 +50,8 @@ class LegacyListValue final public: static constexpr ValueKind kKind = ValueKind::kList; - explicit LegacyListValue( - absl::NullabilityUnknown impl) - : impl_(impl) {} + // NOLINTNEXTLINE(google-explicit-constructor) + explicit LegacyListValue(uintptr_t impl) : impl_(impl) {} // By default, this creates an empty list whose type is `list(dyn)`. Unless // you can help it, you should use a more specific typed list value. @@ -131,23 +127,24 @@ class LegacyListValue final absl::Nonnull arena, absl::Nonnull result) const; using ListValueMixin::Contains; - absl::NullabilityUnknown - cel_list() const { - return impl_; - } - - friend void swap(LegacyListValue& lhs, LegacyListValue& rhs) noexcept { + void swap(LegacyListValue& other) noexcept { using std::swap; - swap(lhs.impl_, rhs.impl_); + swap(impl_, other.impl_); } + uintptr_t NativeValue() const { return impl_; } + private: friend class common_internal::ValueMixin; friend class common_internal::ListValueMixin; - absl::NullabilityUnknown impl_; + uintptr_t impl_; }; +inline void swap(LegacyListValue& lhs, LegacyListValue& rhs) noexcept { + lhs.swap(rhs); +} + inline std::ostream& operator<<(std::ostream& out, const LegacyListValue& type) { return out << type.DebugString(); diff --git a/common/values/legacy_map_value.cc b/common/values/legacy_map_value.cc index 42258d9bf..1de99a916 100644 --- a/common/values/legacy_map_value.cc +++ b/common/values/legacy_map_value.cc @@ -14,6 +14,8 @@ #include "common/values/legacy_map_value.h" +#include + #include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" @@ -60,15 +62,15 @@ absl::optional AsLegacyMapValue(const Value& value) { if (auto custom_map_value = value.AsCustomMap(); custom_map_value) { NativeTypeId native_type_id = NativeTypeId::Of(*custom_map_value); if (native_type_id == NativeTypeId::For()) { - return LegacyMapValue( + return LegacyMapValue(reinterpret_cast( static_cast( cel::internal::down_cast( - (*custom_map_value).operator->()))); + (*custom_map_value).operator->())))); } else if (native_type_id == NativeTypeId::For()) { - return LegacyMapValue( + return LegacyMapValue(reinterpret_cast( static_cast( cel::internal::down_cast( - (*custom_map_value).operator->()))); + (*custom_map_value).operator->())))); } } return absl::nullopt; diff --git a/common/values/legacy_map_value.h b/common/values/legacy_map_value.h index ad905ed2b..d7acb06d4 100644 --- a/common/values/legacy_map_value.h +++ b/common/values/legacy_map_value.h @@ -19,6 +19,7 @@ #define THIRD_PARTY_CEL_CPP_COMMON_VALUES_LEGACY_MAP_VALUE_H_ #include +#include #include #include @@ -35,10 +36,6 @@ #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" -namespace google::api::expr::runtime { -class CelMap; -} - namespace cel { class TypeManager; @@ -53,9 +50,8 @@ class LegacyMapValue final public: static constexpr ValueKind kKind = ValueKind::kMap; - explicit LegacyMapValue( - absl::NullabilityUnknown impl) - : impl_(impl) {} + // NOLINTNEXTLINE(google-explicit-constructor) + explicit LegacyMapValue(uintptr_t impl) : impl_(impl) {} // By default, this creates an empty map whose type is `map(dyn, dyn)`. // Unless you can help it, you should use a more specific typed map value. @@ -152,22 +148,24 @@ class LegacyMapValue final absl::StatusOr> NewIterator() const; - absl::Nonnull cel_map() const { - return impl_; - } - - friend void swap(LegacyMapValue& lhs, LegacyMapValue& rhs) noexcept { + void swap(LegacyMapValue& other) noexcept { using std::swap; - swap(lhs.impl_, rhs.impl_); + swap(impl_, other.impl_); } + uintptr_t NativeValue() const { return impl_; } + private: friend class common_internal::ValueMixin; friend class common_internal::MapValueMixin; - absl::NullabilityUnknown impl_; + uintptr_t impl_; }; +inline void swap(LegacyMapValue& lhs, LegacyMapValue& rhs) noexcept { + lhs.swap(rhs); +} + inline std::ostream& operator<<(std::ostream& out, const LegacyMapValue& type) { return out << type.DebugString(); } diff --git a/common/values/legacy_struct_value.cc b/common/values/legacy_struct_value.cc index c39c6dde6..25184b92c 100644 --- a/common/values/legacy_struct_value.cc +++ b/common/values/legacy_struct_value.cc @@ -15,14 +15,24 @@ #include "absl/log/absl_check.h" #include "absl/types/optional.h" #include "absl/types/variant.h" +#include "base/internal/message_wrapper.h" #include "common/type.h" #include "common/value.h" #include "google/protobuf/message.h" +#include "google/protobuf/message_lite.h" namespace cel::common_internal { StructType LegacyStructValue::GetRuntimeType() const { - return MessageType(message_ptr_->GetDescriptor()); + if ((message_ptr_ & ::cel::base_internal::kMessageWrapperTagMask) == + ::cel::base_internal::kMessageWrapperTagMessageValue) { + return MessageType( + google::protobuf::DownCastMessage( + reinterpret_cast( + message_ptr_ & ::cel::base_internal::kMessageWrapperPtrMask)) + ->GetDescriptor()); + } + return common_internal::MakeBasicStructType(GetTypeName()); } bool IsLegacyStructValue(const Value& value) { diff --git a/common/values/legacy_struct_value.h b/common/values/legacy_struct_value.h index 249b0a272..380760d23 100644 --- a/common/values/legacy_struct_value.h +++ b/common/values/legacy_struct_value.h @@ -40,10 +40,6 @@ #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" -namespace google::api::expr::runtime { -class LegacyTypeInfoApis; -} - namespace cel { class Value; @@ -61,12 +57,8 @@ class LegacyStructValue final public: static constexpr ValueKind kKind = ValueKind::kStruct; - LegacyStructValue( - absl::NullabilityUnknown message_ptr, - absl::NullabilityUnknown< - const google::api::expr::runtime::LegacyTypeInfoApis*> - legacy_type_info) - : message_ptr_(message_ptr), legacy_type_info_(legacy_type_info) {} + LegacyStructValue(uintptr_t message_ptr, uintptr_t type_info) + : message_ptr_(message_ptr), type_info_(type_info) {} LegacyStructValue(const LegacyStructValue&) = default; LegacyStructValue& operator=(const LegacyStructValue&) = default; @@ -106,6 +98,12 @@ class LegacyStructValue final bool IsZeroValue() const; + void swap(LegacyStructValue& other) noexcept { + using std::swap; + swap(message_ptr_, other.message_ptr_); + swap(type_info_, other.type_info_); + } + absl::Status GetFieldByName( absl::string_view name, ProtoWrapperTypeOptions unboxing_options, absl::Nonnull descriptor_pool, @@ -140,32 +138,22 @@ class LegacyStructValue final absl::Nonnull count) const; using StructValueMixin::Qualify; - absl::NullabilityUnknown message_ptr() const { - return message_ptr_; - } - - absl::NullabilityUnknown< - const google::api::expr::runtime::LegacyTypeInfoApis*> - legacy_type_info() const { - return legacy_type_info_; - } + uintptr_t message_ptr() const { return message_ptr_; } - friend void swap(LegacyStructValue& lhs, LegacyStructValue& rhs) noexcept { - using std::swap; - swap(lhs.message_ptr_, rhs.message_ptr_); - swap(lhs.legacy_type_info_, rhs.legacy_type_info_); - } + uintptr_t legacy_type_info() const { return type_info_; } private: friend class common_internal::ValueMixin; friend class common_internal::StructValueMixin; - absl::NullabilityUnknown message_ptr_; - absl::NullabilityUnknown< - const google::api::expr::runtime::LegacyTypeInfoApis*> - legacy_type_info_; + uintptr_t message_ptr_; + uintptr_t type_info_; }; +inline void swap(LegacyStructValue& lhs, LegacyStructValue& rhs) noexcept { + lhs.swap(rhs); +} + inline std::ostream& operator<<(std::ostream& out, const LegacyStructValue& value) { return out << value.DebugString(); diff --git a/common/values/struct_value_builder.cc b/common/values/struct_value_builder.cc index 50ada0e47..18b2de590 100644 --- a/common/values/struct_value_builder.cc +++ b/common/values/struct_value_builder.cc @@ -30,6 +30,8 @@ #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/internal/message_wrapper.h" #include "common/allocator.h" #include "common/any.h" #include "common/memory.h" @@ -295,7 +297,8 @@ absl::Status ProtoMessageFromValueImpl( // Deal with legacy values. if (auto legacy_value = common_internal::AsLegacyStructValue(value); legacy_value) { - const auto* from_message = legacy_value->message_ptr(); + const auto* from_message = reinterpret_cast( + legacy_value->message_ptr() & base_internal::kMessageWrapperPtrMask); return ProtoMessageCopy(message, to_desc, from_message); } diff --git a/eval/public/structs/legacy_type_provider.cc b/eval/public/structs/legacy_type_provider.cc index fded90e03..45c6e3231 100644 --- a/eval/public/structs/legacy_type_provider.cc +++ b/eval/public/structs/legacy_type_provider.cc @@ -35,7 +35,6 @@ #include "internal/status_macros.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" -#include "google/protobuf/message_lite.h" namespace google::api::expr::runtime { @@ -80,9 +79,12 @@ class LegacyStructValueBuilder final : public cel::StructValueBuilder { return absl::FailedPreconditionError("expected MessageWrapper"); } auto message_wrapper = message.MessageWrapperOrDie(); - return cel::common_internal::LegacyStructValue( - google::protobuf::DownCastMessage(message_wrapper.message_ptr()), - message_wrapper.legacy_type_info()); + return cel::common_internal::LegacyStructValue{ + reinterpret_cast(message_wrapper.message_ptr()) | + (message_wrapper.HasFullProto() + ? cel::base_internal::kMessageWrapperTagMessageValue + : uintptr_t{0}), + reinterpret_cast(message_wrapper.legacy_type_info())}; } private: diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index fa93b8c1a..67033fbdb 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -163,6 +163,7 @@ cc_library( ], deps = [ ":type", + "//base/internal:message_wrapper", "//common:memory", "//common:type", "//common:value", diff --git a/extensions/protobuf/value.h b/extensions/protobuf/value.h index bda209c16..82aa16ee0 100644 --- a/extensions/protobuf/value.h +++ b/extensions/protobuf/value.h @@ -31,6 +31,7 @@ #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" +#include "base/internal/message_wrapper.h" #include "common/memory.h" #include "common/type.h" #include "common/value.h" @@ -63,7 +64,9 @@ inline absl::Status ProtoMessageFromValue(const Value& value, if (auto legacy_struct_value = cel::common_internal::AsLegacyStructValue(value); legacy_struct_value) { - src_message = legacy_struct_value->message_ptr(); + src_message = reinterpret_cast( + legacy_struct_value->message_ptr() & + cel::base_internal::kMessageWrapperPtrMask); } if (auto parsed_message_value = value.AsParsedMessage(); parsed_message_value) {