1616
1717package org .springframework .test .context .bean .override .mockito ;
1818
19+ import java .lang .reflect .Field ;
1920import java .util .Arrays ;
2021import java .util .HashSet ;
2122import java .util .Set ;
23+ import java .util .function .Predicate ;
2224
2325import org .mockito .Mockito ;
2426
3032import org .springframework .context .ApplicationContext ;
3133import org .springframework .context .ConfigurableApplicationContext ;
3234import org .springframework .core .Ordered ;
35+ import org .springframework .core .annotation .MergedAnnotation ;
36+ import org .springframework .core .annotation .MergedAnnotations ;
3337import org .springframework .lang .Nullable ;
3438import org .springframework .test .context .TestContext ;
39+ import org .springframework .test .context .TestContextAnnotationUtils ;
40+ import org .springframework .test .context .support .AbstractTestExecutionListener ;
41+ import org .springframework .util .ClassUtils ;
3542
3643/**
3744 * {@code TestExecutionListener} that resets any mock beans that have been marked
4350 * @see MockitoBean @MockitoBean
4451 * @see MockitoSpyBean @MockitoSpyBean
4552 */
46- public class MockitoResetTestExecutionListener extends AbstractMockitoTestExecutionListener {
53+ public class MockitoResetTestExecutionListener extends AbstractTestExecutionListener {
54+
55+ static final boolean mockitoPresent = ClassUtils .isPresent ("org.mockito.Mockito" ,
56+ MockitoResetTestExecutionListener .class .getClassLoader ());
57+
58+ private static final String SPRING_MOCKITO_PACKAGE = "org.springframework.test.context.bean.override.mockito" ;
59+
60+ private static final Predicate <MergedAnnotation <?>> isMockitoAnnotation = mergedAnnotation -> {
61+ String packageName = mergedAnnotation .getType ().getPackageName ();
62+ return packageName .startsWith (SPRING_MOCKITO_PACKAGE );
63+ };
4764
4865 /**
4966 * Executes before {@link org.springframework.test.context.bean.override.BeanOverrideTestExecutionListener}.
@@ -67,6 +84,7 @@ public void afterTestMethod(TestContext testContext) {
6784 }
6885 }
6986
87+
7088 private void resetMocks (ApplicationContext applicationContext , MockReset reset ) {
7189 if (applicationContext instanceof ConfigurableApplicationContext configurableContext ) {
7290 resetMocks (configurableContext , reset );
@@ -119,4 +137,56 @@ private static boolean isStandardBeanOrSingletonFactoryBean(BeanFactory beanFact
119137 return true ;
120138 }
121139
140+ /**
141+ * Determine if the test class for the supplied {@linkplain TestContext
142+ * test context} uses any of the annotations in this package (such as
143+ * {@link MockitoBean @MockitoBean}).
144+ */
145+ static boolean hasMockitoAnnotations (TestContext testContext ) {
146+ return hasMockitoAnnotations (testContext .getTestClass ());
147+ }
148+
149+ /**
150+ * Determine if Mockito annotations are declared on the supplied class, on an
151+ * interface it implements, on a superclass, or on an enclosing class or
152+ * whether a field in any such class is annotated with a Mockito annotation.
153+ */
154+ private static boolean hasMockitoAnnotations (Class <?> clazz ) {
155+ // Declared on the class?
156+ if (MergedAnnotations .from (clazz , MergedAnnotations .SearchStrategy .DIRECT ).stream ().anyMatch (isMockitoAnnotation )) {
157+ return true ;
158+ }
159+
160+ // Declared on a field?
161+ for (Field field : clazz .getDeclaredFields ()) {
162+ if (MergedAnnotations .from (field , MergedAnnotations .SearchStrategy .DIRECT ).stream ().anyMatch (isMockitoAnnotation )) {
163+ return true ;
164+ }
165+ }
166+
167+ // Declared on an interface?
168+ for (Class <?> ifc : clazz .getInterfaces ()) {
169+ if (hasMockitoAnnotations (ifc )) {
170+ return true ;
171+ }
172+ }
173+
174+ // Declared on a superclass?
175+ Class <?> superclass = clazz .getSuperclass ();
176+ if (superclass != null & superclass != Object .class ) {
177+ if (hasMockitoAnnotations (superclass )) {
178+ return true ;
179+ }
180+ }
181+
182+ // Declared on an enclosing class of an inner class?
183+ if (TestContextAnnotationUtils .searchEnclosingClass (clazz )) {
184+ if (hasMockitoAnnotations (clazz .getEnclosingClass ())) {
185+ return true ;
186+ }
187+ }
188+
189+ return false ;
190+ }
191+
122192}
0 commit comments