diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 72f4404c92887..0be60a8f3f96a 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -13,29 +13,25 @@ #ifndef LLDB_HOST_JSONTRANSPORT_H #define LLDB_HOST_JSONTRANSPORT_H +#include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" #include "lldb/Utility/IOObject.h" #include "lldb/Utility/Status.h" #include "lldb/lldb-forward.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" #include #include +#include #include namespace lldb_private { -class TransportEOFError : public llvm::ErrorInfo { -public: - static char ID; - - TransportEOFError() = default; - void log(llvm::raw_ostream &OS) const override; - std::error_code convertToErrorCode() const override; -}; - class TransportUnhandledContentsError : public llvm::ErrorInfo { public: @@ -54,112 +50,214 @@ class TransportUnhandledContentsError std::string m_unhandled_contents; }; -class TransportInvalidError : public llvm::ErrorInfo { +/// A transport is responsible for maintaining the connection to a client +/// application, and reading/writing structured messages to it. +/// +/// Transports have limited thread safety requirements: +/// - Messages will not be sent concurrently. +/// - Messages MAY be sent while Run() is reading, or its callback is active. +template class Transport { public: - static char ID; - - TransportInvalidError() = default; + using Message = std::variant; + + virtual ~Transport() = default; + + /// Sends an event, a message that does not require a response. + virtual llvm::Error Send(const Evt &) = 0; + /// Sends a request, a message that expects a response. + virtual llvm::Error Send(const Req &) = 0; + /// Sends a response to a specific request. + virtual llvm::Error Send(const Resp &) = 0; + + /// Implemented to handle incoming messages. (See Run() below). + class MessageHandler { + public: + virtual ~MessageHandler() = default; + /// Called when an event is received. + virtual void Received(const Evt &) = 0; + /// Called when a request is received. + virtual void Received(const Req &) = 0; + /// Called when a response is received. + virtual void Received(const Resp &) = 0; + + /// Called when an error occurs while reading from the transport. + /// + /// NOTE: This does *NOT* indicate that a specific request failed, but that + /// there was an error in the underlying transport. + virtual void OnError(llvm::Error) = 0; + + /// Called on EOF or client disconnect. + virtual void OnClosed() = 0; + }; + + using MessageHandlerSP = std::shared_ptr; + + /// RegisterMessageHandler registers the Transport with the given MainLoop and + /// handles any incoming messages using the given MessageHandler. + /// + /// If an unexpected error occurs, the MainLoop will be terminated and a log + /// message will include additional information about the termination reason. + virtual llvm::Expected + RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0; - void log(llvm::raw_ostream &OS) const override; - std::error_code convertToErrorCode() const override; +protected: + template inline auto Logv(const char *Fmt, Ts &&...Vals) { + Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); + } + virtual void Log(llvm::StringRef message) = 0; }; -/// A transport class that uses JSON for communication. -class JSONTransport { +/// A JSONTransport will encode and decode messages using JSON. +template +class JSONTransport : public Transport { public: - using ReadHandleUP = MainLoopBase::ReadHandleUP; - template - using Callback = std::function)>; - - JSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output); - virtual ~JSONTransport() = default; - - /// Transport is not copyable. - /// @{ - JSONTransport(const JSONTransport &rhs) = delete; - void operator=(const JSONTransport &rhs) = delete; - /// @} - - /// Writes a message to the output stream. - template llvm::Error Write(const T &t) { - const std::string message = llvm::formatv("{0}", toJSON(t)).str(); - return WriteImpl(message); + using Transport::Transport; + using MessageHandler = typename Transport::MessageHandler; + + JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + : m_in(in), m_out(out) {} + + llvm::Error Send(const Evt &evt) override { return Write(evt); } + llvm::Error Send(const Req &req) override { return Write(req); } + llvm::Error Send(const Resp &resp) override { return Write(resp); } + + llvm::Expected + RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) override { + Status status; + MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject( + m_in, + std::bind(&JSONTransport::OnRead, this, std::placeholders::_1, + std::ref(handler)), + status); + if (status.Fail()) { + return status.takeError(); + } + return read_handle; } - /// Registers the transport with the MainLoop. - template - llvm::Expected RegisterReadObject(MainLoopBase &loop, - Callback read_cb) { - Status error; - ReadHandleUP handle = loop.RegisterReadObject( - m_input, - [read_cb, this](MainLoopBase &loop) { - char buf[kReadBufferSize]; - size_t num_bytes = sizeof(buf); - if (llvm::Error error = m_input->Read(buf, num_bytes).takeError()) { - read_cb(loop, std::move(error)); - return; - } - if (num_bytes) - m_buffer.append(std::string(buf, num_bytes)); - - // If the buffer has contents, try parsing any pending messages. - if (!m_buffer.empty()) { - llvm::Expected> messages = Parse(); - if (llvm::Error error = messages.takeError()) { - read_cb(loop, std::move(error)); - return; - } - - for (const auto &message : *messages) - if constexpr (std::is_same::value) - read_cb(loop, message); - else - read_cb(loop, llvm::json::parse(message)); - } - - // On EOF, notify the callback after the remaining messages were - // handled. - if (num_bytes == 0) { - if (m_buffer.empty()) - read_cb(loop, llvm::make_error()); - else - read_cb(loop, llvm::make_error( - std::string(m_buffer))); - } - }, - error); - if (error.Fail()) - return error.takeError(); - return handle; - } + /// Public for testing purposes, otherwise this should be an implementation + /// detail. + static constexpr size_t kReadBufferSize = 1024; protected: - template inline auto Logv(const char *Fmt, Ts &&...Vals) { - Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); + virtual llvm::Expected> Parse() = 0; + virtual std::string Encode(const llvm::json::Value &message) = 0; + llvm::Error Write(const llvm::json::Value &message) { + this->Logv("<-- {0}", message); + std::string output = Encode(message); + size_t bytes_written = output.size(); + return m_out->Write(output.data(), bytes_written).takeError(); } - virtual void Log(llvm::StringRef message); - virtual llvm::Error WriteImpl(const std::string &message) = 0; - virtual llvm::Expected> Parse() = 0; + llvm::SmallString m_buffer; - static constexpr size_t kReadBufferSize = 1024; +private: + void OnRead(MainLoopBase &loop, MessageHandler &handler) { + char buf[kReadBufferSize]; + size_t num_bytes = sizeof(buf); + if (Status status = m_in->Read(buf, num_bytes); status.Fail()) { + handler.OnError(status.takeError()); + return; + } + + if (num_bytes) + m_buffer.append(llvm::StringRef(buf, num_bytes)); + + // If the buffer has contents, try parsing any pending messages. + if (!m_buffer.empty()) { + llvm::Expected> raw_messages = Parse(); + if (llvm::Error error = raw_messages.takeError()) { + handler.OnError(std::move(error)); + return; + } + + for (const std::string &raw_message : *raw_messages) { + llvm::Expected::Message> message = + llvm::json::parse::Message>( + raw_message); + if (!message) { + handler.OnError(message.takeError()); + return; + } + + std::visit([&handler](auto &&msg) { handler.Received(msg); }, *message); + } + } + + // Check if we reached EOF. + if (num_bytes == 0) { + // EOF reached, but there may still be unhandled contents in the buffer. + if (!m_buffer.empty()) + handler.OnError(llvm::make_error( + std::string(m_buffer.str()))); + handler.OnClosed(); + } + } - lldb::IOObjectSP m_input; - lldb::IOObjectSP m_output; - llvm::SmallString m_buffer; + lldb::IOObjectSP m_in; + lldb::IOObjectSP m_out; }; /// A transport class for JSON with a HTTP header. -class HTTPDelimitedJSONTransport : public JSONTransport { +template +class HTTPDelimitedJSONTransport : public JSONTransport { public: - HTTPDelimitedJSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) - : JSONTransport(input, output) {} - virtual ~HTTPDelimitedJSONTransport() = default; + using JSONTransport::JSONTransport; protected: - llvm::Error WriteImpl(const std::string &message) override; - llvm::Expected> Parse() override; + /// Encodes messages based on + /// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol + std::string Encode(const llvm::json::Value &message) override { + std::string output; + std::string raw_message = llvm::formatv("{0}", message).str(); + llvm::raw_string_ostream OS(output); + OS << kHeaderContentLength << kHeaderFieldSeparator << ' ' + << std::to_string(raw_message.size()) << kEndOfHeader << raw_message; + return output; + } + + /// Parses messages based on + /// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol + llvm::Expected> Parse() override { + std::vector messages; + llvm::StringRef buffer = this->m_buffer; + while (buffer.contains(kEndOfHeader)) { + auto [headers, rest] = buffer.split(kEndOfHeader); + size_t content_length = 0; + // HTTP Headers are formatted like ` ':' []`. + for (const llvm::StringRef &header : + llvm::split(headers, kHeaderSeparator)) { + auto [key, value] = header.split(kHeaderFieldSeparator); + // 'Content-Length' is the only meaningful key at the moment. Others are + // ignored. + if (!key.equals_insensitive(kHeaderContentLength)) + continue; + + value = value.trim(); + if (!llvm::to_integer(value, content_length, 10)) { + // Clear the buffer to avoid re-parsing this malformed message. + this->m_buffer.clear(); + return llvm::createStringError(std::errc::invalid_argument, + "invalid content length: %s", + value.str().c_str()); + } + } + + // Check if we have enough data. + if (content_length > rest.size()) + break; + + llvm::StringRef body = rest.take_front(content_length); + buffer = rest.drop_front(content_length); + messages.emplace_back(body.str()); + this->Logv("--> {0}", body); + } + + // Store the remainder of the buffer for the next read callback. + this->m_buffer = buffer.str(); + + return std::move(messages); + } static constexpr llvm::StringLiteral kHeaderContentLength = "Content-Length"; static constexpr llvm::StringLiteral kHeaderFieldSeparator = ":"; @@ -168,15 +266,31 @@ class HTTPDelimitedJSONTransport : public JSONTransport { }; /// A transport class for JSON RPC. -class JSONRPCTransport : public JSONTransport { +template +class JSONRPCTransport : public JSONTransport { public: - JSONRPCTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) - : JSONTransport(input, output) {} - virtual ~JSONRPCTransport() = default; + using JSONTransport::JSONTransport; protected: - llvm::Error WriteImpl(const std::string &message) override; - llvm::Expected> Parse() override; + std::string Encode(const llvm::json::Value &message) override { + return llvm::formatv("{0}{1}", message, kMessageSeparator).str(); + } + + llvm::Expected> Parse() override { + std::vector messages; + llvm::StringRef buf = this->m_buffer; + while (buf.contains(kMessageSeparator)) { + auto [raw_json, rest] = buf.split(kMessageSeparator); + buf = rest; + messages.emplace_back(raw_json.str()); + this->Logv("--> {0}", raw_json); + } + + // Store the remainder of the buffer for the next read callback. + this->m_buffer = buf.str(); + + return messages; + } static constexpr llvm::StringLiteral kMessageSeparator = "\n"; }; diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index 5f0fb3ce562c3..c4b42eafc85d3 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -7,136 +7,26 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/JSONTransport.h" -#include "lldb/Utility/LLDBLog.h" #include "lldb/Utility/Log.h" #include "lldb/Utility/Status.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Error.h" #include "llvm/Support/raw_ostream.h" #include -#include using namespace llvm; using namespace lldb; using namespace lldb_private; -void TransportEOFError::log(llvm::raw_ostream &OS) const { - OS << "transport EOF"; -} - -std::error_code TransportEOFError::convertToErrorCode() const { - return std::make_error_code(std::errc::io_error); -} +char TransportUnhandledContentsError::ID; TransportUnhandledContentsError::TransportUnhandledContentsError( std::string unhandled_contents) : m_unhandled_contents(unhandled_contents) {} void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const { - OS << "transport EOF with unhandled contents " << m_unhandled_contents; + OS << "transport EOF with unhandled contents: '" << m_unhandled_contents + << "'"; } std::error_code TransportUnhandledContentsError::convertToErrorCode() const { return std::make_error_code(std::errc::bad_message); } - -void TransportInvalidError::log(llvm::raw_ostream &OS) const { - OS << "transport IO object invalid"; -} -std::error_code TransportInvalidError::convertToErrorCode() const { - return std::make_error_code(std::errc::not_connected); -} - -JSONTransport::JSONTransport(IOObjectSP input, IOObjectSP output) - : m_input(std::move(input)), m_output(std::move(output)) {} - -void JSONTransport::Log(llvm::StringRef message) { - LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); -} - -// Parses messages based on -// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol -Expected> HTTPDelimitedJSONTransport::Parse() { - std::vector messages; - StringRef buffer = m_buffer; - while (buffer.contains(kEndOfHeader)) { - auto [headers, rest] = buffer.split(kEndOfHeader); - size_t content_length = 0; - // HTTP Headers are formatted like ` ':' []`. - for (const auto &header : llvm::split(headers, kHeaderSeparator)) { - auto [key, value] = header.split(kHeaderFieldSeparator); - // 'Content-Length' is the only meaningful key at the moment. Others are - // ignored. - if (!key.equals_insensitive(kHeaderContentLength)) - continue; - - value = value.trim(); - if (!llvm::to_integer(value, content_length, 10)) - return createStringError(std::errc::invalid_argument, - "invalid content length: %s", - value.str().c_str()); - } - - // Check if we have enough data. - if (content_length > rest.size()) - break; - - StringRef body = rest.take_front(content_length); - buffer = rest.drop_front(content_length); - messages.emplace_back(body.str()); - Logv("--> {0}", body); - } - - // Store the remainder of the buffer for the next read callback. - m_buffer = buffer.str(); - - return std::move(messages); -} - -Error HTTPDelimitedJSONTransport::WriteImpl(const std::string &message) { - if (!m_output || !m_output->IsValid()) - return llvm::make_error(); - - Logv("<-- {0}", message); - - std::string Output; - raw_string_ostream OS(Output); - OS << kHeaderContentLength << kHeaderFieldSeparator << ' ' << message.length() - << kHeaderSeparator << kHeaderSeparator << message; - size_t num_bytes = Output.size(); - return m_output->Write(Output.data(), num_bytes).takeError(); -} - -Expected> JSONRPCTransport::Parse() { - std::vector messages; - StringRef buf = m_buffer; - while (buf.contains(kMessageSeparator)) { - auto [raw_json, rest] = buf.split(kMessageSeparator); - buf = rest; - messages.emplace_back(raw_json.str()); - Logv("--> {0}", raw_json); - } - - // Store the remainder of the buffer for the next read callback. - m_buffer = buf.str(); - - return messages; -} - -Error JSONRPCTransport::WriteImpl(const std::string &message) { - if (!m_output || !m_output->IsValid()) - return llvm::make_error(); - - Logv("<-- {0}", message); - - std::string Output; - llvm::raw_string_ostream OS(Output); - OS << message << kMessageSeparator; - size_t num_bytes = Output.size(); - return m_output->Write(Output.data(), num_bytes).takeError(); -} - -char TransportEOFError::ID; -char TransportUnhandledContentsError::ID; -char TransportInvalidError::ID; diff --git a/lldb/source/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp index d9b11bd766686..65ddfaee70160 100644 --- a/lldb/source/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Protocol/MCP/Protocol.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/JSON.h" using namespace llvm; diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp index ce910b1f60b85..e51ed096073fe 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -121,11 +121,12 @@ static std::string capitalize(llvm::StringRef str) { llvm::StringRef DAP::debug_adapter_path = ""; DAP::DAP(Log *log, const ReplMode default_repl_mode, - std::vector pre_init_commands, Transport &transport) + std::vector pre_init_commands, + llvm::StringRef client_name, DAPTransport &transport, MainLoop &loop) : log(log), transport(transport), broadcaster("lldb-dap"), progress_event_reporter( [&](const ProgressEvent &event) { SendJSON(event.ToJSON()); }), - repl_mode(default_repl_mode) { + repl_mode(default_repl_mode), m_client_name(client_name), m_loop(loop) { configuration.preInitCommands = std::move(pre_init_commands); RegisterRequests(); } @@ -258,36 +259,49 @@ void DAP::SendJSON(const llvm::json::Value &json) { llvm::json::Path::Root root; if (!fromJSON(json, message, root)) { DAP_LOG_ERROR(log, root.getError(), "({1}) encoding failed: {0}", - transport.GetClientName()); + m_client_name); return; } Send(message); } void DAP::Send(const Message &message) { - // FIXME: After all the requests have migrated from LegacyRequestHandler > - // RequestHandler<> this should be handled in RequestHandler<>::operator(). - if (auto *resp = std::get_if(&message); - resp && debugger.InterruptRequested()) { - // Clear the interrupt request. - debugger.CancelInterruptRequest(); - - // If the debugger was interrupted, convert this response into a 'cancelled' - // response because we might have a partial result. - Response cancelled{/*request_seq=*/resp->request_seq, - /*command=*/resp->command, - /*success=*/false, - /*message=*/eResponseMessageCancelled, - /*body=*/std::nullopt}; - if (llvm::Error err = transport.Write(cancelled)) - DAP_LOG_ERROR(log, std::move(err), "({1}) write failed: {0}", - transport.GetClientName()); + if (const protocol::Event *event = std::get_if(&message)) { + if (llvm::Error err = transport.Send(*event)) + DAP_LOG_ERROR(log, std::move(err), "({0}) sending event failed", + m_client_name); return; } - if (llvm::Error err = transport.Write(message)) - DAP_LOG_ERROR(log, std::move(err), "({1}) write failed: {0}", - transport.GetClientName()); + if (const Request *req = std::get_if(&message)) { + if (llvm::Error err = transport.Send(*req)) + DAP_LOG_ERROR(log, std::move(err), "({0}) sending request failed", + m_client_name); + return; + } + + if (const Response *resp = std::get_if(&message)) { + // FIXME: After all the requests have migrated from LegacyRequestHandler > + // RequestHandler<> this should be handled in RequestHandler<>::operator(). + // If the debugger was interrupted, convert this response into a + // 'cancelled' response because we might have a partial result. + llvm::Error err = + (debugger.InterruptRequested()) + ? transport.Send({/*request_seq=*/resp->request_seq, + /*command=*/resp->command, + /*success=*/false, + /*message=*/eResponseMessageCancelled, + /*body=*/std::nullopt}) + : transport.Send(*resp); + if (err) { + DAP_LOG_ERROR(log, std::move(err), "({0}) sending response failed", + m_client_name); + return; + } + return; + } + + llvm_unreachable("Unexpected message type"); } // "OutputEvent": { @@ -755,7 +769,6 @@ void DAP::RunTerminateCommands() { } lldb::SBTarget DAP::CreateTarget(lldb::SBError &error) { - // Grab the name of the program we need to debug and create a target using // the given program as an argument. Executable file can be a source of target // architecture and platform, if they differ from the host. Setting exe path // in launch info is useless because Target.Launch() will not change @@ -795,7 +808,7 @@ void DAP::SetTarget(const lldb::SBTarget target) { bool DAP::HandleObject(const Message &M) { TelemetryDispatcher dispatcher(&debugger); - dispatcher.Set("client_name", transport.GetClientName().str()); + dispatcher.Set("client_name", m_client_name.str()); if (const auto *req = std::get_if(&M)) { { std::lock_guard guard(m_active_request_mutex); @@ -821,8 +834,8 @@ bool DAP::HandleObject(const Message &M) { dispatcher.Set("error", llvm::Twine("unhandled-command:" + req->command).str()); - DAP_LOG(log, "({0}) error: unhandled command '{1}'", - transport.GetClientName(), req->command); + DAP_LOG(log, "({0}) error: unhandled command '{1}'", m_client_name, + req->command); return false; // Fail } @@ -918,11 +931,7 @@ llvm::Error DAP::Disconnect(bool terminateDebuggee) { } SendTerminatedEvent(); - - disconnecting = true; - m_loop.AddPendingCallback( - [](MainLoopBase &loop) { loop.RequestTermination(); }); - + TerminateLoop(); return ToError(error); } @@ -938,90 +947,121 @@ void DAP::ClearCancelRequest(const CancelArguments &args) { } template -static std::optional getArgumentsIfRequest(const Message &pm, +static std::optional getArgumentsIfRequest(const Request &req, llvm::StringLiteral command) { - auto *const req = std::get_if(&pm); - if (!req || req->command != command) + if (req.command != command) return std::nullopt; T args; llvm::json::Path::Root root; - if (!fromJSON(req->arguments, args, root)) + if (!fromJSON(req.arguments, args, root)) return std::nullopt; return args; } -Status DAP::TransportHandler() { - llvm::set_thread_name(transport.GetClientName() + ".transport_handler"); +void DAP::Received(const protocol::Event &event) { + // no-op, no supported events from the client to the server as of DAP v1.68. +} - auto cleanup = llvm::make_scope_exit([&]() { - // Ensure we're marked as disconnecting when the reader exits. - disconnecting = true; - m_queue_cv.notify_all(); - }); +void DAP::Received(const protocol::Request &request) { + if (request.command == "disconnect") + m_disconnecting = true; - Status status; - auto handle = transport.RegisterReadObject( - m_loop, - [&](MainLoopBase &loop, llvm::Expected message) { - if (message.errorIsA()) { - llvm::consumeError(message.takeError()); - loop.RequestTermination(); - return; - } + const std::optional cancel_args = + getArgumentsIfRequest(request, "cancel"); + if (cancel_args) { + { + std::lock_guard guard(m_cancelled_requests_mutex); + if (cancel_args->requestId) + m_cancelled_requests.insert(*cancel_args->requestId); + } - if (llvm::Error err = message.takeError()) { - status = Status::FromError(std::move(err)); - loop.RequestTermination(); - return; - } + // If a cancel is requested for the active request, make a best + // effort attempt to interrupt. + std::lock_guard guard(m_active_request_mutex); + if (m_active_request && cancel_args->requestId == m_active_request->seq) { + DAP_LOG(log, "({0}) interrupting inflight request (command={1} seq={2})", + m_client_name, m_active_request->command, m_active_request->seq); + debugger.RequestInterrupt(); + } + } - if (const protocol::Request *req = - std::get_if(&*message); - req && req->arguments == "disconnect") - disconnecting = true; - - const std::optional cancel_args = - getArgumentsIfRequest(*message, "cancel"); - if (cancel_args) { - { - std::lock_guard guard(m_cancelled_requests_mutex); - if (cancel_args->requestId) - m_cancelled_requests.insert(*cancel_args->requestId); - } + std::lock_guard guard(m_queue_mutex); + DAP_LOG(log, "({0}) queued (command={1} seq={2})", m_client_name, + request.command, request.seq); + m_queue.push_back(request); + m_queue_cv.notify_one(); +} - // If a cancel is requested for the active request, make a best - // effort attempt to interrupt. - std::lock_guard guard(m_active_request_mutex); - if (m_active_request && - cancel_args->requestId == m_active_request->seq) { - DAP_LOG(log, - "({0}) interrupting inflight request (command={1} seq={2})", - transport.GetClientName(), m_active_request->command, - m_active_request->seq); - debugger.RequestInterrupt(); - } - } +void DAP::Received(const protocol::Response &response) { + std::lock_guard guard(m_queue_mutex); + DAP_LOG(log, "({0}) queued (command={1} seq={2})", m_client_name, + response.command, response.request_seq); + m_queue.push_back(response); + m_queue_cv.notify_one(); +} + +void DAP::OnError(llvm::Error error) { + DAP_LOG_ERROR(log, std::move(error), "({1}) received error: {0}", + m_client_name); + TerminateLoop(/*failed=*/true); +} - std::lock_guard guard(m_queue_mutex); - m_queue.push_back(std::move(*message)); - m_queue_cv.notify_one(); - }); - if (auto err = handle.takeError()) - return Status::FromError(std::move(err)); - if (llvm::Error err = m_loop.Run().takeError()) - return Status::FromError(std::move(err)); - return status; +void DAP::OnClosed() { + DAP_LOG(log, "({0}) received EOF", m_client_name); + TerminateLoop(); +} + +void DAP::TerminateLoop(bool failed) { + std::lock_guard guard(m_queue_mutex); + if (m_disconnecting) + return; // Already disconnecting. + + m_error_occurred = failed; + m_disconnecting = true; + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); +} + +void DAP::TransportHandler() { + auto scope_guard = llvm::make_scope_exit([this] { + std::lock_guard guard(m_queue_mutex); + // Ensure we're marked as disconnecting when the reader exits. + m_disconnecting = true; + m_queue_cv.notify_all(); + }); + + auto handle = transport.RegisterMessageHandler(m_loop, *this); + if (!handle) { + DAP_LOG_ERROR(log, handle.takeError(), + "({1}) registering message handler failed: {0}", + m_client_name); + std::lock_guard guard(m_queue_mutex); + m_error_occurred = true; + return; + } + + if (Status status = m_loop.Run(); status.Fail()) { + DAP_LOG_ERROR(log, status.takeError(), "({1}) MainLoop run failed: {0}", + m_client_name); + std::lock_guard guard(m_queue_mutex); + m_error_occurred = true; + return; + } } llvm::Error DAP::Loop() { - // Can't use \a std::future because it doesn't compile on - // Windows. - std::future queue_reader = - std::async(std::launch::async, &DAP::TransportHandler, this); + { + // Reset disconnect flag once we start the loop. + std::lock_guard guard(m_queue_mutex); + m_disconnecting = false; + } + + auto thread = std::thread(std::bind(&DAP::TransportHandler, this)); - auto cleanup = llvm::make_scope_exit([&]() { + auto cleanup = llvm::make_scope_exit([this]() { + // FIXME: Merge these into the MainLoop handler. out.Stop(); err.Stop(); StopEventHandlers(); @@ -1029,9 +1069,9 @@ llvm::Error DAP::Loop() { while (true) { std::unique_lock lock(m_queue_mutex); - m_queue_cv.wait(lock, [&] { return disconnecting || !m_queue.empty(); }); + m_queue_cv.wait(lock, [&] { return m_disconnecting || !m_queue.empty(); }); - if (disconnecting && m_queue.empty()) + if (m_disconnecting && m_queue.empty()) break; Message next = m_queue.front(); @@ -1045,7 +1085,15 @@ llvm::Error DAP::Loop() { "unhandled packet"); } - return queue_reader.get().takeError(); + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + thread.join(); + + if (m_error_occurred) + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "DAP Loop terminated due to an internal " + "error, see DAP Logs for more information."); + return llvm::Error::success(); } lldb::SBError DAP::WaitForProcessToStop(std::chrono::seconds seconds) { @@ -1284,7 +1332,7 @@ void DAP::ProgressEventThread() { // them prevent multiple threads from writing simultaneously so no locking // is required. void DAP::EventThread() { - llvm::set_thread_name(transport.GetClientName() + ".event_handler"); + llvm::set_thread_name("lldb.DAP.client." + m_client_name + ".event_handler"); lldb::SBEvent event; lldb::SBListener listener = debugger.GetListener(); broadcaster.AddListener(listener, eBroadcastBitStopEventThread); @@ -1316,7 +1364,7 @@ void DAP::EventThread() { if (llvm::Error err = SendThreadStoppedEvent(*this)) DAP_LOG_ERROR(log, std::move(err), "({1}) reporting thread stopped: {0}", - transport.GetClientName()); + m_client_name); } break; case lldb::eStateRunning: diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index b0e9fa9c16b75..0b6373fb80381 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -78,12 +78,16 @@ enum DAPBroadcasterBits { enum class ReplMode { Variable = 0, Command, Auto }; -struct DAP { +using DAPTransport = + lldb_private::Transport; + +struct DAP final : private DAPTransport::MessageHandler { /// Path to the lldb-dap binary itself. static llvm::StringRef debug_adapter_path; Log *log; - Transport &transport; + DAPTransport &transport; lldb::SBFile in; OutputRedirector out; OutputRedirector err; @@ -114,7 +118,6 @@ struct DAP { /// The focused thread for this DAP session. lldb::tid_t focus_tid = LLDB_INVALID_THREAD_ID; - bool disconnecting = false; llvm::once_flag terminated_event_flag; bool stop_at_entry = false; bool is_attach = false; @@ -177,8 +180,11 @@ struct DAP { /// allocated. /// \param[in] transport /// Transport for this debug session. + /// \param[in] loop + /// Main loop associated with this instance. DAP(Log *log, const ReplMode default_repl_mode, - std::vector pre_init_commands, Transport &transport); + std::vector pre_init_commands, llvm::StringRef client_name, + DAPTransport &transport, lldb_private::MainLoop &loop); ~DAP(); @@ -317,7 +323,7 @@ struct DAP { lldb::SBTarget CreateTarget(lldb::SBError &error); /// Set given target object as a current target for lldb-dap and start - /// listeing for its breakpoint events. + /// listening for its breakpoint events. void SetTarget(const lldb::SBTarget target); bool HandleObject(const protocol::Message &M); @@ -420,13 +426,20 @@ struct DAP { const std::optional> &breakpoints); + void Received(const protocol::Event &) override; + void Received(const protocol::Request &) override; + void Received(const protocol::Response &) override; + void OnError(llvm::Error) override; + void OnClosed() override; + private: std::vector SetSourceBreakpoints( const protocol::Source &source, const std::optional> &breakpoints, SourceBreakpointMap &existing_breakpoints); - lldb_private::Status TransportHandler(); + void TransportHandler(); + void TerminateLoop(bool failed = false); /// Registration of request handler. /// @{ @@ -446,6 +459,8 @@ struct DAP { std::thread progress_event_thread; /// @} + const llvm::StringRef m_client_name; + /// List of addresses mapped by sourceReference. std::vector m_source_references; std::mutex m_source_references_mutex; @@ -454,9 +469,11 @@ struct DAP { std::deque m_queue; std::mutex m_queue_mutex; std::condition_variable m_queue_cv; + bool m_disconnecting = false; + bool m_error_occurred = false; // Loop for managing reading from the client. - lldb_private::MainLoop m_loop; + lldb_private::MainLoop &m_loop; std::mutex m_cancelled_requests_mutex; llvm::SmallSet m_cancelled_requests; diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp index bc4fee4aa8b8d..9cd9028d879e9 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp @@ -98,6 +98,10 @@ bool fromJSON(json::Value const &Params, Request &R, json::Path P) { return mapRaw(Params, "arguments", R.arguments, P); } +bool operator==(const Request &a, const Request &b) { + return a.seq == b.seq && a.command == b.command && a.arguments == b.arguments; +} + json::Value toJSON(const Response &R) { json::Object Result{{"type", "response"}, {"seq", 0}, @@ -177,6 +181,11 @@ bool fromJSON(json::Value const &Params, Response &R, json::Path P) { mapRaw(Params, "body", R.body, P); } +bool operator==(const Response &a, const Response &b) { + return a.request_seq == b.request_seq && a.command == b.command && + a.success == b.success && a.message == b.message && a.body == b.body; +} + json::Value toJSON(const ErrorMessage &EM) { json::Object Result{{"id", EM.id}, {"format", EM.format}}; @@ -248,6 +257,10 @@ bool fromJSON(json::Value const &Params, Event &E, json::Path P) { return mapRaw(Params, "body", E.body, P); } +bool operator==(const Event &a, const Event &b) { + return a.event == b.event && a.body == b.body; +} + bool fromJSON(const json::Value &Params, Message &PM, json::Path P) { json::ObjectMapper O(Params, P); if (!O) diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h index 81496380d412f..0a9ef538a7398 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h @@ -52,6 +52,7 @@ struct Request { }; llvm::json::Value toJSON(const Request &); bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); +bool operator==(const Request &, const Request &); /// A debug adapter initiated event. struct Event { @@ -63,6 +64,7 @@ struct Event { }; llvm::json::Value toJSON(const Event &); bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); +bool operator==(const Event &, const Event &); enum ResponseMessage : unsigned { /// The request was cancelled @@ -101,6 +103,7 @@ struct Response { }; bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); llvm::json::Value toJSON(const Response &); +bool operator==(const Response &, const Response &); /// A structured message object. Used to return errors from requests. struct ErrorMessage { @@ -140,6 +143,7 @@ llvm::json::Value toJSON(const ErrorMessage &); using Message = std::variant; bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); llvm::json::Value toJSON(const Message &); +bool operator==(const Message &, const Message &); inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Message &V) { OS << toJSON(V); diff --git a/lldb/tools/lldb-dap/Transport.cpp b/lldb/tools/lldb-dap/Transport.cpp index d602920da34e3..8f71f88cae1f7 100644 --- a/lldb/tools/lldb-dap/Transport.cpp +++ b/lldb/tools/lldb-dap/Transport.cpp @@ -14,7 +14,8 @@ using namespace llvm; using namespace lldb; using namespace lldb_private; -using namespace lldb_dap; + +namespace lldb_dap { Transport::Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output) @@ -24,3 +25,5 @@ Transport::Transport(llvm::StringRef client_name, lldb_dap::Log *log, void Transport::Log(llvm::StringRef message) { DAP_LOG(m_log, "({0}) {1}", m_client_name, message); } + +} // namespace lldb_dap diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index 9a7d8f424d40e..4a9dd76c2303e 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -15,6 +15,7 @@ #define LLDB_TOOLS_LLDB_DAP_TRANSPORT_H #include "DAPForward.h" +#include "Protocol/ProtocolBase.h" #include "lldb/Host/JSONTransport.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" @@ -23,7 +24,9 @@ namespace lldb_dap { /// A transport class that performs the Debug Adapter Protocol communication /// with the client. -class Transport : public lldb_private::HTTPDelimitedJSONTransport { +class Transport final + : public lldb_private::HTTPDelimitedJSONTransport< + protocol::Request, protocol::Response, protocol::Event> { public: Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output); @@ -31,10 +34,6 @@ class Transport : public lldb_private::HTTPDelimitedJSONTransport { void Log(llvm::StringRef message) override; - /// Returns the name of this transport client, for example `stdin/stdout` or - /// `client_1`. - llvm::StringRef GetClientName() { return m_client_name; } - private: llvm::StringRef m_client_name; lldb_dap::Log *m_log; diff --git a/lldb/tools/lldb-dap/tool/lldb-dap.cpp b/lldb/tools/lldb-dap/tool/lldb-dap.cpp index 8bba4162aa7bf..b74085f25f4e2 100644 --- a/lldb/tools/lldb-dap/tool/lldb-dap.cpp +++ b/lldb/tools/lldb-dap/tool/lldb-dap.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/Support/Threading.h" +#include "llvm/Support/WithColor.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -284,7 +285,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, }); std::condition_variable dap_sessions_condition; std::mutex dap_sessions_mutex; - std::map dap_sessions; + std::map dap_sessions; unsigned int clientCount = 0; auto handle = listener->Accept(g_loop, [=, &dap_sessions_condition, &dap_sessions_mutex, &dap_sessions, @@ -300,8 +301,10 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, std::thread client([=, &dap_sessions_condition, &dap_sessions_mutex, &dap_sessions]() { llvm::set_thread_name(client_name + ".runloop"); + MainLoop loop; Transport transport(client_name, log, io, io); - DAP dap(log, default_repl_mode, pre_init_commands, transport); + DAP dap(log, default_repl_mode, pre_init_commands, client_name, transport, + loop); if (auto Err = dap.ConfigureIO()) { llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), @@ -311,7 +314,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, { std::scoped_lock lock(dap_sessions_mutex); - dap_sessions[io.get()] = &dap; + dap_sessions[&loop] = &dap; } if (auto Err = dap.Loop()) { @@ -322,7 +325,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, DAP_LOG(log, "({0}) client disconnected", client_name); std::unique_lock lock(dap_sessions_mutex); - dap_sessions.erase(io.get()); + dap_sessions.erase(&loop); std::notify_all_at_thread_exit(dap_sessions_condition, std::move(lock)); }); client.detach(); @@ -344,13 +347,14 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, bool client_failed = false; { std::scoped_lock lock(dap_sessions_mutex); - for (auto [sock, dap] : dap_sessions) { + for (auto [loop, dap] : dap_sessions) { if (llvm::Error error = dap->Disconnect()) { client_failed = true; - llvm::errs() << "DAP client " << dap->transport.GetClientName() - << " disconnected failed: " - << llvm::toString(std::move(error)) << "\n"; + llvm::WithColor::error() << "DAP client disconnected failed: " + << llvm::toString(std::move(error)) << "\n"; } + loop->AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); } } @@ -550,8 +554,10 @@ int main(int argc, char *argv[]) { stdout_fd, File::eOpenOptionWriteOnly, NativeFile::Unowned); constexpr llvm::StringLiteral client_name = "stdio"; + MainLoop loop; Transport transport(client_name, log.get(), input, output); - DAP dap(log.get(), default_repl_mode, pre_init_commands, transport); + DAP dap(log.get(), default_repl_mode, pre_init_commands, client_name, + transport, loop); // stdout/stderr redirection to the IDE's console if (auto Err = dap.ConfigureIO(stdout, stderr)) { diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index 138910d917424..d5a9591ad0a43 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -10,8 +10,8 @@ #include "Protocol/ProtocolBase.h" #include "TestBase.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" -#include #include using namespace llvm; @@ -19,6 +19,7 @@ using namespace lldb; using namespace lldb_dap; using namespace lldb_dap_tests; using namespace lldb_dap::protocol; +using namespace testing; class DAPTest : public TransportBase {}; @@ -27,12 +28,13 @@ TEST_F(DAPTest, SendProtocolMessages) { /*log=*/nullptr, /*default_repl_mode=*/ReplMode::Auto, /*pre_init_commands=*/{}, - /*transport=*/*to_dap, + /*client_name=*/"test_client", + /*transport=*/*transport, + /*loop=*/loop, }; dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); - RunOnce([&](llvm::Expected message) { - ASSERT_THAT_EXPECTED( - message, HasValue(testing::VariantWith(testing::FieldsAre( - /*event=*/"my-event", /*body=*/std::nullopt)))); - }); + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_CALL(client, Received(IsEvent("my-event", std::nullopt))); + ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded()); } diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp index 0546aeb154d50..c6ff1f90b01d5 100644 --- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp +++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp @@ -23,18 +23,15 @@ using namespace lldb; using namespace lldb_dap; using namespace lldb_dap_tests; using namespace lldb_dap::protocol; +using testing::_; class DisconnectRequestHandlerTest : public DAPTestBase {}; TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) { DisconnectRequestHandler handler(*dap); - EXPECT_FALSE(dap->disconnecting); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); - EXPECT_TRUE(dap->disconnecting); - std::vector messages = DrainOutput(); - EXPECT_THAT(messages, - testing::Contains(testing::VariantWith(testing::FieldsAre( - /*event=*/"terminated", /*body=*/testing::_)))); + EXPECT_CALL(client, Received(IsEvent("terminated", _))); + RunOnce(); } TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { @@ -47,17 +44,14 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { DisconnectRequestHandler handler(*dap); - EXPECT_FALSE(dap->disconnecting); dap->configuration.terminateCommands = {"?script print(1)", "script print(2)"}; EXPECT_EQ(dap->target.GetProcess().GetState(), lldb::eStateStopped); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); - EXPECT_TRUE(dap->disconnecting); - std::vector messages = DrainOutput(); - EXPECT_THAT(messages, testing::ElementsAre( - OutputMatcher("Running terminateCommands:\n"), - OutputMatcher("(lldb) script print(2)\n"), - OutputMatcher("2\n"), - testing::VariantWith(testing::FieldsAre( - /*event=*/"terminated", /*body=*/testing::_)))); + EXPECT_CALL(client, Received(Output("1\n"))); + EXPECT_CALL(client, Received(Output("2\n"))).Times(2); + EXPECT_CALL(client, Received(Output("(lldb) script print(2)\n"))); + EXPECT_CALL(client, Received(Output("Running terminateCommands:\n"))); + EXPECT_CALL(client, Received(IsEvent("terminated", _))); + RunOnce(); } diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index 8f9b098c8b1e1..54ac27da694e6 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -7,19 +7,19 @@ //===----------------------------------------------------------------------===// #include "TestBase.h" -#include "Protocol/ProtocolBase.h" +#include "DAPLog.h" #include "TestingSupport/TestUtilities.h" #include "lldb/API/SBDefines.h" #include "lldb/API/SBStructuredData.h" -#include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/Pipe.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" +#include #include +#include using namespace llvm; using namespace lldb; @@ -27,38 +27,36 @@ using namespace lldb_dap; using namespace lldb_dap::protocol; using namespace lldb_dap_tests; using lldb_private::File; +using lldb_private::FileSpec; +using lldb_private::FileSystem; using lldb_private::MainLoop; -using lldb_private::MainLoopBase; -using lldb_private::NativeFile; using lldb_private::Pipe; -void TransportBase::SetUp() { - PipePairTest::SetUp(); - to_dap = std::make_unique( - "to_dap", nullptr, - std::make_shared(input.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(output.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)); - from_dap = std::make_unique( - "from_dap", nullptr, - std::make_shared(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)); +Expected +TestTransport::RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) { + Expected dummy_file = FileSystem::Instance().Open( + FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite); + if (!dummy_file) + return dummy_file.takeError(); + m_dummy_file = std::move(*dummy_file); + lldb_private::Status status; + auto handle = loop.RegisterReadObject( + m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); + if (status.Fail()) + return status.takeError(); + return handle; } void DAPTestBase::SetUp() { TransportBase::SetUp(); + std::error_code EC; + log = std::make_unique("-", EC); dap = std::make_unique( - /*log=*/nullptr, + /*log=*/log.get(), /*default_repl_mode=*/ReplMode::Auto, /*pre_init_commands=*/std::vector(), - /*transport=*/*to_dap); + /*client_name=*/"test_client", + /*transport=*/*transport, /*loop=*/loop); } void DAPTestBase::TearDown() { @@ -76,7 +74,7 @@ void DAPTestBase::SetUpTestSuite() { } void DAPTestBase::TeatUpTestSuite() { SBDebugger::Terminate(); } -bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { +bool DAPTestBase::GetDebuggerSupportsTarget(StringRef platform) { EXPECT_TRUE(dap->debugger); lldb::SBStructuredData data = dap->debugger.GetBuildConfiguration() @@ -85,7 +83,7 @@ bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { for (size_t i = 0; i < data.GetSize(); i++) { char buf[100] = {0}; size_t size = data.GetItemAtIndex(i).GetStringValue(buf, sizeof(buf)); - if (llvm::StringRef(buf, size) == platform) + if (StringRef(buf, size) == platform) return true; } @@ -95,6 +93,24 @@ bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { void DAPTestBase::CreateDebugger() { dap->debugger = lldb::SBDebugger::Create(); ASSERT_TRUE(dap->debugger); + dap->target = dap->debugger.GetDummyTarget(); + + Expected dev_null = FileSystem::Instance().Open( + FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite); + ASSERT_THAT_EXPECTED(dev_null, Succeeded()); + lldb::FileSP dev_null_sp = std::move(*dev_null); + + std::FILE *dev_null_stream = dev_null_sp->GetStream(); + ASSERT_THAT_ERROR(dap->ConfigureIO(dev_null_stream, dev_null_stream), + Succeeded()); + + dap->debugger.SetInputFile(dap->in); + auto out_fd = dap->out.GetWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(out_fd, Succeeded()); + dap->debugger.SetOutputFile(lldb::SBFile(*out_fd, "w", false)); + auto err_fd = dap->out.GetWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(err_fd, Succeeded()); + dap->debugger.SetErrorFile(lldb::SBFile(*err_fd, "w", false)); } void DAPTestBase::LoadCore() { @@ -118,22 +134,3 @@ void DAPTestBase::LoadCore() { SBProcess process = dap->target.LoadCore(this->core->TmpName.data()); ASSERT_TRUE(process); } - -std::vector DAPTestBase::DrainOutput() { - std::vector msgs; - output.CloseWriteFileDescriptor(); - auto handle = from_dap->RegisterReadObject( - loop, [&](MainLoopBase &loop, Expected next) { - if (llvm::Error error = next.takeError()) { - loop.RequestTermination(); - consumeError(std::move(error)); - return; - } - - msgs.push_back(*next); - }); - - consumeError(handle.takeError()); - consumeError(loop.Run().takeError()); - return msgs; -} diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index afdfb540d39b8..c19eead4e37e7 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -8,55 +8,109 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" -#include "TestingSupport/Host/PipeTestUtilities.h" -#include "Transport.h" +#include "TestingSupport/Host/JSONTransportTestUtilities.h" +#include "TestingSupport/SubsystemRAII.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" #include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" +#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/JSON.h" +#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include namespace lldb_dap_tests { +class TestTransport final + : public lldb_private::Transport { +public: + using Message = lldb_private::Transport::Message; + + TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler) + : m_loop(loop), m_handler(handler) {} + + llvm::Error Send(const lldb_dap::protocol::Event &e) override { + m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) { + this->m_handler.Received(e); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const lldb_dap::protocol::Request &r) override { + m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { + this->m_handler.Received(r); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const lldb_dap::protocol::Response &r) override { + m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { + this->m_handler.Received(r); + }); + return llvm::Error::success(); + } + + llvm::Expected + RegisterMessageHandler(lldb_private::MainLoop &loop, + MessageHandler &handler) override; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; + +private: + lldb_private::MainLoop &m_loop; + MessageHandler &m_handler; + lldb::FileSP m_dummy_file; +}; + /// A base class for tests that need transport configured for communicating DAP /// messages. -class TransportBase : public PipePairTest { +class TransportBase : public testing::Test { protected: - std::unique_ptr to_dap; - std::unique_ptr from_dap; + lldb_private::SubsystemRAII + subsystems; lldb_private::MainLoop loop; + std::unique_ptr transport; + MockMessageHandler + client; - void SetUp() override; - - template - void RunOnce(const std::function)> &callback, - std::chrono::milliseconds timeout = std::chrono::seconds(1)) { - auto handle = from_dap->RegisterReadObject

( - loop, [&](lldb_private::MainLoopBase &loop, llvm::Expected

message) { - callback(std::move(message)); - loop.RequestTermination(); - }); - loop.AddCallback( - [](lldb_private::MainLoopBase &loop) { - loop.RequestTermination(); - FAIL() << "timeout waiting for read callback"; - }, - timeout); - ASSERT_THAT_EXPECTED(handle, llvm::Succeeded()); - ASSERT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); + void SetUp() override { + transport = std::make_unique(loop, client); } }; +/// A matcher for a DAP event. +template +inline testing::Matcher +IsEvent(const M1 &m1, const M2 &m2) { + return testing::AllOf(testing::Field(&lldb_dap::protocol::Event::event, m1), + testing::Field(&lldb_dap::protocol::Event::body, m2)); +} + /// Matches an "output" event. -inline auto OutputMatcher(const llvm::StringRef output, - const llvm::StringRef category = "console") { - return testing::VariantWith(testing::FieldsAre( - /*event=*/"output", /*body=*/testing::Optional( - llvm::json::Object{{"category", category}, {"output", output}}))); +inline auto Output(llvm::StringRef o, llvm::StringRef cat = "console") { + return IsEvent("output", + testing::Optional(llvm::json::Value( + llvm::json::Object{{"category", cat}, {"output", o}}))); } /// A base class for tests that interact with a `lldb_dap::DAP` instance. class DAPTestBase : public TransportBase { protected: + std::unique_ptr log; std::unique_ptr dap; std::optional core; std::optional binary; @@ -73,9 +127,11 @@ class DAPTestBase : public TransportBase { void CreateDebugger(); void LoadCore(); - /// Closes the DAP output pipe and returns the remaining protocol messages in - /// the buffer. - std::vector DrainOutput(); + void RunOnce() { + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(dap->Loop(), llvm::Succeeded()); + } }; } // namespace lldb_dap_tests diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 4e94582d3bc6a..445674f402252 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -7,43 +7,142 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/JSONTransport.h" +#include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/Host/PipeTestUtilities.h" #include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" -#include "llvm/ADT/FunctionExtras.h" +#include "lldb/Utility/Log.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" #include #include -#include #include #include using namespace llvm; using namespace lldb_private; +using testing::_; +using testing::HasSubstr; +using testing::InSequence; namespace { -struct JSONTestType { - std::string str; +namespace test_protocol { + +struct Req { + std::string name; }; +json::Value toJSON(const Req &T) { return json::Object{{"req", T.name}}; } +bool fromJSON(const json::Value &V, Req &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("req", T.name); +} +bool operator==(const Req &a, const Req &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Req &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Req &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; +} -json::Value toJSON(const JSONTestType &T) { - return json::Object{{"str", T.str}}; +struct Resp { + std::string name; +}; +json::Value toJSON(const Resp &T) { return json::Object{{"resp", T.name}}; } +bool fromJSON(const json::Value &V, Resp &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("resp", T.name); +} +bool operator==(const Resp &a, const Resp &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Resp &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; } -bool fromJSON(const json::Value &V, JSONTestType &T, json::Path P) { +struct Evt { + std::string name; +}; +json::Value toJSON(const Evt &T) { return json::Object{{"evt", T.name}}; } +bool fromJSON(const json::Value &V, Evt &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("str", T.str); + return O && O.map("evt", T.name); +} +bool operator==(const Evt &a, const Evt &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Evt &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Evt &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; +} + +using Message = std::variant; +json::Value toJSON(const Message &msg) { + return std::visit([](const auto &msg) { return toJSON(msg); }, msg); +} +bool fromJSON(const json::Value &V, Message &msg, json::Path P) { + const json::Object *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + if (O->get("req")) { + Req R; + if (!fromJSON(V, R, P)) + return false; + + msg = std::move(R); + return true; + } + if (O->get("resp")) { + Resp R; + if (!fromJSON(V, R, P)) + return false; + + msg = std::move(R); + return true; + } + if (O->get("evt")) { + Evt E; + if (!fromJSON(V, E, P)) + return false; + + msg = std::move(E); + return true; + } + P.report("unknown message type"); + return false; } -template class JSONTransportTest : public PipePairTest { +} // namespace test_protocol + +template +class JSONTransportTest : public PipePairTest { + protected: - std::unique_ptr transport; + MockMessageHandler message_handler; + std::unique_ptr transport; MainLoop loop; void SetUp() override { @@ -57,53 +156,57 @@ template class JSONTransportTest : public PipePairTest { NativeFile::Unowned)); } - template - Expected

- RunOnce(std::chrono::milliseconds timeout = std::chrono::seconds(1)) { - std::promise> promised_message; - std::future> future_message = promised_message.get_future(); - RunUntil

( - [&promised_message](Expected

message) mutable -> bool { - promised_message.set_value(std::move(message)); - return /*keep_going*/ false; - }, - timeout); - return future_message.get(); - } - - /// RunUntil runs the event loop until the callback returns `false` or a - /// timeout has occurred. - template - void RunUntil(std::function)> callback, - std::chrono::milliseconds timeout = std::chrono::seconds(1)) { - auto handle = transport->RegisterReadObject

( - loop, [&callback](MainLoopBase &loop, Expected

message) mutable { - bool keep_going = callback(std::move(message)); - if (!keep_going) - loop.RequestTermination(); - }); + /// Run the transport MainLoop and return any messages received. + Error + Run(bool close_input = true, + std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { + if (close_input) { + input.CloseWriteFileDescriptor(); + EXPECT_CALL(message_handler, OnClosed()).WillOnce([this]() { + loop.RequestTermination(); + }); + } loop.AddCallback( - [&callback](MainLoopBase &loop) mutable { + [](MainLoopBase &loop) { loop.RequestTermination(); - callback(createStringError("timeout")); + FAIL() << "timeout"; }, timeout); - EXPECT_THAT_EXPECTED(handle, Succeeded()); - EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); + auto handle = transport->RegisterMessageHandler(loop, message_handler); + if (!handle) + return handle.takeError(); + + return loop.Run().takeError(); } - template llvm::Expected Write(Ts... args) { + template void Write(Ts... args) { std::string message; for (const auto &arg : {args...}) message += Encode(arg); - return input.Write(message.data(), message.size()); + EXPECT_THAT_EXPECTED(input.Write(message.data(), message.size()), + Succeeded()); } virtual std::string Encode(const json::Value &) = 0; }; +class TestHTTPDelimitedJSONTransport final + : public HTTPDelimitedJSONTransport { +public: + using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; +}; + class HTTPDelimitedJSONTransportTest - : public JSONTransportTest { + : public JSONTransportTest { public: using JSONTransportTest::JSONTransportTest; @@ -118,7 +221,22 @@ class HTTPDelimitedJSONTransportTest } }; -class JSONRPCTransportTest : public JSONTransportTest { +class TestJSONRPCTransport final + : public JSONRPCTransport { +public: + using JSONRPCTransport::JSONRPCTransport; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; +}; + +class JSONRPCTransportTest + : public JSONTransportTest { public: using JSONTransportTest::JSONTransportTest; @@ -134,6 +252,7 @@ class JSONRPCTransportTest : public JSONTransportTest { // Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446. #ifndef _WIN32 +using namespace test_protocol; TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { std::string malformed_header = @@ -141,84 +260,83 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - FailedWithMessage("invalid content length: -1")); + + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + ASSERT_THAT_ERROR(std::move(err), + FailedWithMessage("invalid content length: -1")); + }); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, Read) { - ASSERT_THAT_EXPECTED(Write(JSONTestType{"foo"}), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + Write(Req{"foo"}); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { - ASSERT_THAT_EXPECTED(Write(JSONTestType{"one"}, JSONTestType{"two"}), - Succeeded()); - unsigned count = 0; - RunUntil([&](Expected message) -> bool { - if (count == 0) { - EXPECT_THAT_EXPECTED(message, - HasValue(testing::FieldsAre(/*str=*/"one"))); - } else if (count == 1) { - EXPECT_THAT_EXPECTED(message, - HasValue(testing::FieldsAre(/*str=*/"two"))); - } - - count++; - return count < 2; - }); + InSequence seq; + Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); + EXPECT_CALL(message_handler, Received(Req{"one"})); + EXPECT_CALL(message_handler, Received(Evt{"two"})); + EXPECT_CALL(message_handler, Received(Resp{"three"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { - std::string long_str = std::string(2048, 'x'); - ASSERT_THAT_EXPECTED(Write(JSONTestType{long_str}), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/long_str))); + std::string long_str = std::string( + HTTPDelimitedJSONTransport::kReadBufferSize * 2, 'x'); + Write(Req{long_str}); + EXPECT_CALL(message_handler, Received(Req{long_str})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { - std::string message = Encode(JSONTestType{"foo"}); - std::string part1 = message.substr(0, 28); - std::string part2 = message.substr(28); + std::string message = Encode(Req{"foo"}); + auto split_at = message.size() / 2; + std::string part1 = message.substr(0, split_at); + std::string part2 = message.substr(split_at); - ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_stdin=*/false), Succeeded()); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { - std::string message = Encode(JSONTestType{"foo"}); - std::string part1 = message.substr(0, 28); - std::string part2 = message.substr(28); + std::string message = Encode(Req{"foo"}); + auto split_at = message.size() / 2; + std::string part1 = message.substr(0, split_at); + std::string part2 = message.substr(split_at); + + EXPECT_CALL(message_handler, Received(Req{"foo"})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + // Run the main loop once for the initial read. + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_stdin=*/false), Succeeded()); + + // zero-byte write. ASSERT_THAT_EXPECTED(input.Write(part1.data(), 0), Succeeded()); // zero-byte write. + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_stdin=*/false), Succeeded()); - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); - + // Write the remaining part of the message. ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithEOF) { - input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce(), Failed()); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { @@ -227,36 +345,41 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { formatv("Content-Length: {0}\r\nContent-type: text/json\r\n\r\n{1}", json.size(), json) .str(); + + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + // The error should indicate that there are unhandled contents. + ASSERT_THAT_ERROR(std::move(err), + Failed()); + }); + // Write an incomplete message and close the handle. ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), Succeeded()); - input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce(), - Failed()); -} - -TEST_F(HTTPDelimitedJSONTransportTest, NoDataTimeout) { - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { - transport = std::make_unique(nullptr, nullptr); - auto handle = transport->RegisterReadObject( - loop, [&](MainLoopBase &, llvm::Expected) {}); - ASSERT_THAT_EXPECTED(handle, FailedWithMessage("IO object is not valid.")); + transport = + std::make_unique(nullptr, nullptr); + ASSERT_THAT_ERROR(Run(/*close_input=*/false), + FailedWithMessage("IO object is not valid.")); } TEST_F(HTTPDelimitedJSONTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n" - R"json({"str":"foo"})json")); + R"({"req":"foo"})" + "Content-Length: 14\r\n\r\n" + R"({"resp":"bar"})" + "Content-Length: 13\r\n\r\n" + R"({"evt":"baz"})")); } TEST_F(JSONRPCTransportTest, MalformedRequests) { @@ -264,80 +387,94 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), llvm::Failed()); + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + ASSERT_THAT_ERROR(std::move(err), + FailedWithMessage(HasSubstr("Invalid JSON value"))); + }); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, Read) { - ASSERT_THAT_EXPECTED(Write(JSONTestType{"foo"}), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + Write(Message{Req{"foo"}}); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(JSONRPCTransportTest, ReadMultipleMessagesInSingleWrite) { + InSequence seq; + Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); + EXPECT_CALL(message_handler, Received(Req{"one"})); + EXPECT_CALL(message_handler, Received(Evt{"two"})); + EXPECT_CALL(message_handler, Received(Resp{"three"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { - std::string long_str = std::string(2048, 'x'); - std::string message = Encode(JSONTestType{long_str}); - ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), - Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/long_str))); + // Use a string longer than the chunk size to ensure we split the message + // across the chunk boundary. + std::string long_str = + std::string(JSONTransport::kReadBufferSize * 2, 'x'); + Write(Req{long_str}); + EXPECT_CALL(message_handler, Received(Req{long_str})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadPartialMessage) { - std::string message = R"({"str": "foo"})" + std::string message = R"({"req": "foo"})" "\n"; std::string part1 = message.substr(0, 7); std::string part2 = message.substr(7); - ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + EXPECT_CALL(message_handler, Received(Req{"foo"})); - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_input=*/false), Succeeded()); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadWithEOF) { - input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce(), Failed()); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { - std::string message = R"json({"str": "foo"})json" - "\n"; + std::string message = R"json({"req": "foo")json"; // Write an incomplete message and close the handle. ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), Succeeded()); - input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce(), - Failed()); + + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + ASSERT_THAT_ERROR(std::move(err), + Failed()); + }); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"json({"str":"foo"})json" + ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"req":"foo"})" + "\n" + R"({"resp":"bar"})" + "\n" + R"({"evt":"baz"})" "\n")); } TEST_F(JSONRPCTransportTest, InvalidTransport) { - transport = std::make_unique(nullptr, nullptr); - auto handle = transport->RegisterReadObject( - loop, [&](MainLoopBase &, llvm::Expected) {}); - ASSERT_THAT_EXPECTED(handle, FailedWithMessage("IO object is not valid.")); -} - -TEST_F(JSONRPCTransportTest, NoDataTimeout) { - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + transport = std::make_unique(nullptr, nullptr); + ASSERT_THAT_ERROR(Run(/*close_input=*/false), + FailedWithMessage("IO object is not valid.")); } #endif diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index de2ae2313ecd7..18112428950ce 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -8,6 +8,7 @@ #include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" #include "Plugins/Protocol/MCP/ProtocolServerMCP.h" +#include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" #include "lldb/Core/Debugger.h" #include "lldb/Core/ProtocolServer.h" @@ -21,7 +22,9 @@ #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Support/Error.h" +#include "llvm/Support/JSON.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" #include #include @@ -31,6 +34,7 @@ using namespace llvm; using namespace lldb; using namespace lldb_private; using namespace lldb_protocol::mcp; +using testing::_; namespace { class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { @@ -43,11 +47,18 @@ class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { using ProtocolServerMCP::ProtocolServerMCP; }; -class TestJSONTransport : public lldb_private::JSONRPCTransport { +using Message = typename Transport::Message; + +class TestJSONTransport final + : public lldb_private::JSONRPCTransport { public: using JSONRPCTransport::JSONRPCTransport; - using JSONRPCTransport::Parse; - using JSONRPCTransport::WriteImpl; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; }; /// Test tool that returns it argument as text. @@ -55,7 +66,7 @@ class TestTool : public Tool { public: using Tool::Tool; - virtual llvm::Expected Call(const ToolArguments &args) override { + llvm::Expected Call(const ToolArguments &args) override { std::string argument; if (const json::Object *args_obj = std::get(args).getAsObject()) { @@ -73,7 +84,7 @@ class TestTool : public Tool { class TestResourceProvider : public ResourceProvider { using ResourceProvider::ResourceProvider; - virtual std::vector GetResources() const override { + std::vector GetResources() const override { std::vector resources; Resource resource; @@ -86,7 +97,7 @@ class TestResourceProvider : public ResourceProvider { return resources; } - virtual llvm::Expected + llvm::Expected ReadResource(llvm::StringRef uri) const override { if (uri != "lldb://foo/bar") return llvm::make_error(uri.str()); @@ -107,7 +118,7 @@ class ErrorTool : public Tool { public: using Tool::Tool; - virtual llvm::Expected Call(const ToolArguments &args) override { + llvm::Expected Call(const ToolArguments &args) override { return llvm::createStringError("error"); } }; @@ -117,7 +128,7 @@ class FailTool : public Tool { public: using Tool::Tool; - virtual llvm::Expected Call(const ToolArguments &args) override { + llvm::Expected Call(const ToolArguments &args) override { TextResult text_result; text_result.content.emplace_back(TextContent{{"failed"}}); text_result.isError = true; @@ -134,30 +145,30 @@ class ProtocolServerMCPTest : public ::testing::Test { std::unique_ptr m_transport_up; std::unique_ptr m_server_up; MainLoop loop; + MockMessageHandler message_handler; static constexpr llvm::StringLiteral k_localhost = "localhost"; llvm::Error Write(llvm::StringRef message) { - return m_transport_up->WriteImpl(llvm::formatv("{0}\n", message).str()); + std::string output = llvm::formatv("{0}\n", message).str(); + size_t bytes_written = output.size(); + return m_io_sp->Write(output.data(), bytes_written).takeError(); } - template - void - RunOnce(const std::function)> &callback, - std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) { - auto handle = m_transport_up->RegisterReadObject

( - loop, [&](lldb_private::MainLoopBase &loop, llvm::Expected

message) { - callback(std::move(message)); - loop.RequestTermination(); - }); - loop.AddCallback( - [&](lldb_private::MainLoopBase &loop) { - loop.RequestTermination(); - FAIL() << "timeout waiting for read callback"; - }, - timeout); - ASSERT_THAT_EXPECTED(handle, llvm::Succeeded()); - ASSERT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); + void CloseInput() { + EXPECT_THAT_ERROR(m_io_sp->Close().takeError(), Succeeded()); + } + + /// Run the transport MainLoop and return any messages received. + llvm::Error + Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) { + loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, + timeout); + auto handle = m_transport_up->RegisterMessageHandler(loop, message_handler); + if (!handle) + return handle.takeError(); + + return loop.Run().takeError(); } void SetUp() override { @@ -202,41 +213,45 @@ class ProtocolServerMCPTest : public ::testing::Test { TEST_F(ProtocolServerMCPTest, Initialization) { llvm::StringLiteral request = - R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; + R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":1})json"; llvm::StringLiteral response = - R"json( {"id":0,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; + R"json({"id":1,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + ASSERT_THAT_ERROR(Write(request), Succeeded()); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsList) { llvm::StringLiteral request = - R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; - llvm::StringLiteral response = - R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; + R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":"one"})json"; + + ToolDefinition test_tool; + test_tool.name = "test"; + test_tool.description = "test tool"; + test_tool.inputSchema = json::Object{{"type", "object"}}; + + ToolDefinition lldb_command_tool; + lldb_command_tool.description = "Run an lldb command."; + lldb_command_tool.name = "lldb_command"; + lldb_command_tool.inputSchema = json::Object{ + {"type", "object"}, + {"properties", + json::Object{{"arguments", json::Object{{"type", "string"}}}, + {"debugger_id", json::Object{{"type", "number"}}}}}, + {"required", json::Array{"debugger_id"}}}; + Response response; + response.id = "one"; + response.result = json::Object{ + {"tools", + json::Array{std::move(test_tool), std::move(lldb_command_tool)}}, + }; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + EXPECT_CALL(message_handler, Received(response)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ResourcesList) { @@ -246,17 +261,10 @@ TEST_F(ProtocolServerMCPTest, ResourcesList) { R"json({"id":2,"jsonrpc":"2.0","result":{"resources":[{"description":"description","mimeType":"application/json","name":"name","uri":"lldb://foo/bar"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsCall) { @@ -266,17 +274,10 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { @@ -288,17 +289,10 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { R"json({"error":{"code":-32603,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { @@ -310,17 +304,10 @@ TEST_F(ProtocolServerMCPTest, ToolsCallFail) { R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h new file mode 100644 index 0000000000000..5a9eb8e59f2b6 --- /dev/null +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H +#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H + +#include "lldb/Host/JSONTransport.h" +#include "gmock/gmock.h" + +template +class MockMessageHandler final + : public lldb_private::Transport::MessageHandler { +public: + MOCK_METHOD(void, Received, (const Evt &), (override)); + MOCK_METHOD(void, Received, (const Req &), (override)); + MOCK_METHOD(void, Received, (const Resp &), (override)); + MOCK_METHOD(void, OnError, (llvm::Error), (override)); + MOCK_METHOD(void, OnClosed, (), (override)); +}; + +#endif