Skip to content

Add claims_supported to discovery info, without breaking the API #1069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,5 @@ pySilver
Shaheed Haque
Vinay Karanam
Eduardo Oliveira
Andrea Greco
Dominik George
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* #651 Batch expired token deletions in `cleartokens` management command
* Added pt-BR translations.
* #1070 Add a Celery task for clearing expired tokens, e.g. to be scheduled as a [periodic task](https://docs.celeryproject.org/en/stable/userguide/periodic-tasks.html)
* #1069 OIDC: Re-introduce [additional claims](https://django-oauth-toolkit.readthedocs.io/en/latest/oidc.html#adding-claims-to-the-id-token) beyond `sub` to the id_token.

### Fixed
* #1012 Return status for introspecting a nonexistent token from 401 to the correct value of 200 per [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662#section-2.2).
Expand Down
43 changes: 36 additions & 7 deletions docs/oidc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,45 @@ required claims, eg ``iss``, ``aud``, ``exp``, ``iat``, ``auth_time`` etc),
and the ``sub`` claim will use the primary key of the user as the value.
You'll probably want to customize this and add additional claims or change
what is sent for the ``sub`` claim. To do so, you will need to add a method to
our custom validator::
our custom validator. It takes one of two forms:

The first form gets passed a request object, and should return a dictionary
mapping a claim name to claim data::
class CustomOAuth2Validator(OAuth2Validator):

def get_additional_claims(self, request):
return {
"sub": request.user.email,
"first_name": request.user.first_name,
"last_name": request.user.last_name,
}
claims = {}
claims["email"] = request.user.get_user_email()
claims["username"] = request.user.get_full_name()

return claims

The second form gets no request object, and should return a dictionary
mapping a claim name to a callable, accepting a request and producing
the claim data::
class CustomOAuth2Validator(OAuth2Validator):
def get_additional_claims(self):
def get_user_email(request):
return request.user.get_user_email()

claims = {}
claims["email"] = get_user_email
claims["username"] = lambda r: r.user.get_full_name()

return claims

Standard claim ``sub`` is included by default, to remove it override ``get_claim_dict``.

In some cases, it might be desirable to not list all claims in discovery info. To customize
which claims are advertised, you can override the ``get_discovery_claims`` method to return
a list of claim names to advertise. If your ``get_additional_claims`` uses the first form
and you still want to advertise claims, you can also override ``get_discovery_claims``.

In order to help lcients discover claims early, they can be advertised in the discovery
info, under the ``claims_supported`` key. In order for the discovery info view to automatically
add all claims your validator returns, you need to use the second form (producing callables),
because the discovery info views are requested with an unauthenticated request, so directly
producing claim data would fail. If you use the first form, producing claim data directly,
your claims will not be added to discovery info.

.. note::
This ``request`` object is not a ``django.http.Request`` object, but an
Expand Down
35 changes: 29 additions & 6 deletions oauth2_provider/oauth2_validators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import binascii
import http.client
import inspect
import json
import logging
import uuid
Expand Down Expand Up @@ -725,18 +726,40 @@ def _save_id_token(self, jti, request, expires, *args, **kwargs):
)
return id_token

@classmethod
def _get_additional_claims_is_request_agnostic(cls):
return len(inspect.signature(cls.get_additional_claims).parameters) == 1

def get_jwt_bearer_token(self, token, token_handler, request):
return self.get_id_token(token, token_handler, request)

def get_oidc_claims(self, token, token_handler, request):
# Required OIDC claims
claims = {
"sub": str(request.user.id),
}
def get_claim_dict(self, request):
if self._get_additional_claims_is_request_agnostic():
claims = {"sub": lambda r: str(r.user.id)}
else:
claims = {"sub": str(request.user.id)}

# https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
claims.update(**self.get_additional_claims(request))
if self._get_additional_claims_is_request_agnostic():
add = self.get_additional_claims()
else:
add = self.get_additional_claims(request)
claims.update(add)

return claims

def get_discovery_claims(self, request):
claims = ["sub"]
if self._get_additional_claims_is_request_agnostic():
claims += list(self.get_claim_dict(request).keys())
return claims

def get_oidc_claims(self, token, token_handler, request):
data = self.get_claim_dict(request)
claims = {}

for k, v in data.items():
claims[k] = v(request) if callable(v) else v
return claims

def get_id_token_dictionary(self, token, token_handler, request):
Expand Down
6 changes: 6 additions & 0 deletions oauth2_provider/views/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def get(self, request, *args, **kwargs):
signing_algorithms = [Application.HS256_ALGORITHM]
if oauth2_settings.OIDC_RSA_PRIVATE_KEY:
signing_algorithms = [Application.RS256_ALGORITHM, Application.HS256_ALGORITHM]

validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
validator = validator_class()
oidc_claims = list(set(validator.get_discovery_claims(request)))

data = {
"issuer": issuer_url,
"authorization_endpoint": authorization_endpoint,
Expand All @@ -57,6 +62,7 @@ def get(self, request, *args, **kwargs):
"token_endpoint_auth_methods_supported": (
oauth2_settings.OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED
),
"claims_supported": oidc_claims,
}
response = JsonResponse(data)
response["Access-Control-Allow-Origin"] = "*"
Expand Down
50 changes: 46 additions & 4 deletions tests/test_oidc_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_get_connect_discovery_info(self):
"subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["RS256", "HS256"],
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
"claims_supported": ["sub"],
}
response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
self.assertEqual(response.status_code, 200)
Expand All @@ -55,6 +56,7 @@ def test_get_connect_discovery_info_without_issuer_url(self):
"subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["RS256", "HS256"],
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
"claims_supported": ["sub"],
}
response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
self.assertEqual(response.status_code, 200)
Expand Down Expand Up @@ -146,11 +148,47 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client):
assert rsp.status_code == 401


EXAMPLE_EMAIL = "[email protected]"


def claim_user_email(request):
return EXAMPLE_EMAIL


@pytest.mark.django_db
def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings):
class CustomValidator(OAuth2Validator):
def get_additional_claims(self):
return {
"username": claim_user_email,
"email": claim_user_email,
}

oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator
auth_header = "Bearer %s" % oidc_tokens.access_token
rsp = client.get(
reverse("oauth2_provider:user-info"),
HTTP_AUTHORIZATION=auth_header,
)
data = rsp.json()
assert "sub" in data
assert data["sub"] == str(oidc_tokens.user.pk)

assert "username" in data
assert data["username"] == EXAMPLE_EMAIL

assert "email" in data
assert data["email"] == EXAMPLE_EMAIL


@pytest.mark.django_db
def test_userinfo_endpoint_custom_claims(oidc_tokens, client, oauth2_settings):
def test_userinfo_endpoint_custom_claims_plain(oidc_tokens, client, oauth2_settings):
class CustomValidator(OAuth2Validator):
def get_additional_claims(self, request):
return {"state": "very nice"}
return {
"username": EXAMPLE_EMAIL,
"email": EXAMPLE_EMAIL,
}

oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator
auth_header = "Bearer %s" % oidc_tokens.access_token
Expand All @@ -161,5 +199,9 @@ def get_additional_claims(self, request):
data = rsp.json()
assert "sub" in data
assert data["sub"] == str(oidc_tokens.user.pk)
assert "state" in data
assert data["state"] == "very nice"

assert "username" in data
assert data["username"] == EXAMPLE_EMAIL

assert "email" in data
assert data["email"] == EXAMPLE_EMAIL