Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions app/core/management/commands/wait_for_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ class Command(BaseCommand):

def handle(self, *args, **options):
self.stdout.write('Waiting for database...')
db_conn = None
while not db_conn:
cursor = None
while not cursor:
try:
db_conn = connections['default']
self.stdout.write('Obtain cursor and verify database...')
cursor = db_conn.cursor()
except OperationalError:
self.stdout.write('Database unavailable, waiting 1 second...')
time.sleep(1)
Expand Down
17 changes: 9 additions & 8 deletions app/core/tests/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import patch, MagicMock

from django.core.management import call_command
from django.db.utils import OperationalError
Expand All @@ -9,15 +9,16 @@ class CommandTests(TestCase):

def test_wait_for_db_ready(self):
"""Test waiting for db when db is available"""
with patch('django.db.utils.ConnectionHandler.__getitem__') as gi:
gi.return_value = True
conn = MagicMock(return_value=None)
with patch('django.db.utils.ConnectionHandler.__getitem__', return_value=conn):
call_command('wait_for_db')
self.assertEqual(gi.call_count, 1)
conn.cursor.assert_called_once()

@patch('time.sleep', return_value=True)
@patch('time.sleep', return_value=None)
def test_wait_for_db(self, ts):
"""Test waiting for db"""
with patch('django.db.utils.ConnectionHandler.__getitem__') as gi:
gi.side_effect = [OperationalError] * 5 + [True]
conn = MagicMock(return_value=None)
with patch('django.db.utils.ConnectionHandler.__getitem__', return_value=conn):
conn.cursor.side_effect = [OperationalError] * 5 + [True]
call_command('wait_for_db')
self.assertEqual(gi.call_count, 6)
self.assertEqual(conn.cursor.call_count, 6)