diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2d26e4a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.3 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/demo.py b/demo.py index a649240..339f052 100644 --- a/demo.py +++ b/demo.py @@ -3,6 +3,7 @@ import taskgroup import exceptiongroup + class ConnectionClosedError(Exception): pass @@ -39,7 +40,7 @@ async def main(): async with taskgroup.timeout(0.1): await asyncio.sleep(0.5) except asyncio.CancelledError: - print(f"got cancelled incorrectly!") + print("got cancelled incorrectly!") raise except TimeoutError: print("timeout as expected") diff --git a/taskgroup/install.py b/taskgroup/install.py index 7ac5292..b1259bc 100644 --- a/taskgroup/install.py +++ b/taskgroup/install.py @@ -3,7 +3,7 @@ import collections.abc import contextlib import types -from typing import Any, cast +from typing import cast from .tasks import task_factory as _task_factory, Task as _Task @@ -40,6 +40,7 @@ def add_done_callback(self, fn, *, context): def _async_yield(v): return (yield v) + _YieldT_co = TypeVar("_YieldT_co", covariant=True) _SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None) _ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None) @@ -47,8 +48,15 @@ def _async_yield(v): _ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True) -class WrapCoro(collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]): - def __init__(self, coro: collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], context: contextvars.Context): +class WrapCoro( + collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], + collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], +): + def __init__( + self, + coro: collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], + context: contextvars.Context, + ): self._coro = coro self._context = context @@ -88,6 +96,7 @@ async def install_uncancel(): task = asyncio.current_task() assert task is not None + async def asyncio_main(): return await WrapCoro(task.get_coro(), context=context) # type: ignore # see python/typing#1480 diff --git a/taskgroup/runners.py b/taskgroup/runners.py index 7ab1086..04600e7 100644 --- a/taskgroup/runners.py +++ b/taskgroup/runners.py @@ -4,7 +4,7 @@ # support uncancel and contexts from __future__ import annotations -__all__ = ('Runner', 'run') +__all__ = ("Runner", "run") import collections.abc import contextvars @@ -28,6 +28,7 @@ class _State(enum.Enum): _T = TypeVar("_T") + @final class Runner: """A context manager that controls event loop life cycle. @@ -61,8 +62,8 @@ def __init__( self, *, debug: bool | None = None, - loop_factory: collections.abc.Callable[[], AbstractEventLoop] | None = None - ) -> None: + loop_factory: collections.abc.Callable[[], AbstractEventLoop] | None = None, + ) -> None: self._state = _State.CREATED self._debug = debug self._loop_factory = loop_factory @@ -102,7 +103,12 @@ def get_loop(self) -> AbstractEventLoop: assert self._loop is not None return self._loop - def run(self, coro: collections.abc.Coroutine[Any, Any, _T], *, context: contextvars.Context | None = None) -> _T: + def run( + self, + coro: collections.abc.Coroutine[Any, Any, _T], + *, + context: contextvars.Context | None = None, + ) -> _T: """Run a coroutine inside the embedded event loop.""" if not coroutines.iscoroutine(coro): raise ValueError("a coroutine was expected, got {!r}".format(coro)) @@ -110,7 +116,8 @@ def run(self, coro: collections.abc.Coroutine[Any, Any, _T], *, context: context if events._get_running_loop() is not None: # fail fast with short traceback raise RuntimeError( - "Runner.run() cannot be called from a running event loop") + "Runner.run() cannot be called from a running event loop" + ) self._lazy_init() assert self._loop is not None @@ -119,7 +126,8 @@ def run(self, coro: collections.abc.Coroutine[Any, Any, _T], *, context: context context = self._context task = _task_factory(self._loop, coro, context=context) - if (threading.current_thread() is threading.main_thread() + if ( + threading.current_thread() is threading.main_thread() and signal.getsignal(signal.SIGINT) is signal.default_int_handler ): sigint_handler = functools.partial(self._on_sigint, main_task=task) @@ -145,7 +153,8 @@ def run(self, coro: collections.abc.Coroutine[Any, Any, _T], *, context: context raise KeyboardInterrupt() raise # CancelledError finally: - if (sigint_handler is not None + if ( + sigint_handler is not None and signal.getsignal(signal.SIGINT) is sigint_handler ): signal.signal(signal.SIGINT, signal.default_int_handler) @@ -181,7 +190,9 @@ def _on_sigint(self, signum, frame, main_task): raise KeyboardInterrupt() -def run(main: collections.abc.Coroutine[Any, Any, _T], *, debug: bool | None = None) -> _T: +def run( + main: collections.abc.Coroutine[Any, Any, _T], *, debug: bool | None = None +) -> _T: """Execute the coroutine and return the result. This function runs the passed coroutine, taking care of @@ -207,8 +218,7 @@ async def main(): """ if events._get_running_loop() is not None: # fail fast with short traceback - raise RuntimeError( - "asyncio.run() cannot be called from a running event loop") + raise RuntimeError("asyncio.run() cannot be called from a running event loop") with Runner(debug=debug) as runner: return runner.run(main) @@ -228,8 +238,10 @@ def _cancel_all_tasks(loop): if task.cancelled(): continue if task.exception() is not None: - loop.call_exception_handler({ - 'message': 'unhandled exception during asyncio.run() shutdown', - 'exception': task.exception(), - 'task': task, - }) + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) diff --git a/taskgroup/taskgroups.py b/taskgroup/taskgroups.py index 46ed0e2..a518e70 100644 --- a/taskgroup/taskgroups.py +++ b/taskgroup/taskgroups.py @@ -25,7 +25,6 @@ class TaskGroup: - def __init__(self) -> None: self._entered = False self._exiting = False @@ -40,24 +39,23 @@ def __init__(self) -> None: self._cmgr = self._cmgr_factory() def __repr__(self) -> str: - info = [''] + info = [""] if self._tasks: - info.append(f'tasks={len(self._tasks)}') + info.append(f"tasks={len(self._tasks)}") if self._errors: - info.append(f'errors={len(self._errors)}') + info.append(f"errors={len(self._errors)}") if self._aborting: - info.append('cancelling') + info.append("cancelling") elif self._entered: - info.append('entered') + info.append("entered") - info_str = ' '.join(info) - return f'' + info_str = " ".join(info) + return f"" @contextlib.asynccontextmanager async def _cmgr_factory(self) -> AsyncGenerator[Self, None]: if self._entered: - raise RuntimeError( - f"TaskGroup {self!r} has been already entered") + raise RuntimeError(f"TaskGroup {self!r} has been already entered") self._entered = True if self._loop is None: @@ -67,14 +65,17 @@ async def _cmgr_factory(self) -> AsyncGenerator[Self, None]: self._parent_task = tasks.current_task(self._loop) if self._parent_task is None: raise RuntimeError( - f'TaskGroup {self!r} cannot determine the parent task') + f"TaskGroup {self!r} cannot determine the parent task" + ) try: yield self finally: et, exc, _ = sys.exc_info() self._exiting = True - propagate_cancellation_error = exc if et is exceptions.CancelledError else None + propagate_cancellation_error = ( + exc if et is exceptions.CancelledError else None + ) if self._parent_cancel_requested: # If this flag is set we *must* call uncancel(). @@ -148,7 +149,7 @@ async def _cmgr_factory(self) -> AsyncGenerator[Self, None]: errors = self._errors self._errors = None - me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors) + me = BaseExceptionGroup("unhandled errors in a TaskGroup", errors) raise me from None async def __aenter__(self) -> Self: @@ -157,8 +158,13 @@ async def __aenter__(self) -> Self: async def __aexit__(self, *exc_info) -> bool | None: return await self._cmgr.__aexit__(*exc_info) # type: ignore - - def create_task(self, coro: Coroutine[Any, Any, _T], *, name: str | None = None, context: Context | None = None) -> Task[_T]: + def create_task( + self, + coro: Coroutine[Any, Any, _T], + *, + name: str | None = None, + context: Context | None = None, + ) -> Task[_T]: if not self._entered: raise RuntimeError(f"TaskGroup {self!r} has not been entered") if self._exiting and not self._tasks: @@ -220,12 +226,14 @@ def _on_task_done(self, task): if self._parent_task.done(): # Not sure if this case is possible, but we want to handle # it anyways. - self._loop.call_exception_handler({ - 'message': f'Task {task!r} has errored out but its parent ' - f'task {self._parent_task} is already completed', - 'exception': exc, - 'task': task, - }) + self._loop.call_exception_handler( + { + "message": f"Task {task!r} has errored out but its parent " + f"task {self._parent_task} is already completed", + "exception": exc, + "task": task, + } + ) return if not self._aborting and not self._parent_cancel_requested: diff --git a/taskgroup/tasks.py b/taskgroup/tasks.py index 3141b78..dc1ec80 100644 --- a/taskgroup/tasks.py +++ b/taskgroup/tasks.py @@ -3,7 +3,7 @@ import asyncio import collections.abc import contextvars -from typing import Any, cast, TypeAlias +from typing import Any, TypeAlias from typing_extensions import TypeVar import sys @@ -19,11 +19,17 @@ if sys.version_info >= (3, 12): _TaskCompatibleCoro: TypeAlias = collections.abc.Coroutine[Any, Any, _T_co] -elif sys.version_info >= (3, 9): - _TaskCompatibleCoro: TypeAlias = collectiona.abc.Generator[_TaskYieldType, None, _T_co] | Coroutine[Any, Any, _T_co] +else: + _TaskCompatibleCoro: TypeAlias = ( + collections.abc.Generator[_TaskYieldType, None, _T_co] + | collections.abc.Coroutine[Any, Any, _T_co] + ) -class _Interceptor(collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]): +class _Interceptor( + collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], + collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], +): def __init__( self, coro: ( @@ -50,11 +56,7 @@ def close(self) -> None: class Task(asyncio.Task[_T_co]): def __init__( - self, - coro: _TaskCompatibleCoro[_T_co], - *args, - context=None, - **kwargs + self, coro: _TaskCompatibleCoro[_T_co], *args, context=None, **kwargs ) -> None: self._num_cancels_requested = 0 if context is not None: @@ -80,5 +82,8 @@ def get_coro(self) -> _TaskCompatibleCoro[_T_co] | None: return coro._Interceptor__coro # type: ignore return coro -def task_factory(loop: asyncio.AbstractEventLoop, coro: _TaskCompatibleCoro[_T_co], **kwargs: Any) -> Task[_T_co]: + +def task_factory( + loop: asyncio.AbstractEventLoop, coro: _TaskCompatibleCoro[_T_co], **kwargs: Any +) -> Task[_T_co]: return Task(coro, loop=loop, **kwargs) diff --git a/taskgroup/timeouts.py b/taskgroup/timeouts.py index cbfa5f6..e7994a3 100644 --- a/taskgroup/timeouts.py +++ b/taskgroup/timeouts.py @@ -33,7 +33,6 @@ class _State(enum.Enum): @final class Timeout: - def __init__(self, when: Optional[float]) -> None: self._state = _State.CREATED @@ -71,11 +70,11 @@ def expired(self) -> bool: return self._state in (_State.EXPIRING, _State.EXPIRED) def __repr__(self) -> str: - info = [''] + info = [""] if self._state is _State.ENTERED: when = round(self._when, 3) if self._when is not None else None info.append(f"when={when}") - info_str = ' '.join(info) + info_str = " ".join(info) return f"" @contextlib.asynccontextmanager @@ -100,7 +99,10 @@ async def _cmgr_factory(self) -> collections.abc.AsyncGenerator[Self, None]: if self._state is _State.EXPIRING: self._state = _State.EXPIRED - if self._task.uncancel() <= self._cancelling and exc_type is exceptions.CancelledError: + if ( + self._task.uncancel() <= self._cancelling + and exc_type is exceptions.CancelledError + ): # Since there are no outstanding cancel requests, we're # handling this. raise TimeoutError from exc_value