@@ -237,14 +237,13 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request
237237 }
238238
239239 // Check if this is a sampling response (has result/error but no method)
240- isSamplingResponse := jsonMessage .Method == "" && jsonMessage .ID != nil &&
240+ isResponse := jsonMessage .Method == "" && jsonMessage .ID != nil &&
241241 (jsonMessage .Result != nil || jsonMessage .Error != nil )
242-
243242 isInitializeRequest := jsonMessage .Method == mcp .MethodInitialize
244243
245244 // Handle sampling responses separately
246- if isSamplingResponse {
247- if err := s .handleSamplingResponse (w , r , jsonMessage ); err != nil {
245+ if isResponse {
246+ if err := s .handleResponse (w , r , jsonMessage ); err != nil {
248247 s .logger .Errorf ("Failed to handle sampling response: %v" , err )
249248 http .Error (w , "Failed to handle sampling response" , http .StatusInternalServerError )
250249 }
@@ -390,7 +389,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
390389 return
391390 }
392391 defer s .server .UnregisterSession (r .Context (), sessionID )
393-
392+
394393 // Register session for sampling response delivery
395394 s .activeSessions .Store (sessionID , session )
396395 defer s .activeSessions .Delete (sessionID )
@@ -437,6 +436,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
437436 case <- done :
438437 return
439438 }
439+ case elicitationReq := <- session .elicitationRequestChan :
440+ // Send elicitation request to client via SSE
441+ jsonrpcRequest := mcp.JSONRPCRequest {
442+ JSONRPC : "2.0" ,
443+ ID : mcp .NewRequestId (elicitationReq .requestID ),
444+ Request : mcp.Request {
445+ Method : string (mcp .MethodElicitationCreate ),
446+ },
447+ Params : elicitationReq .request .Params ,
448+ }
449+ select {
450+ case writeChan <- jsonrpcRequest :
451+ case <- done :
452+ return
453+ }
440454 case <- done :
441455 return
442456 }
@@ -525,8 +539,8 @@ func writeSSEEvent(w io.Writer, data any) error {
525539 return nil
526540}
527541
528- // handleSamplingResponse processes incoming sampling responses from clients
529- func (s * StreamableHTTPServer ) handleSamplingResponse (w http.ResponseWriter , r * http.Request , responseMessage struct {
542+ // handleResponse processes incoming responses from clients
543+ func (s * StreamableHTTPServer ) handleResponse (w http.ResponseWriter , r * http.Request , responseMessage struct {
530544 ID json.RawMessage `json:"id"`
531545 Result json.RawMessage `json:"result,omitempty"`
532546 Error json.RawMessage `json:"error,omitempty"`
@@ -558,7 +572,7 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
558572 }
559573
560574 // Create the sampling response item
561- response := samplingResponseItem {
575+ response := responseItem {
562576 requestID : requestID ,
563577 }
564578
@@ -575,20 +589,14 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
575589 response .err = fmt .Errorf ("sampling error %d: %s" , jsonrpcError .Code , jsonrpcError .Message )
576590 }
577591 } else if responseMessage .Result != nil {
578- // Parse result
579- var result mcp.CreateMessageResult
580- if err := json .Unmarshal (responseMessage .Result , & result ); err != nil {
581- response .err = fmt .Errorf ("failed to parse sampling result: %v" , err )
582- } else {
583- response .result = & result
584- }
592+ response .result = responseMessage .Result
585593 } else {
586594 response .err = fmt .Errorf ("sampling response has neither result nor error" )
587595 }
588596
589597 // Find the corresponding session and deliver the response
590598 // The response is delivered to the specific session identified by sessionID
591- if err := s .deliverSamplingResponse (sessionID , response ); err != nil {
599+ if err := s .deliverResponse (sessionID , response ); err != nil {
592600 s .logger .Errorf ("Failed to deliver sampling response: %v" , err )
593601 http .Error (w , "Failed to deliver response" , http .StatusInternalServerError )
594602 return err
@@ -600,7 +608,7 @@ func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *
600608}
601609
602610// deliverSamplingResponse delivers a sampling response to the appropriate session
603- func (s * StreamableHTTPServer ) deliverSamplingResponse (sessionID string , response samplingResponseItem ) error {
611+ func (s * StreamableHTTPServer ) deliverResponse (sessionID string , response responseItem ) error {
604612 // Look up the active session
605613 sessionInterface , ok := s .activeSessions .Load (sessionID )
606614 if ! ok {
@@ -613,12 +621,12 @@ func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, respons
613621 }
614622
615623 // Look up the dedicated response channel for this specific request
616- responseChannelInterface , exists := session .samplingRequests .Load (response .requestID )
624+ responseChannelInterface , exists := session .requests .Load (response .requestID )
617625 if ! exists {
618626 return fmt .Errorf ("no pending request found for session %s, request %d" , sessionID , response .requestID )
619627 }
620628
621- responseChan , ok := responseChannelInterface .(chan samplingResponseItem )
629+ responseChan , ok := responseChannelInterface .(chan responseItem )
622630 if ! ok {
623631 return fmt .Errorf ("invalid response channel type for session %s, request %d" , sessionID , response .requestID )
624632 }
@@ -723,15 +731,22 @@ func (s *sessionToolsStore) delete(sessionID string) {
723731type samplingRequestItem struct {
724732 requestID int64
725733 request mcp.CreateMessageRequest
726- response chan samplingResponseItem
734+ response chan responseItem
727735}
728736
729- type samplingResponseItem struct {
737+ type responseItem struct {
730738 requestID int64
731- result * mcp. CreateMessageResult
739+ result json. RawMessage
732740 err error
733741}
734742
743+ // Elicitation support types for HTTP transport
744+ type elicitationRequestItem struct {
745+ requestID int64
746+ request mcp.ElicitationRequest
747+ response chan responseItem
748+ }
749+
735750// streamableHttpSession is a session for streamable-http transport
736751// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
737752// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
@@ -743,18 +758,21 @@ type streamableHttpSession struct {
743758 logLevels * sessionLogLevelsStore
744759
745760 // Sampling support for bidirectional communication
746- samplingRequestChan chan samplingRequestItem // server -> client sampling requests
747- samplingRequests sync.Map // requestID -> pending sampling request context
748- requestIDCounter atomic.Int64 // for generating unique request IDs
761+ samplingRequestChan chan samplingRequestItem // server -> client sampling requests
762+ elicitationRequestChan chan elicitationRequestItem // server -> client elicitation requests
763+
764+ requests sync.Map // requestID -> pending request context
765+ requestIDCounter atomic.Int64 // for generating unique request IDs
749766}
750767
751768func newStreamableHttpSession (sessionID string , toolStore * sessionToolsStore , levels * sessionLogLevelsStore ) * streamableHttpSession {
752769 s := & streamableHttpSession {
753- sessionID : sessionID ,
754- notificationChannel : make (chan mcp.JSONRPCNotification , 100 ),
755- tools : toolStore ,
756- logLevels : levels ,
757- samplingRequestChan : make (chan samplingRequestItem , 10 ),
770+ sessionID : sessionID ,
771+ notificationChannel : make (chan mcp.JSONRPCNotification , 100 ),
772+ tools : toolStore ,
773+ logLevels : levels ,
774+ samplingRequestChan : make (chan samplingRequestItem , 10 ),
775+ elicitationRequestChan : make (chan elicitationRequestItem , 10 ),
758776 }
759777 return s
760778}
@@ -810,21 +828,21 @@ var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
810828func (s * streamableHttpSession ) RequestSampling (ctx context.Context , request mcp.CreateMessageRequest ) (* mcp.CreateMessageResult , error ) {
811829 // Generate unique request ID
812830 requestID := s .requestIDCounter .Add (1 )
813-
831+
814832 // Create response channel for this specific request
815- responseChan := make (chan samplingResponseItem , 1 )
816-
833+ responseChan := make (chan responseItem , 1 )
834+
817835 // Create the sampling request item
818836 samplingRequest := samplingRequestItem {
819837 requestID : requestID ,
820838 request : request ,
821839 response : responseChan ,
822840 }
823-
841+
824842 // Store the pending request
825- s .samplingRequests .Store (requestID , responseChan )
826- defer s .samplingRequests .Delete (requestID )
827-
843+ s .requests .Store (requestID , responseChan )
844+ defer s .requests .Delete (requestID )
845+
828846 // Send the sampling request via the channel (non-blocking)
829847 select {
830848 case s .samplingRequestChan <- samplingRequest :
@@ -834,20 +852,70 @@ func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp
834852 default :
835853 return nil , fmt .Errorf ("sampling request queue is full - server overloaded" )
836854 }
837-
855+
856+ // Wait for response or context cancellation
857+ select {
858+ case response := <- responseChan :
859+ if response .err != nil {
860+ return nil , response .err
861+ }
862+ var result mcp.CreateMessageResult
863+ if err := json .Unmarshal (response .result , & result ); err != nil {
864+ return nil , fmt .Errorf ("failed to unmarshal sampling response: %v" , err )
865+ }
866+ return & result , nil
867+ case <- ctx .Done ():
868+ return nil , ctx .Err ()
869+ }
870+ }
871+
872+ // RequestElicitation implements SessionWithElicitation interface for HTTP transport
873+ func (s * streamableHttpSession ) RequestElicitation (ctx context.Context , request mcp.ElicitationRequest ) (* mcp.ElicitationResult , error ) {
874+ // Generate unique request ID
875+ requestID := s .requestIDCounter .Add (1 )
876+
877+ // Create response channel for this specific request
878+ responseChan := make (chan responseItem , 1 )
879+
880+ // Create the sampling request item
881+ elicitationRequest := elicitationRequestItem {
882+ requestID : requestID ,
883+ request : request ,
884+ response : responseChan ,
885+ }
886+
887+ // Store the pending request
888+ s .requests .Store (requestID , responseChan )
889+ defer s .requests .Delete (requestID )
890+
891+ // Send the sampling request via the channel (non-blocking)
892+ select {
893+ case s .elicitationRequestChan <- elicitationRequest :
894+ // Request queued successfully
895+ case <- ctx .Done ():
896+ return nil , ctx .Err ()
897+ default :
898+ return nil , fmt .Errorf ("elicitation request queue is full - server overloaded" )
899+ }
900+
838901 // Wait for response or context cancellation
839902 select {
840903 case response := <- responseChan :
841904 if response .err != nil {
842905 return nil , response .err
843906 }
844- return response .result , nil
907+ var result mcp.ElicitationResult
908+ if err := json .Unmarshal (response .result , & result ); err != nil {
909+ return nil , fmt .Errorf ("failed to unmarshal elicitation response: %v" , err )
910+ }
911+ return & result , nil
845912 case <- ctx .Done ():
846913 return nil , ctx .Err ()
847914 }
848915}
849916
850917var _ SessionWithSampling = (* streamableHttpSession )(nil )
918+ var _ SessionWithElicitation = (* streamableHttpSession )(nil )
851919
852920// --- session id manager ---
853921
0 commit comments