diff --git a/common/internal/shared_byte_string.cc b/common/internal/shared_byte_string.cc index d080bab43..9132ada80 100644 --- a/common/internal/shared_byte_string.cc +++ b/common/internal/shared_byte_string.cc @@ -16,6 +16,7 @@ #include #include +#include #include "absl/base/nullability.h" #include "absl/functional/overload.h" @@ -65,6 +66,24 @@ SharedByteString::SharedByteString(Allocator<> allocator, } } +SharedByteString::SharedByteString(Allocator<> allocator, std::string&& value) + : header_(/*is_cord=*/false, /*size=*/value.size()) { + if (value.empty()) { + content_.string.data = ""; + content_.string.refcount = 0; + } else { + if (auto* arena = allocator.arena(); arena != nullptr) { + content_.string.data = + google::protobuf::Arena::Create(arena, std::move(value))->data(); + content_.string.refcount = 0; + return; + } + auto pair = MakeReferenceCountedString(std::move(value)); + content_.string.data = pair.second.data(); + content_.string.refcount = reinterpret_cast(pair.first); + } +} + SharedByteString SharedByteString::Clone(Allocator<> allocator) const { if (absl::Nullable arena = allocator.arena(); arena != nullptr) { diff --git a/common/internal/shared_byte_string.h b/common/internal/shared_byte_string.h index 1d8a37e51..81d58b992 100644 --- a/common/internal/shared_byte_string.h +++ b/common/internal/shared_byte_string.h @@ -37,6 +37,11 @@ #include "common/internal/reference_count.h" #include "common/memory.h" +namespace cel { +class BytesValueInputStream; +class BytesValueOutputStream; +} // namespace cel + namespace cel::common_internal { class TrivialValue; @@ -158,6 +163,11 @@ class SharedByteString final { // if necessary. SharedByteString(Allocator<> allocator, const absl::Cord& value); + SharedByteString(Allocator<> allocator, std::string&& value); + + SharedByteString(Allocator<> allocator, absl::Nullable value) + : SharedByteString(allocator, absl::NullSafeStringView(value)) {} + // Constructs a shared byte string which is borrowed and references `value`. SharedByteString(Borrower borrower, absl::string_view value) : SharedByteString(common_internal::BorrowerRelease(borrower), value) {} @@ -399,6 +409,8 @@ class SharedByteString final { private: friend class TrivialValue; friend class SharedByteStringView; + friend class cel::BytesValueInputStream; + friend class cel::BytesValueOutputStream; static void SwapMixed(SharedByteString& cord, SharedByteString& string) noexcept { diff --git a/common/value.h b/common/value.h index fd7d466a8..338765352 100644 --- a/common/value.h +++ b/common/value.h @@ -46,6 +46,8 @@ #include "common/value_kind.h" #include "common/values/bool_value.h" // IWYU pragma: export #include "common/values/bytes_value.h" // IWYU pragma: export +#include "common/values/bytes_value_input_stream.h" // IWYU pragma: export +#include "common/values/bytes_value_output_stream.h" // IWYU pragma: export #include "common/values/custom_list_value.h" // IWYU pragma: export #include "common/values/custom_map_value.h" // IWYU pragma: export #include "common/values/custom_struct_value.h" // IWYU pragma: export diff --git a/common/values/bytes_value.h b/common/values/bytes_value.h index 039b00c1a..aa0317679 100644 --- a/common/values/bytes_value.h +++ b/common/values/bytes_value.h @@ -47,6 +47,8 @@ namespace cel { class Value; class BytesValue; class TypeManager; +class BytesValueInputStream; +class BytesValueOutputStream; namespace common_internal { class TrivialValue; @@ -89,6 +91,12 @@ class BytesValue final : private common_internal::ValueMixin { BytesValue(Allocator<> allocator, const absl::Cord& value) : value_(allocator, value) {} + BytesValue(Allocator<> allocator, std::string&& value) + : value_(allocator, std::move(value)) {} + + BytesValue(Allocator<> allocator, absl::Nullable value) + : value_(allocator, value) {} + BytesValue(Borrower borrower, absl::string_view value) : value_(borrower, value) {} @@ -206,6 +214,8 @@ class BytesValue final : private common_internal::ValueMixin { friend const common_internal::SharedByteString& common_internal::AsSharedByteString(const BytesValue& value); friend class common_internal::ValueMixin; + friend class BytesValueInputStream; + friend class BytesValueOutputStream; common_internal::SharedByteString value_; }; diff --git a/common/values/bytes_value_input_stream.h b/common/values/bytes_value_input_stream.h new file mode 100644 index 000000000..ed5d8f9a4 --- /dev/null +++ b/common/values/bytes_value_input_stream.h @@ -0,0 +1,119 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_INPUT_STREAM_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_INPUT_STREAM_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "absl/utility/utility.h" +#include "common/values/bytes_value.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { + +class BytesValueInputStream final : public google::protobuf::io::ZeroCopyInputStream { + public: + explicit BytesValueInputStream( + absl::Nonnull value ABSL_ATTRIBUTE_LIFETIME_BOUND) { + Construct(value); + } + + ~BytesValueInputStream() override { AsVariant().~variant(); } + + bool Next(const void** data, int* size) override { + return absl::visit( + [&data, &size](auto& alternative) -> bool { + return alternative.Next(data, size); + }, + AsVariant()); + } + + void BackUp(int count) override { + absl::visit( + [&count](auto& alternative) -> void { alternative.BackUp(count); }, + AsVariant()); + } + + bool Skip(int count) override { + return absl::visit( + [&count](auto& alternative) -> bool { return alternative.Skip(count); }, + AsVariant()); + } + + int64_t ByteCount() const override { + return absl::visit( + [](const auto& alternative) -> int64_t { + return alternative.ByteCount(); + }, + AsVariant()); + } + + bool ReadCord(absl::Cord* cord, int count) override { + return absl::visit( + [&cord, &count](auto& alternative) -> bool { + return alternative.ReadCord(cord, count); + }, + AsVariant()); + } + + private: + using Variant = + absl::variant; + + void Construct(absl::Nonnull value) { + ABSL_DCHECK(value != nullptr); + if (value->value_.header_.is_cord) { + ::new (static_cast(&impl_[0])) + Variant(absl::in_place_type, + value->value_.cord_ptr()); + } else { + absl::string_view string = value->value_.AsStringView(); + ABSL_DCHECK_LE(string.size(), + static_cast(std::numeric_limits::max())); + ::new (static_cast(&impl_[0])) + Variant(absl::in_place_type, + string.data(), static_cast(string.size())); + } + } + + void Destruct() { AsVariant().~variant(); } + + Variant& AsVariant() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *std::launder(reinterpret_cast(&impl_[0])); + } + + const Variant& AsVariant() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *std::launder(reinterpret_cast(&impl_[0])); + } + + alignas(Variant) char impl_[sizeof(Variant)]; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_INPUT_STREAM_H_ diff --git a/common/values/bytes_value_output_stream.h b/common/values/bytes_value_output_stream.h new file mode 100644 index 000000000..58c259fdc --- /dev/null +++ b/common/values/bytes_value_output_stream.h @@ -0,0 +1,162 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "common/value.h" +// IWYU pragma: friend "common/value.h" + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_OUTPUT_STREAM_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_OUTPUT_STREAM_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "absl/utility/utility.h" +#include "common/values/bytes_value.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" + +namespace cel { + +class BytesValueOutputStream final : public google::protobuf::io::ZeroCopyOutputStream { + public: + explicit BytesValueOutputStream(const BytesValue& value) + : BytesValueOutputStream(value, /*arena=*/nullptr) {} + + BytesValueOutputStream(const BytesValue& value, + absl::Nullable arena) { + Construct(value, arena); + } + + bool Next(void** data, int* size) override { + return absl::visit(absl::Overload( + [&data, &size](String& string) -> bool { + return string.stream.Next(data, size); + }, + [&data, &size](Cord& cord) -> bool { + return cord.Next(data, size); + }), + AsVariant()); + } + + void BackUp(int count) override { + absl::visit( + absl::Overload( + [&count](String& string) -> void { string.stream.BackUp(count); }, + [&count](Cord& cord) -> void { cord.BackUp(count); }), + AsVariant()); + } + + int64_t ByteCount() const override { + return absl::visit( + absl::Overload( + [](const String& string) -> int64_t { + return string.stream.ByteCount(); + }, + [](const Cord& cord) -> int64_t { return cord.ByteCount(); }), + AsVariant()); + } + + bool WriteAliasedRaw(const void* data, int size) override { + return absl::visit(absl::Overload( + [&data, &size](String& string) -> bool { + return string.stream.WriteAliasedRaw(data, size); + }, + [&data, &size](Cord& cord) -> bool { + return cord.WriteAliasedRaw(data, size); + }), + AsVariant()); + } + + bool AllowsAliasing() const override { + return absl::visit( + absl::Overload( + [](const String& string) -> bool { + return string.stream.AllowsAliasing(); + }, + [](const Cord& cord) -> bool { return cord.AllowsAliasing(); }), + AsVariant()); + } + + bool WriteCord(const absl::Cord& out) override { + return absl::visit( + absl::Overload( + [&out](String& string) -> bool { + return string.stream.WriteCord(out); + }, + [&out](Cord& cord) -> bool { return cord.WriteCord(out); }), + AsVariant()); + } + + BytesValue Consume() && { + return absl::visit(absl::Overload( + [](String& string) -> BytesValue { + return BytesValue(string.arena, + std::move(string.target)); + }, + [](Cord& cord) -> BytesValue { + return BytesValue(cord.Consume()); + }), + AsVariant()); + } + + private: + struct String final { + String(absl::string_view target, absl::Nullable arena) + : target(target), stream(&this->target), arena(arena) {} + + std::string target; + google::protobuf::io::StringOutputStream stream; + absl::Nullable arena; + }; + + using Cord = google::protobuf::io::CordOutputStream; + + using Variant = absl::variant; + + void Construct(const BytesValue& value, + absl::Nullable arena) { + if (value.value_.header_.is_cord) { + ::new (static_cast(&impl_[0])) + Variant(absl::in_place_type, *value.value_.cord_ptr()); + } else { + ::new (static_cast(&impl_[0])) Variant( + absl::in_place_type, value.value_.AsStringView(), arena); + } + } + + void Destruct() { AsVariant().~variant(); } + + Variant& AsVariant() ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *std::launder(reinterpret_cast(&impl_[0])); + } + + const Variant& AsVariant() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return *std::launder(reinterpret_cast(&impl_[0])); + } + + alignas(Variant) char impl_[sizeof(Variant)]; +}; + +} // namespace cel + +#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_BYTES_VALUE_OUTPUT_STREAM_H_ diff --git a/common/values/bytes_value_test.cc b/common/values/bytes_value_test.cc index a2023017a..0ce53e287 100644 --- a/common/values/bytes_value_test.cc +++ b/common/values/bytes_value_test.cc @@ -14,7 +14,9 @@ #include #include +#include +#include "google/protobuf/struct.pb.h" #include "absl/status/status_matchers.h" #include "absl/strings/cord.h" #include "absl/strings/cord_test_helpers.h" @@ -28,7 +30,9 @@ namespace cel { namespace { using ::absl_testing::IsOk; +using ::testing::An; using ::testing::Eq; +using ::testing::NotNull; using ::testing::Optional; using BytesValueTest = common_internal::ValueTest<>; @@ -154,5 +158,97 @@ TEST_F(BytesValueTest, Comparison) { EXPECT_FALSE(BytesValue("foo") < BytesValue("bar")); } +TEST_F(BytesValueTest, StringInputStream) { + BytesValue value = BytesValue("foo"); + BytesValueInputStream stream(&value); + const void* data; + int size; + absl::Cord cord; + ASSERT_TRUE(stream.Next(&data, &size)); + EXPECT_THAT(data, NotNull()); + EXPECT_EQ(size, 3); + EXPECT_EQ(stream.ByteCount(), 3); + stream.BackUp(size); + ASSERT_TRUE(stream.Skip(3)); + EXPECT_FALSE(stream.ReadCord(&cord, 3)); + EXPECT_FALSE(stream.Next(&data, &size)); +} + +TEST_F(BytesValueTest, CordInputStream) { + BytesValue value = BytesValue(absl::Cord("foo")); + BytesValueInputStream stream(&value); + const void* data; + int size; + absl::Cord cord; + ASSERT_TRUE(stream.Next(&data, &size)); + EXPECT_THAT(data, NotNull()); + EXPECT_EQ(size, 3); + EXPECT_EQ(stream.ByteCount(), 3); + stream.BackUp(size); + ASSERT_TRUE(stream.Skip(3)); + EXPECT_FALSE(stream.ReadCord(&cord, 3)); + EXPECT_FALSE(stream.Next(&data, &size)); +} + +TEST_F(BytesValueTest, ArenaStringOutputStream) { + BytesValue value = BytesValue(""); + { + BytesValueOutputStream stream(value, arena()); + EXPECT_THAT(stream.AllowsAliasing(), An()); + EXPECT_EQ(stream.ByteCount(), 0); + google::protobuf::Value value_proto; + auto* struct_proto = value_proto.mutable_struct_value(); + (*struct_proto->mutable_fields())["foo"].set_string_value("bar"); + (*struct_proto->mutable_fields())["baz"].set_number_value(3.14159); + ASSERT_TRUE(value_proto.SerializePartialToZeroCopyStream(&stream)); + EXPECT_EQ(std::move(stream).Consume(), + value_proto.SerializePartialAsString()); + } + { + BytesValueOutputStream stream(value); + EXPECT_EQ(std::move(stream).Consume(), ""); + } +} + +TEST_F(BytesValueTest, StringOutputStream) { + BytesValue value = BytesValue(""); + { + BytesValueOutputStream stream(value); + EXPECT_THAT(stream.AllowsAliasing(), An()); + EXPECT_EQ(stream.ByteCount(), 0); + google::protobuf::Value value_proto; + auto* struct_proto = value_proto.mutable_struct_value(); + (*struct_proto->mutable_fields())["foo"].set_string_value("bar"); + (*struct_proto->mutable_fields())["baz"].set_number_value(3.14159); + ASSERT_TRUE(value_proto.SerializePartialToZeroCopyStream(&stream)); + EXPECT_EQ(std::move(stream).Consume(), + value_proto.SerializePartialAsString()); + } + { + BytesValueOutputStream stream(value); + EXPECT_EQ(std::move(stream).Consume(), ""); + } +} + +TEST_F(BytesValueTest, CordOutputStream) { + BytesValue value = BytesValue(absl::Cord()); + { + BytesValueOutputStream stream(value); + EXPECT_THAT(stream.AllowsAliasing(), An()); + EXPECT_EQ(stream.ByteCount(), 0); + google::protobuf::Value value_proto; + auto* struct_proto = value_proto.mutable_struct_value(); + (*struct_proto->mutable_fields())["foo"].set_string_value("bar"); + (*struct_proto->mutable_fields())["baz"].set_number_value(3.14159); + ASSERT_TRUE(value_proto.SerializePartialToZeroCopyStream(&stream)); + EXPECT_EQ(std::move(stream).Consume(), + value_proto.SerializePartialAsString()); + } + { + BytesValueOutputStream stream(value); + EXPECT_EQ(std::move(stream).Consume(), ""); + } +} + } // namespace } // namespace cel