Skip to content

[lldb] Make MCP server instance global #145616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions lldb/include/lldb/Core/Debugger.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,10 +602,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
void FlushProcessOutput(Process &process, bool flush_stdout,
bool flush_stderr);

void AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
void RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
lldb::ProtocolServerSP GetProtocolServer(llvm::StringRef protocol) const;

SourceManager::SourceFileCache &GetSourceFileCache() {
return m_source_file_cache;
}
Expand Down Expand Up @@ -776,8 +772,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
mutable std::mutex m_progress_reports_mutex;
/// @}

llvm::SmallVector<lldb::ProtocolServerSP> m_protocol_servers;

std::mutex m_destroy_callback_mutex;
lldb::callback_token_t m_destroy_callback_next_token = 0;
struct DestroyCallbackInfo {
Expand Down
5 changes: 3 additions & 2 deletions lldb/include/lldb/Core/ProtocolServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ class ProtocolServer : public PluginInterface {
ProtocolServer() = default;
virtual ~ProtocolServer() = default;

static lldb::ProtocolServerSP Create(llvm::StringRef name,
Debugger &debugger);
static ProtocolServer *GetOrCreate(llvm::StringRef name);

static std::vector<llvm::StringRef> GetSupportedProtocols();

struct Connection {
Socket::SocketProtocol protocol;
Expand Down
2 changes: 1 addition & 1 deletion lldb/include/lldb/lldb-forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ typedef std::shared_ptr<lldb_private::Platform> PlatformSP;
typedef std::shared_ptr<lldb_private::Process> ProcessSP;
typedef std::shared_ptr<lldb_private::ProcessAttachInfo> ProcessAttachInfoSP;
typedef std::shared_ptr<lldb_private::ProcessLaunchInfo> ProcessLaunchInfoSP;
typedef std::shared_ptr<lldb_private::ProtocolServer> ProtocolServerSP;
typedef std::unique_ptr<lldb_private::ProtocolServer> ProtocolServerUP;
typedef std::weak_ptr<lldb_private::Process> ProcessWP;
typedef std::shared_ptr<lldb_private::RegisterCheckpoint> RegisterCheckpointSP;
typedef std::shared_ptr<lldb_private::RegisterContext> RegisterContextSP;
Expand Down
3 changes: 1 addition & 2 deletions lldb/include/lldb/lldb-private-interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force,
typedef lldb::ProcessSP (*ProcessCreateInstance)(
lldb::TargetSP target_sp, lldb::ListenerSP listener_sp,
const FileSpec *crash_file_path, bool can_connect);
typedef lldb::ProtocolServerSP (*ProtocolServerCreateInstance)(
Debugger &debugger);
typedef lldb::ProtocolServerUP (*ProtocolServerCreateInstance)();
typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)(
Target &target);
typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)(
Expand Down
51 changes: 9 additions & 42 deletions lldb/source/Commands/CommandObjectProtocolServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,6 @@ using namespace lldb_private;
#define LLDB_OPTIONS_mcp
#include "CommandOptions.inc"

static std::vector<llvm::StringRef> GetSupportedProtocols() {
std::vector<llvm::StringRef> supported_protocols;
size_t i = 0;

for (llvm::StringRef protocol_name =
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
!protocol_name.empty();
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
supported_protocols.push_back(protocol_name);
}

return supported_protocols;
}

class CommandObjectProtocolServerStart : public CommandObjectParsed {
public:
CommandObjectProtocolServerStart(CommandInterpreter &interpreter)
Expand All @@ -57,12 +43,11 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
}

llvm::StringRef protocol = args.GetArgumentAtIndex(0);
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
if (llvm::find(supported_protocols, protocol) ==
supported_protocols.end()) {
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
if (!server) {
result.AppendErrorWithFormatv(
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
llvm::join(GetSupportedProtocols(), ", "));
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
return;
}

Expand All @@ -72,10 +57,6 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
}
llvm::StringRef connection_uri = args.GetArgumentAtIndex(1);

ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol);
if (!server_sp)
server_sp = ProtocolServer::Create(protocol, GetDebugger());

const char *connection_error =
"unsupported connection specifier, expected 'accept:///path' or "
"'listen://[host]:port', got '{0}'.";
Expand All @@ -98,14 +79,12 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname,
uri->port.value_or(0));

if (llvm::Error error = server_sp->Start(connection)) {
if (llvm::Error error = server->Start(connection)) {
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
return;
}

GetDebugger().AddProtocolServer(server_sp);

if (Socket *socket = server_sp->GetSocket()) {
if (Socket *socket = server->GetSocket()) {
std::string address =
llvm::join(socket->GetListeningConnectionURI(), ", ");
result.AppendMessageWithFormatv(
Expand Down Expand Up @@ -134,30 +113,18 @@ class CommandObjectProtocolServerStop : public CommandObjectParsed {
}

llvm::StringRef protocol = args.GetArgumentAtIndex(0);
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
if (llvm::find(supported_protocols, protocol) ==
supported_protocols.end()) {
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
if (!server) {
result.AppendErrorWithFormatv(
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
llvm::join(GetSupportedProtocols(), ", "));
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
return;
}

Debugger &debugger = GetDebugger();

ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol);
if (!server_sp) {
result.AppendError(
llvm::formatv("no {0} protocol server running", protocol).str());
return;
}

if (llvm::Error error = server_sp->Stop()) {
if (llvm::Error error = server->Stop()) {
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
return;
}

debugger.RemoveProtocolServer(server_sp);
}
};

Expand Down
23 changes: 0 additions & 23 deletions lldb/source/Core/Debugger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2376,26 +2376,3 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() {
"Debugger::GetThreadPool called before Debugger::Initialize");
return *g_thread_pool;
}

void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
assert(protocol_server_sp &&
GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr);
m_protocol_servers.push_back(protocol_server_sp);
}

void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
auto it = llvm::find(m_protocol_servers, protocol_server_sp);
if (it != m_protocol_servers.end())
m_protocol_servers.erase(it);
}

lldb::ProtocolServerSP
Debugger::GetProtocolServer(llvm::StringRef protocol) const {
for (ProtocolServerSP protocol_server_sp : m_protocol_servers) {
if (!protocol_server_sp)
continue;
if (protocol_server_sp->GetPluginName() == protocol)
return protocol_server_sp;
}
return nullptr;
}
34 changes: 30 additions & 4 deletions lldb/source/Core/ProtocolServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,36 @@
using namespace lldb_private;
using namespace lldb;

ProtocolServerSP ProtocolServer::Create(llvm::StringRef name,
Debugger &debugger) {
ProtocolServer *ProtocolServer::GetOrCreate(llvm::StringRef name) {
static std::mutex g_mutex;
static llvm::StringMap<ProtocolServerUP> g_protocol_server_instances;

std::lock_guard<std::mutex> guard(g_mutex);

auto it = g_protocol_server_instances.find(name);
if (it != g_protocol_server_instances.end())
return it->second.get();

if (ProtocolServerCreateInstance create_callback =
PluginManager::GetProtocolCreateCallbackForPluginName(name))
return create_callback(debugger);
PluginManager::GetProtocolCreateCallbackForPluginName(name)) {
auto pair =
g_protocol_server_instances.try_emplace(name, create_callback());
return pair.first->second.get();
}

return nullptr;
}

std::vector<llvm::StringRef> ProtocolServer::GetSupportedProtocols() {
std::vector<llvm::StringRef> supported_protocols;
size_t i = 0;

for (llvm::StringRef protocol_name =
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
!protocol_name.empty();
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
supported_protocols.push_back(protocol_name);
}

return supported_protocols;
}
2 changes: 2 additions & 0 deletions lldb/source/Plugins/Protocol/MCP/Protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ using Message = std::variant<Request, Response, Notification, Error>;
bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path);
llvm::json::Value toJSON(const Message &);

using ToolArguments = std::variant<std::monostate, llvm::json::Value>;

} // namespace lldb_private::mcp::protocol

#endif
30 changes: 17 additions & 13 deletions lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ LLDB_PLUGIN_DEFINE(ProtocolServerMCP)

static constexpr size_t kChunkSize = 1024;

ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
: ProtocolServer(), m_debugger(debugger) {
ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {
AddRequestHandler("initialize",
std::bind(&ProtocolServerMCP::InitializeHandler, this,
std::placeholders::_1));
Expand All @@ -39,8 +38,10 @@ ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
"notifications/initialized", [](const protocol::Notification &) {
LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete");
});
AddTool(std::make_unique<LLDBCommandTool>(
"lldb_command", "Run an lldb command.", m_debugger));
AddTool(
std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));
AddTool(std::make_unique<DebuggerListTool>(
"lldb_debugger_list", "List debugger instances with their debugger_id."));
}

ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); }
Expand All @@ -54,8 +55,8 @@ void ProtocolServerMCP::Terminate() {
PluginManager::UnregisterPlugin(CreateInstance);
}

lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) {
return std::make_shared<ProtocolServerMCP>(debugger);
lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() {
return std::make_unique<ProtocolServerMCP>();
}

llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
Expand Down Expand Up @@ -145,7 +146,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
std::lock_guard<std::mutex> guard(m_server_mutex);

if (m_running)
return llvm::createStringError("server already running");
return llvm::createStringError("the MCP server is already running");

Status status;
m_listener = Socket::Create(connection.protocol, status);
Expand All @@ -162,10 +163,10 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
if (llvm::Error error = handles.takeError())
return error;

m_running = true;
m_listen_handlers = std::move(*handles);
m_loop_thread = std::thread([=] {
llvm::set_thread_name(
llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID()));
llvm::set_thread_name("protocol-server.mcp");
m_loop.Run();
});

Expand All @@ -175,6 +176,8 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
llvm::Error ProtocolServerMCP::Stop() {
{
std::lock_guard<std::mutex> guard(m_server_mutex);
if (!m_running)
return createStringError("the MCP sever is not running");
m_running = false;
}

Expand Down Expand Up @@ -311,11 +314,12 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) {
if (it == m_tools.end())
return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name));

const json::Value *args = param_obj->get("arguments");
if (!args)
return llvm::createStringError("no tool arguments");
protocol::ToolArguments tool_args;
if (const json::Value *args = param_obj->get("arguments"))
tool_args = *args;

llvm::Expected<protocol::TextResult> text_result = it->second->Call(*args);
llvm::Expected<protocol::TextResult> text_result =
it->second->Call(tool_args);
if (!text_result)
return text_result.takeError();

Expand Down
6 changes: 2 additions & 4 deletions lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace lldb_private::mcp {

class ProtocolServerMCP : public ProtocolServer {
public:
ProtocolServerMCP(Debugger &debugger);
ProtocolServerMCP();
virtual ~ProtocolServerMCP() override;

virtual llvm::Error Start(ProtocolServer::Connection connection) override;
Expand All @@ -33,7 +33,7 @@ class ProtocolServerMCP : public ProtocolServer {
static llvm::StringRef GetPluginNameStatic() { return "MCP"; }
static llvm::StringRef GetPluginDescriptionStatic();

static lldb::ProtocolServerSP CreateInstance(Debugger &debugger);
static lldb::ProtocolServerUP CreateInstance();

llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); }

Expand Down Expand Up @@ -71,8 +71,6 @@ class ProtocolServerMCP : public ProtocolServer {
llvm::StringLiteral kName = "lldb-mcp";
llvm::StringLiteral kVersion = "0.1.0";

Debugger &m_debugger;

bool m_running = false;

MainLoop m_loop;
Expand Down
Loading
Loading