diff --git a/client/transport/oauth.go b/client/transport/oauth.go index bb702adfe..0204fd3c6 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -146,16 +146,88 @@ type OAuthHandler struct { expectedState string // Expected state value for CSRF protection } -// NewOAuthHandler creates a new OAuth handler -func NewOAuthHandler(config OAuthConfig) *OAuthHandler { +type OAuthHandlerOption func(*OAuthHandler) + +// WithOAuthHTTPClient allows setting a custom http.Client for the OAuthHandler. +// +// This is useful when you need to: +// - Configure custom timeouts for OAuth operations (token exchange, refresh, metadata discovery) +// - Use a custom transport (e.g., for proxy settings, TLS configuration, custom CA certificates) +// - Set connection pooling parameters (MaxIdleConns, IdleConnTimeout, etc.) +// - Add request/response interceptors or custom retry logic +// - Configure specific dial or keep-alive settings +// +// If not specified, a default http.Client with a 30-second timeout is used. +// If nil is passed, the option is ignored and the default client is retained. +// +// Validation: If the provided client has a zero timeout, a default 30-second timeout +// will be automatically applied to prevent indefinite hangs during OAuth operations. +// This ensures robust behavior even if users forget to set a timeout. +// +// Note: The provided client will be used for all OAuth-related HTTP requests including: +// - OAuth server metadata discovery (/.well-known/oauth-authorization-server, etc.) +// - Token endpoint requests (authorization code exchange, token refresh) +// - Dynamic client registration +// +// Example usage with custom timeout: +// +// customClient := &http.Client{ +// Timeout: 10 * time.Second, +// } +// handler := transport.NewOAuthHandler(config, transport.WithOAuthHTTPClient(customClient)) +// +// Example usage with custom transport: +// +// customTransport := &http.Transport{ +// MaxIdleConns: 100, +// IdleConnTimeout: 90 * time.Second, +// TLSHandshakeTimeout: 10 * time.Second, +// } +// customClient := &http.Client{ +// Timeout: 15 * time.Second, +// Transport: customTransport, +// } +// handler := transport.NewOAuthHandler(config, transport.WithOAuthHTTPClient(customClient)) +// +// Example usage with proxy: +// +// proxyURL, _ := url.Parse("http://proxy.example.com:8080") +// customTransport := &http.Transport{ +// Proxy: http.ProxyURL(proxyURL), +// } +// customClient := &http.Client{ +// Transport: customTransport, +// } +// handler := transport.NewOAuthHandler(config, transport.WithOAuthHTTPClient(customClient)) +func WithOAuthHTTPClient(client *http.Client) OAuthHandlerOption { + return func(h *OAuthHandler) { + if client != nil { + // If client has zero timeout, set a reasonable default to prevent indefinite hangs + if client.Timeout == 0 { + client.Timeout = 30 * time.Second + } + h.httpClient = client + } + } +} + +// NewOAuthHandler creates a new OAuth handler. +// Optionally accepts functional options such as WithOAuthHTTPClient. +func NewOAuthHandler(config OAuthConfig, opts ...OAuthHandlerOption) *OAuthHandler { if config.TokenStore == nil { config.TokenStore = NewMemoryTokenStore() } - return &OAuthHandler{ + handler := &OAuthHandler{ config: config, httpClient: &http.Client{Timeout: 30 * time.Second}, } + + for _, opt := range opts { + opt(handler) + } + + return handler } // GetAuthorizationHeader returns the Authorization header value for a request diff --git a/client/transport/oauth_test.go b/client/transport/oauth_test.go index a6bc39d61..67b18b03e 100644 --- a/client/transport/oauth_test.go +++ b/client/transport/oauth_test.go @@ -940,3 +940,288 @@ func TestOAuthHandler_GetServerMetadata_FallbackToDefaultEndpoints(t *testing.T) t.Errorf("Expected token endpoint to be %s/token, got %s", server.URL, metadata.TokenEndpoint) } } + +func TestWithOAuthHTTPClient(t *testing.T) { + t.Run("WithOAuthHTTPClient sets custom http client", func(t *testing.T) { + // Create a custom http.Client with specific timeout + customClient := &http.Client{ + Timeout: 10 * time.Second, + } + + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read"}, + TokenStore: NewMemoryTokenStore(), + } + + // Create handler with custom client + handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient)) + + // Verify the custom client was set + if handler.httpClient != customClient { + t.Error("Expected custom http client to be set") + } + if handler.httpClient.Timeout != 10*time.Second { + t.Errorf("Expected timeout to be 10s, got %v", handler.httpClient.Timeout) + } + }) + + t.Run("WithOAuthHTTPClient ignores nil client", func(t *testing.T) { + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read"}, + TokenStore: NewMemoryTokenStore(), + } + + // Create handler with nil client + handler := NewOAuthHandler(config, WithOAuthHTTPClient(nil)) + + // Verify the default client is still used + if handler.httpClient == nil { + t.Error("Expected default http client to be set") + } + if handler.httpClient.Timeout != 30*time.Second { + t.Errorf("Expected default timeout to be 30s, got %v", handler.httpClient.Timeout) + } + }) + + t.Run("Multiple options including WithOAuthHTTPClient", func(t *testing.T) { + customClient := &http.Client{ + Timeout: 5 * time.Second, + } + + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read"}, + TokenStore: NewMemoryTokenStore(), + } + + // Create handler with custom client option + handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient)) + + // Verify both the config and custom client were applied + if handler.config.ClientID != "test-client" { + t.Errorf("Expected client ID to be test-client, got %s", handler.config.ClientID) + } + if handler.httpClient != customClient { + t.Error("Expected custom http client to be set") + } + if handler.httpClient.Timeout != 5*time.Second { + t.Errorf("Expected timeout to be 5s, got %v", handler.httpClient.Timeout) + } + }) + + t.Run("Custom client with custom transport", func(t *testing.T) { + // Create a custom transport for testing + customTransport := &http.Transport{ + MaxIdleConns: 100, + } + customClient := &http.Client{ + Timeout: 15 * time.Second, + Transport: customTransport, + } + + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + TokenStore: NewMemoryTokenStore(), + } + + handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient)) + + // Verify the custom transport is preserved + if handler.httpClient.Transport != customTransport { + t.Error("Expected custom transport to be preserved") + } + }) + + t.Run("Custom client is used for OAuth requests", func(t *testing.T) { + // Track if the custom client was actually used + requestMade := false + + // Create a test server + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestMade = true + switch r.URL.Path { + case "/.well-known/oauth-protected-resource": + w.WriteHeader(http.StatusNotFound) + case "/.well-known/oauth-authorization-server": + metadata := AuthServerMetadata{ + Issuer: server.URL, + AuthorizationEndpoint: server.URL + "/authorize", + TokenEndpoint: server.URL + "/token", + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(metadata); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + // Create custom client with very short timeout + customClient := &http.Client{ + Timeout: 2 * time.Second, // Short but reasonable timeout + } + + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: server.URL + "/callback", + TokenStore: NewMemoryTokenStore(), + AuthServerMetadataURL: "", + } + + handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient)) + handler.SetBaseURL(server.URL) + + // Make a request that should use the custom client + _, err := handler.GetServerMetadata(context.Background()) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // Verify the request was made + if !requestMade { + t.Error("Expected request to be made using custom client") + } + }) + + t.Run("Without WithOAuthHTTPClient uses default client", func(t *testing.T) { + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + TokenStore: NewMemoryTokenStore(), + } + + // Create handler without custom client option + handler := NewOAuthHandler(config) + + // Verify default client is used + if handler.httpClient == nil { + t.Error("Expected default http client to be set") + } + if handler.httpClient.Timeout != 30*time.Second { + t.Errorf("Expected default timeout to be 30s, got %v", handler.httpClient.Timeout) + } + }) + + t.Run("WithOAuthHTTPClient validates and sets default timeout for zero timeout", func(t *testing.T) { + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + TokenStore: NewMemoryTokenStore(), + } + + // Create a client with zero timeout + customClient := &http.Client{ + Timeout: 0, // Zero timeout + } + + // Original client should have zero timeout + if customClient.Timeout != 0 { + t.Errorf("Expected original client timeout to be 0, got %v", customClient.Timeout) + } + + // Create handler with zero-timeout client + handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient)) + + // Verify the timeout was automatically set to 30 seconds + if handler.httpClient.Timeout != 30*time.Second { + t.Errorf("Expected timeout to be auto-set to 30s, got %v", handler.httpClient.Timeout) + } + + // Verify the same client instance was used (modified in place) + if handler.httpClient != customClient { + t.Error("Expected the same client instance to be used") + } + + // Verify the client was modified in place + if customClient.Timeout != 30*time.Second { + t.Errorf("Expected client timeout to be modified to 30s, got %v", customClient.Timeout) + } + }) + + t.Run("WithOAuthHTTPClient preserves non-zero timeout", func(t *testing.T) { + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + TokenStore: NewMemoryTokenStore(), + } + + // Create a client with explicit timeout + customClient := &http.Client{ + Timeout: 5 * time.Second, + } + + handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient)) + + // Verify the original timeout is preserved + if handler.httpClient.Timeout != 5*time.Second { + t.Errorf("Expected timeout to remain 5s, got %v", handler.httpClient.Timeout) + } + }) + + t.Run("WithOAuthHTTPClient zero timeout with custom transport", func(t *testing.T) { + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + TokenStore: NewMemoryTokenStore(), + } + + // Create a client with custom transport but zero timeout + customTransport := &http.Transport{ + MaxIdleConns: 100, + } + customClient := &http.Client{ + Timeout: 0, // Zero timeout + Transport: customTransport, + } + + handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient)) + + // Verify timeout was set + if handler.httpClient.Timeout != 30*time.Second { + t.Errorf("Expected timeout to be auto-set to 30s, got %v", handler.httpClient.Timeout) + } + + // Verify transport is preserved + if handler.httpClient.Transport != customTransport { + t.Error("Expected custom transport to be preserved") + } + }) + + t.Run("WithOAuthHTTPClient validation prevents indefinite hangs", func(t *testing.T) { + // This test demonstrates that the validation prevents potential production issues + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + TokenStore: NewMemoryTokenStore(), + } + + // User might create a client without thinking about timeout + customClient := &http.Client{ + Transport: &http.Transport{ + MaxIdleConns: 50, + }, + // Forgot to set Timeout! + } + + handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient)) + + // The validation should have added a timeout automatically + if handler.httpClient.Timeout == 0 { + t.Error("Expected validation to prevent zero timeout, but it's still zero") + } + + if handler.httpClient.Timeout != 30*time.Second { + t.Errorf("Expected validation to set 30s timeout, got %v", handler.httpClient.Timeout) + } + }) +}