diff --git a/AUTHORS b/AUTHORS index 16c2058b8..58ae037ee 100644 --- a/AUTHORS +++ b/AUTHORS @@ -56,6 +56,7 @@ Jens Timmerman Jerome Leclanche Jesse Gibbs Jim Graham +John Byrne Jonas Nygaard Pedersen Jonathan Steffan Jordi Sanchez diff --git a/CHANGELOG.md b/CHANGELOG.md index 292300ce2..93176fe4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] ### Added +* #1185 Add middleware for adding access token to request * #1273 Add caching of loading of OIDC private key. * #1285 Add post_logout_redirect_uris field in application views. diff --git a/docs/tutorial/tutorial_03.rst b/docs/tutorial/tutorial_03.rst index 09486c3d6..ef5d57969 100644 --- a/docs/tutorial/tutorial_03.rst +++ b/docs/tutorial/tutorial_03.rst @@ -47,6 +47,8 @@ will not try to get user from the session. If you use AuthenticationMiddleware, be sure it appears before OAuth2TokenMiddleware. However AuthenticationMiddleware is NOT required for using django-oauth-toolkit. +Note, `OAuth2TokenMiddleware` adds the user to the request object. There is also an optional `OAuth2ExtraTokenMiddleware` that adds the `Token` to the request. This makes it convenient to access the `Application` object within your views. To use it just add `oauth2_provider.middleware.OAuth2ExtraTokenMiddleware` to the `MIDDLEWARE` setting. + Protect your view ----------------- The authentication backend will run smoothly with, for example, `login_required` decorators, so diff --git a/oauth2_provider/middleware.py b/oauth2_provider/middleware.py index 17ba6c35f..28bd968f8 100644 --- a/oauth2_provider/middleware.py +++ b/oauth2_provider/middleware.py @@ -1,6 +1,13 @@ +import logging + from django.contrib.auth import authenticate from django.utils.cache import patch_vary_headers +from oauth2_provider.models import AccessToken + + +log = logging.getLogger(__name__) + class OAuth2TokenMiddleware: """ @@ -36,3 +43,20 @@ def __call__(self, request): response = self.get_response(request) patch_vary_headers(response, ("Authorization",)) return response + + +class OAuth2ExtraTokenMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + authheader = request.META.get("HTTP_AUTHORIZATION", "") + if authheader.startswith("Bearer"): + tokenstring = authheader.split()[1] + try: + token = AccessToken.objects.get(token=tokenstring) + request.access_token = token + except AccessToken.DoesNotExist as e: + log.exception(e) + response = self.get_response(request) + return response diff --git a/tests/test_auth_backends.py b/tests/test_auth_backends.py index 8eeb8ef12..6b958ecb0 100644 --- a/tests/test_auth_backends.py +++ b/tests/test_auth_backends.py @@ -10,7 +10,7 @@ from django.utils.timezone import now, timedelta from oauth2_provider.backends import OAuth2Backend -from oauth2_provider.middleware import OAuth2TokenMiddleware +from oauth2_provider.middleware import OAuth2ExtraTokenMiddleware, OAuth2TokenMiddleware from oauth2_provider.models import get_access_token_model, get_application_model @@ -162,3 +162,62 @@ def test_middleware_response_header(self): response = m(request) self.assertIn("Vary", response) self.assertIn("Authorization", response["Vary"]) + + +@override_settings( + AUTHENTICATION_BACKENDS=( + "oauth2_provider.backends.OAuth2Backend", + "django.contrib.auth.backends.ModelBackend", + ), +) +@modify_settings( + MIDDLEWARE={ + "append": "oauth2_provider.middleware.OAuth2TokenMiddleware", + } +) +class TestOAuth2ExtraTokenMiddleware(BaseTest): + def setUp(self): + super().setUp() + self.anon_user = AnonymousUser() + + def dummy_get_response(self, request): + return HttpResponse() + + def test_middleware_wrong_headers(self): + m = OAuth2ExtraTokenMiddleware(self.dummy_get_response) + request = self.factory.get("/a-resource") + m(request) + self.assertFalse(hasattr(request, "access_token")) + auth_headers = { + "HTTP_AUTHORIZATION": "Beerer " + "badstring", # a Beer token for you! + } + request = self.factory.get("/a-resource", **auth_headers) + m(request) + self.assertFalse(hasattr(request, "access_token")) + + def test_middleware_token_does_not_exist(self): + m = OAuth2ExtraTokenMiddleware(self.dummy_get_response) + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + "badtokstr", + } + request = self.factory.get("/a-resource", **auth_headers) + m(request) + self.assertFalse(hasattr(request, "access_token")) + + def test_middleware_success(self): + m = OAuth2ExtraTokenMiddleware(self.dummy_get_response) + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + "tokstr", + } + request = self.factory.get("/a-resource", **auth_headers) + m(request) + self.assertEqual(request.access_token, self.token) + + def test_middleware_response(self): + m = OAuth2ExtraTokenMiddleware(self.dummy_get_response) + auth_headers = { + "HTTP_AUTHORIZATION": "Bearer " + "tokstr", + } + request = self.factory.get("/a-resource", **auth_headers) + response = m(request) + self.assertIsInstance(response, HttpResponse)