From 20b6fee3d26af7df4cd7bd27a1cca9458bea510a Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:15:29 -0700 Subject: [PATCH 1/3] Polish Tests Issue gh-16251 --- .../oauth2/jwt/NimbusJwtDecoderTests.java | 26 +++++++------------ 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index fb4535f240d..d638795df9a 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,7 +60,6 @@ import org.springframework.cache.Cache; import org.springframework.cache.concurrent.ConcurrentMapCache; -import org.springframework.cache.support.SimpleValueWrapper; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpStatus; @@ -702,9 +701,8 @@ public void decodeWhenCacheStoredThenAbleToRetrieveJwkSetFromCache() { @Test public void decodeWhenCacheThenRetrieveFromCache() throws Exception { RestOperations restOperations = mock(RestOperations.class); - Cache cache = mock(Cache.class); - given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET); - given(cache.get(eq(JWK_SET_URI))).willReturn(mock(Cache.ValueWrapper.class)); + Cache cache = new ConcurrentMapCache("cache"); + cache.put(JWK_SET_URI, JWK_SET); // @formatter:off NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .cache(cache) @@ -712,9 +710,7 @@ public void decodeWhenCacheThenRetrieveFromCache() throws Exception { .build(); // @formatter:on jwtDecoder.decode(SIGNED_JWT); - verify(cache).get(eq(JWK_SET_URI), eq(String.class)); - verify(cache, times(2)).get(eq(JWK_SET_URI)); - verifyNoMoreInteractions(cache); + assertThat(cache.get(JWK_SET_URI, String.class)).isSameAs(JWK_SET); verifyNoInteractions(restOperations); } @@ -722,9 +718,8 @@ public void decodeWhenCacheThenRetrieveFromCache() throws Exception { @Test public void decodeWhenCacheAndUnknownKidShouldTriggerFetchOfJwkSet() throws JOSEException { RestOperations restOperations = mock(RestOperations.class); - Cache cache = mock(Cache.class); - given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET); - given(cache.get(eq(JWK_SET_URI))).willReturn(new SimpleValueWrapper(JWK_SET)); + Cache cache = new ConcurrentMapCache("cache"); + cache.put(JWK_SET_URI, JWK_SET); given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) .willReturn(new ResponseEntity<>(NEW_KID_JWK_SET, HttpStatus.OK)); @@ -794,9 +789,8 @@ public void decodeWhenCacheIsConfiguredAndValueLoaderErrorsThenThrowsJwtExceptio @Test public void decodeWhenCacheIsConfiguredAndParseFailsOnCachedValueThenExceptionIgnored() { RestOperations restOperations = mock(RestOperations.class); - Cache cache = mock(Cache.class); - given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET); - given(cache.get(eq(JWK_SET_URI))).willReturn(mock(Cache.ValueWrapper.class)); + Cache cache = new ConcurrentMapCache("cache"); + cache.put(JWK_SET_URI, JWK_SET); // @formatter:off NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .cache(cache) @@ -804,9 +798,7 @@ public void decodeWhenCacheIsConfiguredAndParseFailsOnCachedValueThenExceptionIg .build(); // @formatter:on jwtDecoder.decode(SIGNED_JWT); - verify(cache).get(eq(JWK_SET_URI), eq(String.class)); - verify(cache, times(2)).get(eq(JWK_SET_URI)); - verifyNoMoreInteractions(cache); + assertThat(cache.get(JWK_SET_URI, String.class)).isSameAs(JWK_SET); verifyNoInteractions(restOperations); } From 234924799e95b1c6095f1dbc8108460e7c60f0e0 Mon Sep 17 00:00:00 2001 From: Daeho Kwon Date: Wed, 5 Feb 2025 02:51:32 +0900 Subject: [PATCH 2/3] Remove Deprecated Usages of RemoteJWKSet Closes gh-16251 Signed-off-by: Daeho Kwon --- .../security/oauth2/jwt/NimbusJwtDecoder.java | 180 +++++++++++------- .../security/oauth2/jwt/JwtDecodersTests.java | 3 +- 2 files changed, 115 insertions(+), 68 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 2713ee96b2d..732ecc2476a 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,12 @@ package org.springframework.security.oauth2.jwt; +import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.jwk.JWKMatcher; +import com.nimbusds.jose.jwk.JWKSelector; +import com.nimbusds.jose.jwk.source.JWKSetParseException; +import com.nimbusds.jose.jwk.source.JWKSetRetrievalException; import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; @@ -26,8 +32,10 @@ import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; import java.util.function.Function; @@ -35,17 +43,12 @@ import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; -import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.jwk.JWKSet; -import com.nimbusds.jose.jwk.source.JWKSetCache; import com.nimbusds.jose.jwk.source.JWKSource; -import com.nimbusds.jose.jwk.source.RemoteJWKSet; import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.proc.SingleKeyJWSKeySelector; -import com.nimbusds.jose.util.Resource; -import com.nimbusds.jose.util.ResourceRetriever; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTParser; @@ -57,6 +60,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.cache.Cache; +import org.springframework.cache.support.NoOpCache; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -80,6 +84,7 @@ * @author Josh Cummings * @author Joe Grandja * @author Mykyta Bezverkhyi + * @author Daeho Kwon * @since 5.2 */ public final class NimbusJwtDecoder implements JwtDecoder { @@ -165,7 +170,7 @@ private Jwt createJwt(String token, JWT parsedJwt) { .build(); // @formatter:on } - catch (RemoteKeySourceException ex) { + catch (KeySourceException ex) { this.logger.trace("Failed to retrieve JWK set", ex); if (ex.getCause() instanceof ParseException) { throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex); @@ -273,7 +278,7 @@ public static final class JwkSetUriJwtDecoderBuilder { private RestOperations restOperations = new RestTemplate(); - private Cache cache; + private Cache cache = new NoOpCache("default"); private Consumer> jwtProcessorCustomizer; @@ -376,18 +381,13 @@ JWSKeySelector jwsKeySelector(JWKSource jwkSou return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); } - JWKSource jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) { - if (this.cache == null) { - return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever); - } - JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache); - return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache); + JWKSource jwkSource() { + String jwkSetUri = this.jwkSetUri.apply(this.restOperations); + return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri); } JWTProcessor processor() { - ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations); - String jwkSetUri = this.jwkSetUri.apply(this.restOperations); - JWKSource jwkSource = jwkSource(jwkSetRetriever, jwkSetUri); + JWKSource jwkSource = jwkSource(); ConfigurableJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource)); // Spring Security validates the claim set independent from Nimbus @@ -414,84 +414,130 @@ private static URL toURL(String url) { } } - private static final class SpringJWKSetCache implements JWKSetCache { + private static final class SpringJWKSource implements JWKSource { - private final String jwkSetUri; + private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); + + private final ReentrantLock reentrantLock = new ReentrantLock(); + + private final RestOperations restOperations; private final Cache cache; - private JWKSet jwkSet; + private final URL url; - SpringJWKSetCache(String jwkSetUri, Cache cache) { - this.jwkSetUri = jwkSetUri; + private final String jwkSetUri; + + private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; this.cache = cache; - this.updateJwkSetFromCache(); + this.url = url; + this.jwkSetUri = jwkSetUri; } - private void updateJwkSetFromCache() { + + @Override + public List get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException { String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class); + JWKSet jwkSet = null; if (cachedJwkSet != null) { - try { - this.jwkSet = JWKSet.parse(cachedJwkSet); - } - catch (ParseException ignored) { - // Ignore invalid cache value + jwkSet = parse(cachedJwkSet); + } + if (jwkSet == null) { + if(reentrantLock.tryLock()) { + try { + String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class); + if (cachedJwkSetAfterLock != null) { + jwkSet = parse(cachedJwkSetAfterLock); + } + if(jwkSet == null) { + try { + jwkSet = fetchJWKSet(); + } catch (IOException e) { + throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); + } + } + } finally { + reentrantLock.unlock(); + } } } - } - - // Note: Only called from inside a synchronized block in RemoteJWKSet. - @Override - public void put(JWKSet jwkSet) { - this.jwkSet = jwkSet; - this.cache.put(this.jwkSetUri, jwkSet.toString(false)); - } - - @Override - public JWKSet get() { - return (!requiresRefresh()) ? this.jwkSet : null; - - } - - @Override - public boolean requiresRefresh() { - return this.cache.get(this.jwkSetUri) == null; - } - - } - - private static class RestOperationsResourceRetriever implements ResourceRetriever { - - private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); - - private final RestOperations restOperations; + List matches = jwkSelector.select(jwkSet); + if(!matches.isEmpty()) { + return matches; + } + String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher()); + if (soughtKeyID == null) { + return Collections.emptyList(); + } + if (jwkSet.getKeyByKeyId(soughtKeyID) != null) { + return Collections.emptyList(); + } - RestOperationsResourceRetriever(RestOperations restOperations) { - Assert.notNull(restOperations, "restOperations cannot be null"); - this.restOperations = restOperations; + if(reentrantLock.tryLock()) { + try { + String jwkSetUri = this.cache.get(this.jwkSetUri, String.class); + JWKSet cacheJwkSet = parse(jwkSetUri); + if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) { + try { + jwkSet = fetchJWKSet(); + } catch (IOException e) { + throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); + } + } else if (jwkSetUri != null) { + jwkSet = parse(jwkSetUri); + } + } finally { + reentrantLock.unlock(); + } + } + if(jwkSet == null) { + return Collections.emptyList(); + } + return jwkSelector.select(jwkSet); } - @Override - public Resource retrieveResource(URL url) throws IOException { + private JWKSet fetchJWKSet() throws IOException, KeySourceException { HttpHeaders headers = new HttpHeaders(); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); - ResponseEntity response = getResponse(url, headers); + ResponseEntity response = getResponse(headers); if (response.getStatusCode().value() != 200) { throw new IOException(response.toString()); } - return new Resource(response.getBody(), "UTF-8"); + try { + String jwkSet = response.getBody(); + this.cache.put(this.jwkSetUri, jwkSet); + return JWKSet.parse(jwkSet); + } catch (ParseException e) { + throw new JWKSetParseException("Unable to parse JWK set", e); + } } - private ResponseEntity getResponse(URL url, HttpHeaders headers) throws IOException { + private ResponseEntity getResponse(HttpHeaders headers) throws IOException { try { - RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI()); + RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI()); return this.restOperations.exchange(request, String.class); - } - catch (Exception ex) { + } catch (Exception ex) { throw new IOException(ex); } } + private JWKSet parse(String cachedJwkSet) { + JWKSet jwkSet = null; + try { + jwkSet = JWKSet.parse(cachedJwkSet); + } catch (ParseException ignored) { + // Ignore invalid cache value + } + return jwkSet; + } + + private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) { + Set keyIDs = jwkMatcher.getKeyIDs(); + return (keyIDs == null || keyIDs.isEmpty()) ? + null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null); + } } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java index f343cd2b69a..378a6dbd416 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -308,6 +308,7 @@ private void prepareConfigurationResponse() { private void prepareConfigurationResponse(String body) { this.server.enqueue(response(body)); this.server.enqueue(response(JWK_SET)); + this.server.enqueue(response(JWK_SET)); // default NoOpCache } private void prepareConfigurationResponseOidc() { From 457f4db7f52a97cc331e2f6c0489a4373e3e940f Mon Sep 17 00:00:00 2001 From: Josh Cummings <3627351+jzheaux@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:32:12 -0700 Subject: [PATCH 3/3] Polish Nimbus JWK Source Implementation Issue gh-16251 --- .../security/oauth2/jwt/NimbusJwtDecoder.java | 164 ++++++------------ .../security/oauth2/jwt/JwtDecodersTests.java | 3 +- 2 files changed, 50 insertions(+), 117 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 732ecc2476a..df0239ebfec 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -16,15 +16,7 @@ package org.springframework.security.oauth2.jwt; -import com.nimbusds.jose.KeySourceException; -import com.nimbusds.jose.jwk.JWK; -import com.nimbusds.jose.jwk.JWKMatcher; -import com.nimbusds.jose.jwk.JWKSelector; -import com.nimbusds.jose.jwk.source.JWKSetParseException; -import com.nimbusds.jose.jwk.source.JWKSetRetrievalException; -import java.io.IOException; -import java.net.MalformedURLException; -import java.net.URL; +import java.net.URI; import java.security.interfaces.RSAPublicKey; import java.text.ParseException; import java.util.Arrays; @@ -32,7 +24,6 @@ import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.locks.ReentrantLock; @@ -43,8 +34,13 @@ import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.KeySourceException; +import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator; +import com.nimbusds.jose.jwk.source.JWKSetSource; import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.jwk.source.JWKSourceBuilder; import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; @@ -170,7 +166,7 @@ private Jwt createJwt(String token, JWT parsedJwt) { .build(); // @formatter:on } - catch (KeySourceException ex) { + catch (RemoteKeySourceException ex) { this.logger.trace("Failed to retrieve JWK set", ex); if (ex.getCause() instanceof ParseException) { throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex); @@ -383,7 +379,11 @@ JWSKeySelector jwsKeySelector(JWKSource jwkSou JWKSource jwkSource() { String jwkSetUri = this.jwkSetUri.apply(this.restOperations); - return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri); + return JWKSourceBuilder.create(new SpringJWKSource<>(this.restOperations, this.cache, jwkSetUri)) + .refreshAheadCache(false) + .rateLimited(false) + .cache(this.cache instanceof NoOpCache) + .build(); } JWTProcessor processor() { @@ -405,16 +405,7 @@ public NimbusJwtDecoder build() { return new NimbusJwtDecoder(processor()); } - private static URL toURL(String url) { - try { - return new URL(url); - } - catch (MalformedURLException ex) { - throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex); - } - } - - private static final class SpringJWKSource implements JWKSource { + private static final class SpringJWKSource implements JWKSetSource { private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); @@ -424,120 +415,63 @@ private static final class SpringJWKSource implements private final Cache cache; - private final URL url; - private final String jwkSetUri; - private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) { + private JWKSet jwkSet; + + private SpringJWKSource(RestOperations restOperations, Cache cache, String jwkSetUri) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; this.cache = cache; - this.url = url; this.jwkSetUri = jwkSetUri; - } - - - @Override - public List get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException { - String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class); - JWKSet jwkSet = null; - if (cachedJwkSet != null) { - jwkSet = parse(cachedJwkSet); - } - if (jwkSet == null) { - if(reentrantLock.tryLock()) { - try { - String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class); - if (cachedJwkSetAfterLock != null) { - jwkSet = parse(cachedJwkSetAfterLock); - } - if(jwkSet == null) { - try { - jwkSet = fetchJWKSet(); - } catch (IOException e) { - throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); - } - } - } finally { - reentrantLock.unlock(); - } - } - } - List matches = jwkSelector.select(jwkSet); - if(!matches.isEmpty()) { - return matches; - } - String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher()); - if (soughtKeyID == null) { - return Collections.emptyList(); - } - if (jwkSet.getKeyByKeyId(soughtKeyID) != null) { - return Collections.emptyList(); - } - - if(reentrantLock.tryLock()) { + String jwks = this.cache.get(this.jwkSetUri, String.class); + if (jwks != null) { try { - String jwkSetUri = this.cache.get(this.jwkSetUri, String.class); - JWKSet cacheJwkSet = parse(jwkSetUri); - if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) { - try { - jwkSet = fetchJWKSet(); - } catch (IOException e) { - throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e); - } - } else if (jwkSetUri != null) { - jwkSet = parse(jwkSetUri); - } - } finally { - reentrantLock.unlock(); + this.jwkSet = JWKSet.parse(jwks); + } + catch (ParseException ignored) { + // Ignore invalid cache value } } - if(jwkSet == null) { - return Collections.emptyList(); - } - return jwkSelector.select(jwkSet); } - private JWKSet fetchJWKSet() throws IOException, KeySourceException { + private String fetchJwks() throws Exception { HttpHeaders headers = new HttpHeaders(); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); - ResponseEntity response = getResponse(headers); - if (response.getStatusCode().value() != 200) { - throw new IOException(response.toString()); - } - try { - String jwkSet = response.getBody(); - this.cache.put(this.jwkSetUri, jwkSet); - return JWKSet.parse(jwkSet); - } catch (ParseException e) { - throw new JWKSetParseException("Unable to parse JWK set", e); - } + RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, URI.create(this.jwkSetUri)); + ResponseEntity response = this.restOperations.exchange(request, String.class); + String jwks = response.getBody(); + this.jwkSet = JWKSet.parse(jwks); + return jwks; } - private ResponseEntity getResponse(HttpHeaders headers) throws IOException { + @Override + public JWKSet getJWKSet(JWKSetCacheRefreshEvaluator refreshEvaluator, long currentTime, C context) + throws KeySourceException { try { - RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI()); - return this.restOperations.exchange(request, String.class); - } catch (Exception ex) { - throw new IOException(ex); + this.reentrantLock.lock(); + if (refreshEvaluator.requiresRefresh(this.jwkSet)) { + this.cache.invalidate(); + } + this.cache.get(this.jwkSetUri, this::fetchJwks); + return this.jwkSet; } - } - - private JWKSet parse(String cachedJwkSet) { - JWKSet jwkSet = null; - try { - jwkSet = JWKSet.parse(cachedJwkSet); - } catch (ParseException ignored) { - // Ignore invalid cache value + catch (Cache.ValueRetrievalException ex) { + if (ex.getCause() instanceof RemoteKeySourceException keys) { + throw keys; + } + throw new RemoteKeySourceException(ex.getCause().getMessage(), ex.getCause()); + } + finally { + this.reentrantLock.unlock(); } - return jwkSet; } - private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) { - Set keyIDs = jwkMatcher.getKeyIDs(); - return (keyIDs == null || keyIDs.isEmpty()) ? - null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null); + @Override + public void close() { + } + } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java index 378a6dbd416..f343cd2b69a 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2025 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -308,7 +308,6 @@ private void prepareConfigurationResponse() { private void prepareConfigurationResponse(String body) { this.server.enqueue(response(body)); this.server.enqueue(response(JWK_SET)); - this.server.enqueue(response(JWK_SET)); // default NoOpCache } private void prepareConfigurationResponseOidc() {