Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions lldb/include/lldb/Protocol/MCP/Server.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#ifndef LLDB_PROTOCOL_MCP_SERVER_H
#define LLDB_PROTOCOL_MCP_SERVER_H

#include "lldb/Host/JSONTransport.h"
#include "lldb/Host/MainLoop.h"
#include "lldb/Protocol/MCP/Protocol.h"
#include "lldb/Protocol/MCP/Resource.h"
#include "lldb/Protocol/MCP/Tool.h"
Expand All @@ -18,26 +20,52 @@

namespace lldb_protocol::mcp {

class Server {
class MCPTransport final
: public lldb_private::JSONRPCTransport<Request, Response, Notification> {
public:
Server(std::string name, std::string version);
virtual ~Server() = default;
using LogCallback = std::function<void(llvm::StringRef message)>;

MCPTransport(lldb::IOObjectSP in, lldb::IOObjectSP out,
std::string client_name, LogCallback log_callback = {})
: JSONRPCTransport(in, out), m_client_name(std::move(client_name)),
m_log_callback(log_callback) {}
virtual ~MCPTransport() = default;

void Log(llvm::StringRef message) override {
if (m_log_callback)
m_log_callback(llvm::formatv("{0}: {1}", m_client_name, message).str());
}

private:
std::string m_client_name;
LogCallback m_log_callback;
};

class Server : public MCPTransport::MessageHandler {
public:
Server(std::string name, std::string version,
std::unique_ptr<MCPTransport> transport_up,
lldb_private::MainLoop &loop);
~Server() = default;

using NotificationHandler = std::function<void(const Notification &)>;

void AddTool(std::unique_ptr<Tool> tool);
void AddResourceProvider(std::unique_ptr<ResourceProvider> resource_provider);
void AddNotificationHandler(llvm::StringRef method,
NotificationHandler handler);

llvm::Error Run();

protected:
virtual Capabilities GetCapabilities() = 0;
Capabilities GetCapabilities();

using RequestHandler =
std::function<llvm::Expected<Response>(const Request &)>;
using NotificationHandler = std::function<void(const Notification &)>;

void AddRequestHandlers();

void AddRequestHandler(llvm::StringRef method, RequestHandler handler);
void AddNotificationHandler(llvm::StringRef method,
NotificationHandler handler);

llvm::Expected<std::optional<Message>> HandleData(llvm::StringRef data);

Expand All @@ -52,12 +80,23 @@ class Server {
llvm::Expected<Response> ResourcesListHandler(const Request &);
llvm::Expected<Response> ResourcesReadHandler(const Request &);

void Received(const Request &) override;
void Received(const Response &) override;
void Received(const Notification &) override;
void OnError(llvm::Error) override;
void OnClosed() override;

void TerminateLoop();

std::mutex m_mutex;

private:
const std::string m_name;
const std::string m_version;

std::unique_ptr<MCPTransport> m_transport_up;
lldb_private::MainLoop &m_loop;

llvm::StringMap<std::unique_ptr<Tool>> m_tools;
std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers;

Expand Down
106 changes: 28 additions & 78 deletions lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,10 @@ using namespace llvm;

LLDB_PLUGIN_DEFINE(ProtocolServerMCP)

static constexpr size_t kChunkSize = 1024;
static constexpr llvm::StringLiteral kName = "lldb-mcp";
static constexpr llvm::StringLiteral kVersion = "0.1.0";

ProtocolServerMCP::ProtocolServerMCP()
: ProtocolServer(),
lldb_protocol::mcp::Server(std::string(kName), std::string(kVersion)) {
AddNotificationHandler("notifications/initialized",
[](const lldb_protocol::mcp::Notification &) {
LLDB_LOG(GetLog(LLDBLog::Host),
"MCP initialization complete");
});

AddTool(
std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));

AddResourceProvider(std::make_unique<DebuggerResourceProvider>());
}
ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {}

ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); }

Expand All @@ -64,57 +50,37 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
return "MCP Server.";
}

void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const {
server.AddNotificationHandler("notifications/initialized",
[](const lldb_protocol::mcp::Notification &) {
LLDB_LOG(GetLog(LLDBLog::Host),
"MCP initialization complete");
});
server.AddTool(
std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));
server.AddResourceProvider(std::make_unique<DebuggerResourceProvider>());
}

void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected",
m_clients.size() + 1);
Log *log = GetLog(LLDBLog::Host);
std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1);
LLDB_LOG(log, "New MCP client connected: {0}", client_name);

lldb::IOObjectSP io_sp = std::move(socket);
auto client_up = std::make_unique<Client>();
client_up->io_sp = io_sp;
Client *client = client_up.get();

Status status;
auto read_handle_up = m_loop.RegisterReadObject(
io_sp,
[this, client](MainLoopBase &loop) {
if (llvm::Error error = ReadCallback(*client)) {
LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}");
client->read_handle_up.reset();
}
},
status);
if (status.Fail())
auto transport_up = std::make_unique<lldb_protocol::mcp::MCPTransport>(
io_sp, io_sp, std::move(client_name), [&](llvm::StringRef message) {
LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message);
});
auto instance_up = std::make_unique<lldb_protocol::mcp::Server>(
std::string(kName), std::string(kVersion), std::move(transport_up),
m_loop);
Extend(*instance_up);
llvm::Error error = instance_up->Run();
if (error) {
LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}");
return;

client_up->read_handle_up = std::move(read_handle_up);
m_clients.emplace_back(std::move(client_up));
}

llvm::Error ProtocolServerMCP::ReadCallback(Client &client) {
char chunk[kChunkSize];
size_t bytes_read = sizeof(chunk);
if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail())
return status.takeError();
client.buffer.append(chunk, bytes_read);

for (std::string::size_type pos;
(pos = client.buffer.find('\n')) != std::string::npos;) {
llvm::Expected<std::optional<lldb_protocol::mcp::Message>> message =
HandleData(StringRef(client.buffer.data(), pos));
client.buffer = client.buffer.erase(0, pos + 1);
if (!message)
return message.takeError();

if (*message) {
std::string Output;
llvm::raw_string_ostream OS(Output);
OS << llvm::formatv("{0}", toJSON(**message)) << '\n';
size_t num_bytes = Output.size();
return client.io_sp->Write(Output.data(), num_bytes).takeError();
}
}

return llvm::Error::success();
m_instances.push_back(std::move(instance_up));
}

llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
Expand Down Expand Up @@ -158,27 +124,11 @@ llvm::Error ProtocolServerMCP::Stop() {

// Stop the main loop.
m_loop.AddPendingCallback(
[](MainLoopBase &loop) { loop.RequestTermination(); });
[](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });

// Wait for the main loop to exit.
if (m_loop_thread.joinable())
m_loop_thread.join();

{
std::lock_guard<std::mutex> guard(m_mutex);
m_listener.reset();
m_listen_handlers.clear();
m_clients.clear();
}

return llvm::Error::success();
}

lldb_protocol::mcp::Capabilities ProtocolServerMCP::GetCapabilities() {
lldb_protocol::mcp::Capabilities capabilities;
capabilities.tools.listChanged = true;
// FIXME: Support sending notifications when a debugger/target are
// added/removed.
capabilities.resources.listChanged = false;
return capabilities;
}
23 changes: 10 additions & 13 deletions lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@

namespace lldb_private::mcp {

class ProtocolServerMCP : public ProtocolServer,
public lldb_protocol::mcp::Server {
class ProtocolServerMCP : public ProtocolServer {
public:
ProtocolServerMCP();
virtual ~ProtocolServerMCP() override;
Expand All @@ -39,26 +38,24 @@ class ProtocolServerMCP : public ProtocolServer,

Socket *GetSocket() const override { return m_listener.get(); }

protected:
// This adds tools and resource providers that
// are specific to this server. Overridable by the unit tests.
virtual void Extend(lldb_protocol::mcp::Server &server) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Does extend describe what we're extending here?


private:
void AcceptCallback(std::unique_ptr<Socket> socket);

lldb_protocol::mcp::Capabilities GetCapabilities() override;

bool m_running = false;

MainLoop m_loop;
lldb_private::MainLoop m_loop;
std::thread m_loop_thread;
std::mutex m_mutex;

std::unique_ptr<Socket> m_listener;
std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers;

struct Client {
lldb::IOObjectSP io_sp;
MainLoopBase::ReadHandleUP read_handle_up;
std::string buffer;
};
llvm::Error ReadCallback(Client &client);
std::vector<std::unique_ptr<Client>> m_clients;
std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers;
std::vector<std::unique_ptr<lldb_protocol::mcp::Server>> m_instances;
};
} // namespace lldb_private::mcp

Expand Down
75 changes: 73 additions & 2 deletions lldb/source/Protocol/MCP/Server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
using namespace lldb_protocol::mcp;
using namespace llvm;

Server::Server(std::string name, std::string version)
: m_name(std::move(name)), m_version(std::move(version)) {
Server::Server(std::string name, std::string version,
std::unique_ptr<MCPTransport> transport_up,
lldb_private::MainLoop &loop)
: m_name(std::move(name)), m_version(std::move(version)),
m_transport_up(std::move(transport_up)), m_loop(loop) {
AddRequestHandlers();
}

Expand Down Expand Up @@ -232,3 +235,71 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
llvm::formatv("no resource handler for uri: {0}", uri_str).str(),
MCPError::kResourceNotFound);
}

Capabilities Server::GetCapabilities() {
lldb_protocol::mcp::Capabilities capabilities;
capabilities.tools.listChanged = true;
// FIXME: Support sending notifications when a debugger/target are
// added/removed.
capabilities.resources.listChanged = false;
return capabilities;
}

llvm::Error Server::Run() {
auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this);
if (!handle)
return handle.takeError();

lldb_private::Status status = m_loop.Run();
if (status.Fail())
return status.takeError();

return llvm::Error::success();
}

void Server::Received(const Request &request) {
auto SendResponse = [this](const Response &response) {
if (llvm::Error error = m_transport_up->Send(response))
m_transport_up->Log(llvm::toString(std::move(error)));
};

llvm::Expected<Response> response = Handle(request);
if (response)
return SendResponse(*response);

lldb_protocol::mcp::Error protocol_error;
llvm::handleAllErrors(
response.takeError(),
[&](const MCPError &err) { protocol_error = err.toProtocolError(); },
[&](const llvm::ErrorInfoBase &err) {
protocol_error.code = MCPError::kInternalError;
protocol_error.message = err.message();
});
Response error_response;
error_response.id = request.id;
error_response.result = std::move(protocol_error);
SendResponse(error_response);
}

void Server::Received(const Response &response) {
m_transport_up->Log("unexpected MCP message: response");
}

void Server::Received(const Notification &notification) {
Handle(notification);
}

void Server::OnError(llvm::Error error) {
m_transport_up->Log(llvm::toString(std::move(error)));
TerminateLoop();
}

void Server::OnClosed() {
m_transport_up->Log("EOF");
TerminateLoop();
}

void Server::TerminateLoop() {
m_loop.AddPendingCallback(
[](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); });
}
Loading
Loading