Skip to content

Commit 15f8179

Browse files
committed
Introduce class-level execution phases for @Sql
Resolves gh-18929.
1 parent 7bf520f commit 15f8179

File tree

8 files changed

+233
-9
lines changed

8 files changed

+233
-9
lines changed

spring-test/src/main/java/org/springframework/test/context/TestContext.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,25 @@ default void publishEvent(Function<TestContext, ? extends ApplicationEvent> even
110110
*/
111111
Object getTestInstance();
112112

113+
/**
114+
* Tests whether a test method is part of this test context. Returns
115+
* {@code true} if this context has a current test method, {@code false}
116+
* otherwise.
117+
*
118+
* <p>The default implementation of this method always returns {@code false}.
119+
* Custom {@code TestContext} implementations are therefore highly encouraged
120+
* to override this method with a more meaningful implementation. Note that
121+
* the standard {@code TestContext} implementation in Spring overrides this
122+
* method appropriately.
123+
* @return {@code true} if the test execution has already entered a test
124+
* method
125+
* @since 6.1
126+
* @see #getTestMethod()
127+
*/
128+
default boolean hasTestMethod() {
129+
return false;
130+
}
131+
113132
/**
114133
* Get the current {@linkplain Method test method} for this test context.
115134
* <p>Note: this is a mutable property.

spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@
3333
*
3434
* <p>Method-level declarations override class-level declarations by default,
3535
* but this behavior can be configured via {@link SqlMergeMode @SqlMergeMode}.
36+
* However, this does not apply to class-level declarations that use
37+
* {@link ExecutionPhase#BEFORE_TEST_CLASS} or
38+
* {@link ExecutionPhase#AFTER_TEST_CLASS}. Such declarations are retained and
39+
* scripts and statements are executed once per class in addition to any
40+
* method-level annotations.
3641
*
3742
* <p>Script execution is performed by the {@link SqlScriptsTestExecutionListener},
3843
* which is enabled by default.
@@ -161,6 +166,18 @@
161166
*/
162167
enum ExecutionPhase {
163168

169+
/**
170+
* The configured SQL scripts and statements will be executed
171+
* once <em>before</em> any test method is run.
172+
*/
173+
BEFORE_TEST_CLASS,
174+
175+
/**
176+
* The configured SQL scripts and statements will be executed
177+
* once <em>after</em> any test method is run.
178+
*/
179+
AFTER_TEST_CLASS,
180+
164181
/**
165182
* The configured SQL scripts and statements will be executed
166183
* <em>before</em> the corresponding test method.

spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,17 @@
6767
* {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
6868
* configured via the {@link Sql @Sql} annotation.
6969
*
70-
* <p>Scripts and inlined statements will be executed {@linkplain #beforeTestMethod(TestContext) before}
71-
* or {@linkplain #afterTestMethod(TestContext) after} execution of the corresponding
72-
* {@linkplain java.lang.reflect.Method test method}, depending on the configured
73-
* value of the {@link Sql#executionPhase executionPhase} flag.
70+
* <p>Class-level annotations that are constrained to a class-level execution
71+
* phase ({@link ExecutionPhase#BEFORE_TEST_CLASS} or
72+
* {@link ExecutionPhase#AFTER_TEST_CLASS}) will be run
73+
* {@linkplain #beforeTestClass(TestContext) once before all test methods} or
74+
* {@linkplain #afterTestMethod(TestContext) once after all test methods},
75+
* respectively. All other scripts and inlined statements will be executed
76+
* {@linkplain #beforeTestMethod(TestContext) before} or
77+
* {@linkplain #afterTestMethod(TestContext) after} execution of the
78+
* corresponding {@linkplain java.lang.reflect.Method test method}, depending
79+
* on the configured value of the {@link Sql#executionPhase executionPhase}
80+
* flag.
7481
*
7582
* <p>Scripts and inlined statements will be executed without a transaction,
7683
* within an existing Spring-managed transaction, or within an isolated transaction,
@@ -126,6 +133,26 @@ public final int getOrder() {
126133
return 5000;
127134
}
128135

136+
/**
137+
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
138+
* {@link TestContext} once per test class <em>before</em> any test method
139+
* is run.
140+
*/
141+
@Override
142+
public void beforeTestClass(TestContext testContext) throws Exception {
143+
executeBeforeOrAfterClassSqlScripts(testContext, ExecutionPhase.BEFORE_TEST_CLASS);
144+
}
145+
146+
/**
147+
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
148+
* {@link TestContext} once per test class <em>after</em> all test methods
149+
* have been run.
150+
*/
151+
@Override
152+
public void afterTestClass(TestContext testContext) throws Exception {
153+
executeBeforeOrAfterClassSqlScripts(testContext, ExecutionPhase.AFTER_TEST_CLASS);
154+
}
155+
129156
/**
130157
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
131158
* {@link TestContext} <em>before</em> the current test method.
@@ -159,6 +186,17 @@ public void processAheadOfTime(RuntimeHints runtimeHints, Class<?> testClass, Cl
159186
registerClasspathResources(getScripts(sql, testClass, testMethod, false), runtimeHints, classLoader)));
160187
}
161188

189+
/**
190+
* Execute class-level SQL scripts configured via {@link Sql @Sql} for the
191+
* supplied {@link TestContext} and the execution phases
192+
* {@link ExecutionPhase#BEFORE_TEST_CLASS} and
193+
* {@link ExecutionPhase#AFTER_TEST_CLASS}.
194+
*/
195+
private void executeBeforeOrAfterClassSqlScripts(TestContext testContext, ExecutionPhase executionPhase) {
196+
Class<?> testClass = testContext.getTestClass();
197+
executeSqlScripts(getSqlAnnotationsFor(testClass), testContext, executionPhase, true);
198+
}
199+
162200
/**
163201
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
164202
* {@link TestContext} and {@link ExecutionPhase}.
@@ -260,7 +298,12 @@ else if (logger.isDebugEnabled()) {
260298
.formatted(executionPhase, testContext.getTestClass().getName()));
261299
}
262300

263-
String[] scripts = getScripts(sql, testContext.getTestClass(), testContext.getTestMethod(), classLevel);
301+
Method testMethod = null;
302+
if (testContext.hasTestMethod()) {
303+
testMethod = testContext.getTestMethod();
304+
}
305+
306+
String[] scripts = getScripts(sql, testContext.getTestClass(), testMethod, classLevel);
264307
List<Resource> scriptResources = TestContextResourceUtils.convertToResourceList(
265308
testContext.getApplicationContext(), scripts);
266309
for (String stmt : sql.statements()) {
@@ -354,7 +397,7 @@ private DataSource getDataSourceFromTransactionManager(PlatformTransactionManage
354397
return null;
355398
}
356399

357-
private String[] getScripts(Sql sql, Class<?> testClass, Method testMethod, boolean classLevel) {
400+
private String[] getScripts(Sql sql, Class<?> testClass, @Nullable Method testMethod, boolean classLevel) {
358401
String[] scripts = sql.scripts();
359402
if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
360403
scripts = new String[] {detectDefaultScript(testClass, testMethod, classLevel)};
@@ -366,7 +409,11 @@ private String[] getScripts(Sql sql, Class<?> testClass, Method testMethod, bool
366409
* Detect a default SQL script by implementing the algorithm defined in
367410
* {@link Sql#scripts}.
368411
*/
369-
private String detectDefaultScript(Class<?> testClass, Method testMethod, boolean classLevel) {
412+
private String detectDefaultScript(Class<?> testClass, @Nullable Method testMethod, boolean classLevel) {
413+
if (!classLevel && testMethod == null) {
414+
throw new AssertionError("Method-level @Sql requires a testMethod");
415+
}
416+
370417
String elementType = (classLevel ? "class" : "method");
371418
String elementName = (classLevel ? testClass.getName() : testMethod.toString());
372419

spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ public final Object getTestInstance() {
166166
return testInstance;
167167
}
168168

169+
@Override
170+
public boolean hasTestMethod() {
171+
return this.testMethod != null;
172+
}
173+
169174
@Override
170175
public final Method getTestMethod() {
171176
Method testMethod = this.testMethod;

spring-test/src/main/java/org/springframework/test/context/transaction/TestContextTransactionUtils.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ private static void logBeansException(TestContext testContext, BeansException ex
227227
/**
228228
* Create a delegating {@link TransactionAttribute} for the supplied target
229229
* {@link TransactionAttribute} and {@link TestContext}, using the names of
230-
* the test class and test method to build the name of the transaction.
230+
* the test class and test method (if available) to build the name of the
231+
* transaction.
231232
* @param testContext the {@code TestContext} upon which to base the name
232233
* @param targetAttribute the {@code TransactionAttribute} to delegate to
233234
* @return the delegating {@code TransactionAttribute}
@@ -248,7 +249,13 @@ private static class TestContextTransactionAttribute extends DelegatingTransacti
248249

249250
public TestContextTransactionAttribute(TransactionAttribute targetAttribute, TestContext testContext) {
250251
super(targetAttribute);
251-
this.name = ClassUtils.getQualifiedMethodName(testContext.getTestMethod(), testContext.getTestClass());
252+
253+
if (testContext.hasTestMethod()) {
254+
this.name = ClassUtils.getQualifiedMethodName(testContext.getTestMethod(), testContext.getTestClass());
255+
}
256+
else {
257+
this.name = testContext.getTestClass().getName();
258+
}
252259
}
253260

254261
@Override
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2002-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.jdbc;
18+
19+
import javax.sql.DataSource;
20+
21+
import org.junit.jupiter.api.Test;
22+
23+
import org.springframework.core.Ordered;
24+
import org.springframework.jdbc.BadSqlGrammarException;
25+
import org.springframework.jdbc.core.JdbcTemplate;
26+
import org.springframework.test.annotation.DirtiesContext;
27+
import org.springframework.test.context.TestContext;
28+
import org.springframework.test.context.TestExecutionListener;
29+
import org.springframework.test.context.TestExecutionListeners;
30+
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
31+
import org.springframework.test.context.transaction.TestContextTransactionUtils;
32+
33+
@SpringJUnitConfig(PopulatedSchemaDatabaseConfig.class)
34+
@DirtiesContext
35+
@Sql(value = {"drop-schema.sql"}, executionPhase = Sql.ExecutionPhase.AFTER_TEST_CLASS)
36+
@TestExecutionListeners(
37+
value = AfterTestClassSqlScriptsTests.VerifyTestExecutionListener.class,
38+
mergeMode = TestExecutionListeners.MergeMode.MERGE_WITH_DEFAULTS
39+
)
40+
class AfterTestClassSqlScriptsTests extends AbstractTransactionalTests {
41+
42+
@Test
43+
@Sql(scripts = "data-add-catbert.sql")
44+
void databaseHasBeenInitialized() {
45+
// Ensure that the database has been initialized and can be accessed.
46+
assertUsers("Catbert");
47+
}
48+
49+
static class VerifyTestExecutionListener implements TestExecutionListener, Ordered {
50+
51+
@Override
52+
public void afterTestClass(TestContext testContext) throws Exception {
53+
DataSource dataSource = TestContextTransactionUtils.retrieveDataSource(testContext, null);
54+
try {
55+
new JdbcTemplate(dataSource).queryForList("SELECT name FROM user", String.class);
56+
throw new AssertionError("BadSqlGrammarException should have been thrown.");
57+
}
58+
catch (BadSqlGrammarException expected) {
59+
}
60+
}
61+
62+
@Override
63+
public int getOrder() {
64+
// Must run before DirtiesContextTestExecutionListener. Otherwise, the old data source will be removed and
65+
// replaced with a new one.
66+
return 3001;
67+
}
68+
}
69+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2002-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.jdbc;
18+
19+
import org.junit.jupiter.api.Test;
20+
21+
import org.springframework.test.annotation.DirtiesContext;
22+
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
23+
24+
import static org.springframework.test.context.jdbc.SqlMergeMode.MergeMode.MERGE;
25+
import static org.springframework.test.context.jdbc.SqlMergeMode.MergeMode.OVERRIDE;
26+
27+
@SpringJUnitConfig(classes = EmptyDatabaseConfig.class)
28+
@DirtiesContext
29+
@Sql(value = {"schema.sql", "data-add-catbert.sql"}, executionPhase = Sql.ExecutionPhase.BEFORE_TEST_CLASS)
30+
class BeforeTestClassSqlScriptsTests extends AbstractTransactionalTests {
31+
32+
@Test
33+
void classLevelScriptsHaveBeenRun() {
34+
assertUsers("Catbert");
35+
}
36+
37+
@Test
38+
@Sql("data-add-dogbert.sql")
39+
@SqlMergeMode(MERGE)
40+
void mergeDoesNotAffectClassLevelPhase() {
41+
assertUsers("Catbert", "Dogbert");
42+
}
43+
44+
@Test
45+
@Sql({"data-add-dogbert.sql"})
46+
@SqlMergeMode(OVERRIDE)
47+
void overrideDoesNotAffectClassLevelPhase() {
48+
assertUsers("Dogbert", "Catbert");
49+
}
50+
51+
@Test
52+
@Sql(scripts = {"data-add-catbert.sql"}, executionPhase = Sql.ExecutionPhase.BEFORE_TEST_CLASS)
53+
void classLevelPhaseIsIgnoredOnMethod() {
54+
// There's a unique constraint on the name. If the script succeeded, there would be a
55+
// constraint violation.
56+
assertUsers("Catbert");
57+
}
58+
}
59+

spring-test/src/test/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListenerTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ void missingValueAndScriptsAndStatementsAtClassLevel() throws Exception {
5656
void missingValueAndScriptsAndStatementsAtMethodLevel() throws Exception {
5757
Class<?> clazz = MissingValueAndScriptsAndStatementsAtMethodLevel.class;
5858
BDDMockito.<Class<?>> given(testContext.getTestClass()).willReturn(clazz);
59+
given(testContext.hasTestMethod()).willReturn(true);
5960
given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("foo"));
6061

6162
assertExceptionContains(clazz.getSimpleName() + ".foo" + ".sql");

0 commit comments

Comments
 (0)