diff --git a/AUTHORS b/AUTHORS index 77ccc1eff..fa3820f64 100644 --- a/AUTHORS +++ b/AUTHORS @@ -41,6 +41,7 @@ Hossein Shakiba Hiroki Kiyohara Jens Timmerman Jerome Leclanche +Jesse Gibbs Jim Graham Jonas Nygaard Pedersen Jonathan Steffan diff --git a/CHANGELOG.md b/CHANGELOG.md index 675a055f1..abc5a401d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] ### Added +* Support `prompt=login` for the OIDC Authorization Code Flow end user [Authentication Request](https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest). * Add spanish (es) translations. ### Changed diff --git a/docs/oidc.rst b/docs/oidc.rst index 4b427ba86..2211a972a 100644 --- a/docs/oidc.rst +++ b/docs/oidc.rst @@ -359,6 +359,15 @@ token, so you will probably want to re-use that:: claims["color_scheme"] = get_color_scheme(request.user) return claims +Customizing the login flow +========================== + +Clients can request that the user logs in each time a request to the +``/authorize`` endpoint is made during the OIDC Authorization Code Flow by +adding the ``prompt=login`` query parameter and value. Only ``login`` is +currently supported. See +OIDC's `3.1.2.1 Authentication Request `_ +for details. OIDC Views ========== diff --git a/oauth2_provider/views/base.py b/oauth2_provider/views/base.py index 211da45ed..abaa81f59 100644 --- a/oauth2_provider/views/base.py +++ b/oauth2_provider/views/base.py @@ -1,8 +1,11 @@ import json import logging +from urllib.parse import parse_qsl, urlencode, urlparse from django.contrib.auth.mixins import LoginRequiredMixin +from django.contrib.auth.views import redirect_to_login from django.http import HttpResponse +from django.shortcuts import resolve_url from django.utils import timezone from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt @@ -144,6 +147,10 @@ def get(self, request, *args, **kwargs): # Application is not available at this time. return self.error_response(error, application=None) + prompt = request.GET.get("prompt") + if prompt == "login": + return self.handle_prompt_login() + all_scopes = get_scopes_backend().get_all_scopes() kwargs["scopes_descriptions"] = [all_scopes[scope] for scope in scopes] kwargs["scopes"] = scopes @@ -211,6 +218,32 @@ def get(self, request, *args, **kwargs): return self.render_to_response(self.get_context_data(**kwargs)) + def handle_prompt_login(self): + path = self.request.build_absolute_uri() + resolved_login_url = resolve_url(self.get_login_url()) + + # If the login url is the same scheme and net location then use the + # path as the "next" url. + login_scheme, login_netloc = urlparse(resolved_login_url)[:2] + current_scheme, current_netloc = urlparse(path)[:2] + if (not login_scheme or login_scheme == current_scheme) and ( + not login_netloc or login_netloc == current_netloc + ): + path = self.request.get_full_path() + + parsed = urlparse(path) + + parsed_query = dict(parse_qsl(parsed.query)) + parsed_query.pop("prompt") + + parsed = parsed._replace(query=urlencode(parsed_query)) + + return redirect_to_login( + parsed.geturl(), + resolved_login_url, + self.get_redirect_field_name(), + ) + @method_decorator(csrf_exempt, name="dispatch") class TokenView(OAuthLibMixin, View): diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index 8bface719..924bdc1db 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -5,6 +5,7 @@ from urllib.parse import parse_qs, urlparse import pytest +from django.conf import settings from django.contrib.auth import get_user_model from django.test import RequestFactory, TestCase from django.urls import reverse @@ -612,6 +613,40 @@ def test_id_token_code_post_auth_allow(self): self.assertIn("state=random_state_string", response["Location"]) self.assertIn("code=", response["Location"]) + def test_prompt_login(self): + """ + Test response for redirect when supplied with prompt: login + """ + self.oauth2_settings.PKCE_REQUIRED = False + self.client.login(username="test_user", password="123456") + + query_data = { + "client_id": self.application.client_id, + "response_type": "code", + "state": "random_state_string", + "scope": "read write", + "redirect_uri": "http://example.org", + "prompt": "login", + } + + response = self.client.get(reverse("oauth2_provider:authorize"), data=query_data) + + self.assertEqual(response.status_code, 302) + + scheme, netloc, path, params, query, fragment = urlparse(response["Location"]) + + self.assertEqual(path, settings.LOGIN_URL) + + parsed_query = parse_qs(query) + next = parsed_query["next"][0] + + self.assertIn("redirect_uri=http%3A%2F%2Fexample.org", next) + self.assertIn("state=random_state_string", next) + self.assertIn("scope=read+write", next) + self.assertIn(f"client_id={self.application.client_id}", next) + + self.assertNotIn("prompt=login", next) + class BaseAuthorizationCodeTokenView(BaseTest): def get_auth(self, scope="read write"):