7
7
from functools import partial
8
8
from typing import (
9
9
TYPE_CHECKING ,
10
+ AbstractSet ,
10
11
Any ,
11
12
Callable ,
12
13
ContextManager ,
16
17
Literal ,
17
18
Optional ,
18
19
Protocol ,
20
+ Sequence ,
19
21
Tuple ,
20
22
Union ,
21
23
)
@@ -119,6 +121,56 @@ def django_db_createdb(request: pytest.FixtureRequest) -> bool:
119
121
return create_db
120
122
121
123
124
+ def _get_databases_for_test (test : pytest .Item ) -> tuple [Iterable [str ], bool ]:
125
+ """Get the database aliases that need to be setup for a test, and whether
126
+ they need to be serialized."""
127
+ from django .db import DEFAULT_DB_ALIAS , connections
128
+ from django .test import TransactionTestCase
129
+
130
+ test_cls = getattr (test , "cls" , None )
131
+ if test_cls and issubclass (test_cls , TransactionTestCase ):
132
+ serialized_rollback = getattr (test , "serialized_rollback" , False )
133
+ databases = getattr (test , "databases" , None )
134
+ else :
135
+ fixtures = getattr (test , "fixturenames" , ())
136
+ marker_db = test .get_closest_marker ("django_db" )
137
+ if marker_db :
138
+ (
139
+ transaction ,
140
+ reset_sequences ,
141
+ databases ,
142
+ serialized_rollback ,
143
+ available_apps ,
144
+ ) = validate_django_db (marker_db )
145
+ elif "db" in fixtures or "transactional_db" in fixtures or "live_server" in fixtures :
146
+ serialized_rollback = "django_db_serialized_rollback" in fixtures
147
+ databases = None
148
+ else :
149
+ return (), False
150
+ if databases is None :
151
+ return (DEFAULT_DB_ALIAS ,), serialized_rollback
152
+ elif databases == "__all__" :
153
+ return connections , serialized_rollback
154
+ else :
155
+ return databases , serialized_rollback
156
+
157
+
158
+ def _get_databases_for_setup (
159
+ items : Sequence [pytest .Item ],
160
+ ) -> tuple [AbstractSet [str ], AbstractSet [str ]]:
161
+ """Get the database aliases that need to be setup, and the subset that needs
162
+ to be serialized."""
163
+ # Code derived from django.test.utils.DiscoverRunner.get_databases().
164
+ aliases : set [str ] = set ()
165
+ serialized_aliases : set [str ] = set ()
166
+ for test in items :
167
+ databases , serialized_rollback = _get_databases_for_test (test )
168
+ aliases .update (databases )
169
+ if serialized_rollback :
170
+ serialized_aliases .update (databases )
171
+ return aliases , serialized_aliases
172
+
173
+
122
174
@pytest .fixture (scope = "session" )
123
175
def django_db_setup (
124
176
request : pytest .FixtureRequest ,
@@ -140,10 +192,14 @@ def django_db_setup(
140
192
if django_db_keepdb and not django_db_createdb :
141
193
setup_databases_args ["keepdb" ] = True
142
194
195
+ aliases , serialized_aliases = _get_databases_for_setup (request .session .items )
196
+
143
197
with django_db_blocker .unblock ():
144
198
db_cfg = setup_databases (
145
199
verbosity = request .config .option .verbose ,
146
200
interactive = False ,
201
+ aliases = aliases ,
202
+ serialized_aliases = serialized_aliases ,
147
203
** setup_databases_args ,
148
204
)
149
205
0 commit comments