Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 75 additions & 3 deletions client/transport/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,88 @@ type OAuthHandler struct {
expectedState string // Expected state value for CSRF protection
}

// NewOAuthHandler creates a new OAuth handler
func NewOAuthHandler(config OAuthConfig) *OAuthHandler {
type OAuthHandlerOption func(*OAuthHandler)

// WithOAuthHTTPClient allows setting a custom http.Client for the OAuthHandler.
//
// This is useful when you need to:
// - Configure custom timeouts for OAuth operations (token exchange, refresh, metadata discovery)
// - Use a custom transport (e.g., for proxy settings, TLS configuration, custom CA certificates)
// - Set connection pooling parameters (MaxIdleConns, IdleConnTimeout, etc.)
// - Add request/response interceptors or custom retry logic
// - Configure specific dial or keep-alive settings
//
// If not specified, a default http.Client with a 30-second timeout is used.
// If nil is passed, the option is ignored and the default client is retained.
//
// Validation: If the provided client has a zero timeout, a default 30-second timeout
// will be automatically applied to prevent indefinite hangs during OAuth operations.
// This ensures robust behavior even if users forget to set a timeout.
//
// Note: The provided client will be used for all OAuth-related HTTP requests including:
// - OAuth server metadata discovery (/.well-known/oauth-authorization-server, etc.)
// - Token endpoint requests (authorization code exchange, token refresh)
// - Dynamic client registration
//
// Example usage with custom timeout:
//
// customClient := &http.Client{
// Timeout: 10 * time.Second,
// }
// handler := transport.NewOAuthHandler(config, transport.WithOAuthHTTPClient(customClient))
//
// Example usage with custom transport:
//
// customTransport := &http.Transport{
// MaxIdleConns: 100,
// IdleConnTimeout: 90 * time.Second,
// TLSHandshakeTimeout: 10 * time.Second,
// }
// customClient := &http.Client{
// Timeout: 15 * time.Second,
// Transport: customTransport,
// }
// handler := transport.NewOAuthHandler(config, transport.WithOAuthHTTPClient(customClient))
//
// Example usage with proxy:
//
// proxyURL, _ := url.Parse("http://proxy.example.com:8080")
// customTransport := &http.Transport{
// Proxy: http.ProxyURL(proxyURL),
// }
// customClient := &http.Client{
// Transport: customTransport,
// }
// handler := transport.NewOAuthHandler(config, transport.WithOAuthHTTPClient(customClient))
func WithOAuthHTTPClient(client *http.Client) OAuthHandlerOption {
return func(h *OAuthHandler) {
if client != nil {
// If client has zero timeout, set a reasonable default to prevent indefinite hangs
if client.Timeout == 0 {
client.Timeout = 30 * time.Second
}
h.httpClient = client
}
}
}

// NewOAuthHandler creates a new OAuth handler.
// Optionally accepts functional options such as WithOAuthHTTPClient.
func NewOAuthHandler(config OAuthConfig, opts ...OAuthHandlerOption) *OAuthHandler {
if config.TokenStore == nil {
config.TokenStore = NewMemoryTokenStore()
}

return &OAuthHandler{
handler := &OAuthHandler{
config: config,
httpClient: &http.Client{Timeout: 30 * time.Second},
}

for _, opt := range opts {
opt(handler)
}

return handler
}

// GetAuthorizationHeader returns the Authorization header value for a request
Expand Down
285 changes: 285 additions & 0 deletions client/transport/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -940,3 +940,288 @@ func TestOAuthHandler_GetServerMetadata_FallbackToDefaultEndpoints(t *testing.T)
t.Errorf("Expected token endpoint to be %s/token, got %s", server.URL, metadata.TokenEndpoint)
}
}

func TestWithOAuthHTTPClient(t *testing.T) {
t.Run("WithOAuthHTTPClient sets custom http client", func(t *testing.T) {
// Create a custom http.Client with specific timeout
customClient := &http.Client{
Timeout: 10 * time.Second,
}

config := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
Scopes: []string{"mcp.read"},
TokenStore: NewMemoryTokenStore(),
}

// Create handler with custom client
handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient))

// Verify the custom client was set
if handler.httpClient != customClient {
t.Error("Expected custom http client to be set")
}
if handler.httpClient.Timeout != 10*time.Second {
t.Errorf("Expected timeout to be 10s, got %v", handler.httpClient.Timeout)
}
})

t.Run("WithOAuthHTTPClient ignores nil client", func(t *testing.T) {
config := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
Scopes: []string{"mcp.read"},
TokenStore: NewMemoryTokenStore(),
}

// Create handler with nil client
handler := NewOAuthHandler(config, WithOAuthHTTPClient(nil))

// Verify the default client is still used
if handler.httpClient == nil {
t.Error("Expected default http client to be set")
}
if handler.httpClient.Timeout != 30*time.Second {
t.Errorf("Expected default timeout to be 30s, got %v", handler.httpClient.Timeout)
}
})

t.Run("Multiple options including WithOAuthHTTPClient", func(t *testing.T) {
customClient := &http.Client{
Timeout: 5 * time.Second,
}

config := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
Scopes: []string{"mcp.read"},
TokenStore: NewMemoryTokenStore(),
}

// Create handler with custom client option
handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient))

// Verify both the config and custom client were applied
if handler.config.ClientID != "test-client" {
t.Errorf("Expected client ID to be test-client, got %s", handler.config.ClientID)
}
if handler.httpClient != customClient {
t.Error("Expected custom http client to be set")
}
if handler.httpClient.Timeout != 5*time.Second {
t.Errorf("Expected timeout to be 5s, got %v", handler.httpClient.Timeout)
}
})

t.Run("Custom client with custom transport", func(t *testing.T) {
// Create a custom transport for testing
customTransport := &http.Transport{
MaxIdleConns: 100,
}
customClient := &http.Client{
Timeout: 15 * time.Second,
Transport: customTransport,
}

config := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
TokenStore: NewMemoryTokenStore(),
}

handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient))

// Verify the custom transport is preserved
if handler.httpClient.Transport != customTransport {
t.Error("Expected custom transport to be preserved")
}
})

t.Run("Custom client is used for OAuth requests", func(t *testing.T) {
// Track if the custom client was actually used
requestMade := false

// Create a test server
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestMade = true
switch r.URL.Path {
case "/.well-known/oauth-protected-resource":
w.WriteHeader(http.StatusNotFound)
case "/.well-known/oauth-authorization-server":
metadata := AuthServerMetadata{
Issuer: server.URL,
AuthorizationEndpoint: server.URL + "/authorize",
TokenEndpoint: server.URL + "/token",
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(metadata); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()

// Create custom client with very short timeout
customClient := &http.Client{
Timeout: 2 * time.Second, // Short but reasonable timeout
}

config := OAuthConfig{
ClientID: "test-client",
RedirectURI: server.URL + "/callback",
TokenStore: NewMemoryTokenStore(),
AuthServerMetadataURL: "",
}

handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient))
handler.SetBaseURL(server.URL)

// Make a request that should use the custom client
_, err := handler.GetServerMetadata(context.Background())
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}

// Verify the request was made
if !requestMade {
t.Error("Expected request to be made using custom client")
}
})

t.Run("Without WithOAuthHTTPClient uses default client", func(t *testing.T) {
config := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
TokenStore: NewMemoryTokenStore(),
}

// Create handler without custom client option
handler := NewOAuthHandler(config)

// Verify default client is used
if handler.httpClient == nil {
t.Error("Expected default http client to be set")
}
if handler.httpClient.Timeout != 30*time.Second {
t.Errorf("Expected default timeout to be 30s, got %v", handler.httpClient.Timeout)
}
})

t.Run("WithOAuthHTTPClient validates and sets default timeout for zero timeout", func(t *testing.T) {
config := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
TokenStore: NewMemoryTokenStore(),
}

// Create a client with zero timeout
customClient := &http.Client{
Timeout: 0, // Zero timeout
}

// Original client should have zero timeout
if customClient.Timeout != 0 {
t.Errorf("Expected original client timeout to be 0, got %v", customClient.Timeout)
}

// Create handler with zero-timeout client
handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient))

// Verify the timeout was automatically set to 30 seconds
if handler.httpClient.Timeout != 30*time.Second {
t.Errorf("Expected timeout to be auto-set to 30s, got %v", handler.httpClient.Timeout)
}

// Verify the same client instance was used (modified in place)
if handler.httpClient != customClient {
t.Error("Expected the same client instance to be used")
}

// Verify the client was modified in place
if customClient.Timeout != 30*time.Second {
t.Errorf("Expected client timeout to be modified to 30s, got %v", customClient.Timeout)
}
})

t.Run("WithOAuthHTTPClient preserves non-zero timeout", func(t *testing.T) {
config := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
TokenStore: NewMemoryTokenStore(),
}

// Create a client with explicit timeout
customClient := &http.Client{
Timeout: 5 * time.Second,
}

handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient))

// Verify the original timeout is preserved
if handler.httpClient.Timeout != 5*time.Second {
t.Errorf("Expected timeout to remain 5s, got %v", handler.httpClient.Timeout)
}
})

t.Run("WithOAuthHTTPClient zero timeout with custom transport", func(t *testing.T) {
config := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
TokenStore: NewMemoryTokenStore(),
}

// Create a client with custom transport but zero timeout
customTransport := &http.Transport{
MaxIdleConns: 100,
}
customClient := &http.Client{
Timeout: 0, // Zero timeout
Transport: customTransport,
}

handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient))

// Verify timeout was set
if handler.httpClient.Timeout != 30*time.Second {
t.Errorf("Expected timeout to be auto-set to 30s, got %v", handler.httpClient.Timeout)
}

// Verify transport is preserved
if handler.httpClient.Transport != customTransport {
t.Error("Expected custom transport to be preserved")
}
})

t.Run("WithOAuthHTTPClient validation prevents indefinite hangs", func(t *testing.T) {
// This test demonstrates that the validation prevents potential production issues
config := OAuthConfig{
ClientID: "test-client",
RedirectURI: "http://localhost:8085/callback",
TokenStore: NewMemoryTokenStore(),
}

// User might create a client without thinking about timeout
customClient := &http.Client{
Transport: &http.Transport{
MaxIdleConns: 50,
},
// Forgot to set Timeout!
}

handler := NewOAuthHandler(config, WithOAuthHTTPClient(customClient))

// The validation should have added a timeout automatically
if handler.httpClient.Timeout == 0 {
t.Error("Expected validation to prevent zero timeout, but it's still zero")
}

if handler.httpClient.Timeout != 30*time.Second {
t.Errorf("Expected validation to set 30s timeout, got %v", handler.httpClient.Timeout)
}
})
}
Loading