diff --git a/.ci/appveyor.yml b/.ci/appveyor.yml
index 6e85843d..caa39ff5 100644
--- a/.ci/appveyor.yml
+++ b/.ci/appveyor.yml
@@ -26,13 +26,13 @@ branches:
 
 install:
     - "%PYTHON% -m pip install --upgrade pip wheel setuptools"
-    - "%PYTHON% -m pip install -r .ci/requirements-win.txt"
+    - "%PYTHON% -m pip install --upgrade -r .ci/requirements-win.txt"
 
 build_script:
     - "%PYTHON% setup.py build_ext --inplace"
 
 test_script:
-    - "%PYTHON% -m unittest discover -s tests"
+    - "%PYTHON% setup.py test"
 
 after_test:
     - "%PYTHON% setup.py bdist_wheel"
diff --git a/.ci/requirements-win.txt b/.ci/requirements-win.txt
index 82ad782c..48cf3ac7 100644
--- a/.ci/requirements-win.txt
+++ b/.ci/requirements-win.txt
@@ -1,2 +1,2 @@
-cython>=0.24
+cython>=0.27.2
 tinys3
diff --git a/.ci/requirements.txt b/.ci/requirements.txt
index 54cc2771..a981c788 100644
--- a/.ci/requirements.txt
+++ b/.ci/requirements.txt
@@ -1,5 +1,5 @@
-cython>=0.24
+cython>=0.27.2
 flake8>=3.4.1
-uvloop>=0.5.0
+uvloop>=0.8.0
 tinys3
 twine
diff --git a/.ci/travis-install.sh b/.ci/travis-install.sh
index 1af43fdb..e9715eed 100755
--- a/.ci/travis-install.sh
+++ b/.ci/travis-install.sh
@@ -10,4 +10,4 @@ fi
 
 pip install --upgrade pip wheel
 pip install --upgrade setuptools
-pip install -r .ci/requirements.txt
+pip install --upgrade -r .ci/requirements.txt
diff --git a/.gitignore b/.gitignore
index 9ad1a9a4..38b642a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@
 *.ymlc~
 *.scssc
 *.so
+*.pyd
 *~
 .#*
 .DS_Store
diff --git a/asyncpg/_testbase.py b/asyncpg/_testbase/__init__.py
similarity index 60%
rename from asyncpg/_testbase.py
rename to asyncpg/_testbase/__init__.py
index 9b56b580..fb4f1793 100644
--- a/asyncpg/_testbase.py
+++ b/asyncpg/_testbase/__init__.py
@@ -21,6 +21,8 @@
 from asyncpg import connection as pg_connection
 from asyncpg import pool as pg_pool
 
+from . import fuzzer
+
 
 @contextlib.contextmanager
 def silence_asyncio_long_exec_warning():
@@ -36,7 +38,16 @@ def flt(log_record):
         logger.removeFilter(flt)
 
 
+def with_timeout(timeout):
+    def wrap(func):
+        func.__timeout__ = timeout
+        return func
+
+    return wrap
+
+
 class TestCaseMeta(type(unittest.TestCase)):
+    TEST_TIMEOUT = None
 
     @staticmethod
     def _iter_methods(bases, ns):
@@ -64,7 +75,18 @@ def __new__(mcls, name, bases, ns):
         for methname, meth in mcls._iter_methods(bases, ns):
             @functools.wraps(meth)
             def wrapper(self, *args, __meth__=meth, **kwargs):
-                self.loop.run_until_complete(__meth__(self, *args, **kwargs))
+                coro = __meth__(self, *args, **kwargs)
+                timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
+                if timeout:
+                    coro = asyncio.wait_for(coro, timeout, loop=self.loop)
+                    try:
+                        self.loop.run_until_complete(coro)
+                    except asyncio.TimeoutError:
+                        raise self.failureException(
+                            'test timed out after {} seconds'.format(
+                                timeout)) from None
+                else:
+                    self.loop.run_until_complete(coro)
             ns[methname] = wrapper
 
         return super().__new__(mcls, name, bases, ns)
@@ -128,16 +150,30 @@ def handler(loop, ctx):
 _default_cluster = None
 
 
-def _start_cluster(ClusterCls, cluster_kwargs, server_settings):
+def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
+                   initdb_options=None):
     cluster = ClusterCls(**cluster_kwargs)
-    cluster.init()
+    cluster.init(**(initdb_options or {}))
     cluster.trust_local_connections()
     cluster.start(port='dynamic', server_settings=server_settings)
     atexit.register(_shutdown_cluster, cluster)
     return cluster
 
 
-def _start_default_cluster(server_settings={}):
+def _get_initdb_options(initdb_options=None):
+    if not initdb_options:
+        initdb_options = {}
+    else:
+        initdb_options = dict(initdb_options)
+
+    # Make the default superuser name stable.
+    if 'username' not in initdb_options:
+        initdb_options['username'] = 'postgres'
+
+    return initdb_options
+
+
+def _start_default_cluster(server_settings={}, initdb_options=None):
     global _default_cluster
 
     if _default_cluster is None:
@@ -147,13 +183,16 @@ def _start_default_cluster(server_settings={}):
             _default_cluster = pg_cluster.RunningCluster()
         else:
             _default_cluster = _start_cluster(
-                pg_cluster.TempCluster, {}, server_settings)
+                pg_cluster.TempCluster, cluster_kwargs={},
+                server_settings=server_settings,
+                initdb_options=_get_initdb_options(initdb_options))
 
     return _default_cluster
 
 
 def _shutdown_cluster(cluster):
-    cluster.stop()
+    if cluster.get_status() == 'running':
+        cluster.stop()
     cluster.destroy()
@@ -193,15 +232,78 @@ def setUpClass(cls):
         super().setUpClass()
         cls.setup_cluster()
 
-    def create_pool(self, pool_class=pg_pool.Pool, **kwargs):
-        conn_spec = self.cluster.get_connection_spec()
+    @classmethod
+    def get_connection_spec(cls, kwargs={}):
+        conn_spec = cls.cluster.get_connection_spec()
         conn_spec.update(kwargs)
-        return create_pool(loop=self.loop, pool_class=pool_class, **conn_spec)
+        if not os.environ.get('PGHOST'):
+            if 'database' not in conn_spec:
+                conn_spec['database'] = 'postgres'
+            if 'user' not in conn_spec:
+                conn_spec['user'] = 'postgres'
+        return conn_spec
+
+    def create_pool(self, pool_class=pg_pool.Pool,
+                    connection_class=pg_connection.Connection, **kwargs):
+        conn_spec = self.get_connection_spec(kwargs)
+        return create_pool(loop=self.loop, pool_class=pool_class,
+                           connection_class=connection_class, **conn_spec)
+
+    @classmethod
+    def connect(cls, **kwargs):
+        conn_spec = cls.get_connection_spec(kwargs)
+        return pg_connection.connect(**conn_spec, loop=cls.loop)
 
     @classmethod
     def start_cluster(cls, ClusterCls, *,
-                      cluster_kwargs={}, server_settings={}):
-        return _start_cluster(ClusterCls, cluster_kwargs, server_settings)
+                      cluster_kwargs={}, server_settings={},
+                      initdb_options={}):
+        return _start_cluster(
+            ClusterCls, cluster_kwargs,
+            server_settings, _get_initdb_options(initdb_options))
+
+
+class ProxiedClusterTestCase(ClusterTestCase):
+    @classmethod
+    def get_server_settings(cls):
+        settings = dict(super().get_server_settings())
+        settings['listen_addresses'] = '127.0.0.1'
+        return settings
+
+    @classmethod
+    def get_proxy_settings(cls):
+        return {'fuzzing-mode': None}
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        conn_spec = cls.cluster.get_connection_spec()
+        host = conn_spec.get('host')
+        if not host:
+            host = '127.0.0.1'
+        elif host.startswith('/'):
+            host = '127.0.0.1'
+        cls.proxy = fuzzer.TCPFuzzingProxy(
+            backend_host=host,
+            backend_port=conn_spec['port'],
+        )
+        cls.proxy.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.proxy.stop()
+        super().tearDownClass()
+
+    @classmethod
+    def get_connection_spec(cls, kwargs):
+        conn_spec = super().get_connection_spec(kwargs)
+        conn_spec['host'] = cls.proxy.listening_addr
+        conn_spec['port'] = cls.proxy.listening_port
+        return conn_spec
+
+    def tearDown(self):
+        self.proxy.reset()
+        super().tearDown()
 
 
 def with_connection_options(**options):
@@ -223,13 +325,7 @@ def setUp(self):
         # Extract options set up with `with_connection_options`.
         test_func = getattr(self, self._testMethodName).__func__
         opts = getattr(test_func, '__connect_options__', {})
-        if 'database' not in opts:
-            opts = dict(opts)
-            opts['database'] = 'postgres'
-
-        self.con = self.loop.run_until_complete(
-            self.cluster.connect(loop=self.loop, **opts))
-
+        self.con = self.loop.run_until_complete(self.connect(**opts))
         self.server_version = self.con.get_server_version()
 
     def tearDown(self):
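Note on the machinery above: `with_timeout` merely annotates the test coroutine, and the metaclass-generated wrapper reads that annotation (falling back to `TEST_TIMEOUT`), wraps the coroutine in `asyncio.wait_for()`, and converts the resulting `asyncio.TimeoutError` into a regular test failure. A condensed, self-contained sketch of the same pattern, using nothing beyond the standard library (the `Example` class and `_slow` coroutine are illustrative only, not asyncpg code):

    # Sketch of the timeout-wrapping pattern used by TestCaseMeta above.
    import asyncio
    import unittest


    def with_timeout(timeout):
        def wrap(func):
            func.__timeout__ = timeout  # annotation read by the wrapper
            return func
        return wrap


    class Example(unittest.TestCase):
        loop = asyncio.new_event_loop()

        @with_timeout(0.1)
        async def _slow(self):
            await asyncio.sleep(10)

        def test_timeout_enforced(self):
            # This is effectively what the metaclass-generated wrapper does,
            # except that it also rewraps TimeoutError as a test failure.
            coro = asyncio.wait_for(self._slow(), self._slow.__timeout__,
                                    loop=self.loop)
            with self.assertRaises(asyncio.TimeoutError):
                self.loop.run_until_complete(coro)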
diff --git a/asyncpg/_testbase/fuzzer.py b/asyncpg/_testbase/fuzzer.py
new file mode 100644
index 00000000..da1150cf
--- /dev/null
+++ b/asyncpg/_testbase/fuzzer.py
@@ -0,0 +1,295 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+
+import asyncio
+import socket
+import threading
+import typing
+
+from asyncpg import cluster
+
+
+class StopServer(Exception):
+    pass
+
+
+class TCPFuzzingProxy:
+    def __init__(self, *, listening_addr: str='127.0.0.1',
+                 listening_port: typing.Optional[int]=None,
+                 backend_host: str, backend_port: int,
+                 settings: typing.Optional[dict]=None) -> None:
+        self.listening_addr = listening_addr
+        self.listening_port = listening_port
+        self.backend_host = backend_host
+        self.backend_port = backend_port
+        self.settings = settings or {}
+        self.loop = None
+        self.connectivity = None
+        self.connectivity_loss = None
+        self.stop_event = None
+        self.connections = {}
+        self.sock = None
+        self.listen_task = None
+
+    async def _wait(self, work):
+        work_task = asyncio.ensure_future(work, loop=self.loop)
+        stop_event_task = asyncio.ensure_future(self.stop_event.wait(),
+                                                loop=self.loop)
+
+        try:
+            await asyncio.wait(
+                [work_task, stop_event_task],
+                return_when=asyncio.FIRST_COMPLETED,
+                loop=self.loop)
+
+            if self.stop_event.is_set():
+                raise StopServer()
+            else:
+                return work_task.result()
+        finally:
+            if not work_task.done():
+                work_task.cancel()
+            if not stop_event_task.done():
+                stop_event_task.cancel()
+
+    def start(self):
+        started = threading.Event()
+        self.thread = threading.Thread(target=self._start, args=(started,))
+        self.thread.start()
+        if not started.wait(timeout=2):
+            raise RuntimeError('fuzzer proxy failed to start')
+
+    def stop(self):
+        self.loop.call_soon_threadsafe(self._stop)
+        self.thread.join()
+
+    def _stop(self):
+        self.stop_event.set()
+
+    def _start(self, started_event):
+        self.loop = asyncio.new_event_loop()
+
+        self.connectivity = asyncio.Event(loop=self.loop)
+        self.connectivity.set()
+        self.connectivity_loss = asyncio.Event(loop=self.loop)
+        self.stop_event = asyncio.Event(loop=self.loop)
+
+        if self.listening_port is None:
+            self.listening_port = cluster.find_available_port()
+
+        self.sock = socket.socket()
+        self.sock.bind((self.listening_addr, self.listening_port))
+        self.sock.listen(50)
+        self.sock.setblocking(False)
+
+        try:
+            self.loop.run_until_complete(self._main(started_event))
+        finally:
+            self.loop.close()
+
+    async def _main(self, started_event):
+        self.listen_task = asyncio.ensure_future(self.listen(), loop=self.loop)
+        # Notify the main thread that we are ready to go.
+        started_event.set()
+        try:
+            await self.listen_task
+        finally:
+            for c in list(self.connections):
+                c.close()
+            await asyncio.sleep(0.01, loop=self.loop)
+            if hasattr(self.loop, 'remove_reader'):
+                self.loop.remove_reader(self.sock.fileno())
+            self.sock.close()
+
+    async def listen(self):
+        while True:
+            try:
+                client_sock, _ = await self._wait(
+                    self.loop.sock_accept(self.sock))
+
+                backend_sock = socket.socket()
+                backend_sock.setblocking(False)
+
+                await self._wait(self.loop.sock_connect(
+                    backend_sock, (self.backend_host, self.backend_port)))
+            except StopServer:
+                break
+
+            conn = Connection(client_sock, backend_sock, self)
+            conn_task = self.loop.create_task(conn.handle())
+            self.connections[conn] = conn_task
+
+    def trigger_connectivity_loss(self):
+        self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)
+
+    def _trigger_connectivity_loss(self):
+        self.connectivity.clear()
+        self.connectivity_loss.set()
+
+    def restore_connectivity(self):
+        self.loop.call_soon_threadsafe(self._restore_connectivity)
+
+    def _restore_connectivity(self):
+        self.connectivity.set()
+        self.connectivity_loss.clear()
+
+    def reset(self):
+        self.restore_connectivity()
+
+    def _close_connection(self, connection):
+        conn_task = self.connections.pop(connection, None)
+        if conn_task is not None:
+            conn_task.cancel()
+
+
+class Connection:
+    def __init__(self, client_sock, backend_sock, proxy):
+        self.client_sock = client_sock
+        self.backend_sock = backend_sock
+        self.proxy = proxy
+        self.loop = proxy.loop
+        self.connectivity = proxy.connectivity
+        self.connectivity_loss = proxy.connectivity_loss
+        self.proxy_to_backend_task = None
+        self.proxy_from_backend_task = None
+        self.is_closed = False
+
+    def close(self):
+        if self.is_closed:
+            return
+
+        self.is_closed = True
+
+        if self.proxy_to_backend_task is not None:
+            self.proxy_to_backend_task.cancel()
+            self.proxy_to_backend_task = None
+
+        if self.proxy_from_backend_task is not None:
+            self.proxy_from_backend_task.cancel()
+            self.proxy_from_backend_task = None
+
+        self.proxy._close_connection(self)
+
+    async def handle(self):
+        self.proxy_to_backend_task = asyncio.ensure_future(
+            self.proxy_to_backend(), loop=self.loop)
+
+        self.proxy_from_backend_task = asyncio.ensure_future(
+            self.proxy_from_backend(), loop=self.loop)
+
+        try:
+            await asyncio.wait(
+                [self.proxy_to_backend_task, self.proxy_from_backend_task],
+                loop=self.loop, return_when=asyncio.FIRST_COMPLETED)
+
+        finally:
+            if hasattr(self.loop, 'remove_reader'):
+                # Asyncio *really* doesn't like when the sockets are
+                # closed under it.
+                self.loop.remove_reader(self.client_sock.fileno())
+                self.loop.remove_writer(self.client_sock.fileno())
+                self.loop.remove_reader(self.backend_sock.fileno())
+                self.loop.remove_writer(self.backend_sock.fileno())
+
+            self.client_sock.close()
+            self.backend_sock.close()
+
+    async def _read(self, sock, n):
+        read_task = asyncio.ensure_future(
+            self.loop.sock_recv(sock, n),
+            loop=self.loop)
+        conn_event_task = asyncio.ensure_future(
+            self.connectivity_loss.wait(),
+            loop=self.loop)
+
+        try:
+            await asyncio.wait(
+                [read_task, conn_event_task],
+                return_when=asyncio.FIRST_COMPLETED,
+                loop=self.loop)
+
+            if self.connectivity_loss.is_set():
+                return None
+            else:
+                return read_task.result()
+        finally:
+            if not read_task.done():
+                read_task.cancel()
+            if not conn_event_task.done():
+                conn_event_task.cancel()
+
+    async def _write(self, sock, data):
+        write_task = asyncio.ensure_future(
+            self.loop.sock_sendall(sock, data), loop=self.loop)
+        conn_event_task = asyncio.ensure_future(
+            self.connectivity_loss.wait(), loop=self.loop)
+
+        try:
+            await asyncio.wait(
+                [write_task, conn_event_task],
+                return_when=asyncio.FIRST_COMPLETED,
+                loop=self.loop)
+
+            if self.connectivity_loss.is_set():
+                return None
+            else:
+                return write_task.result()
+        finally:
+            if not write_task.done():
+                write_task.cancel()
+            if not conn_event_task.done():
+                conn_event_task.cancel()
+
+    async def proxy_to_backend(self):
+        buf = None
+
+        try:
+            while True:
+                await self.connectivity.wait()
+                if buf is not None:
+                    data = buf
+                    buf = None
+                else:
+                    data = await self._read(self.client_sock, 4096)
+                if data == b'':
+                    break
+                if self.connectivity_loss.is_set():
+                    if data:
+                        buf = data
+                    continue
+                await self._write(self.backend_sock, data)
+
+        except ConnectionError:
+            pass
+
+        finally:
+            self.loop.call_soon(self.close)
+
+    async def proxy_from_backend(self):
+        buf = None
+
+        try:
+            while True:
+                await self.connectivity.wait()
+                if buf is not None:
+                    data = buf
+                    buf = None
+                else:
+                    data = await self._read(self.backend_sock, 4096)
+                if data == b'':
+                    break
+                if self.connectivity_loss.is_set():
+                    if data:
+                        buf = data
+                    continue
+                await self._write(self.client_sock, data)
+
+        except ConnectionError:
+            pass
+
+        finally:
+            self.loop.call_soon(self.close)
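The proxy above never drops bytes on a simulated outage: `_read()` and `_write()` race the socket operation against the `connectivity_loss` event, and any chunk read just before the loss is stashed in `buf` and replayed once connectivity is restored. A rough usage sketch, assuming a local PostgreSQL server on 127.0.0.1:5432 (the host, port, user, and database values below are placeholders):

    import asyncio
    import asyncpg
    from asyncpg._testbase import fuzzer

    async def main(proxy):
        con = await asyncpg.connect(
            host=proxy.listening_addr, port=proxy.listening_port,
            user='postgres', database='postgres')
        proxy.trigger_connectivity_loss()   # traffic is now held back
        try:
            await con.close(timeout=0.5)    # cannot complete, times out
        except asyncio.TimeoutError:
            pass
        finally:
            proxy.restore_connectivity()

    proxy = fuzzer.TCPFuzzingProxy(backend_host='127.0.0.1', backend_port=5432)
    proxy.start()
    try:
        asyncio.get_event_loop().run_until_complete(main(proxy))
    finally:
        proxy.stop()

The proxy runs its own event loop on a separate thread, which is why tests interact with it only through the thread-safe `trigger_connectivity_loss()` / `restore_connectivity()` entry points.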
diff --git a/asyncpg/cluster.py b/asyncpg/cluster.py
index 55262202..31a40e37 100644
--- a/asyncpg/cluster.py
+++ b/asyncpg/cluster.py
@@ -183,7 +183,7 @@ def start(self, wait=60, *, server_settings={}, **opts):
             # Make sure server certificate key file has correct permissions.
             keyfile = os.path.join(self._data_dir, 'srvkey.pem')
             shutil.copy(ssl_key, keyfile)
-            os.chmod(keyfile, 0o400)
+            os.chmod(keyfile, 0o600)
 
             server_settings = server_settings.copy()
             server_settings['ssl_key_file'] = keyfile
@@ -202,16 +202,24 @@
             # of postgres daemon under an Administrative account
             # is not permitted and there is no easy way to drop
             # privileges.
+            if os.getenv('ASYNCPG_DEBUG_SERVER'):
+                stdout = sys.stdout
+            else:
+                stdout = subprocess.DEVNULL
+
             process = subprocess.run(
                 [self._pg_ctl, 'start', '-D', self._data_dir,
                  '-o', ' '.join(extra_args)],
-                stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-            stderr = process.stderr
+                stdout=stdout, stderr=subprocess.STDOUT)
 
             if process.returncode != 0:
+                if process.stderr:
+                    stderr = ':\n{}'.format(process.stderr.decode())
+                else:
+                    stderr = ''
                 raise ClusterError(
-                    'pg_ctl start exited with status {:d}: {}'.format(
-                        process.returncode, stderr.decode()))
+                    'pg_ctl start exited with status {:d}{}'.format(
+                        process.returncode, stderr))
         else:
             if os.getenv('ASYNCPG_DEBUG_SERVER'):
                 stdout = sys.stdout
@@ -448,6 +456,7 @@ def _test_connection(self, timeout=60):
             try:
                 con = loop.run_until_complete(
                     asyncpg.connect(database='postgres',
+                                    user='postgres',
                                     timeout=5, loop=loop,
                                     **self._connection_addr))
             except (OSError, asyncio.TimeoutError,
diff --git a/asyncpg/connection.py b/asyncpg/connection.py
index c962426a..880c283b 100644
--- a/asyncpg/connection.py
+++ b/asyncpg/connection.py
@@ -44,7 +44,7 @@ class Connection(metaclass=ConnectionMeta):
                  '_listeners', '_server_version', '_server_caps',
                  '_intro_query', '_reset_query', '_proxy',
                  '_stmt_exclusive_section', '_config', '_params', '_addr',
-                 '_log_listeners')
+                 '_log_listeners', '_cancellations')
 
     def __init__(self, protocol, transport, loop,
                  addr: (str, int) or str,
@@ -74,6 +74,7 @@ def __init__(self, protocol, transport, loop,
 
         self._listeners = {}
         self._log_listeners = set()
+        self._cancellations = set()
 
         settings = self._protocol.get_settings()
         ver_string = settings.server_version
@@ -958,15 +959,30 @@ def is_closed(self):
         """
         return not self._protocol.is_connected() or self._aborted
 
-    async def close(self):
-        """Close the connection gracefully."""
+    async def close(self, *, timeout=None):
+        """Close the connection gracefully.
+
+        :param float timeout:
+            Optional timeout value in seconds.
+
+        .. versionchanged:: 0.14.0
+           Added the *timeout* parameter.
+        """
         if self.is_closed():
             return
         self._mark_stmts_as_closed()
         self._listeners.clear()
         self._log_listeners.clear()
         self._aborted = True
-        await self._protocol.close()
+        try:
+            await self._protocol.close(timeout)
+        except Exception:
+            # If we fail to close gracefully, abort the connection.
+            self._aborted = True
+            self._protocol.abort()
+            raise
+        finally:
+            self._clean_tasks()
 
     def terminate(self):
         """Terminate the connection without waiting for pending data."""
@@ -975,14 +991,23 @@ def terminate(self):
         self._log_listeners.clear()
         self._aborted = True
         self._protocol.abort()
+        self._clean_tasks()
 
-    async def reset(self):
+    async def reset(self, *, timeout=None):
         self._check_open()
         self._listeners.clear()
         self._log_listeners.clear()
         reset_query = self._get_reset_query()
 
         if reset_query:
-            await self.execute(reset_query)
+            await self.execute(reset_query, timeout=timeout)
+
+    def _clean_tasks(self):
+        # Wrap-up any remaining tasks associated with this connection.
+        if self._cancellations:
+            for fut in self._cancellations:
+                if not fut.done():
+                    fut.cancel()
+            self._cancellations.clear()
 
     def _check_open(self):
         if self.is_closed():
@@ -1027,36 +1052,47 @@ async def _cleanup_stmts(self):
             # so we ignore the timeout.
             await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT)
 
-    def _cancel_current_command(self, waiter):
-        async def cancel():
-            try:
-                # Open new connection to the server
-                r, w = await connect_utils._open_connection(
-                    loop=self._loop, addr=self._addr, params=self._params)
-            except Exception as ex:
+    async def _cancel(self, waiter):
+        r = w = None
+
+        try:
+            # Open new connection to the server
+            r, w = await connect_utils._open_connection(
+                loop=self._loop, addr=self._addr, params=self._params)
+
+            # Pack CancelRequest message
+            msg = struct.pack('!llll', 16, 80877102,
+                              self._protocol.backend_pid,
+                              self._protocol.backend_secret)
+
+            w.write(msg)
+            await r.read()  # Wait until EOF
+        except ConnectionResetError as ex:
+            # On some systems Postgres will reset the connection
+            # after processing the cancellation command.
+            if r is None and not waiter.done():
                 waiter.set_exception(ex)
-                return
-
-            try:
-                # Pack CancelRequest message
-                msg = struct.pack('!llll', 16, 80877102,
-                                  self._protocol.backend_pid,
-                                  self._protocol.backend_secret)
-
-                w.write(msg)
-                await r.read()  # Wait until EOF
-            except ConnectionResetError:
-                # On some systems Postgres will reset the connection
-                # after processing the cancellation command.
-                pass
-            except Exception as ex:
+        except asyncio.CancelledError:
+            # There are two scenarios in which the cancellation
+            # itself will be cancelled: 1) the connection is being closed,
+            # 2) the event loop is being shut down.
+            # In either case we do not care about the propagation of
+            # the CancelledError, and don't want the loop to warn about
+            # an unretrieved exception.
+            pass
+        except Exception as ex:
+            if not waiter.done():
                 waiter.set_exception(ex)
-            finally:
-                if not waiter.done():  # Ensure set_exception wasn't called.
-                    waiter.set_result(None)
+        finally:
+            self._cancellations.discard(
+                asyncio.Task.current_task(self._loop))
+            if not waiter.done():
+                waiter.set_result(None)
+            if w is not None:
                 w.close()
 
-        self._loop.create_task(cancel())
+    def _cancel_current_command(self, waiter):
+        self._cancellations.add(self._loop.create_task(self._cancel(waiter)))
 
     def _process_log_message(self, fields, last_query):
         if not self._log_listeners:
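The `struct.pack('!llll', 16, 80877102, ...)` call in `_cancel()` builds a PostgreSQL CancelRequest message: a 16-byte packet carrying the length, the cancel request code 80877102 (that is, 1234 << 16 | 5678, per the frontend/backend protocol), and the backend PID and secret key obtained at connection time. For illustration:

    # CancelRequest layout per the PostgreSQL wire protocol.
    import struct

    def cancel_request(backend_pid: int, backend_secret: int) -> bytes:
        # 80877102 == (1234 << 16) | 5678, the protocol's cancel request code.
        return struct.pack('!llll', 16, 80877102,
                           backend_pid, backend_secret)

    assert len(cancel_request(42, 7)) == 16

Since the cancellation runs as a detached task, it is now also recorded in `self._cancellations`, so that `close()` and `terminate()` can reap it via `_clean_tasks()` instead of leaking a pending task.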
diff --git a/asyncpg/pool.py b/asyncpg/pool.py
index dff51bfc..6b8df59d 100644
--- a/asyncpg/pool.py
+++ b/asyncpg/pool.py
@@ -8,6 +8,7 @@
 import asyncio
 import functools
 import inspect
+import time
 
 from . import connection
 from . import connect_utils
@@ -94,7 +95,7 @@ class PoolConnectionHolder:
                  '_connect_args', '_connect_kwargs',
                  '_max_queries', '_setup', '_init',
                  '_max_inactive_time', '_in_use',
-                 '_inactive_callback')
+                 '_inactive_callback', '_timeout')
 
     def __init__(self, pool, *, connect_args, connect_kwargs,
                  max_queries, setup, init, max_inactive_time):
@@ -110,6 +111,7 @@ def __init__(self, pool, *, connect_args, connect_kwargs,
         self._init = init
         self._inactive_callback = None
         self._in_use = False
+        self._timeout = None
 
     async def connect(self):
         assert self._con is None
@@ -172,9 +174,10 @@ async def acquire(self) -> PoolConnectionProxy:
         self._in_use = True
         return proxy
 
-    async def release(self):
+    async def release(self, timeout):
         assert self._in_use
         self._in_use = False
+        self._timeout = None
 
         self._con._on_release()
@@ -183,13 +186,25 @@ async def release(self):
 
         elif self._con._protocol.queries_count >= self._max_queries:
             try:
-                await self._con.close()
+                await self._con.close(timeout=timeout)
             finally:
                 self._con = None
 
         else:
             try:
-                await self._con.reset()
+                budget = timeout
+
+                if self._con._protocol._is_cancelling():
+                    # If the connection is in cancellation state,
+                    # wait for the cancellation
+                    started = time.monotonic()
+                    await asyncio.wait_for(
+                        self._con._protocol._wait_for_cancellation(),
+                        budget, loop=self._pool._loop)
+                    if budget is not None:
+                        budget -= time.monotonic() - started
+
+                await self._con.reset(timeout=budget)
             except Exception as ex:
                 # If the `reset` call failed, terminate the connection.
                 # A new one will be created when `acquire` is called
@@ -449,6 +464,9 @@ async def _acquire_impl():
                 self._queue.put_nowait(ch)
                 raise
             else:
+                # Record the timeout, as we will apply it by default
+                # in release().
+                ch._timeout = timeout
                 return proxy
 
         self._check_init()
@@ -458,11 +476,22 @@ async def _acquire_impl():
             return await asyncio.wait_for(
                 _acquire_impl(), timeout=timeout, loop=self._loop)
 
-    async def release(self, connection):
-        """Release a database connection back to the pool."""
-        async def _release_impl(ch: PoolConnectionHolder):
+    async def release(self, connection, *, timeout=None):
+        """Release a database connection back to the pool.
+
+        :param Connection connection:
+            A :class:`~asyncpg.connection.Connection` object to release.
+        :param float timeout:
+            A timeout for releasing the connection.  If not specified,
+            defaults to the timeout provided in the corresponding call to the
+            :meth:`Pool.acquire() <asyncpg.pool.Pool.acquire>` method.
+
+        .. versionchanged:: 0.14.0
+           Added the *timeout* parameter.
+        """
+        async def _release_impl(ch: PoolConnectionHolder, timeout: float):
             try:
-                await ch.release()
+                await ch.release(timeout)
             finally:
                 self._queue.put_nowait(ch)
@@ -481,11 +510,14 @@ async def _release_impl(ch: PoolConnectionHolder):
 
         connection._detach()
 
+        if timeout is None:
+            timeout = connection._holder._timeout
+
         # Use asyncio.shield() to guarantee that task cancellation
         # does not prevent the connection from being returned to the
         # pool properly.
-        return await asyncio.shield(_release_impl(connection._holder),
-                                    loop=self._loop)
+        return await asyncio.shield(
+            _release_impl(connection._holder, timeout), loop=self._loop)
 
     async def close(self):
         """Gracefully close all connections in the pool."""
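`Pool.release()` now has to fit two potentially slow steps into one budget: waiting out a pending query cancellation and resetting the session state; the holder deducts the time spent on the first from what remains for the second. A sketch of how the defaulting looks from the caller's side, assuming a reachable local database (connection parameters below are placeholders):

    import asyncio
    import asyncpg

    async def main():
        pool = await asyncpg.create_pool(
            database='postgres', min_size=1, max_size=1)
        con = await pool.acquire(timeout=5.0)  # 5.0 is remembered on the holder
        try:
            print(await con.fetchval('SELECT 1'))
        finally:
            # An explicit value overrides the remembered acquire() timeout.
            await pool.release(con, timeout=1.0)
        await pool.close()

    asyncio.get_event_loop().run_until_complete(main())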
diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx
index ac9e08d3..09fc8c11 100644
--- a/asyncpg/protocol/protocol.pyx
+++ b/asyncpg/protocol/protocol.pyx
@@ -353,6 +353,9 @@
                         # Abort the COPY operation on any error in
                         # output sink.
                         self._request_cancel()
+                        # Make asyncio shut up about unretrieved
+                        # QueryCanceledError
+                        waiter.add_done_callback(lambda f: f.exception())
                         raise
 
                 # done will be True upon receipt of CopyDone.
@@ -474,6 +477,8 @@ cdef class BaseProtocol(CoreProtocol):
             except Exception as e:
                 self._write_copy_fail_msg(str(e))
                 self._request_cancel()
+                # Make asyncio shut up about unretrieved QueryCanceledError
+                waiter.add_done_callback(lambda f: f.exception())
                 raise
 
         self._write_copy_done_msg()
@@ -521,24 +526,45 @@ cdef class BaseProtocol(CoreProtocol):
         self._terminate()
         self.transport.abort()
 
-    async def close(self):
+    async def close(self, timeout):
+        if self.closing:
+            return
+
+        self.closing = True
+
+        if self.cancel_sent_waiter is not None:
+            await self.cancel_sent_waiter
+            self.cancel_sent_waiter = None
+
         if self.cancel_waiter is not None:
             await self.cancel_waiter
-        if self.cancel_sent_waiter is not None:
+
+        if self.waiter is not None:
+            # If there is a query running, cancel it
+            self._request_cancel()
             await self.cancel_sent_waiter
             self.cancel_sent_waiter = None
+            if self.cancel_waiter is not None:
+                await self.cancel_waiter
 
-        self._handle_waiter_on_connection_lost(None)
         assert self.waiter is None
 
-        if self.closing:
-            return
+        timeout = self._get_timeout_impl(timeout)
 
+        # Ask the server to terminate the connection and wait for it
+        # to drop.
+        self.waiter = self._new_waiter(timeout)
         self._terminate()
-        self.waiter = self.create_future()
-        self.closing = True
+        try:
+            await self.waiter
+        except ConnectionResetError:
+            # There appears to be a difference in behaviour of asyncio
+            # in Windows, where, instead of calling protocol.connection_lost()
+            # a ConnectionResetError will be thrown into the task.
+            pass
+        finally:
+            self.waiter = None
 
         self.transport.abort()
-        return await self.waiter
 
     def _request_cancel(self):
         self.cancel_waiter = self.create_future()
@@ -615,6 +641,19 @@ cdef class BaseProtocol(CoreProtocol):
             raise apg_exc.InterfaceError(
                 'cannot perform operation: another operation is in progress')
 
+    def _is_cancelling(self):
+        return (
+            self.cancel_waiter is not None or
+            self.cancel_sent_waiter is not None
+        )
+
+    async def _wait_for_cancellation(self):
+        if self.cancel_sent_waiter is not None:
+            await self.cancel_sent_waiter
+            self.cancel_sent_waiter = None
+        if self.cancel_waiter is not None:
+            await self.cancel_waiter
+
     cdef _coreproto_error(self):
         try:
             if self.waiter is not None:
@@ -764,12 +803,16 @@ cdef class BaseProtocol(CoreProtocol):
             self.timeout_handle = None
 
         if self.cancel_waiter is not None:
-            # We have received the result of a cancelled operation.
-            # Simply ignore the result.
-            self.cancel_waiter.set_result(None)
+            # We have received the result of a cancelled command.
+            if not self.cancel_waiter.done():
+                # The cancellation future might have been cancelled
+                # by the cancellation of the entire task running the query.
+                self.cancel_waiter.set_result(None)
             self.cancel_waiter = None
-            self.waiter = None
-            return
+            if self.waiter is not None and self.waiter.done():
+                self.waiter = None
+            if self.waiter is None:
+                return
 
         try:
             self._dispatch_result()
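The repeated `waiter.add_done_callback(lambda f: f.exception())` lines exist only to mark the future's exception as retrieved; otherwise asyncio logs a "Future exception was never retrieved" warning when a cancelled query's `QueryCanceledError` is garbage-collected unobserved. A minimal demonstration of the idiom (plain asyncio, unrelated to asyncpg internals):

    import asyncio

    async def main(loop):
        fut = loop.create_future()
        fut.set_exception(RuntimeError('simulated cancelled query'))
        fut.add_done_callback(lambda f: f.exception())  # mark as retrieved
        await asyncio.sleep(0)  # let the callback run

    loop = asyncio.new_event_loop()
    loop.run_until_complete(main(loop))
    loop.close()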
diff --git a/tests/test_adversity.py b/tests/test_adversity.py
new file mode 100644
index 00000000..a9032b54
--- /dev/null
+++ b/tests/test_adversity.py
@@ -0,0 +1,64 @@
+# Copyright (C) 2016-present the asyncpg authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of asyncpg and is released under
+# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+
+"""Tests how asyncpg behaves in non-ideal conditions."""
+
+import asyncio
+import os
+import unittest
+
+from asyncpg import _testbase as tb
+
+
+@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
+class TestConnectionLoss(tb.ProxiedClusterTestCase):
+    @tb.with_timeout(30.0)
+    async def test_connection_close_timeout(self):
+        con = await self.connect()
+        self.proxy.trigger_connectivity_loss()
+        with self.assertRaises(asyncio.TimeoutError):
+            await con.close(timeout=0.5)
+
+    @tb.with_timeout(30.0)
+    async def test_pool_release_timeout(self):
+        pool = await self.create_pool(
+            database='postgres', min_size=2, max_size=2)
+        try:
+            with self.assertRaises(asyncio.TimeoutError):
+                async with pool.acquire(timeout=0.5):
+                    self.proxy.trigger_connectivity_loss()
+        finally:
+            self.proxy.restore_connectivity()
+            await pool.close()
+
+    @tb.with_timeout(30.0)
+    async def test_pool_handles_abrupt_connection_loss(self):
+        pool_size = 3
+        query_runtime = 0.5
+        pool_timeout = cmd_timeout = 1.0
+        concurrency = 9
+        pool_concurrency = (concurrency - 1) // pool_size + 1
+
+        # Worst expected runtime + 20% to account for other latencies.
+        worst_runtime = (pool_timeout + cmd_timeout) * pool_concurrency * 1.2
+
+        async def worker(pool):
+            async with pool.acquire(timeout=pool_timeout) as con:
+                await con.fetch('SELECT pg_sleep($1)', query_runtime)
+
+        def kill_connectivity():
+            self.proxy.trigger_connectivity_loss()
+
+        new_pool = self.create_pool(
+            database='postgres', min_size=pool_size, max_size=pool_size,
+            timeout=cmd_timeout, command_timeout=cmd_timeout)
+
+        with self.assertRunUnder(worst_runtime):
+            async with new_pool as pool:
+                workers = [worker(pool) for _ in range(concurrency)]
+                self.loop.call_later(1, kill_connectivity)
+                await asyncio.gather(
+                    *workers, loop=self.loop, return_exceptions=True)
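For concreteness, with the constants above: pool_concurrency = (9 - 1) // 3 + 1 = 3, so the worst case is three back-to-back waves of workers, each wave allowed pool_timeout + cmd_timeout = 2.0 seconds, giving worst_runtime = 2.0 * 3 * 1.2 = 7.2 seconds including the 20% slack. assertRunUnder then fails the test if the nine workers take longer than that to settle after connectivity is cut one second in.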
diff --git a/tests/test_codecs.py b/tests/test_codecs.py
index 07359fd8..2b53fe73 100644
--- a/tests/test_codecs.py
+++ b/tests/test_codecs.py
@@ -1011,7 +1011,7 @@ async def test_custom_codec_override_binary(self):
         """Test overriding core codecs."""
         import json
 
-        conn = await self.cluster.connect(database='postgres', loop=self.loop)
+        conn = await self.connect()
         try:
             def _encoder(value):
                 return json.dumps(value).encode('utf-8')
@@ -1035,7 +1035,7 @@ async def test_custom_codec_override_text(self):
         """Test overriding core codecs."""
         import json
 
-        conn = await self.cluster.connect(database='postgres', loop=self.loop)
+        conn = await self.connect()
         try:
             def _encoder(value):
                 return json.dumps(value)
@@ -1087,7 +1087,7 @@ async def test_custom_codec_override_tuple(self):
             ('interval', (2, 3, 1), '2 mons 3 days 00:00:00.000001')
         ]
 
-        conn = await self.cluster.connect(database='postgres', loop=self.loop)
+        conn = await self.connect()
 
         def _encoder(value):
             return tuple(value)
diff --git a/tests/test_connect.py b/tests/test_connect.py
index a1e28004..3bc49d95 100644
--- a/tests/test_connect.py
+++ b/tests/test_connect.py
@@ -137,57 +137,50 @@ async def _try_connect(self, **kwargs):
         if _system == 'Windows':
             for tried in range(3):
                 try:
-                    return await self.cluster.connect(**kwargs)
+                    return await self.connect(**kwargs)
                 except asyncpg.ConnectionDoesNotExistError:
                     pass
 
-        return await self.cluster.connect(**kwargs)
+        return await self.connect(**kwargs)
 
     async def test_auth_bad_user(self):
         with self.assertRaises(
                 asyncpg.InvalidAuthorizationSpecificationError):
-            await self._try_connect(user='__nonexistent__',
-                                    database='postgres',
-                                    loop=self.loop)
+            await self._try_connect(user='__nonexistent__')
 
     async def test_auth_trust(self):
-        conn = await self.cluster.connect(
-            user='trust_user', database='postgres', loop=self.loop)
+        conn = await self.connect(user='trust_user')
         await conn.close()
 
     async def test_auth_reject(self):
         with self.assertRaisesRegex(
                 asyncpg.InvalidAuthorizationSpecificationError,
                 'pg_hba.conf rejects connection'):
-            await self._try_connect(
-                user='reject_user', database='postgres',
-                loop=self.loop)
+            await self._try_connect(user='reject_user')
 
     async def test_auth_password_cleartext(self):
-        conn = await self.cluster.connect(
-            user='password_user', database='postgres',
-            password='correctpassword', loop=self.loop)
+        conn = await self.connect(
+            user='password_user',
+            password='correctpassword')
         await conn.close()
 
         with self.assertRaisesRegex(
                 asyncpg.InvalidPasswordError,
                 'password authentication failed for user "password_user"'):
             await self._try_connect(
-                user='password_user', database='postgres',
-                password='wrongpassword', loop=self.loop)
+                user='password_user',
+                password='wrongpassword')
 
     async def test_auth_password_md5(self):
-        conn = await self.cluster.connect(
-            user='md5_user', database='postgres', password='correctpassword',
-            loop=self.loop)
+        conn = await self.connect(
+            user='md5_user', password='correctpassword')
         await conn.close()
 
         with self.assertRaisesRegex(
                 asyncpg.InvalidPasswordError,
                 'password authentication failed for user "md5_user"'):
             await self._try_connect(
-                user='md5_user', database='postgres', password='wrongpassword',
-                loop=self.loop)
+                user='md5_user', password='wrongpassword')
 
     async def test_auth_unsupported(self):
         pass
@@ -494,11 +487,9 @@ async def test_connection_ssl_to_no_ssl_server(self):
         ssl_context.load_verify_locations(SSL_CA_CERT_FILE)
 
         with self.assertRaisesRegex(ConnectionError, 'rejected SSL'):
-            await self.cluster.connect(
+            await self.connect(
                 host='localhost',
                 user='ssl_user',
-                database='postgres',
-                loop=self.loop,
                 ssl=ssl_context)
 
     async def test_connection_ssl_unix(self):
@@ -507,15 +498,17 @@ async def test_connection_ssl_unix(self):
 
         with self.assertRaisesRegex(asyncpg.InterfaceError,
                                     'can only be enabled for TCP addresses'):
-            await self.cluster.connect(
+            await self.connect(
                 host='/tmp',
-                loop=self.loop,
                 ssl=ssl_context)
 
     async def test_connection_implicit_host(self):
-        conn_spec = self.cluster.get_connection_spec()
+        conn_spec = self.get_connection_spec()
         con = await asyncpg.connect(
-            port=conn_spec.get('port'), database='postgres', loop=self.loop)
+            port=conn_spec.get('port'),
+            database=conn_spec.get('database'),
+            user=conn_spec.get('user'),
+            loop=self.loop)
         await con.close()
@@ -576,11 +569,9 @@ async def test_ssl_connection_custom_context(self):
         ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
         ssl_context.load_verify_locations(SSL_CA_CERT_FILE)
 
-        con = await self.cluster.connect(
+        con = await self.connect(
             host='localhost',
             user='ssl_user',
-            database='postgres',
-            loop=self.loop,
             ssl=ssl_context)
 
         try:
@@ -595,11 +586,9 @@ async def test_ssl_connection_custom_context(self):
 
     async def test_ssl_connection_default_context(self):
         with self.assertRaisesRegex(ssl.SSLError, 'verify failed'):
-            await self.cluster.connect(
+            await self.connect(
                 host='localhost',
                 user='ssl_user',
-                database='postgres',
-                loop=self.loop,
                 ssl=True)
 
     async def test_ssl_connection_pool(self):
diff --git a/tests/test_execute.py b/tests/test_execute.py
index 78b0b000..ccde0993 100644
--- a/tests/test_execute.py
+++ b/tests/test_execute.py
@@ -78,8 +78,7 @@ async def test_execute_script_interrupted_close(self):
         await self.con.close()
         self.assertTrue(self.con.is_closed())
 
-        with self.assertRaisesRegex(asyncpg.ConnectionDoesNotExistError,
-                                    'closed in the middle'):
+        with self.assertRaises(asyncpg.QueryCanceledError):
             await fut
 
     async def test_execute_script_interrupted_terminate(self):
diff --git a/tests/test_introspection.py b/tests/test_introspection.py
index 8960399b..d46095f8 100644
--- a/tests/test_introspection.py
+++ b/tests/test_introspection.py
@@ -16,8 +16,7 @@ class TestIntrospection(tb.ConnectedTestCase):
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
-        cls.adminconn = cls.loop.run_until_complete(
-            cls.cluster.connect(database='postgres', loop=cls.loop))
+        cls.adminconn = cls.loop.run_until_complete(cls.connect())
         cls.loop.run_until_complete(
             cls.adminconn.execute('CREATE DATABASE asyncpg_intro_test'))
diff --git a/tests/test_pool.py b/tests/test_pool.py
index 30d3d362..8aefb1b4 100644
--- a/tests/test_pool.py
+++ b/tests/test_pool.py
@@ -31,15 +31,16 @@ class SlowResetConnection(pg_connection.Connection):
     """Connection class to simulate races with Connection.reset()."""
-    async def reset(self):
+    async def reset(self, *, timeout=None):
         await asyncio.sleep(0.2, loop=self._loop)
-        return await super().reset()
+        return await super().reset(timeout=timeout)
 
 
-class SlowResetConnectionPool(pg_pool.Pool):
-    async def _connect(self, *args, **kwargs):
-        return await pg_connection.connect(
-            *args, connection_class=SlowResetConnection, **kwargs)
+class SlowCancelConnection(pg_connection.Connection):
+    """Connection class to simulate races with Connection._cancel()."""
+    async def _cancel(self, waiter):
+        await asyncio.sleep(0.2, loop=self._loop)
+        return await super()._cancel(waiter)
 
 
 class TestPool(tb.ConnectedTestCase):
@@ -351,12 +352,12 @@ async def worker():
         self.cluster.trust_local_connections()
         self.cluster.reload()
 
-    async def test_pool_handles_cancel_in_release(self):
+    async def test_pool_handles_task_cancel_in_release(self):
         # Use SlowResetConnectionPool to simulate
         # the Task.cancel() and __aexit__ race.
         pool = await self.create_pool(database='postgres',
                                       min_size=1, max_size=1,
-                                      pool_class=SlowResetConnectionPool)
+                                      connection_class=SlowResetConnection)
 
         async def worker():
             async with pool.acquire():
@@ -372,6 +373,27 @@ async def worker():
         # Check that the connection has been returned to the pool.
         self.assertEqual(pool._queue.qsize(), 1)
 
+    async def test_pool_handles_query_cancel_in_release(self):
+        # Use SlowResetConnectionPool to simulate
+        # the Task.cancel() and __aexit__ race.
+        pool = await self.create_pool(database='postgres',
+                                      min_size=1, max_size=1,
+                                      connection_class=SlowCancelConnection)
+
+        async def worker():
+            async with pool.acquire() as con:
+                await con.execute('SELECT pg_sleep(10)')
+
+        task = self.loop.create_task(worker())
+        # Let the worker() run.
+        await asyncio.sleep(0.1, loop=self.loop)
+        # Cancel the worker.
+        task.cancel()
+        # Wait to make sure the cleanup has completed.
+        await asyncio.sleep(0.5, loop=self.loop)
+        # Check that the connection has been returned to the pool.
+        self.assertEqual(pool._queue.qsize(), 1)
+
     async def test_pool_no_acquire_deadlock(self):
         async with self.create_pool(database='postgres',
                                     min_size=1, max_size=1,
@@ -492,9 +514,10 @@ async def run(N, meth):
         methods = [test_fetch, test_fetchrow, test_fetchval,
                    test_execute, test_execute_with_arg]
 
-        for method in methods:
-            with self.subTest(method=method.__name__):
-                await run(200, method)
+        with tb.silence_asyncio_long_exec_warning():
+            for method in methods:
+                with self.subTest(method=method.__name__):
+                    await run(200, method)
 
     async def test_pool_connection_execute_many(self):
         async def worker(pool):
@@ -656,7 +679,8 @@ def setUpClass(cls):
 
         try:
             con = cls.loop.run_until_complete(
-                cls.master_cluster.connect(database='postgres', loop=cls.loop))
+                cls.master_cluster.connect(
+                    database='postgres', user='postgres', loop=cls.loop))
 
             cls.loop.run_until_complete(
                 con.execute('''
@@ -696,8 +720,9 @@ def create_pool(self, **kwargs):
     async def test_standby_pool_01(self):
         for n in {1, 3, 5, 10, 20, 100}:
             with self.subTest(tasksnum=n):
-                pool = await self.create_pool(database='postgres',
-                                              min_size=5, max_size=10)
+                pool = await self.create_pool(
+                    database='postgres', user='postgres',
+                    min_size=5, max_size=10)
 
                 async def worker():
                     con = await pool.acquire()
@@ -710,7 +735,7 @@ async def worker():
 
     async def test_standby_cursors(self):
         con = await self.standby_cluster.connect(
-            database='postgres', loop=self.loop)
+            database='postgres', user='postgres', loop=self.loop)
 
         try:
             async with con.transaction():
diff --git a/tests/test_prepare.py b/tests/test_prepare.py
index 0f2efca0..510e1929 100644
--- a/tests/test_prepare.py
+++ b/tests/test_prepare.py
@@ -94,8 +94,7 @@ async def test_prepare_06_interrupted_close(self):
         await self.con.close()
         self.assertTrue(self.con.is_closed())
 
-        with self.assertRaisesRegex(asyncpg.ConnectionDoesNotExistError,
-                                    'closed in the middle'):
+        with self.assertRaises(asyncpg.QueryCanceledError):
             await fut
 
         # Test that it's OK to call close again
diff --git a/tests/test_timeout.py b/tests/test_timeout.py
index 0e10029c..e03a3387 100644
--- a/tests/test_timeout.py
+++ b/tests/test_timeout.py
@@ -114,9 +114,7 @@ async def test_invalid_timeout(self):
             with self.subTest(command_timeout=command_timeout):
                 with self.assertRaisesRegex(ValueError,
                                             'invalid command_timeout'):
-                    await self.cluster.connect(
-                        database='postgres', loop=self.loop,
-                        command_timeout=command_timeout)
+                    await self.connect(command_timeout=command_timeout)
 
         # Note: negative timeouts are OK for method calls.
         for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}: