Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,5 @@ pySilver
Łukasz Skarżyński
Shaheed Haque
Peter Karman
Andrea Greco
Vinay Karanam
Eduardo Oliveira
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
* #1012 Return status for introspecting a nonexistent token from 401 to the correct value of 200 per [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662#section-2.2).
* #1068 Revert #967 which incorrectly changed an API. See #1066.

## [1.6.1] 2021-12-23

Expand Down
19 changes: 9 additions & 10 deletions docs/oidc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,16 @@ required claims, eg ``iss``, ``aud``, ``exp``, ``iat``, ``auth_time`` etc),
and the ``sub`` claim will use the primary key of the user as the value.
You'll probably want to customize this and add additional claims or change
what is sent for the ``sub`` claim. To do so, you will need to add a method to
our custom validator.
Standard claim ``sub`` is included by default, for remove it override ``get_claim_list``::
our custom validator::

class CustomOAuth2Validator(OAuth2Validator):
def get_additional_claims(self):
def get_user_email(request):
return request.user.get_full_name()

# Element name, callback to obtain data
claims_list = [ ("email", get_sub_cod),
("username", get_user_email) ]
return claims_list

def get_additional_claims(self, request):
return {
"sub": request.user.email,
"first_name": request.user.first_name,
"last_name": request.user.last_name,
}

.. note::
This ``request`` object is not a ``django.http.Request`` object, but an
Expand Down
25 changes: 8 additions & 17 deletions oauth2_provider/oauth2_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,24 +740,15 @@ def _save_id_token(self, jti, request, expires, *args, **kwargs):
def get_jwt_bearer_token(self, token, token_handler, request):
return self.get_id_token(token, token_handler, request)

def get_claim_list(self):
def get_sub_code(request):
return str(request.user.id)

list = [("sub", get_sub_code)]
def get_oidc_claims(self, token, token_handler, request):
# Required OIDC claims
claims = {
"sub": str(request.user.id),
}

# https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims
add = self.get_additional_claims()
list.extend(add)

return list
claims.update(**self.get_additional_claims(request))

def get_oidc_claims(self, token, token_handler, request):
data = self.get_claim_list()
claims = {}

for k, call in data:
claims[k] = call(request)
return claims

def get_id_token_dictionary(self, token, token_handler, request):
Expand Down Expand Up @@ -910,5 +901,5 @@ def get_userinfo_claims(self, request):
"""
return self.get_oidc_claims(None, None, request)

def get_additional_claims(self):
return []
def get_additional_claims(self, request):
return {}
8 changes: 0 additions & 8 deletions oauth2_provider/views/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@ def get(self, request, *args, **kwargs):
signing_algorithms = [Application.HS256_ALGORITHM]
if oauth2_settings.OIDC_RSA_PRIVATE_KEY:
signing_algorithms = [Application.RS256_ALGORITHM, Application.HS256_ALGORITHM]

validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
validator = validator_class()
oidc_claims = []
for el, _ in validator.get_claim_list():
oidc_claims.append(el)

data = {
"issuer": issuer_url,
"authorization_endpoint": authorization_endpoint,
Expand All @@ -64,7 +57,6 @@ def get(self, request, *args, **kwargs):
"token_endpoint_auth_methods_supported": (
oauth2_settings.OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED
),
"claims_supported": oidc_claims,
}
response = JsonResponse(data)
response["Access-Control-Allow-Origin"] = "*"
Expand Down
24 changes: 4 additions & 20 deletions tests/test_oidc_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def test_get_connect_discovery_info(self):
"subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["RS256", "HS256"],
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
"claims_supported": ["sub"],
}
response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
self.assertEqual(response.status_code, 200)
Expand All @@ -56,7 +55,6 @@ def test_get_connect_discovery_info_without_issuer_url(self):
"subject_types_supported": ["public"],
"id_token_signing_alg_values_supported": ["RS256", "HS256"],
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
"claims_supported": ["sub"],
}
response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info"))
self.assertEqual(response.status_code, 200)
Expand Down Expand Up @@ -148,21 +146,11 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client):
assert rsp.status_code == 401


EXAMPLE_EMAIL = "[email protected]"


def claim_user_email(request):
return EXAMPLE_EMAIL


@pytest.mark.django_db
def test_userinfo_endpoint_custom_claims(oidc_tokens, client, oauth2_settings):
class CustomValidator(OAuth2Validator):
def get_additional_claims(self):
return [
("username", claim_user_email),
("email", claim_user_email),
]
def get_additional_claims(self, request):
return {"state": "very nice"}

oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator
auth_header = "Bearer %s" % oidc_tokens.access_token
Expand All @@ -173,9 +161,5 @@ def get_additional_claims(self):
data = rsp.json()
assert "sub" in data
assert data["sub"] == str(oidc_tokens.user.pk)

assert "username" in data
assert data["username"] == EXAMPLE_EMAIL

assert "email" in data
assert data["email"] == EXAMPLE_EMAIL
assert "state" in data
assert data["state"] == "very nice"