diff --git a/src/main/java/com/intuit/springwebclient/client/CommonSpringWebClient.java b/src/main/java/com/intuit/springwebclient/client/CommonSpringWebClient.java index 8e8a3d4..ce0668c 100644 --- a/src/main/java/com/intuit/springwebclient/client/CommonSpringWebClient.java +++ b/src/main/java/com/intuit/springwebclient/client/CommonSpringWebClient.java @@ -11,10 +11,13 @@ import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.client.UnknownContentTypeException; import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClient.RequestBodySpec; + import reactor.core.publisher.Mono; import reactor.util.retry.Retry; import java.time.Duration; +import java.util.Objects; import java.util.function.Consumer; /** @@ -60,19 +63,27 @@ public ClientHttpResponse syncHttpResponse(ClientH } /** - * Generate Web Client Response spec from http request. - * @param httpRequest - * @return - */ - private WebClient.ResponseSpec generateResponseSpec(ClientHttpRequest httpRequest) { + * Generate Web Client Response spec from http request. + * + * @param httpRequest + * @return + */ + private WebClient.ResponseSpec generateResponseSpec( + ClientHttpRequest httpRequest) { - Consumer httpHeadersConsumer = (httpHeaders -> httpHeaders.putAll(httpRequest.getRequestHeaders())); - return webClient.method(httpRequest.getHttpMethod()) - .uri(httpRequest.getUrl()) - .headers(httpHeadersConsumer) - .body(Mono.just(httpRequest.getRequest()), httpRequest.getRequestType()) - .retrieve(); - } + Consumer httpHeadersConsumer = (httpHeaders -> httpHeaders + .putAll(httpRequest.getRequestHeaders())); + RequestBodySpec webClientBuilder = webClient.method(httpRequest.getHttpMethod()).uri(httpRequest.getUrl()) + .headers(httpHeadersConsumer); + + // set only when provided + if (Objects.nonNull(httpRequest.getRequest()) && Objects.nonNull(httpRequest.getRequestType())) { + webClientBuilder.body(Mono.just(httpRequest.getRequest()), httpRequest.getRequestType()); + } + + return webClientBuilder.retrieve(); + + } /** * Generates retry spec for the request based on config provided. diff --git a/src/test/java/com/intuit/springwebclient/client/CommonSpringWebClientTest.java b/src/test/java/com/intuit/springwebclient/client/CommonSpringWebClientTest.java index 301caf7..5c270b2 100644 --- a/src/test/java/com/intuit/springwebclient/client/CommonSpringWebClientTest.java +++ b/src/test/java/com/intuit/springwebclient/client/CommonSpringWebClientTest.java @@ -1,6 +1,8 @@ package com.intuit.springwebclient.client; import com.intuit.springwebclient.entity.ClientHttpRequest; +import com.intuit.springwebclient.entity.ClientHttpRequest.ClientHttpRequestBuilder; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.InjectMocks; @@ -39,7 +41,7 @@ public class CommonSpringWebClientTest { @Test public void testSyncHttpResponseSuccess() { - ClientHttpRequest clientHttpRequest = createClientHttpRequest(); + ClientHttpRequest clientHttpRequest = createClientHttpRequest().build(); mockRequestBody(); Mockito.when(headersSpec.retrieve()).thenReturn(responseSpec); @@ -48,11 +50,21 @@ public void testSyncHttpResponseSuccess() { commonSpringWebClient.syncHttpResponse(clientHttpRequest); } + + @Test + public void testSyncHttpResponseSuccessNoRequestBody() { + ClientHttpRequest clientHttpRequest = createClientHttpRequest().request(null).requestType(null).build(); + Mockito.when(webClient.method(HttpMethod.GET)).thenReturn(requestBodyUriSpec); + Mockito.doReturn(requestBodyUriSpec).when(requestBodyUriSpec).uri("test-url"); + Mockito.when(requestBodyUriSpec.headers(Mockito.any())).thenReturn(requestBodySpec); + + commonSpringWebClient.syncHttpResponse(clientHttpRequest); + } @Test public void testHttpStatusCodeException() { - ClientHttpRequest clientHttpRequest = createClientHttpRequest(); + ClientHttpRequest clientHttpRequest = createClientHttpRequest().build(); mockRequestBody(); HttpClientErrorException httpClientErrorException = Mockito.mock(HttpClientErrorException.class); @@ -67,7 +79,7 @@ public void testHttpStatusCodeException() { @Test public void testUnknownContentTypeException() { - ClientHttpRequest clientHttpRequest = createClientHttpRequest(); + ClientHttpRequest clientHttpRequest = createClientHttpRequest().build(); mockRequestBody(); UnknownContentTypeException unknownContentTypeException = Mockito.mock(UnknownContentTypeException.class); @@ -82,7 +94,7 @@ public void testUnknownContentTypeException() { @Test public void testOtherException() { - ClientHttpRequest clientHttpRequest = createClientHttpRequest(); + ClientHttpRequest clientHttpRequest = createClientHttpRequest().build(); mockRequestBody(); Mockito.when(headersSpec.retrieve()).thenThrow(IllegalArgumentException.class); @@ -90,7 +102,7 @@ public void testOtherException() { commonSpringWebClient.syncHttpResponse(clientHttpRequest); } - private ClientHttpRequest createClientHttpRequest() { + private ClientHttpRequestBuilder createClientHttpRequest() { HttpHeaders httpHeadersMock = new HttpHeaders(); Consumer httpHeadersConsumer = new Consumer() { @Override @@ -106,8 +118,7 @@ public void accept(HttpHeaders httpHeaders) { .requestHeaders(httpHeadersMock) .requestType(ParameterizedTypeReference.forType(String.class)) .request("hello") - .responseType(ParameterizedTypeReference.forType(String.class)) - .build(); + .responseType(ParameterizedTypeReference.forType(String.class)); } private void mockRequestBody() {