| 
 | 1 | +package io.modelcontextprotocol.kotlin.sdk.client  | 
 | 2 | + | 
 | 3 | +import io.github.oshai.kotlinlogging.KotlinLogging  | 
 | 4 | +import io.ktor.client.HttpClient  | 
 | 5 | +import io.ktor.client.plugins.ClientRequestException  | 
 | 6 | +import io.ktor.client.plugins.sse.ClientSSESession  | 
 | 7 | +import io.ktor.client.plugins.sse.sseSession  | 
 | 8 | +import io.ktor.client.request.HttpRequestBuilder  | 
 | 9 | +import io.ktor.client.request.accept  | 
 | 10 | +import io.ktor.client.request.delete  | 
 | 11 | +import io.ktor.client.request.headers  | 
 | 12 | +import io.ktor.client.request.post  | 
 | 13 | +import io.ktor.client.request.setBody  | 
 | 14 | +import io.ktor.client.statement.HttpResponse  | 
 | 15 | +import io.ktor.client.statement.bodyAsChannel  | 
 | 16 | +import io.ktor.client.statement.bodyAsText  | 
 | 17 | +import io.ktor.http.ContentType  | 
 | 18 | +import io.ktor.http.HttpHeaders  | 
 | 19 | +import io.ktor.http.HttpMethod  | 
 | 20 | +import io.ktor.http.HttpStatusCode  | 
 | 21 | +import io.ktor.http.contentType  | 
 | 22 | +import io.ktor.http.isSuccess  | 
 | 23 | +import io.ktor.utils.io.readUTF8Line  | 
 | 24 | +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage  | 
 | 25 | +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification  | 
 | 26 | +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest  | 
 | 27 | +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse  | 
 | 28 | +import io.modelcontextprotocol.kotlin.sdk.RequestId  | 
 | 29 | +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport  | 
 | 30 | +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson  | 
 | 31 | +import kotlinx.coroutines.CancellationException  | 
 | 32 | +import kotlinx.coroutines.CoroutineName  | 
 | 33 | +import kotlinx.coroutines.CoroutineScope  | 
 | 34 | +import kotlinx.coroutines.Dispatchers  | 
 | 35 | +import kotlinx.coroutines.Job  | 
 | 36 | +import kotlinx.coroutines.SupervisorJob  | 
 | 37 | +import kotlinx.coroutines.cancel  | 
 | 38 | +import kotlinx.coroutines.cancelAndJoin  | 
 | 39 | +import kotlinx.coroutines.launch  | 
 | 40 | +import kotlin.concurrent.atomics.AtomicBoolean  | 
 | 41 | +import kotlin.concurrent.atomics.ExperimentalAtomicApi  | 
 | 42 | +import kotlin.time.Duration  | 
 | 43 | + | 
 | 44 | +private val logger = KotlinLogging.logger {}  | 
 | 45 | + | 
 | 46 | +private const val MCP_SESSION_ID_HEADER = "mcp-session-id"  | 
 | 47 | +private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"  | 
 | 48 | +private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID"  | 
 | 49 | + | 
 | 50 | +/**  | 
 | 51 | + * Error class for Streamable HTTP transport errors.  | 
 | 52 | + */  | 
 | 53 | +public class StreamableHttpError(  | 
 | 54 | +    public val code: Int? = null,  | 
 | 55 | +    message: String? = null  | 
 | 56 | +) : Exception("Streamable HTTP error: $message")  | 
 | 57 | + | 
 | 58 | +/**  | 
 | 59 | + * Client transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification.  | 
 | 60 | + * It will connect to a server using HTTP POST for sending messages and HTTP GET with Server-Sent Events  | 
 | 61 | + * for receiving messages.  | 
 | 62 | + */  | 
 | 63 | +@OptIn(ExperimentalAtomicApi::class)  | 
 | 64 | +public class StreamableHttpClientTransport(  | 
 | 65 | +    private val client: HttpClient,  | 
 | 66 | +    private val url: String,  | 
 | 67 | +    private val reconnectionTime: Duration? = null,  | 
 | 68 | +    private val requestBuilder: HttpRequestBuilder.() -> Unit = {},  | 
 | 69 | +) : AbstractTransport() {  | 
 | 70 | + | 
 | 71 | +    public var sessionId: String? = null  | 
 | 72 | +        private set  | 
 | 73 | +    public var protocolVersion: String? = null  | 
 | 74 | + | 
 | 75 | +    private val initialized: AtomicBoolean = AtomicBoolean(false)  | 
 | 76 | + | 
 | 77 | +    private var sseSession: ClientSSESession? = null  | 
 | 78 | +    private var sseJob: Job? = null  | 
 | 79 | + | 
 | 80 | +    private val scope by lazy { CoroutineScope(SupervisorJob() + Dispatchers.Default) }  | 
 | 81 | + | 
 | 82 | +    private var lastEventId: String? = null  | 
 | 83 | + | 
 | 84 | +    override suspend fun start() {  | 
 | 85 | +        if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {  | 
 | 86 | +            error("StreamableHttpClientTransport already started!")  | 
 | 87 | +        }  | 
 | 88 | +        logger.debug { "Client transport starting..." }  | 
 | 89 | +    }  | 
 | 90 | + | 
 | 91 | +    /**  | 
 | 92 | +     * Sends a single message with optional resumption support  | 
 | 93 | +     */  | 
 | 94 | +    override suspend fun send(message: JSONRPCMessage) {  | 
 | 95 | +        send(message, null)  | 
 | 96 | +    }  | 
 | 97 | + | 
 | 98 | +    /**  | 
 | 99 | +     * Sends one or more messages with optional resumption support.  | 
 | 100 | +     * This is the main send method that matches the TypeScript implementation.  | 
 | 101 | +     */  | 
 | 102 | +    public suspend fun send(  | 
 | 103 | +        message: JSONRPCMessage,  | 
 | 104 | +        resumptionToken: String?,  | 
 | 105 | +        onResumptionToken: ((String) -> Unit)? = null  | 
 | 106 | +    ) {  | 
 | 107 | +        logger.debug { "Client sending message via POST to $url: ${McpJson.encodeToString(message)}" }  | 
 | 108 | + | 
 | 109 | +        // If we have a resumption token, reconnect the SSE stream with it  | 
 | 110 | +        resumptionToken?.let { token ->  | 
 | 111 | +            startSseSession(  | 
 | 112 | +                resumptionToken = token, onResumptionToken = onResumptionToken,  | 
 | 113 | +                replayMessageId = if (message is JSONRPCRequest) message.id else null  | 
 | 114 | +            )  | 
 | 115 | +            return  | 
 | 116 | +        }  | 
 | 117 | + | 
 | 118 | +        val jsonBody = McpJson.encodeToString(message)  | 
 | 119 | +        val response = client.post(url) {  | 
 | 120 | +            applyCommonHeaders(this)  | 
 | 121 | +            headers.append(HttpHeaders.Accept, "${ContentType.Application.Json}, ${ContentType.Text.EventStream}")  | 
 | 122 | +            contentType(ContentType.Application.Json)  | 
 | 123 | +            setBody(jsonBody)  | 
 | 124 | +            requestBuilder()  | 
 | 125 | +        }  | 
 | 126 | + | 
 | 127 | +        response.headers[MCP_SESSION_ID_HEADER]?.let { sessionId = it }  | 
 | 128 | + | 
 | 129 | +        if (response.status == HttpStatusCode.Accepted) {  | 
 | 130 | +            if (message is JSONRPCNotification && message.method == "notifications/initialized") {  | 
 | 131 | +                startSseSession(onResumptionToken = onResumptionToken)  | 
 | 132 | +            }  | 
 | 133 | +            return  | 
 | 134 | +        }  | 
 | 135 | + | 
 | 136 | +        if (!response.status.isSuccess()) {  | 
 | 137 | +            val error = StreamableHttpError(response.status.value, response.bodyAsText())  | 
 | 138 | +            _onError(error)  | 
 | 139 | +            throw error  | 
 | 140 | +        }  | 
 | 141 | + | 
 | 142 | +        when (response.contentType()?.withoutParameters()) {  | 
 | 143 | +            ContentType.Application.Json -> response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json ->  | 
 | 144 | +                runCatching { McpJson.decodeFromString<JSONRPCMessage>(json) }  | 
 | 145 | +                    .onSuccess { _onMessage(it) }  | 
 | 146 | +                    .onFailure(_onError)  | 
 | 147 | +            }  | 
 | 148 | + | 
 | 149 | +            ContentType.Text.EventStream -> handleInlineSse(  | 
 | 150 | +                response, onResumptionToken = onResumptionToken,  | 
 | 151 | +                replayMessageId = if (message is JSONRPCRequest) message.id else null  | 
 | 152 | +            )  | 
 | 153 | +            else -> {  | 
 | 154 | +                val body = response.bodyAsText()  | 
 | 155 | +                if (response.contentType() == null && body.isBlank()) return  | 
 | 156 | + | 
 | 157 | +                val ct = response.contentType()?.toString() ?: "<none>"  | 
 | 158 | +                val error = StreamableHttpError(-1, "Unexpected content type: $$ct")  | 
 | 159 | +                _onError(error)  | 
 | 160 | +                throw error  | 
 | 161 | +            }  | 
 | 162 | +        }  | 
 | 163 | +    }  | 
 | 164 | + | 
 | 165 | +    override suspend fun close() {  | 
 | 166 | +        if (!initialized.load()) return // Already closed or never started  | 
 | 167 | +        logger.debug { "Client transport closing." }  | 
 | 168 | + | 
 | 169 | +        try {  | 
 | 170 | +            // Try to terminate session if we have one  | 
 | 171 | +            terminateSession()  | 
 | 172 | + | 
 | 173 | +            sseSession?.cancel()  | 
 | 174 | +            sseJob?.cancelAndJoin()  | 
 | 175 | +            scope.cancel()  | 
 | 176 | +        } catch (_: Exception) {  | 
 | 177 | +            // Ignore errors during cleanup  | 
 | 178 | +        } finally {  | 
 | 179 | +            initialized.store(false)  | 
 | 180 | +            _onClose()  | 
 | 181 | +        }  | 
 | 182 | +    }  | 
 | 183 | + | 
 | 184 | +    /**  | 
 | 185 | +     * Terminates the current session by sending a DELETE request to the server.  | 
 | 186 | +     */  | 
 | 187 | +    public suspend fun terminateSession() {  | 
 | 188 | +        if (sessionId == null) return  | 
 | 189 | +        logger.debug { "Terminating session: $sessionId" }  | 
 | 190 | +        val response = client.delete(url) {  | 
 | 191 | +            applyCommonHeaders(this)  | 
 | 192 | +            requestBuilder()  | 
 | 193 | +        }  | 
 | 194 | + | 
 | 195 | +        // 405 means server doesn't support explicit session termination  | 
 | 196 | +        if (!response.status.isSuccess() && response.status != HttpStatusCode.MethodNotAllowed) {  | 
 | 197 | +            val error = StreamableHttpError(  | 
 | 198 | +                response.status.value,  | 
 | 199 | +                "Failed to terminate session: ${response.status.description}"  | 
 | 200 | +            )  | 
 | 201 | +            logger.error(error) { "Failed to terminate session" }  | 
 | 202 | +            _onError(error)  | 
 | 203 | +            throw error  | 
 | 204 | +        }  | 
 | 205 | + | 
 | 206 | +        sessionId = null  | 
 | 207 | +        lastEventId = null  | 
 | 208 | +        logger.debug { "Session terminated successfully" }  | 
 | 209 | +    }  | 
 | 210 | + | 
 | 211 | +    private suspend fun startSseSession(  | 
 | 212 | +        resumptionToken: String? = null,  | 
 | 213 | +        replayMessageId: RequestId? = null,  | 
 | 214 | +        onResumptionToken: ((String) -> Unit)? = null  | 
 | 215 | +    ) {  | 
 | 216 | +        sseSession?.cancel()  | 
 | 217 | +        sseJob?.cancelAndJoin()  | 
 | 218 | + | 
 | 219 | +        logger.debug { "Client attempting to start SSE session at url: $url" }  | 
 | 220 | +        try {  | 
 | 221 | +            sseSession = client.sseSession(  | 
 | 222 | +                urlString = url,  | 
 | 223 | +                reconnectionTime = reconnectionTime,  | 
 | 224 | +            ) {  | 
 | 225 | +                method = HttpMethod.Get  | 
 | 226 | +                applyCommonHeaders(this)  | 
 | 227 | +                accept(ContentType.Text.EventStream)  | 
 | 228 | +                (resumptionToken ?: lastEventId)?.let { headers.append(MCP_RESUMPTION_TOKEN_HEADER, it) }  | 
 | 229 | +                requestBuilder()  | 
 | 230 | +            }  | 
 | 231 | +            logger.debug { "Client SSE session started successfully." }  | 
 | 232 | +        } catch (e: ClientRequestException) {  | 
 | 233 | +            if (e.response.status == HttpStatusCode.MethodNotAllowed) {  | 
 | 234 | +                logger.info { "Server returned 405 for GET/SSE, stream disabled." }  | 
 | 235 | +                return  | 
 | 236 | +            }  | 
 | 237 | +            _onError(e)  | 
 | 238 | +            throw e  | 
 | 239 | +        }  | 
 | 240 | + | 
 | 241 | +        sseJob = scope.launch(CoroutineName("StreamableHttpTransport.collect#${hashCode()}")) {  | 
 | 242 | +            sseSession?.let { collectSse(it, replayMessageId, onResumptionToken) }  | 
 | 243 | +        }  | 
 | 244 | +    }  | 
 | 245 | + | 
 | 246 | +    private fun applyCommonHeaders(builder: HttpRequestBuilder) {  | 
 | 247 | +        builder.headers {  | 
 | 248 | +            sessionId?.let { append(MCP_SESSION_ID_HEADER, it) }  | 
 | 249 | +            protocolVersion?.let { append(MCP_PROTOCOL_VERSION_HEADER, it) }  | 
 | 250 | +        }  | 
 | 251 | +    }  | 
 | 252 | + | 
 | 253 | +    private suspend fun collectSse(  | 
 | 254 | +        session: ClientSSESession,  | 
 | 255 | +        replayMessageId: RequestId?,  | 
 | 256 | +        onResumptionToken: ((String) -> Unit)?  | 
 | 257 | +    ) {  | 
 | 258 | +        try {  | 
 | 259 | +            session.incoming.collect { event ->  | 
 | 260 | +                event.id?.let {  | 
 | 261 | +                    lastEventId = it  | 
 | 262 | +                    onResumptionToken?.invoke(it)  | 
 | 263 | +                }  | 
 | 264 | +                logger.trace { "Client received SSE event: event=${event.event}, data=${event.data}, id=${event.id}" }  | 
 | 265 | +                when (event.event) {  | 
 | 266 | +                    null, "message" ->  | 
 | 267 | +                        event.data?.takeIf { it.isNotEmpty() }?.let { json ->  | 
 | 268 | +                            runCatching { McpJson.decodeFromString<JSONRPCMessage>(json) }  | 
 | 269 | +                                .onSuccess { msg ->  | 
 | 270 | +                                    if (replayMessageId != null && msg is JSONRPCResponse) {  | 
 | 271 | +                                        _onMessage(msg.copy(id = replayMessageId))  | 
 | 272 | +                                    } else {  | 
 | 273 | +                                        _onMessage(msg)  | 
 | 274 | +                                    }  | 
 | 275 | +                                }  | 
 | 276 | +                                .onFailure(_onError)  | 
 | 277 | +                        }  | 
 | 278 | + | 
 | 279 | +                    "error" -> _onError(StreamableHttpError(null, event.data))  | 
 | 280 | +                }  | 
 | 281 | +            }  | 
 | 282 | +        } catch (_: CancellationException) {  | 
 | 283 | +            // ignore  | 
 | 284 | +        } catch (t: Throwable) {  | 
 | 285 | +            _onError(t)  | 
 | 286 | +        }  | 
 | 287 | +    }  | 
 | 288 | + | 
 | 289 | +    private suspend fun handleInlineSse(  | 
 | 290 | +        response: HttpResponse,  | 
 | 291 | +        replayMessageId: RequestId?,  | 
 | 292 | +        onResumptionToken: ((String) -> Unit)?  | 
 | 293 | +    ) {  | 
 | 294 | +        logger.trace { "Handling inline SSE from POST response" }  | 
 | 295 | +        val channel = response.bodyAsChannel()  | 
 | 296 | + | 
 | 297 | +        val sb = StringBuilder()  | 
 | 298 | +        var id: String? = null  | 
 | 299 | +        var eventName: String? = null  | 
 | 300 | + | 
 | 301 | +        suspend fun dispatch(data: String) {  | 
 | 302 | +            id?.let {  | 
 | 303 | +                lastEventId = it  | 
 | 304 | +                onResumptionToken?.invoke(it)  | 
 | 305 | +            }  | 
 | 306 | +            if (eventName == null || eventName == "message") {  | 
 | 307 | +                runCatching { McpJson.decodeFromString<JSONRPCMessage>(data) }  | 
 | 308 | +                    .onSuccess { msg ->  | 
 | 309 | +                        if (replayMessageId != null && msg is JSONRPCResponse) {  | 
 | 310 | +                            _onMessage(msg.copy(id = replayMessageId))  | 
 | 311 | +                        } else {  | 
 | 312 | +                            _onMessage(msg)  | 
 | 313 | +                        }  | 
 | 314 | +                    }  | 
 | 315 | +                    .onFailure(_onError)  | 
 | 316 | +            }  | 
 | 317 | +            // reset  | 
 | 318 | +            id = null  | 
 | 319 | +            eventName = null  | 
 | 320 | +            sb.clear()  | 
 | 321 | +        }  | 
 | 322 | + | 
 | 323 | +        while (!channel.isClosedForRead) {  | 
 | 324 | +            val line = channel.readUTF8Line() ?: break  | 
 | 325 | +            if (line.isEmpty()) {  | 
 | 326 | +                dispatch(sb.toString())  | 
 | 327 | +                continue  | 
 | 328 | +            }  | 
 | 329 | +            when {  | 
 | 330 | +                line.startsWith("id:") -> id = line.substringAfter("id:").trim()  | 
 | 331 | +                line.startsWith("event:") -> eventName = line.substringAfter("event:").trim()  | 
 | 332 | +                line.startsWith("data:") -> sb.append(line.substringAfter("data:").trim())  | 
 | 333 | +            }  | 
 | 334 | +        }  | 
 | 335 | +    }  | 
 | 336 | +}  | 
0 commit comments