diff --git a/app/core/management/commands/wait_for_db.py b/app/core/management/commands/wait_for_db.py index b6bdcde..cc5f1aa 100644 --- a/app/core/management/commands/wait_for_db.py +++ b/app/core/management/commands/wait_for_db.py @@ -10,10 +10,12 @@ class Command(BaseCommand): def handle(self, *args, **options): self.stdout.write('Waiting for database...') - db_conn = None - while not db_conn: + cursor = None + while not cursor: try: db_conn = connections['default'] + self.stdout.write('Obtain cursor and verify database...') + cursor = db_conn.cursor() except OperationalError: self.stdout.write('Database unavailable, waiting 1 second...') time.sleep(1) diff --git a/app/core/tests/test_commands.py b/app/core/tests/test_commands.py index 1ccd0d9..89c1ab5 100644 --- a/app/core/tests/test_commands.py +++ b/app/core/tests/test_commands.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import patch, MagicMock from django.core.management import call_command from django.db.utils import OperationalError @@ -9,15 +9,16 @@ class CommandTests(TestCase): def test_wait_for_db_ready(self): """Test waiting for db when db is available""" - with patch('django.db.utils.ConnectionHandler.__getitem__') as gi: - gi.return_value = True + conn = MagicMock(return_value=None) + with patch('django.db.utils.ConnectionHandler.__getitem__', return_value=conn): call_command('wait_for_db') - self.assertEqual(gi.call_count, 1) + conn.cursor.assert_called_once() - @patch('time.sleep', return_value=True) + @patch('time.sleep', return_value=None) def test_wait_for_db(self, ts): """Test waiting for db""" - with patch('django.db.utils.ConnectionHandler.__getitem__') as gi: - gi.side_effect = [OperationalError] * 5 + [True] + conn = MagicMock(return_value=None) + with patch('django.db.utils.ConnectionHandler.__getitem__', return_value=conn): + conn.cursor.side_effect = [OperationalError] * 5 + [True] call_command('wait_for_db') - self.assertEqual(gi.call_count, 6) + self.assertEqual(conn.cursor.call_count, 6)