Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
break
except ValidationError:
continue
elif oauth_metadata_response.status_code != 404:
break # Non-404 error, stop trying
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
break # Non-4XX error, stop trying

# Step 3: Register client if needed
registration_request = await self._register_client()
Expand Down
103 changes: 103 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,109 @@ async def test_oauth_discovery_fallback_order(self, oauth_provider):
"https://api.example.com/v1/mcp/.well-known/openid-configuration",
]

@pytest.mark.anyio
async def test_oauth_discovery_fallback_conditions(self, oauth_provider):
"""Test the conditions during which an AS metadata discovery fallback will be attempted."""
# Ensure no tokens are stored
oauth_provider.context.current_tokens = None
oauth_provider.context.token_expiry_time = None
oauth_provider._initialized = True

# Mock client info to skip DCR
oauth_provider.context.client_info = OAuthClientInformationFull(
client_id="existing_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)

# Create a test request
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")

# Mock the auth flow
auth_flow = oauth_provider.async_auth_flow(test_request)

# First request should be the original request without auth header
request = await auth_flow.__anext__()
assert "Authorization" not in request.headers

# Send a 401 response to trigger the OAuth flow
response = httpx.Response(
401,
headers={
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
},
request=test_request,
)

# Next request should be to discover protected resource metadata
discovery_request = await auth_flow.asend(response)
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
assert discovery_request.method == "GET"

# Send a successful discovery response with minimal protected resource metadata
discovery_response = httpx.Response(
200,
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com/v1/mcp"]}',
request=discovery_request,
)

# Next request should be to discover OAuth metadata
oauth_metadata_request_1 = await auth_flow.asend(discovery_response)
assert (
str(oauth_metadata_request_1.url)
== "https://auth.example.com/.well-known/oauth-authorization-server/v1/mcp"
)
assert oauth_metadata_request_1.method == "GET"

# Send a 404 response
oauth_metadata_response_1 = httpx.Response(
404,
content=b"Not Found",
request=oauth_metadata_request_1,
)

# Next request should be to discover OAuth metadata at the next endpoint
oauth_metadata_request_2 = await auth_flow.asend(oauth_metadata_response_1)
assert str(oauth_metadata_request_2.url) == "https://auth.example.com/.well-known/oauth-authorization-server"
assert oauth_metadata_request_2.method == "GET"

# Send a 400 response
oauth_metadata_response_2 = httpx.Response(
400,
content=b"Bad Request",
request=oauth_metadata_request_2,
)

# Next request should be to discover OAuth metadata at the next endpoint
oauth_metadata_request_3 = await auth_flow.asend(oauth_metadata_response_2)
assert str(oauth_metadata_request_3.url) == "https://auth.example.com/.well-known/openid-configuration/v1/mcp"
assert oauth_metadata_request_3.method == "GET"

# Send a 500 response
oauth_metadata_response_3 = httpx.Response(
500,
content=b"Internal Server Error",
request=oauth_metadata_request_3,
)

# Mock the authorization process to minimize unnecessary state in this test
oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier"))

# Next request should fall back to legacy behavior and auth with the RS (mocked /authorize, next is /token)
token_request = await auth_flow.asend(oauth_metadata_response_3)
assert str(token_request.url) == "https://api.example.com/token"
assert token_request.method == "POST"

# Send a successful token response
token_response = httpx.Response(
200,
content=(
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
b'"refresh_token": "new_refresh_token"}'
),
request=token_request,
)
token_request = await auth_flow.asend(token_response)

@pytest.mark.anyio
async def test_handle_metadata_response_success(self, oauth_provider):
"""Test successful metadata response handling."""
Expand Down
Loading