From 295721a5f12c48e911a1efcd278431bc7d7a5dd5 Mon Sep 17 00:00:00 2001 From: Luca Chang Date: Wed, 23 Jul 2025 13:37:16 -0700 Subject: [PATCH] fix: perform auth server metadata discovery fallbacks on any 4xx --- src/mcp/client/auth.py | 4 +- tests/client/test_auth.py | 103 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index b00db7b9b..775fb0f6c 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -526,8 +526,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. break except ValidationError: continue - elif oauth_metadata_response.status_code != 404: - break # Non-404 error, stop trying + elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: + break # Non-4XX error, stop trying # Step 3: Register client if needed registration_request = await self._register_client() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 46208d69c..bb962bfc1 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -261,6 +261,109 @@ async def test_oauth_discovery_fallback_order(self, oauth_provider): "https://api.example.com/v1/mcp/.well-known/openid-configuration", ] + @pytest.mark.anyio + async def test_oauth_discovery_fallback_conditions(self, oauth_provider): + """Test the conditions during which an AS metadata discovery fallback will be attempted.""" + # Ensure no tokens are stored + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider._initialized = True + + # Mock client info to skip DCR + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + + # Create a test request + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + + # Mock the auth flow + auth_flow = oauth_provider.async_auth_flow(test_request) + + # First request should be the original request without auth header + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send a 401 response to trigger the OAuth flow + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + + # Next request should be to discover protected resource metadata + discovery_request = await auth_flow.asend(response) + assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource" + assert discovery_request.method == "GET" + + # Send a successful discovery response with minimal protected resource metadata + discovery_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com/v1/mcp"]}', + request=discovery_request, + ) + + # Next request should be to discover OAuth metadata + oauth_metadata_request_1 = await auth_flow.asend(discovery_response) + assert ( + str(oauth_metadata_request_1.url) + == "https://auth.example.com/.well-known/oauth-authorization-server/v1/mcp" + ) + assert oauth_metadata_request_1.method == "GET" + + # Send a 404 response + oauth_metadata_response_1 = httpx.Response( + 404, + content=b"Not Found", + request=oauth_metadata_request_1, + ) + + # Next request should be to discover OAuth metadata at the next endpoint + oauth_metadata_request_2 = await auth_flow.asend(oauth_metadata_response_1) + assert str(oauth_metadata_request_2.url) == "https://auth.example.com/.well-known/oauth-authorization-server" + assert oauth_metadata_request_2.method == "GET" + + # Send a 400 response + oauth_metadata_response_2 = httpx.Response( + 400, + content=b"Bad Request", + request=oauth_metadata_request_2, + ) + + # Next request should be to discover OAuth metadata at the next endpoint + oauth_metadata_request_3 = await auth_flow.asend(oauth_metadata_response_2) + assert str(oauth_metadata_request_3.url) == "https://auth.example.com/.well-known/openid-configuration/v1/mcp" + assert oauth_metadata_request_3.method == "GET" + + # Send a 500 response + oauth_metadata_response_3 = httpx.Response( + 500, + content=b"Internal Server Error", + request=oauth_metadata_request_3, + ) + + # Mock the authorization process to minimize unnecessary state in this test + oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + + # Next request should fall back to legacy behavior and auth with the RS (mocked /authorize, next is /token) + token_request = await auth_flow.asend(oauth_metadata_response_3) + assert str(token_request.url) == "https://api.example.com/token" + assert token_request.method == "POST" + + # Send a successful token response + token_response = httpx.Response( + 200, + content=( + b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, ' + b'"refresh_token": "new_refresh_token"}' + ), + request=token_request, + ) + token_request = await auth_flow.asend(token_response) + @pytest.mark.anyio async def test_handle_metadata_response_success(self, oauth_provider): """Test successful metadata response handling."""