Skip to content

Commit 23ce717

Browse files
committed
Simplify customizing OAuth2AuthorizationRequest
Fixes gh-7696
1 parent 6123d79 commit 23ce717

File tree

7 files changed

+344
-63
lines changed

7 files changed

+344
-63
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -41,6 +41,7 @@
4141
import java.util.Base64;
4242
import java.util.HashMap;
4343
import java.util.Map;
44+
import java.util.function.Consumer;
4445

4546
/**
4647
* An implementation of an {@link OAuth2AuthorizationRequestResolver} that attempts to
@@ -66,6 +67,7 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
6667
private final AntPathRequestMatcher authorizationRequestMatcher;
6768
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
6869
private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
70+
private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};
6971

7072
/**
7173
* Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters.
@@ -98,6 +100,18 @@ public OAuth2AuthorizationRequest resolve(HttpServletRequest request, String reg
98100
return resolve(request, registrationId, redirectUriAction);
99101
}
100102

103+
/**
104+
* Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
105+
* allowing for further customizations.
106+
*
107+
* @since 5.3
108+
* @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
109+
*/
110+
public void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
111+
Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
112+
this.authorizationRequestCustomizer = authorizationRequestCustomizer;
113+
}
114+
101115
private String getAction(HttpServletRequest request, String defaultAction) {
102116
String action = request.getParameter("action");
103117
if (action == null) {
@@ -144,16 +158,17 @@ private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String re
144158

145159
String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction);
146160

147-
OAuth2AuthorizationRequest authorizationRequest = builder
161+
builder
148162
.clientId(clientRegistration.getClientId())
149163
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
150164
.redirectUri(redirectUriStr)
151165
.scopes(clientRegistration.getScopes())
152166
.state(this.stateGenerator.generateKey())
153-
.attributes(attributes)
154-
.build();
167+
.attributes(attributes);
168+
169+
this.authorizationRequestCustomizer.accept(builder);
155170

156-
return authorizationRequest;
171+
return builder.build();
157172
}
158173

159174
private String resolveRegistrationId(HttpServletRequest request) {

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -46,6 +46,7 @@
4646
import java.util.Base64;
4747
import java.util.HashMap;
4848
import java.util.Map;
49+
import java.util.function.Consumer;
4950

5051
/**
5152
* The default implementation of {@link ServerOAuth2AuthorizationRequestResolver}.
@@ -81,6 +82,8 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
8182

8283
private final StringKeyGenerator secureKeyGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
8384

85+
private Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer = customizer -> {};
86+
8487
/**
8588
* Creates a new instance
8689
* @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration}
@@ -121,6 +124,18 @@ public Mono<OAuth2AuthorizationRequest> resolve(ServerWebExchange exchange,
121124
.map(clientRegistration -> authorizationRequest(exchange, clientRegistration));
122125
}
123126

127+
/**
128+
* Sets the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
129+
* allowing for further customizations.
130+
*
131+
* @since 5.3
132+
* @param authorizationRequestCustomizer the {@code Consumer} to be provided the {@link OAuth2AuthorizationRequest.Builder}
133+
*/
134+
public final void setAuthorizationRequestCustomizer(Consumer<OAuth2AuthorizationRequest.Builder> authorizationRequestCustomizer) {
135+
Assert.notNull(authorizationRequestCustomizer, "authorizationRequestCustomizer cannot be null");
136+
this.authorizationRequestCustomizer = authorizationRequestCustomizer;
137+
}
138+
124139
private Mono<ClientRegistration> findByRegistrationId(ServerWebExchange exchange, String clientRegistration) {
125140
return this.clientRegistrationRepository.findByRegistrationId(clientRegistration)
126141
.switchIfEmpty(Mono.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id")));
@@ -155,13 +170,17 @@ private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchan
155170
"Invalid Authorization Grant Type (" + clientRegistration.getAuthorizationGrantType().getValue()
156171
+ ") for Client Registration with Id: " + clientRegistration.getRegistrationId());
157172
}
158-
return builder
173+
builder
159174
.clientId(clientRegistration.getClientId())
160175
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
161-
.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes())
176+
.redirectUri(redirectUriStr)
177+
.scopes(clientRegistration.getScopes())
162178
.state(this.stateGenerator.generateKey())
163-
.attributes(attributes)
164-
.build();
179+
.attributes(attributes);
180+
181+
this.authorizationRequestCustomizer.accept(builder);
182+
183+
return builder.build();
165184
}
166185

167186
/**

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/jackson2/OAuth2AuthorizationRequestMixinTests.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.util.CollectionUtils;
2828
import org.springframework.util.StringUtils;
2929

30+
import java.util.Collections;
3031
import java.util.LinkedHashMap;
3132
import java.util.Map;
3233
import java.util.stream.Collectors;
@@ -70,8 +71,8 @@ public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception
7071
this.authorizationRequestBuilder
7172
.scopes(null)
7273
.state(null)
73-
.additionalParameters(null)
74-
.attributes(null)
74+
.additionalParameters(Collections.emptyMap())
75+
.attributes(Collections.emptyMap())
7576
.build();
7677
String expectedJson = asJson(authorizationRequest);
7778
String json = this.mapper.writeValueAsString(authorizationRequest);
@@ -118,8 +119,8 @@ public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Excep
118119
this.authorizationRequestBuilder
119120
.scopes(null)
120121
.state(null)
121-
.additionalParameters(null)
122-
.attributes(null)
122+
.additionalParameters(Collections.emptyMap())
123+
.attributes(Collections.emptyMap())
123124
.build();
124125
String json = asJson(expectedAuthorizationRequest);
125126
OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class);

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -31,7 +31,9 @@
3131
import org.springframework.security.oauth2.core.oidc.OidcScopes;
3232
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
3333

34-
import static org.assertj.core.api.Assertions.*;
34+
import static org.assertj.core.api.Assertions.assertThat;
35+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
36+
import static org.assertj.core.api.Assertions.entry;
3537

3638
/**
3739
* Tests for {@link DefaultOAuth2AuthorizationRequestResolver}.
@@ -81,6 +83,12 @@ public void constructorWhenAuthorizationRequestBaseUriIsNullThenThrowIllegalArgu
8183
.isInstanceOf(IllegalArgumentException.class);
8284
}
8385

86+
@Test
87+
public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
88+
assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
89+
.isInstanceOf(IllegalArgumentException.class);
90+
}
91+
8492
@Test
8593
public void resolveWhenNotAuthorizationRequestThenDoesNotResolve() {
8694
String requestUri = "/path";
@@ -414,6 +422,76 @@ public void resolveWhenAuthenticationRequestWithValidOidcClientThenResolves() {
414422
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
415423
}
416424

425+
// gh-7696
426+
@Test
427+
public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
428+
ClientRegistration clientRegistration = this.oidcRegistration;
429+
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
430+
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
431+
request.setServletPath(requestUri);
432+
433+
this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
434+
.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
435+
.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));
436+
437+
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
438+
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
439+
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
440+
assertThat(authorizationRequest.getAuthorizationRequestUri())
441+
.matches("https://example.com/login/oauth/authorize\\?" +
442+
"response_type=code&client_id=client-id&" +
443+
"scope=openid&state=.{15,}&" +
444+
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id");
445+
}
446+
447+
@Test
448+
public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
449+
ClientRegistration clientRegistration = this.oidcRegistration;
450+
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
451+
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
452+
request.setServletPath(requestUri);
453+
454+
this.resolver.setAuthorizationRequestCustomizer(customizer ->
455+
customizer.authorizationRequestUri(uriBuilder -> {
456+
uriBuilder.queryParam("param1", "value1");
457+
return uriBuilder.build();
458+
})
459+
);
460+
461+
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
462+
assertThat(authorizationRequest.getAuthorizationRequestUri())
463+
.matches("https://example.com/login/oauth/authorize\\?" +
464+
"response_type=code&client_id=client-id&" +
465+
"scope=openid&state=.{15,}&" +
466+
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
467+
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
468+
"param1=value1");
469+
}
470+
471+
@Test
472+
public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
473+
ClientRegistration clientRegistration = this.oidcRegistration;
474+
String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
475+
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
476+
request.setServletPath(requestUri);
477+
478+
this.resolver.setAuthorizationRequestCustomizer(customizer ->
479+
customizer.parameters(params -> {
480+
params.put("appid", params.get("client_id"));
481+
params.remove("client_id");
482+
})
483+
);
484+
485+
OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
486+
assertThat(authorizationRequest.getAuthorizationRequestUri())
487+
.matches("https://example.com/login/oauth/authorize\\?" +
488+
"response_type=code&" +
489+
"scope=openid&state=.{15,}&" +
490+
"redirect_uri=http://localhost/login/oauth2/code/oidc-registration-id&" +
491+
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
492+
"appid=client-id");
493+
}
494+
417495
private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() {
418496
return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration")
419497
.redirectUriTemplate("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}")

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolverTests.java

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -37,6 +37,7 @@
3737
import reactor.core.publisher.Mono;
3838

3939
import static org.assertj.core.api.Assertions.assertThat;
40+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4041
import static org.assertj.core.api.Assertions.catchThrowableOfType;
4142
import static org.mockito.ArgumentMatchers.any;
4243
import static org.mockito.Mockito.when;
@@ -59,6 +60,12 @@ public void setup() {
5960
this.resolver = new DefaultServerOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository);
6061
}
6162

63+
@Test
64+
public void setAuthorizationRequestCustomizerWhenNullThenThrowIllegalArgumentException() {
65+
assertThatThrownBy(() -> this.resolver.setAuthorizationRequestCustomizer(null))
66+
.isInstanceOf(IllegalArgumentException.class);
67+
}
68+
6269
@Test
6370
public void resolveWhenNotMatchThenNull() {
6471
assertThat(resolve("/")).isNull();
@@ -139,6 +146,79 @@ public void resolveWhenAuthenticationRequestWithValidOidcClientThenResolves() {
139146
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
140147
}
141148

149+
// gh-7696
150+
@Test
151+
public void resolveWhenAuthorizationRequestCustomizerRemovesNonceThenQueryExcludesNonce() {
152+
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
153+
Mono.just(TestClientRegistrations.clientRegistration()
154+
.scope(OidcScopes.OPENID)
155+
.build()));
156+
157+
this.resolver.setAuthorizationRequestCustomizer(customizer -> customizer
158+
.additionalParameters(params -> params.remove(OidcParameterNames.NONCE))
159+
.attributes(attrs -> attrs.remove(OidcParameterNames.NONCE)));
160+
161+
OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
162+
163+
assertThat(authorizationRequest.getAdditionalParameters()).doesNotContainKey(OidcParameterNames.NONCE);
164+
assertThat(authorizationRequest.getAttributes()).doesNotContainKey(OidcParameterNames.NONCE);
165+
assertThat(authorizationRequest.getAuthorizationRequestUri())
166+
.matches("https://example.com/login/oauth/authorize\\?" +
167+
"response_type=code&client_id=client-id&" +
168+
"scope=openid&state=.{15,}&" +
169+
"redirect_uri=/login/oauth2/code/registration-id");
170+
}
171+
172+
@Test
173+
public void resolveWhenAuthorizationRequestCustomizerAddsParameterThenQueryIncludesParameter() {
174+
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
175+
Mono.just(TestClientRegistrations.clientRegistration()
176+
.scope(OidcScopes.OPENID)
177+
.build()));
178+
179+
this.resolver.setAuthorizationRequestCustomizer(customizer ->
180+
customizer.authorizationRequestUri(uriBuilder -> {
181+
uriBuilder.queryParam("param1", "value1");
182+
return uriBuilder.build();
183+
})
184+
);
185+
186+
OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
187+
188+
assertThat(authorizationRequest.getAuthorizationRequestUri())
189+
.matches("https://example.com/login/oauth/authorize\\?" +
190+
"response_type=code&client_id=client-id&" +
191+
"scope=openid&state=.{15,}&" +
192+
"redirect_uri=/login/oauth2/code/registration-id&" +
193+
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
194+
"param1=value1");
195+
}
196+
197+
@Test
198+
public void resolveWhenAuthorizationRequestCustomizerOverridesParameterThenQueryIncludesParameter() {
199+
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(
200+
Mono.just(TestClientRegistrations.clientRegistration()
201+
.scope(OidcScopes.OPENID)
202+
.build()));
203+
204+
this.resolver.setAuthorizationRequestCustomizer(customizer ->
205+
customizer.parameters(params -> {
206+
params.put("appid", params.get("client_id"));
207+
params.remove("client_id");
208+
})
209+
);
210+
211+
OAuth2AuthorizationRequest authorizationRequest = resolve("/oauth2/authorization/registration-id");
212+
213+
assertThat(authorizationRequest.getAuthorizationRequestUri())
214+
.matches("https://example.com/login/oauth/authorize\\?" +
215+
"response_type=code&" +
216+
"scope=openid&state=.{15,}&" +
217+
"redirect_uri=/login/oauth2/code/registration-id&" +
218+
"nonce=([a-zA-Z0-9\\-\\.\\_\\~]){43}&" +
219+
"appid=client-id");
220+
}
221+
142222
private OAuth2AuthorizationRequest resolve(String path) {
143223
ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get(path));
144224
return this.resolver.resolve(exchange).block();

0 commit comments

Comments
 (0)