diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 3a0d2e0df..7b09f5a9f 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -10,6 +10,9 @@ This library adheres to `Semantic Versioning 2.0 `_. - Fixed an async fixture's ``self`` being different than the test's ``self`` in class-based tests (`#633 `_) (PR by @agronholm and @graingert) +- Fixed TaskGroup and CancelScope producing cyclic references in tracebacks + when raising exceptions (`#806 `_) + (PR by @graingert) **4.6.0** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index ed00dee04..0a69e7ac6 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -425,6 +425,8 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: + del exc_tb + if not self._active: raise RuntimeError("This cancel scope is not active") if current_task() is not self._host_task: @@ -441,42 +443,46 @@ def __exit__( "current cancel scope" ) - self._active = False - if self._timeout_handle: - self._timeout_handle.cancel() - self._timeout_handle = None - - self._tasks.remove(self._host_task) - if self._parent_scope is not None: - self._parent_scope._child_scopes.remove(self) - self._parent_scope._tasks.add(self._host_task) - - host_task_state.cancel_scope = self._parent_scope - - # Undo all cancellations done by this scope - if self._cancelling is not None: - while self._cancel_calls: - self._cancel_calls -= 1 - if self._host_task.uncancel() <= self._cancelling: - break + try: + self._active = False + if self._timeout_handle: + self._timeout_handle.cancel() + self._timeout_handle = None - # We only swallow the exception iff it was an AnyIO CancelledError, either - # directly as exc_val or inside an exception group and there are no cancelled - # parent cancel scopes visible to us here - not_swallowed_exceptions = 0 - swallow_exception = False - if exc_val is not None: - for exc in iterate_exceptions(exc_val): - if self._cancel_called and isinstance(exc, CancelledError): - if not (swallow_exception := self._uncancel(exc)): + self._tasks.remove(self._host_task) + if self._parent_scope is not None: + self._parent_scope._child_scopes.remove(self) + self._parent_scope._tasks.add(self._host_task) + + host_task_state.cancel_scope = self._parent_scope + + # Undo all cancellations done by this scope + if self._cancelling is not None: + while self._cancel_calls: + self._cancel_calls -= 1 + if self._host_task.uncancel() <= self._cancelling: + break + + # We only swallow the exception iff it was an AnyIO CancelledError, either + # directly as exc_val or inside an exception group and there are no cancelled + # parent cancel scopes visible to us here + not_swallowed_exceptions = 0 + swallow_exception = False + if exc_val is not None: + for exc in iterate_exceptions(exc_val): + if self._cancel_called and isinstance(exc, CancelledError): + if not (swallow_exception := self._uncancel(exc)): + not_swallowed_exceptions += 1 + else: not_swallowed_exceptions += 1 - else: - not_swallowed_exceptions += 1 - # Restart the cancellation effort in the closest visible, cancelled parent - # scope if necessary - self._restart_cancellation_in_parent() - return swallow_exception and not not_swallowed_exceptions + # Restart the cancellation effort in the closest visible, cancelled parent + # scope if necessary + self._restart_cancellation_in_parent() + return swallow_exception and not not_swallowed_exceptions + finally: + self._host_task = None + del exc_val @property def _effectively_cancelled(self) -> bool: @@ -683,6 +689,26 @@ def started(self, value: T_contra | None = None) -> None: _task_states[task].parent_id = self._parent_id +async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None: + tasks = set(tasks) + waiter = get_running_loop().create_future() + + def on_completion(task: asyncio.Task[object]) -> None: + tasks.discard(task) + if not tasks and not waiter.done(): + waiter.set_result(None) + + for task in tasks: + task.add_done_callback(on_completion) + del task + + try: + await waiter + finally: + while tasks: + tasks.pop().remove_done_callback(on_completion) + + class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self.cancel_scope: CancelScope = CancelScope() @@ -701,50 +727,53 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: - if exc_val is not None: - self.cancel_scope.cancel() - if not isinstance(exc_val, CancelledError): - self._exceptions.append(exc_val) - try: - if self._tasks: - with CancelScope() as wait_scope: - while self._tasks: - try: - await asyncio.wait(self._tasks) - except CancelledError as exc: - # Shield the scope against further cancellation attempts, - # as they're not productive (#695) - wait_scope.shield = True - self.cancel_scope.cancel() - - # Set exc_val from the cancellation exception if it was - # previously unset. However, we should not replace a native - # cancellation exception with one raise by a cancel scope. - if exc_val is None or ( - isinstance(exc_val, CancelledError) - and not is_anyio_cancellation(exc) - ): - exc_val = exc - else: - # If there are no child tasks to wait on, run at least one checkpoint - # anyway - await AsyncIOBackend.cancel_shielded_checkpoint() + if exc_val is not None: + self.cancel_scope.cancel() + if not isinstance(exc_val, CancelledError): + self._exceptions.append(exc_val) - self._active = False - if self._exceptions: - raise BaseExceptionGroup( - "unhandled errors in a TaskGroup", self._exceptions - ) - elif exc_val: - raise exc_val - except BaseException as exc: - if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__): - return True + try: + if self._tasks: + with CancelScope() as wait_scope: + while self._tasks: + try: + await _wait(self._tasks) + except CancelledError as exc: + # Shield the scope against further cancellation attempts, + # as they're not productive (#695) + wait_scope.shield = True + self.cancel_scope.cancel() + + # Set exc_val from the cancellation exception if it was + # previously unset. However, we should not replace a native + # cancellation exception with one raise by a cancel scope. + if exc_val is None or ( + isinstance(exc_val, CancelledError) + and not is_anyio_cancellation(exc) + ): + exc_val = exc + else: + # If there are no child tasks to wait on, run at least one checkpoint + # anyway + await AsyncIOBackend.cancel_shielded_checkpoint() - raise + self._active = False + if self._exceptions: + raise BaseExceptionGroup( + "unhandled errors in a TaskGroup", self._exceptions + ) + elif exc_val: + raise exc_val + except BaseException as exc: + if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__): + return True + + raise - return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb) + return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb) + finally: + del exc_val, exc_tb, self._exceptions def _spawn( self, diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index c10a72a19..24dcd7444 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -186,13 +186,12 @@ async def __aexit__( try: return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb) except BaseExceptionGroup as exc: - _, rest = exc.split(trio.Cancelled) - if not rest: - cancelled_exc = trio.Cancelled._create() - raise cancelled_exc from exc + if not exc.split(trio.Cancelled)[1]: + raise trio.Cancelled._create() from exc raise finally: + del exc_val, exc_tb self._active = False def start_soon( diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 31490572b..78ef9983c 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import gc import math import sys import time @@ -9,7 +10,7 @@ from typing import Any, NoReturn, cast import pytest -from exceptiongroup import catch +from exceptiongroup import ExceptionGroup, catch from pytest_mock import MockerFixture import anyio @@ -1548,6 +1549,125 @@ async def in_task_group(task_status: TaskStatus[None]) -> None: assert not tg.cancel_scope.cancel_called +if sys.version_info <= (3, 11): + + def no_other_refs() -> list[object]: + return [sys._getframe(1)] +else: + + def no_other_refs() -> list[object]: + return [] + + +@pytest.mark.skipif( + sys.implementation.name == "pypy", + reason=( + "gc.get_referrers is broken on PyPy see " + "https://github.com/pypy/pypy/issues/5075" + ), +) +class TestRefcycles: + async def test_exception_refcycles_direct(self) -> None: + """ + Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup + + Note: This test never failed on anyio, but keeping this test to align + with the tests from cpython. + """ + tg = create_task_group() + exc = None + + class _Done(Exception): + pass + + try: + async with tg: + raise _Done + except ExceptionGroup as e: + exc = e + + assert exc is not None + assert gc.get_referrers(exc) == no_other_refs() + + async def test_exception_refcycles_errors(self) -> None: + """Test that TaskGroup deletes self._exceptions, and __aexit__ args""" + tg = create_task_group() + exc = None + + class _Done(Exception): + pass + + try: + async with tg: + raise _Done + except ExceptionGroup as excs: + exc = excs.exceptions[0] + + assert isinstance(exc, _Done) + assert gc.get_referrers(exc) == no_other_refs() + + async def test_exception_refcycles_parent_task(self) -> None: + """Test that TaskGroup's cancel_scope deletes self._host_task""" + tg = create_task_group() + exc = None + + class _Done(Exception): + pass + + async def coro_fn() -> None: + async with tg: + raise _Done + + try: + async with anyio.create_task_group() as tg2: + tg2.start_soon(coro_fn) + except ExceptionGroup as excs: + exc = excs.exceptions[0].exceptions[0] + + assert isinstance(exc, _Done) + assert gc.get_referrers(exc) == no_other_refs() + + async def test_exception_refcycles_propagate_cancellation_error(self) -> None: + """Test that TaskGroup deletes cancelled_exc""" + tg = anyio.create_task_group() + exc = None + + with CancelScope() as cs: + cs.cancel() + try: + async with tg: + await checkpoint() + except get_cancelled_exc_class() as e: + exc = e + raise + + assert isinstance(exc, get_cancelled_exc_class()) + assert gc.get_referrers(exc) == no_other_refs() + + async def test_exception_refcycles_base_error(self) -> None: + """ + Test for BaseExceptions. + + anyio doesn't treat these differently so this test is redundant + but copied from CPython's asyncio.TaskGroup tests for completion. + """ + + class MyKeyboardInterrupt(KeyboardInterrupt): + pass + + tg = create_task_group() + exc = None + + try: + async with tg: + raise MyKeyboardInterrupt + except BaseExceptionGroup as excs: + exc = excs.exceptions[0] + + assert isinstance(exc, MyKeyboardInterrupt) + assert gc.get_referrers(exc) == no_other_refs() + + class TestTaskStatusTyping: """ These tests do not do anything at run time, but since the test suite is also checked