diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 3a0d2e0df..7b09f5a9f 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -10,6 +10,9 @@ This library adheres to `Semantic Versioning 2.0 `_.
- Fixed an async fixture's ``self`` being different than the test's ``self`` in
class-based tests (`#633 `_)
(PR by @agronholm and @graingert)
+- Fixed TaskGroup and CancelScope producing cyclic references in tracebacks
+ when raising exceptions (`#806 `_)
+ (PR by @graingert)
**4.6.0**
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index ed00dee04..0a69e7ac6 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -425,6 +425,8 @@ def __exit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
+ del exc_tb
+
if not self._active:
raise RuntimeError("This cancel scope is not active")
if current_task() is not self._host_task:
@@ -441,42 +443,46 @@ def __exit__(
"current cancel scope"
)
- self._active = False
- if self._timeout_handle:
- self._timeout_handle.cancel()
- self._timeout_handle = None
-
- self._tasks.remove(self._host_task)
- if self._parent_scope is not None:
- self._parent_scope._child_scopes.remove(self)
- self._parent_scope._tasks.add(self._host_task)
-
- host_task_state.cancel_scope = self._parent_scope
-
- # Undo all cancellations done by this scope
- if self._cancelling is not None:
- while self._cancel_calls:
- self._cancel_calls -= 1
- if self._host_task.uncancel() <= self._cancelling:
- break
+ try:
+ self._active = False
+ if self._timeout_handle:
+ self._timeout_handle.cancel()
+ self._timeout_handle = None
- # We only swallow the exception iff it was an AnyIO CancelledError, either
- # directly as exc_val or inside an exception group and there are no cancelled
- # parent cancel scopes visible to us here
- not_swallowed_exceptions = 0
- swallow_exception = False
- if exc_val is not None:
- for exc in iterate_exceptions(exc_val):
- if self._cancel_called and isinstance(exc, CancelledError):
- if not (swallow_exception := self._uncancel(exc)):
+ self._tasks.remove(self._host_task)
+ if self._parent_scope is not None:
+ self._parent_scope._child_scopes.remove(self)
+ self._parent_scope._tasks.add(self._host_task)
+
+ host_task_state.cancel_scope = self._parent_scope
+
+ # Undo all cancellations done by this scope
+ if self._cancelling is not None:
+ while self._cancel_calls:
+ self._cancel_calls -= 1
+ if self._host_task.uncancel() <= self._cancelling:
+ break
+
+ # We only swallow the exception iff it was an AnyIO CancelledError, either
+ # directly as exc_val or inside an exception group and there are no cancelled
+ # parent cancel scopes visible to us here
+ not_swallowed_exceptions = 0
+ swallow_exception = False
+ if exc_val is not None:
+ for exc in iterate_exceptions(exc_val):
+ if self._cancel_called and isinstance(exc, CancelledError):
+ if not (swallow_exception := self._uncancel(exc)):
+ not_swallowed_exceptions += 1
+ else:
not_swallowed_exceptions += 1
- else:
- not_swallowed_exceptions += 1
- # Restart the cancellation effort in the closest visible, cancelled parent
- # scope if necessary
- self._restart_cancellation_in_parent()
- return swallow_exception and not not_swallowed_exceptions
+ # Restart the cancellation effort in the closest visible, cancelled parent
+ # scope if necessary
+ self._restart_cancellation_in_parent()
+ return swallow_exception and not not_swallowed_exceptions
+ finally:
+ self._host_task = None
+ del exc_val
@property
def _effectively_cancelled(self) -> bool:
@@ -683,6 +689,26 @@ def started(self, value: T_contra | None = None) -> None:
_task_states[task].parent_id = self._parent_id
+async def _wait(tasks: Iterable[asyncio.Task[object]]) -> None:
+ tasks = set(tasks)
+ waiter = get_running_loop().create_future()
+
+ def on_completion(task: asyncio.Task[object]) -> None:
+ tasks.discard(task)
+ if not tasks and not waiter.done():
+ waiter.set_result(None)
+
+ for task in tasks:
+ task.add_done_callback(on_completion)
+ del task
+
+ try:
+ await waiter
+ finally:
+ while tasks:
+ tasks.pop().remove_done_callback(on_completion)
+
+
class TaskGroup(abc.TaskGroup):
def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
@@ -701,50 +727,53 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
- if exc_val is not None:
- self.cancel_scope.cancel()
- if not isinstance(exc_val, CancelledError):
- self._exceptions.append(exc_val)
-
try:
- if self._tasks:
- with CancelScope() as wait_scope:
- while self._tasks:
- try:
- await asyncio.wait(self._tasks)
- except CancelledError as exc:
- # Shield the scope against further cancellation attempts,
- # as they're not productive (#695)
- wait_scope.shield = True
- self.cancel_scope.cancel()
-
- # Set exc_val from the cancellation exception if it was
- # previously unset. However, we should not replace a native
- # cancellation exception with one raise by a cancel scope.
- if exc_val is None or (
- isinstance(exc_val, CancelledError)
- and not is_anyio_cancellation(exc)
- ):
- exc_val = exc
- else:
- # If there are no child tasks to wait on, run at least one checkpoint
- # anyway
- await AsyncIOBackend.cancel_shielded_checkpoint()
+ if exc_val is not None:
+ self.cancel_scope.cancel()
+ if not isinstance(exc_val, CancelledError):
+ self._exceptions.append(exc_val)
- self._active = False
- if self._exceptions:
- raise BaseExceptionGroup(
- "unhandled errors in a TaskGroup", self._exceptions
- )
- elif exc_val:
- raise exc_val
- except BaseException as exc:
- if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
- return True
+ try:
+ if self._tasks:
+ with CancelScope() as wait_scope:
+ while self._tasks:
+ try:
+ await _wait(self._tasks)
+ except CancelledError as exc:
+ # Shield the scope against further cancellation attempts,
+ # as they're not productive (#695)
+ wait_scope.shield = True
+ self.cancel_scope.cancel()
+
+ # Set exc_val from the cancellation exception if it was
+ # previously unset. However, we should not replace a native
+ # cancellation exception with one raise by a cancel scope.
+ if exc_val is None or (
+ isinstance(exc_val, CancelledError)
+ and not is_anyio_cancellation(exc)
+ ):
+ exc_val = exc
+ else:
+ # If there are no child tasks to wait on, run at least one checkpoint
+ # anyway
+ await AsyncIOBackend.cancel_shielded_checkpoint()
- raise
+ self._active = False
+ if self._exceptions:
+ raise BaseExceptionGroup(
+ "unhandled errors in a TaskGroup", self._exceptions
+ )
+ elif exc_val:
+ raise exc_val
+ except BaseException as exc:
+ if self.cancel_scope.__exit__(type(exc), exc, exc.__traceback__):
+ return True
+
+ raise
- return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
+ return self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
+ finally:
+ del exc_val, exc_tb, self._exceptions
def _spawn(
self,
diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py
index c10a72a19..24dcd7444 100644
--- a/src/anyio/_backends/_trio.py
+++ b/src/anyio/_backends/_trio.py
@@ -186,13 +186,12 @@ async def __aexit__(
try:
return await self._nursery_manager.__aexit__(exc_type, exc_val, exc_tb)
except BaseExceptionGroup as exc:
- _, rest = exc.split(trio.Cancelled)
- if not rest:
- cancelled_exc = trio.Cancelled._create()
- raise cancelled_exc from exc
+ if not exc.split(trio.Cancelled)[1]:
+ raise trio.Cancelled._create() from exc
raise
finally:
+ del exc_val, exc_tb
self._active = False
def start_soon(
diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py
index 31490572b..78ef9983c 100644
--- a/tests/test_taskgroups.py
+++ b/tests/test_taskgroups.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
+import gc
import math
import sys
import time
@@ -9,7 +10,7 @@
from typing import Any, NoReturn, cast
import pytest
-from exceptiongroup import catch
+from exceptiongroup import ExceptionGroup, catch
from pytest_mock import MockerFixture
import anyio
@@ -1548,6 +1549,125 @@ async def in_task_group(task_status: TaskStatus[None]) -> None:
assert not tg.cancel_scope.cancel_called
+if sys.version_info <= (3, 11):
+
+ def no_other_refs() -> list[object]:
+ return [sys._getframe(1)]
+else:
+
+ def no_other_refs() -> list[object]:
+ return []
+
+
+@pytest.mark.skipif(
+ sys.implementation.name == "pypy",
+ reason=(
+ "gc.get_referrers is broken on PyPy see "
+ "https://github.com/pypy/pypy/issues/5075"
+ ),
+)
+class TestRefcycles:
+ async def test_exception_refcycles_direct(self) -> None:
+ """
+ Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup
+
+ Note: This test never failed on anyio, but keeping this test to align
+ with the tests from cpython.
+ """
+ tg = create_task_group()
+ exc = None
+
+ class _Done(Exception):
+ pass
+
+ try:
+ async with tg:
+ raise _Done
+ except ExceptionGroup as e:
+ exc = e
+
+ assert exc is not None
+ assert gc.get_referrers(exc) == no_other_refs()
+
+ async def test_exception_refcycles_errors(self) -> None:
+ """Test that TaskGroup deletes self._exceptions, and __aexit__ args"""
+ tg = create_task_group()
+ exc = None
+
+ class _Done(Exception):
+ pass
+
+ try:
+ async with tg:
+ raise _Done
+ except ExceptionGroup as excs:
+ exc = excs.exceptions[0]
+
+ assert isinstance(exc, _Done)
+ assert gc.get_referrers(exc) == no_other_refs()
+
+ async def test_exception_refcycles_parent_task(self) -> None:
+ """Test that TaskGroup's cancel_scope deletes self._host_task"""
+ tg = create_task_group()
+ exc = None
+
+ class _Done(Exception):
+ pass
+
+ async def coro_fn() -> None:
+ async with tg:
+ raise _Done
+
+ try:
+ async with anyio.create_task_group() as tg2:
+ tg2.start_soon(coro_fn)
+ except ExceptionGroup as excs:
+ exc = excs.exceptions[0].exceptions[0]
+
+ assert isinstance(exc, _Done)
+ assert gc.get_referrers(exc) == no_other_refs()
+
+ async def test_exception_refcycles_propagate_cancellation_error(self) -> None:
+ """Test that TaskGroup deletes cancelled_exc"""
+ tg = anyio.create_task_group()
+ exc = None
+
+ with CancelScope() as cs:
+ cs.cancel()
+ try:
+ async with tg:
+ await checkpoint()
+ except get_cancelled_exc_class() as e:
+ exc = e
+ raise
+
+ assert isinstance(exc, get_cancelled_exc_class())
+ assert gc.get_referrers(exc) == no_other_refs()
+
+ async def test_exception_refcycles_base_error(self) -> None:
+ """
+ Test for BaseExceptions.
+
+ anyio doesn't treat these differently so this test is redundant
+ but copied from CPython's asyncio.TaskGroup tests for completion.
+ """
+
+ class MyKeyboardInterrupt(KeyboardInterrupt):
+ pass
+
+ tg = create_task_group()
+ exc = None
+
+ try:
+ async with tg:
+ raise MyKeyboardInterrupt
+ except BaseExceptionGroup as excs:
+ exc = excs.exceptions[0]
+
+ assert isinstance(exc, MyKeyboardInterrupt)
+ assert gc.get_referrers(exc) == no_other_refs()
+
+
class TestTaskStatusTyping:
"""
These tests do not do anything at run time, but since the test suite is also checked