Skip to content

Commit 224be6a

Browse files
jckingcopybara-github
authored andcommitted
Consolidate struct value builder
PiperOrigin-RevId: 690350923
1 parent feaecdb commit 224be6a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1374
-1468
lines changed

common/type_reflector.cc

Lines changed: 62 additions & 915 deletions
Large diffs are not rendered by default.

common/value.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2640,7 +2640,7 @@ class ValueBuilder {
26402640

26412641
virtual absl::Status SetFieldByNumber(int64_t number, Value value) = 0;
26422642

2643-
virtual Value Build() && = 0;
2643+
virtual absl::StatusOr<Value> Build() && = 0;
26442644
};
26452645

26462646
using ValueBuilderPtr = std::unique_ptr<ValueBuilder>;

common/values/struct_value_builder.cc

Lines changed: 92 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include "common/type_introspector.h"
4242
#include "common/type_reflector.h"
4343
#include "common/value.h"
44+
#include "common/value_factory.h"
4445
#include "common/value_kind.h"
4546
#include "common/value_manager.h"
4647
#include "extensions/protobuf/internal/map_reflection.h"
@@ -129,28 +130,9 @@ class CompatTypeReflector final : public TypeReflector {
129130
absl::StatusOr<absl::optional<Value>> DeserializeValueImpl(
130131
ValueFactory& value_factory, absl::string_view type_url,
131132
const absl::Cord& value) const override {
132-
absl::string_view type_name;
133-
if (!ParseTypeUrl(type_url, &type_name)) {
134-
return absl::InvalidArgumentError("invalid type URL");
135-
}
136-
const auto* descriptor =
137-
descriptor_pool()->FindMessageTypeByName(type_name);
138-
if (descriptor == nullptr) {
139-
return absl::nullopt;
140-
}
141-
const auto* prototype = message_factory()->GetPrototype(descriptor);
142-
if (prototype == nullptr) {
143-
return absl::nullopt;
144-
}
145-
absl::Nullable<google::protobuf::Arena*> arena =
146-
value_factory.GetMemoryManager().arena();
147-
auto message = WrapShared(prototype->New(arena), arena);
148-
if (!message->ParsePartialFromCord(value)) {
149-
return absl::InvalidArgumentError(
150-
absl::StrCat("failed to parse `", type_url, "`"));
151-
}
152-
return Value::Message(WrapShared(prototype->New(arena), arena), pool_,
153-
factory_);
133+
// This should not be reachable, as we provide both the pool and the factory
134+
// which should trigger DeserializeValue to handle the call and not call us.
135+
return absl::nullopt;
154136
}
155137

156138
private:
@@ -1011,9 +993,9 @@ GetProtoRepeatedFieldFromValueMutator(
1011993
}
1012994
}
1013995

1014-
class StructValueBuilderImpl final : public StructValueBuilder {
996+
class MessageValueBuilderImpl {
1015997
public:
1016-
StructValueBuilderImpl(
998+
MessageValueBuilderImpl(
1017999
absl::Nullable<google::protobuf::Arena*> arena,
10181000
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
10191001
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
@@ -1025,13 +1007,13 @@ class StructValueBuilderImpl final : public StructValueBuilder {
10251007
descriptor_(message_->GetDescriptor()),
10261008
reflection_(message_->GetReflection()) {}
10271009

1028-
~StructValueBuilderImpl() override {
1010+
~MessageValueBuilderImpl() {
10291011
if (arena_ == nullptr && message_ != nullptr) {
10301012
delete message_;
10311013
}
10321014
}
10331015

1034-
absl::Status SetFieldByName(absl::string_view name, Value value) override {
1016+
absl::Status SetFieldByName(absl::string_view name, Value value) {
10351017
const auto* field = descriptor_->FindFieldByName(name);
10361018
if (field == nullptr) {
10371019
field = descriptor_pool_->FindExtensionByPrintableName(descriptor_, name);
@@ -1042,7 +1024,7 @@ class StructValueBuilderImpl final : public StructValueBuilder {
10421024
return SetField(field, std::move(value));
10431025
}
10441026

1045-
absl::Status SetFieldByNumber(int64_t number, Value value) override {
1027+
absl::Status SetFieldByNumber(int64_t number, Value value) {
10461028
if (number < std::numeric_limits<int32_t>::min() ||
10471029
number > std::numeric_limits<int32_t>::max()) {
10481030
return NoSuchFieldError(absl::StrCat(number)).NativeValue();
@@ -1055,7 +1037,12 @@ class StructValueBuilderImpl final : public StructValueBuilder {
10551037
return SetField(field, std::move(value));
10561038
}
10571039

1058-
absl::StatusOr<StructValue> Build() && override {
1040+
absl::StatusOr<Value> Build() && {
1041+
return Value::Message(WrapShared(std::exchange(message_, nullptr)),
1042+
descriptor_pool_, message_factory_);
1043+
}
1044+
1045+
absl::StatusOr<StructValue> BuildStruct() && {
10591046
return ParsedMessageValue(
10601047
WrapShared(std::exchange(message_, nullptr), Allocator(arena_)));
10611048
}
@@ -1519,19 +1506,91 @@ class StructValueBuilderImpl final : public StructValueBuilder {
15191506
well_known_types::Reflection well_known_types_;
15201507
};
15211508

1509+
class ValueBuilderImpl final : public ValueBuilder {
1510+
public:
1511+
ValueBuilderImpl(absl::Nullable<google::protobuf::Arena*> arena,
1512+
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
1513+
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
1514+
absl::Nonnull<google::protobuf::Message*> message)
1515+
: builder_(arena, descriptor_pool, message_factory, message) {}
1516+
1517+
absl::Status SetFieldByName(absl::string_view name, Value value) override {
1518+
return builder_.SetFieldByName(name, std::move(value));
1519+
}
1520+
1521+
absl::Status SetFieldByNumber(int64_t number, Value value) override {
1522+
return builder_.SetFieldByNumber(number, std::move(value));
1523+
}
1524+
1525+
absl::StatusOr<Value> Build() && override {
1526+
return std::move(builder_).Build();
1527+
}
1528+
1529+
private:
1530+
MessageValueBuilderImpl builder_;
1531+
};
1532+
1533+
class StructValueBuilderImpl final : public StructValueBuilder {
1534+
public:
1535+
StructValueBuilderImpl(
1536+
absl::Nullable<google::protobuf::Arena*> arena,
1537+
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
1538+
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
1539+
absl::Nonnull<google::protobuf::Message*> message)
1540+
: builder_(arena, descriptor_pool, message_factory, message) {}
1541+
1542+
absl::Status SetFieldByName(absl::string_view name, Value value) override {
1543+
return builder_.SetFieldByName(name, std::move(value));
1544+
}
1545+
1546+
absl::Status SetFieldByNumber(int64_t number, Value value) override {
1547+
return builder_.SetFieldByNumber(number, std::move(value));
1548+
}
1549+
1550+
absl::StatusOr<StructValue> Build() && override {
1551+
return std::move(builder_).BuildStruct();
1552+
}
1553+
1554+
private:
1555+
MessageValueBuilderImpl builder_;
1556+
};
1557+
15221558
} // namespace
15231559

1524-
absl::StatusOr<absl::Nonnull<cel::StructValueBuilderPtr>> NewStructValueBuilder(
1560+
absl::StatusOr<absl::Nullable<cel::ValueBuilderPtr>> NewValueBuilder(
1561+
Allocator<> allocator,
1562+
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
1563+
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
1564+
absl::string_view name) {
1565+
absl::Nullable<const google::protobuf::Descriptor*> descriptor =
1566+
descriptor_pool->FindMessageTypeByName(name);
1567+
if (descriptor == nullptr) {
1568+
return nullptr;
1569+
}
1570+
absl::Nullable<const google::protobuf::Message*> prototype =
1571+
message_factory->GetPrototype(descriptor);
1572+
if (prototype == nullptr) {
1573+
return absl::NotFoundError(absl::StrCat(
1574+
"unable to get prototype for descriptor: ", descriptor->full_name()));
1575+
}
1576+
return std::make_unique<ValueBuilderImpl>(allocator.arena(), descriptor_pool,
1577+
message_factory,
1578+
prototype->New(allocator.arena()));
1579+
}
1580+
1581+
absl::StatusOr<absl::Nullable<cel::StructValueBuilderPtr>>
1582+
NewStructValueBuilder(
15251583
Allocator<> allocator,
15261584
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
15271585
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
15281586
absl::string_view name) {
1529-
const auto* descriptor = descriptor_pool->FindMessageTypeByName(name);
1587+
absl::Nullable<const google::protobuf::Descriptor*> descriptor =
1588+
descriptor_pool->FindMessageTypeByName(name);
15301589
if (descriptor == nullptr) {
1531-
return absl::NotFoundError(
1532-
absl::StrCat("unable to find descriptor for type: ", name));
1590+
return nullptr;
15331591
}
1534-
const auto* prototype = message_factory->GetPrototype(descriptor);
1592+
absl::Nullable<const google::protobuf::Message*> prototype =
1593+
message_factory->GetPrototype(descriptor);
15351594
if (prototype == nullptr) {
15361595
return absl::NotFoundError(absl::StrCat(
15371596
"unable to get prototype for descriptor: ", descriptor->full_name()));

common/values/struct_value_builder.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,15 @@
2323
#include "google/protobuf/descriptor.h"
2424
#include "google/protobuf/message.h"
2525

26-
namespace cel {
26+
namespace cel::common_internal {
2727

28-
class ValueFactory;
29-
30-
namespace common_internal {
31-
32-
absl::StatusOr<absl::Nonnull<cel::StructValueBuilderPtr>> NewStructValueBuilder(
28+
absl::StatusOr<absl::Nullable<cel::StructValueBuilderPtr>>
29+
NewStructValueBuilder(
3330
Allocator<> allocator,
3431
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
3532
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
3633
absl::string_view name);
3734

38-
} // namespace common_internal
39-
40-
} // namespace cel
35+
} // namespace cel::common_internal
4136

4237
#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_STRUCT_VALUE_BUILDER_H_

common/values/value_builder.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_
16+
#define THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_
17+
18+
#include "absl/base/nullability.h"
19+
#include "absl/status/statusor.h"
20+
#include "absl/strings/string_view.h"
21+
#include "common/allocator.h"
22+
#include "common/value.h"
23+
#include "google/protobuf/descriptor.h"
24+
#include "google/protobuf/message.h"
25+
26+
namespace cel::common_internal {
27+
28+
// Like NewStructValueBuilder, but deals with well known types.
29+
absl::StatusOr<absl::Nullable<cel::ValueBuilderPtr>> NewValueBuilder(
30+
Allocator<> allocator,
31+
absl::Nonnull<const google::protobuf::DescriptorPool*> descriptor_pool,
32+
absl::Nonnull<google::protobuf::MessageFactory*> message_factory,
33+
absl::string_view name);
34+
35+
} // namespace cel::common_internal
36+
37+
#endif // THIRD_PARTY_CEL_CPP_COMMON_VALUES_VALUE_BUILDER_H_

conformance/run.bzl

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,19 @@ def _expand_tests_to_skip(tests_to_skip):
3535
result.append(test_to_skip[0:slash] + part)
3636
return result
3737

38-
def _conformance_test_name(name, modern, arena, optimize, recursive, skip_check):
38+
def _conformance_test_name(name, optimize, recursive):
3939
return "_".join(
4040
[
4141
name,
42-
"arena" if arena else "refcount",
4342
"optimized" if optimize else "unoptimized",
4443
"recursive" if recursive else "iterative",
4544
],
4645
)
4746

48-
def _conformance_test_args(modern, arena, optimize, recursive, skip_check, skip_tests, dashboard):
47+
def _conformance_test_args(modern, optimize, recursive, skip_check, skip_tests, dashboard):
4948
args = []
5049
if modern:
5150
args.append("--modern")
52-
elif not arena:
53-
fail("arena must be true for legacy")
54-
if not modern or arena:
55-
args.append("--arena")
5651
if optimize:
5752
args.append("--opt")
5853
if recursive:
@@ -66,10 +61,10 @@ def _conformance_test_args(modern, arena, optimize, recursive, skip_check, skip_
6661
args.append("--dashboard")
6762
return args
6863

69-
def _conformance_test(name, data, modern, arena, optimize, recursive, skip_check, skip_tests, tags, dashboard):
64+
def _conformance_test(name, data, modern, optimize, recursive, skip_check, skip_tests, tags, dashboard):
7065
native.cc_test(
71-
name = _conformance_test_name(name, modern, arena, optimize, recursive, skip_check),
72-
args = _conformance_test_args(modern, arena, optimize, recursive, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data],
66+
name = _conformance_test_name(name, optimize, recursive),
67+
args = _conformance_test_args(modern, optimize, recursive, skip_check, skip_tests, dashboard) + ["$(location " + test + ")" for test in data],
7368
data = data,
7469
deps = ["//conformance:run"],
7570
tags = tags,
@@ -89,15 +84,12 @@ def gen_conformance_tests(name, data, modern = False, checked = False, dashboard
8984
dashboard: enable dashboard mode
9085
"""
9186
skip_check = not checked
92-
93-
# TODO: enable refcount mode for modern.
9487
for optimize in (True, False):
9588
for recursive in (True, False):
9689
_conformance_test(
9790
name,
9891
data,
9992
modern = modern,
100-
arena = True,
10193
optimize = optimize,
10294
recursive = recursive,
10395
skip_check = skip_check,

conformance/run.cc

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ ABSL_FLAG(bool, opt, false, "Enable optimizations (constant folding)");
5757
ABSL_FLAG(
5858
bool, modern, false,
5959
"Use modern cel::Value APIs implementation of the conformance service.");
60-
ABSL_FLAG(bool, arena, false,
61-
"Use arena memory manager (default: global heap ref-counted). Only "
62-
"affects the modern implementation");
6360
ABSL_FLAG(bool, recursive, false,
6461
"Enable recursive plans. Depth limited to slightly more than the "
6562
"default nesting limit.");
@@ -279,7 +276,6 @@ NewConformanceServiceFromFlags() {
279276
cel_conformance::ConformanceServiceOptions{
280277
.optimize = absl::GetFlag(FLAGS_opt),
281278
.modern = absl::GetFlag(FLAGS_modern),
282-
.arena = absl::GetFlag(FLAGS_arena),
283279
.recursive = absl::GetFlag(FLAGS_recursive)});
284280
ABSL_CHECK_OK(status_or_service);
285281
return std::shared_ptr<cel_conformance::ConformanceServiceInterface>(

0 commit comments

Comments
 (0)