Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.3
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
3 changes: 2 additions & 1 deletion demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import taskgroup
import exceptiongroup


class ConnectionClosedError(Exception):
pass

Expand Down Expand Up @@ -39,7 +40,7 @@ async def main():
async with taskgroup.timeout(0.1):
await asyncio.sleep(0.5)
except asyncio.CancelledError:
print(f"got cancelled incorrectly!")
print("got cancelled incorrectly!")
raise
except TimeoutError:
print("timeout as expected")
Expand Down
15 changes: 12 additions & 3 deletions taskgroup/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import collections.abc
import contextlib
import types
from typing import Any, cast
from typing import cast

from .tasks import task_factory as _task_factory, Task as _Task

Expand Down Expand Up @@ -40,15 +40,23 @@ def add_done_callback(self, fn, *, context):
def _async_yield(v):
return (yield v)


_YieldT_co = TypeVar("_YieldT_co", covariant=True)
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
_ReturnT_co = TypeVar("_ReturnT_co", covariant=True, default=None)
_SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True)
_ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True)


class WrapCoro(collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]):
def __init__(self, coro: collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], context: contextvars.Context):
class WrapCoro(
collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
):
def __init__(
self,
coro: collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
context: contextvars.Context,
):
self._coro = coro
self._context = context

Expand Down Expand Up @@ -88,6 +96,7 @@ async def install_uncancel():

task = asyncio.current_task()
assert task is not None

async def asyncio_main():
return await WrapCoro(task.get_coro(), context=context) # type: ignore # see python/typing#1480

Expand Down
42 changes: 27 additions & 15 deletions taskgroup/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# support uncancel and contexts
from __future__ import annotations

__all__ = ('Runner', 'run')
__all__ = ("Runner", "run")

import collections.abc
import contextvars
Expand All @@ -28,6 +28,7 @@ class _State(enum.Enum):

_T = TypeVar("_T")


@final
class Runner:
"""A context manager that controls event loop life cycle.
Expand Down Expand Up @@ -61,8 +62,8 @@ def __init__(
self,
*,
debug: bool | None = None,
loop_factory: collections.abc.Callable[[], AbstractEventLoop] | None = None
) -> None:
loop_factory: collections.abc.Callable[[], AbstractEventLoop] | None = None,
) -> None:
self._state = _State.CREATED
self._debug = debug
self._loop_factory = loop_factory
Expand Down Expand Up @@ -102,15 +103,21 @@ def get_loop(self) -> AbstractEventLoop:
assert self._loop is not None
return self._loop

def run(self, coro: collections.abc.Coroutine[Any, Any, _T], *, context: contextvars.Context | None = None) -> _T:
def run(
self,
coro: collections.abc.Coroutine[Any, Any, _T],
*,
context: contextvars.Context | None = None,
) -> _T:
"""Run a coroutine inside the embedded event loop."""
if not coroutines.iscoroutine(coro):
raise ValueError("a coroutine was expected, got {!r}".format(coro))

if events._get_running_loop() is not None:
# fail fast with short traceback
raise RuntimeError(
"Runner.run() cannot be called from a running event loop")
"Runner.run() cannot be called from a running event loop"
)

self._lazy_init()
assert self._loop is not None
Expand All @@ -119,7 +126,8 @@ def run(self, coro: collections.abc.Coroutine[Any, Any, _T], *, context: context
context = self._context
task = _task_factory(self._loop, coro, context=context)

if (threading.current_thread() is threading.main_thread()
if (
threading.current_thread() is threading.main_thread()
and signal.getsignal(signal.SIGINT) is signal.default_int_handler
):
sigint_handler = functools.partial(self._on_sigint, main_task=task)
Expand All @@ -145,7 +153,8 @@ def run(self, coro: collections.abc.Coroutine[Any, Any, _T], *, context: context
raise KeyboardInterrupt()
raise # CancelledError
finally:
if (sigint_handler is not None
if (
sigint_handler is not None
and signal.getsignal(signal.SIGINT) is sigint_handler
):
signal.signal(signal.SIGINT, signal.default_int_handler)
Expand Down Expand Up @@ -181,7 +190,9 @@ def _on_sigint(self, signum, frame, main_task):
raise KeyboardInterrupt()


def run(main: collections.abc.Coroutine[Any, Any, _T], *, debug: bool | None = None) -> _T:
def run(
main: collections.abc.Coroutine[Any, Any, _T], *, debug: bool | None = None
) -> _T:
"""Execute the coroutine and return the result.

This function runs the passed coroutine, taking care of
Expand All @@ -207,8 +218,7 @@ async def main():
"""
if events._get_running_loop() is not None:
# fail fast with short traceback
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop")
raise RuntimeError("asyncio.run() cannot be called from a running event loop")

with Runner(debug=debug) as runner:
return runner.run(main)
Expand All @@ -228,8 +238,10 @@ def _cancel_all_tasks(loop):
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
50 changes: 29 additions & 21 deletions taskgroup/taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


class TaskGroup:

def __init__(self) -> None:
self._entered = False
self._exiting = False
Expand All @@ -40,24 +39,23 @@ def __init__(self) -> None:
self._cmgr = self._cmgr_factory()

def __repr__(self) -> str:
info = ['']
info = [""]
if self._tasks:
info.append(f'tasks={len(self._tasks)}')
info.append(f"tasks={len(self._tasks)}")
if self._errors:
info.append(f'errors={len(self._errors)}')
info.append(f"errors={len(self._errors)}")
if self._aborting:
info.append('cancelling')
info.append("cancelling")
elif self._entered:
info.append('entered')
info.append("entered")

info_str = ' '.join(info)
return f'<TaskGroup{info_str}>'
info_str = " ".join(info)
return f"<TaskGroup{info_str}>"

@contextlib.asynccontextmanager
async def _cmgr_factory(self) -> AsyncGenerator[Self, None]:
if self._entered:
raise RuntimeError(
f"TaskGroup {self!r} has been already entered")
raise RuntimeError(f"TaskGroup {self!r} has been already entered")
self._entered = True

if self._loop is None:
Expand All @@ -67,14 +65,17 @@ async def _cmgr_factory(self) -> AsyncGenerator[Self, None]:
self._parent_task = tasks.current_task(self._loop)
if self._parent_task is None:
raise RuntimeError(
f'TaskGroup {self!r} cannot determine the parent task')
f"TaskGroup {self!r} cannot determine the parent task"
)

try:
yield self
finally:
et, exc, _ = sys.exc_info()
self._exiting = True
propagate_cancellation_error = exc if et is exceptions.CancelledError else None
propagate_cancellation_error = (
exc if et is exceptions.CancelledError else None
)

if self._parent_cancel_requested:
# If this flag is set we *must* call uncancel().
Expand Down Expand Up @@ -148,7 +149,7 @@ async def _cmgr_factory(self) -> AsyncGenerator[Self, None]:
errors = self._errors
self._errors = None

me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)
me = BaseExceptionGroup("unhandled errors in a TaskGroup", errors)
raise me from None

async def __aenter__(self) -> Self:
Expand All @@ -157,8 +158,13 @@ async def __aenter__(self) -> Self:
async def __aexit__(self, *exc_info) -> bool | None:
return await self._cmgr.__aexit__(*exc_info) # type: ignore


def create_task(self, coro: Coroutine[Any, Any, _T], *, name: str | None = None, context: Context | None = None) -> Task[_T]:
def create_task(
self,
coro: Coroutine[Any, Any, _T],
*,
name: str | None = None,
context: Context | None = None,
) -> Task[_T]:
if not self._entered:
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
if self._exiting and not self._tasks:
Expand Down Expand Up @@ -220,12 +226,14 @@ def _on_task_done(self, task):
if self._parent_task.done():
# Not sure if this case is possible, but we want to handle
# it anyways.
self._loop.call_exception_handler({
'message': f'Task {task!r} has errored out but its parent '
f'task {self._parent_task} is already completed',
'exception': exc,
'task': task,
})
self._loop.call_exception_handler(
{
"message": f"Task {task!r} has errored out but its parent "
f"task {self._parent_task} is already completed",
"exception": exc,
"task": task,
}
)
return

if not self._aborting and not self._parent_cancel_requested:
Expand Down
25 changes: 15 additions & 10 deletions taskgroup/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import collections.abc
import contextvars
from typing import Any, cast, TypeAlias
from typing import Any, TypeAlias
from typing_extensions import TypeVar
import sys

Expand All @@ -19,11 +19,17 @@

if sys.version_info >= (3, 12):
_TaskCompatibleCoro: TypeAlias = collections.abc.Coroutine[Any, Any, _T_co]
elif sys.version_info >= (3, 9):
_TaskCompatibleCoro: TypeAlias = collectiona.abc.Generator[_TaskYieldType, None, _T_co] | Coroutine[Any, Any, _T_co]
else:
_TaskCompatibleCoro: TypeAlias = (
collections.abc.Generator[_TaskYieldType, None, _T_co]
| collections.abc.Coroutine[Any, Any, _T_co]
)

class _Interceptor(collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd], collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]):

class _Interceptor(
collections.abc.Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
collections.abc.Coroutine[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
):
def __init__(
self,
coro: (
Expand All @@ -50,11 +56,7 @@ def close(self) -> None:

class Task(asyncio.Task[_T_co]):
def __init__(
self,
coro: _TaskCompatibleCoro[_T_co],
*args,
context=None,
**kwargs
self, coro: _TaskCompatibleCoro[_T_co], *args, context=None, **kwargs
) -> None:
self._num_cancels_requested = 0
if context is not None:
Expand All @@ -80,5 +82,8 @@ def get_coro(self) -> _TaskCompatibleCoro[_T_co] | None:
return coro._Interceptor__coro # type: ignore
return coro

def task_factory(loop: asyncio.AbstractEventLoop, coro: _TaskCompatibleCoro[_T_co], **kwargs: Any) -> Task[_T_co]:

def task_factory(
loop: asyncio.AbstractEventLoop, coro: _TaskCompatibleCoro[_T_co], **kwargs: Any
) -> Task[_T_co]:
return Task(coro, loop=loop, **kwargs)
10 changes: 6 additions & 4 deletions taskgroup/timeouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class _State(enum.Enum):

@final
class Timeout:

def __init__(self, when: Optional[float]) -> None:
self._state = _State.CREATED

Expand Down Expand Up @@ -71,11 +70,11 @@ def expired(self) -> bool:
return self._state in (_State.EXPIRING, _State.EXPIRED)

def __repr__(self) -> str:
info = ['']
info = [""]
if self._state is _State.ENTERED:
when = round(self._when, 3) if self._when is not None else None
info.append(f"when={when}")
info_str = ' '.join(info)
info_str = " ".join(info)
return f"<Timeout [{self._state.value}]{info_str}>"

@contextlib.asynccontextmanager
Expand All @@ -100,7 +99,10 @@ async def _cmgr_factory(self) -> collections.abc.AsyncGenerator[Self, None]:
if self._state is _State.EXPIRING:
self._state = _State.EXPIRED

if self._task.uncancel() <= self._cancelling and exc_type is exceptions.CancelledError:
if (
self._task.uncancel() <= self._cancelling
and exc_type is exceptions.CancelledError
):
# Since there are no outstanding cancel requests, we're
# handling this.
raise TimeoutError from exc_value
Expand Down
Loading