Skip to content

Commit ff7c2e2

Browse files
committed
Split out redirect uri logic for easier testing
This adds some unit tests for loopback IP code in particular, as part of reviewing the change
1 parent a410883 commit ff7c2e2

File tree

2 files changed

+69
-23
lines changed

2 files changed

+69
-23
lines changed

oauth2_provider/models.py

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -125,29 +125,7 @@ def redirect_uri_allowed(self, uri):
125125
126126
:param uri: Url to check
127127
"""
128-
parsed_uri = urlparse(uri)
129-
uqs_set = set(parse_qsl(parsed_uri.query))
130-
for allowed_uri in self.redirect_uris.split():
131-
parsed_allowed_uri = urlparse(allowed_uri)
132-
133-
if (
134-
parsed_allowed_uri.scheme == parsed_uri.scheme == "http"
135-
and parsed_uri.hostname in ["127.0.0.1", "::1"]
136-
and isinstance(parsed_allowed_uri.port, type(None))
137-
and parsed_allowed_uri.hostname == parsed_uri.hostname
138-
and parsed_allowed_uri.path == parsed_uri.path
139-
) or (
140-
parsed_allowed_uri.scheme == parsed_uri.scheme
141-
and parsed_allowed_uri.netloc == parsed_uri.netloc
142-
and parsed_allowed_uri.path == parsed_uri.path
143-
):
144-
145-
aqs_set = set(parse_qsl(parsed_allowed_uri.query))
146-
147-
if aqs_set.issubset(uqs_set):
148-
return True
149-
150-
return False
128+
return redirect_to_uri_allowed(uri, self.redirect_uris.split())
151129

152130
def clean(self):
153131
from django.core.exceptions import ValidationError
@@ -680,3 +658,50 @@ def clear_expired():
680658

681659
access_tokens.delete()
682660
grants.delete()
661+
662+
663+
def redirect_to_uri_allowed(uri, allowed_uris):
664+
"""
665+
Checks if a given uri can be redirected to based on the provided allowed_uris configuration.
666+
667+
On top of exact matches, this function also handles loopback IPs based on RFC 8252.
668+
669+
:param uri: URI to check
670+
:param allowed_uris: A list of URIs that are allowed
671+
"""
672+
673+
parsed_uri = urlparse(uri)
674+
uqs_set = set(parse_qsl(parsed_uri.query))
675+
for allowed_uri in allowed_uris:
676+
parsed_allowed_uri = urlparse(allowed_uri)
677+
678+
# From RFC 8252 (Section 7.3)
679+
#
680+
# Loopback redirect URIs use the "http" scheme
681+
# [...]
682+
# The authorization server MUST allow any port to be specified at the
683+
# time of the request for loopback IP redirect URIs, to accommodate
684+
# clients that obtain an available ephemeral port from the operating
685+
# system at the time of the request.
686+
687+
allowed_uri_is_loopback = (
688+
parsed_allowed_uri.scheme == "http"
689+
and parsed_allowed_uri.hostname in ["127.0.0.1", "::1"]
690+
and parsed_allowed_uri.port is None
691+
)
692+
if (
693+
allowed_uri_is_loopback
694+
and parsed_allowed_uri.scheme == parsed_uri.scheme
695+
and parsed_allowed_uri.hostname == parsed_uri.hostname
696+
and parsed_allowed_uri.path == parsed_uri.path
697+
) or (
698+
parsed_allowed_uri.scheme == parsed_uri.scheme
699+
and parsed_allowed_uri.netloc == parsed_uri.netloc
700+
and parsed_allowed_uri.path == parsed_uri.path
701+
):
702+
703+
aqs_set = set(parse_qsl(parsed_allowed_uri.query))
704+
if aqs_set.issubset(uqs_set):
705+
return True
706+
707+
return False

tests/test_oauth2_backends.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
from django.test import RequestFactory, TestCase
55

6+
from oauth2_provider.models import redirect_to_uri_allowed
67
from oauth2_provider.backends import get_oauthlib_core
78
from oauth2_provider.oauth2_backends import JSONOAuthLibCore, OAuthLibCore
89

@@ -110,3 +111,23 @@ def test_validate_authorization_request_unsafe_query(self):
110111

111112
oauthlib_core = get_oauthlib_core()
112113
oauthlib_core.verify_request(request, scopes=[])
114+
115+
116+
@pytest.mark.parametrize(
117+
"uri, expected_result",
118+
# localhost is _not_ a loopback URI
119+
[
120+
("http://localhost:3456", False),
121+
# only http scheme is supported for loopback URIs
122+
("https://127.0.0.1:3456", False),
123+
("http://127.0.0.1:3456", True),
124+
("http://[::1]", True),
125+
("http://[::1]:34", True),
126+
],
127+
)
128+
def test_uri_loopback_redirect_check(uri, expected_result):
129+
allowed_uris = ["http://127.0.0.1", "http://[::1]"]
130+
if expected_result:
131+
assert redirect_to_uri_allowed(uri, allowed_uris)
132+
else:
133+
assert not redirect_to_uri_allowed(uri, allowed_uris)

0 commit comments

Comments
 (0)