Skip to content

Commit ee0947f

Browse files
committed
Feature: add abstract and factory classes for overload controller module
1 parent 00e845f commit ee0947f

12 files changed

+392
-0
lines changed

trpc/common/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ cc_library(
136136
"//trpc/transport/common:ssl_helper",
137137
"//trpc/util/log/default:default_log",
138138
"//trpc/util:net_util",
139+
"//trpc/overload_control:trpc_overload_control",
139140
] + select({
140141
"//trpc:trpc_include_rpcz": [
141142
"//trpc/rpcz:collector",

trpc/common/trpc_plugin.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#ifdef TRPC_BUILD_INCLUDE_RPCZ
3636
#include "trpc/rpcz/collector.h"
3737
#endif
38+
#include "trpc/overload_control/trpc_overload_control.h"
3839
#include "trpc/runtime/common/periphery_task_scheduler.h"
3940
#include "trpc/runtime/common/runtime_info_report/runtime_info_reporter.h"
4041
#include "trpc/runtime/common/stats/frame_stats.h"
@@ -82,6 +83,8 @@ int TrpcPlugin::RegisterPlugins() {
8283
TRPC_ASSERT(telemetry::Init());
8384
TRPC_ASSERT(naming::Init());
8485

86+
TRPC_ASSERT(overload_control::Init());
87+
8588
CollectPlugins();
8689
InitPlugins();
8790

@@ -229,6 +232,8 @@ int TrpcPlugin::UnregisterPlugins() {
229232

230233
StopPlugins();
231234

235+
overload_control::Stop();
236+
232237
PeripheryTaskScheduler::GetInstance()->Stop();
233238
PeripheryTaskScheduler::GetInstance()->Join();
234239

@@ -529,6 +534,8 @@ void TrpcPlugin::DestroyResource() {
529534

530535
log::Destroy();
531536

537+
overload_control::Destroy();
538+
532539
GetTrpcClient()->Destroy();
533540

534541
is_all_inited_ = false;

trpc/overload_control/BUILD

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,40 @@ cc_library(
1212
}),
1313
visibility = ["//visibility:public"],
1414
)
15+
16+
cc_library(
17+
name = "server_overload_controller",
18+
hdrs = ["server_overload_controller.h"],
19+
deps = [
20+
"//trpc/server:server_context",
21+
],
22+
)
23+
24+
cc_library(
25+
name = "server_overload_controller_factory",
26+
hdrs = ["server_overload_controller_factory.h"],
27+
deps = [
28+
":server_overload_controller",
29+
"//trpc/overload_control/common:overload_control_factory",
30+
],
31+
)
32+
33+
cc_test(
34+
name = "server_overload_controller_factory_test",
35+
srcs = ["server_overload_controller_factory_test.cc"],
36+
deps = [
37+
":server_overload_controller_factory",
38+
"//trpc/overload_control/testing:overload_control_testing",
39+
"@com_google_googletest//:gtest_main",
40+
],
41+
)
42+
43+
cc_library(
44+
name = "trpc_overload_control",
45+
srcs = ["trpc_overload_control.cc"],
46+
hdrs = ["trpc_overload_control.h"],
47+
deps = [
48+
":server_overload_controller_factory",
49+
"//trpc/filter:filter_manager",
50+
],
51+
)

trpc/overload_control/common/BUILD

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,11 @@ cc_test(
188188
"@com_google_googletest//:gtest_main",
189189
],
190190
)
191+
192+
cc_library(
193+
name = "overload_control_factory",
194+
hdrs = ["overload_control_factory.h"],
195+
deps = [
196+
"//trpc/log:trpc_log",
197+
],
198+
)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Copyright (c) 2024, Tencent Inc.
2+
// All rights reserved.
3+
4+
#pragma once
5+
6+
#include <string>
7+
#include <unordered_map>
8+
9+
#include "trpc/log/trpc_log.h"
10+
11+
namespace trpc::overload_control {
12+
13+
/// @brief Factory of overload control strategy(as template T).
14+
template <class T>
15+
class OverloadControlFactory {
16+
public:
17+
/// @brief Singleton
18+
static OverloadControlFactory* GetInstance() {
19+
static OverloadControlFactory instance;
20+
return &instance;
21+
}
22+
23+
/// @brief Can't construct by user.
24+
OverloadControlFactory(const OverloadControlFactory&) = delete;
25+
OverloadControlFactory& operator=(const OverloadControlFactory&) = delete;
26+
27+
/// @brief Register the overload control strategy.
28+
/// @param obj strategy
29+
/// @note Non-thread-safe
30+
bool Register(const T& obj);
31+
32+
/// @brief Get the overload control strategy by name
33+
/// @param name name of strategy
34+
/// @return strategy
35+
T Get(const std::string& name);
36+
37+
/// @brief Get number of strategies.
38+
/// @return number of strategies
39+
/// @note Non-thread-safe
40+
size_t Size() const { return objs_map_.size(); }
41+
42+
/// @brief Stop all of overload control strategies.
43+
// Mainly used to stop inner thread createdy by each strategy
44+
/// @note Non-thread-safe.
45+
void Stop();
46+
47+
/// @brief Destroy resource of overload control strategies.
48+
/// @note Non-thread-safe.
49+
void Destroy();
50+
51+
/// @brief Clear overload control strategies in this factory.
52+
/// @note Non-thread-safe.
53+
void Clear() { objs_map_.clear(); }
54+
55+
private:
56+
OverloadControlFactory() = default;
57+
58+
private:
59+
// strategies mapping(name->stratege obj)
60+
std::unordered_map<std::string, T> objs_map_;
61+
};
62+
63+
template <class T>
64+
bool OverloadControlFactory<T>::Register(const T& obj) {
65+
if (!obj) {
66+
TRPC_FMT_ERROR("register object is nullptr");
67+
return false;
68+
}
69+
if (Get(obj->Name())) {
70+
return false;
71+
}
72+
if (obj->Init()) {
73+
objs_map_.emplace(obj->Name(), obj);
74+
return true;
75+
}
76+
TRPC_FMT_ERROR("{} is `Init` failed ", obj->Name());
77+
return false;
78+
}
79+
80+
template <class T>
81+
T OverloadControlFactory<T>::Get(const std::string& name) {
82+
T obj = nullptr;
83+
auto iter = objs_map_.find(name);
84+
if (iter != objs_map_.end()) {
85+
obj = iter->second;
86+
}
87+
return obj;
88+
}
89+
90+
template <class T>
91+
void OverloadControlFactory<T>::Stop() {
92+
for (auto& obj : objs_map_) {
93+
obj.second->Stop();
94+
}
95+
}
96+
97+
template <class T>
98+
void OverloadControlFactory<T>::Destroy() {
99+
for (auto& obj : objs_map_) {
100+
obj.second->Destroy();
101+
}
102+
Clear();
103+
}
104+
105+
} // namespace trpc::overload_control
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2024, Tencent Inc.
2+
// All rights reserved.
3+
4+
#pragma once
5+
6+
#include <memory>
7+
#include <string>
8+
9+
#include "trpc/server/server_context.h"
10+
11+
namespace trpc::overload_control {
12+
13+
/// @brief Base class of overload controller.
14+
class ServerOverloadController {
15+
public:
16+
virtual ~ServerOverloadController() = default;
17+
18+
/// @brief Name of this controller.
19+
virtual std::string Name() = 0;
20+
21+
/// @brief Initialize controller.
22+
/// You can allocate resources or start thread as controller need.
23+
/// @return bool true: succ; false: failed
24+
virtual bool Init() { return true; }
25+
26+
/// @brief Whether this request should be scheduled to handle.
27+
/// When reject this request, you should also set status with error code TRPC_SERVER_OVERLOAD_ERR
28+
/// into context.
29+
/// @param context server context.
30+
/// @return bool true: this request will be handled; false: this request should be rejected.
31+
virtual bool BeforeSchedule(const ServerContextPtr& context) = 0;
32+
33+
/// @brief After this request being sheduled. At this point, it may be handled or rejected.
34+
// You can check status from context to distinguish these 2 scenes when implement.
35+
/// @param context server context.
36+
/// @return bool true: succ; false: failed.
37+
virtual bool AfterSchedule(const ServerContextPtr& context) = 0;
38+
39+
/// @brief Stop controller. One can stop the thread execution of controller implemetation.
40+
virtual void Stop() {}
41+
42+
/// @brief Destroy resources of controller.
43+
virtual void Destroy() {}
44+
};
45+
46+
using ServerOverloadControllerPtr = std::shared_ptr<ServerOverloadController>;
47+
48+
} // namespace trpc::overload_control
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Copyright (c) 2024, Tencent Inc.
2+
// All rights reserved.
3+
4+
#pragma once
5+
6+
#include <string>
7+
#include <unordered_map>
8+
9+
#include "trpc/overload_control/common/overload_control_factory.h"
10+
#include "trpc/overload_control/server_overload_controller.h"
11+
12+
namespace trpc::overload_control {
13+
14+
using ServerOverloadControllerFactory = OverloadControlFactory<ServerOverloadControllerPtr>;
15+
16+
} // namespace trpc::overload_control
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright (c) 2024, Tencent Inc.
2+
// All rights reserved.
3+
4+
#include "trpc/overload_control/server_overload_controller_factory.h"
5+
6+
#include "gmock/gmock.h"
7+
#include "gtest/gtest.h"
8+
9+
#include "trpc/log/trpc_log.h"
10+
#include "trpc/overload_control/testing/overload_control_testing.h"
11+
12+
namespace trpc::overload_control {
13+
14+
namespace testing {
15+
16+
TEST(ServerOverloadControllerFactory, All) {
17+
// Testing register interface
18+
{
19+
// 1. Register nullptr, failed.
20+
ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(nullptr));
21+
// 2. First time register, succ.
22+
ServerOverloadControllerPtr controller = std::make_shared<MockServerOverloadController>();
23+
MockServerOverloadController* mock_controller = static_cast<MockServerOverloadController*>(controller.get());
24+
EXPECT_CALL(*mock_controller, Init()).WillOnce(::testing::Return(false));
25+
EXPECT_CALL(*mock_controller, Name()).WillRepeatedly(::testing::Return(std::string("mock_controller")));
26+
ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(controller));
27+
28+
EXPECT_CALL(*mock_controller, Init()).WillOnce(::testing::Return(true));
29+
ASSERT_TRUE(ServerOverloadControllerFactory::GetInstance()->Register(controller));
30+
// 3. Duplicated register, failed.
31+
ASSERT_FALSE(ServerOverloadControllerFactory::GetInstance()->Register(controller));
32+
33+
auto size = ServerOverloadControllerFactory::GetInstance()->Size();
34+
ASSERT_EQ(size, 1);
35+
}
36+
// Testing get interface
37+
{
38+
ServerOverloadControllerPtr controller = ServerOverloadControllerFactory::GetInstance()->Get("xxx");
39+
ASSERT_EQ(controller, nullptr);
40+
controller = ServerOverloadControllerFactory::GetInstance()->Get("mock_controller");
41+
ASSERT_NE(controller, nullptr);
42+
}
43+
44+
// Testing series of cleaning interface
45+
{
46+
ServerOverloadControllerFactory::GetInstance()->Stop();
47+
ServerOverloadControllerFactory::GetInstance()->Destroy();
48+
auto size = ServerOverloadControllerFactory::GetInstance()->Size();
49+
ASSERT_EQ(size, 0);
50+
}
51+
}
52+
53+
} // namespace testing
54+
55+
} // namespace trpc::overload_control
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
licenses(["notice"])
2+
3+
package(default_visibility = ["//visibility:public"])
4+
5+
cc_library(
6+
name = "overload_control_testing",
7+
hdrs = ["overload_control_testing.h"],
8+
visibility = ["//visibility:public"],
9+
deps = [
10+
"//trpc/codec:protocol",
11+
"//trpc/coroutine:fiber",
12+
"//trpc/coroutine/testing:fiber_runtime_test",
13+
"//trpc/filter:filter_manager",
14+
"//trpc/overload_control:server_overload_controller",
15+
"//trpc/server:service",
16+
"//trpc/server/testing:service_adapter_testing",
17+
"@com_google_googletest//:gtest_main",
18+
],
19+
)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) 2024, Tencent Inc.
2+
// All rights reserved.
3+
4+
#pragma once
5+
6+
#include <atomic>
7+
8+
#include "gmock/gmock.h"
9+
#include "gtest/gtest.h"
10+
11+
#include "trpc/codec/protocol.h"
12+
#include "trpc/coroutine/fiber.h"
13+
#include "trpc/coroutine/fiber_latch.h"
14+
#include "trpc/coroutine/testing/fiber_runtime.h"
15+
#include "trpc/filter/filter_manager.h"
16+
#include "trpc/overload_control/server_overload_controller.h"
17+
#include "trpc/server/service.h"
18+
#include "trpc/server/testing/service_adapter_testing.h"
19+
20+
namespace trpc::overload_control {
21+
namespace testing {
22+
23+
// Mock protocol, only allowed to be used at overload control module.
24+
class MockProtocol : public Protocol {
25+
public:
26+
MOCK_METHOD1(ZeroCopyDecode, bool(NoncontiguousBuffer&));
27+
MOCK_METHOD1(ZeroCopyEncode, bool(NoncontiguousBuffer&));
28+
MOCK_METHOD1(SetCallType, void(RpcCallType));
29+
MOCK_METHOD0(GetCallType, RpcCallType());
30+
};
31+
32+
using MockProtocolPtr = std::shared_ptr<MockProtocol>;
33+
34+
// Get filter object by filter point and filter name.
35+
inline MessageServerFilterPtr GetGlobalServerFilterByName(FilterPoint type, const std::string& name) {
36+
const std::deque<MessageServerFilterPtr>& filters = FilterManager::GetInstance()->GetMessageServerGlobalFilter(type);
37+
for (auto& filter : filters) {
38+
if (!filter->Name().compare(name)) {
39+
return filter;
40+
}
41+
}
42+
return nullptr;
43+
}
44+
45+
class MockServerOverloadController : public ServerOverloadController {
46+
public:
47+
MOCK_METHOD0(Name, std::string());
48+
MOCK_METHOD0(Init, bool());
49+
MOCK_METHOD1(BeforeSchedule, bool(const ServerContextPtr&));
50+
MOCK_METHOD1(AfterSchedule, bool(const ServerContextPtr&));
51+
};
52+
53+
} // namespace testing
54+
} // namespace trpc::overload_control

0 commit comments

Comments
 (0)