From 428b9330ce694a528c1d0188ba01d10643e740b3 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sun, 11 Feb 2024 15:40:40 +0000 Subject: [PATCH 1/3] Improve annotations for various set methods --- stdlib/_weakrefset.pyi | 13 ++++++++----- stdlib/builtins.pyi | 16 +++++++++++----- stdlib/typing.pyi | 9 ++++++--- test_cases/stdlib/builtins/check_set.py | 23 +++++++++++++++++++++++ 4 files changed, 48 insertions(+), 13 deletions(-) create mode 100644 test_cases/stdlib/builtins/check_set.py diff --git a/stdlib/_weakrefset.pyi b/stdlib/_weakrefset.pyi index 6482ade1271e..02c2eb213d5e 100644 --- a/stdlib/_weakrefset.pyi +++ b/stdlib/_weakrefset.pyi @@ -1,5 +1,5 @@ import sys -from collections.abc import Iterable, Iterator, MutableSet +from collections.abc import Iterable, Iterator, MutableSet, Set as AbstractSet from typing import Any, TypeVar, overload from typing_extensions import Self @@ -17,7 +17,7 @@ class WeakSet(MutableSet[_T]): @overload def __init__(self, data: Iterable[_T]) -> None: ... def add(self, item: _T) -> None: ... - def discard(self, item: _T) -> None: ... + def discard(self, item: _T | None) -> None: ... def copy(self) -> Self: ... def remove(self, item: _T) -> None: ... def update(self, other: Iterable[_T]) -> None: ... @@ -26,9 +26,12 @@ class WeakSet(MutableSet[_T]): def __iter__(self) -> Iterator[_T]: ... def __ior__(self, other: Iterable[_T]) -> Self: ... # type: ignore[override,misc] def difference(self, other: Iterable[_T]) -> Self: ... - def __sub__(self, other: Iterable[Any]) -> Self: ... - def difference_update(self, other: Iterable[Any]) -> None: ... - def __isub__(self, other: Iterable[Any]) -> Self: ... + @overload # type: ignore[override] + def __sub__(self: AbstractSet[_S | None], other: Iterable[None]) -> WeakSet[_S]: ... # type: ignore[overload-overlap] + @overload + def __sub__(self, other: Iterable[_T | None]) -> Self: ... + def difference_update(self, other: Iterable[_T | None]) -> None: ... + def __isub__(self, other: Iterable[_T | None]) -> Self: ... # type: ignore[misc] def intersection(self, other: Iterable[_T]) -> Self: ... def __and__(self, other: Iterable[Any]) -> Self: ... def intersection_update(self, other: Iterable[Any]) -> None: ... diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index b8c807dd388d..6d06165953b3 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -1103,9 +1103,9 @@ class set(MutableSet[_T]): def __init__(self, __iterable: Iterable[_T]) -> None: ... def add(self, __element: _T) -> None: ... def copy(self) -> set[_T]: ... - def difference(self, *s: Iterable[Any]) -> set[_T]: ... - def difference_update(self, *s: Iterable[Any]) -> None: ... - def discard(self, __element: _T) -> None: ... + def difference(self, *s: Iterable[_T | None]) -> set[_T]: ... + def difference_update(self, *s: Iterable[_T | None]) -> None: ... + def discard(self, __element: _T | None) -> None: ... def intersection(self, *s: Iterable[Any]) -> set[_T]: ... def intersection_update(self, *s: Iterable[Any]) -> None: ... def isdisjoint(self, __s: Iterable[Any]) -> bool: ... @@ -1123,8 +1123,11 @@ class set(MutableSet[_T]): def __iand__(self, __value: AbstractSet[object]) -> Self: ... def __or__(self, __value: AbstractSet[_S]) -> set[_T | _S]: ... def __ior__(self, __value: AbstractSet[_T]) -> Self: ... # type: ignore[override,misc] + @overload # type: ignore[override] + def __sub__(self: AbstractSet[_S | None], __value: AbstractSet[None]) -> set[_S]: ... + @overload def __sub__(self, __value: AbstractSet[_T | None]) -> set[_T]: ... - def __isub__(self, __value: AbstractSet[object]) -> Self: ... + def __isub__(self, __value: AbstractSet[_T | None]) -> Self: ... # type: ignore[misc] def __xor__(self, __value: AbstractSet[_S]) -> set[_T | _S]: ... def __ixor__(self, __value: AbstractSet[_T]) -> Self: ... # type: ignore[override,misc] def __le__(self, __value: AbstractSet[object]) -> bool: ... @@ -1154,7 +1157,10 @@ class frozenset(AbstractSet[_T_co]): def __iter__(self) -> Iterator[_T_co]: ... def __and__(self, __value: AbstractSet[_T_co]) -> frozenset[_T_co]: ... def __or__(self, __value: AbstractSet[_S]) -> frozenset[_T_co | _S]: ... - def __sub__(self, __value: AbstractSet[_T_co]) -> frozenset[_T_co]: ... + @overload # type: ignore[override] + def __sub__(self: AbstractSet[_S | None], __value: AbstractSet[None]) -> frozenset[_S]: ... + @overload + def __sub__(self, __value: AbstractSet[_T_co | None]) -> frozenset[_T_co]: ... def __xor__(self, __value: AbstractSet[_S]) -> frozenset[_T_co | _S]: ... def __le__(self, __value: AbstractSet[object]) -> bool: ... def __lt__(self, __value: AbstractSet[object]) -> bool: ... diff --git a/stdlib/typing.pyi b/stdlib/typing.pyi index 5d01be539016..4ec22d697418 100644 --- a/stdlib/typing.pyi +++ b/stdlib/typing.pyi @@ -569,7 +569,10 @@ class AbstractSet(Collection[_T_co]): def __ge__(self, other: AbstractSet[Any]) -> bool: ... def __and__(self, other: AbstractSet[Any]) -> AbstractSet[_T_co]: ... def __or__(self, other: AbstractSet[_T]) -> AbstractSet[_T_co | _T]: ... - def __sub__(self, other: AbstractSet[Any]) -> AbstractSet[_T_co]: ... + @overload + def __sub__(self: AbstractSet[_S | None], other: AbstractSet[None]) -> AbstractSet[_S]: ... + @overload + def __sub__(self, other: AbstractSet[_T_co | None]) -> AbstractSet[_T_co]: ... def __xor__(self, other: AbstractSet[_T]) -> AbstractSet[_T_co | _T]: ... def __eq__(self, other: object) -> bool: ... def isdisjoint(self, other: Iterable[Any]) -> bool: ... @@ -578,7 +581,7 @@ class MutableSet(AbstractSet[_T]): @abstractmethod def add(self, value: _T) -> None: ... @abstractmethod - def discard(self, value: _T) -> None: ... + def discard(self, value: _T | None) -> None: ... # Mixin methods def clear(self) -> None: ... def pop(self) -> _T: ... @@ -586,7 +589,7 @@ class MutableSet(AbstractSet[_T]): def __ior__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] def __iand__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... def __ixor__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] - def __isub__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... + def __isub__(self, it: AbstractSet[_T | None]) -> typing_extensions.Self: ... # type: ignore[misc] class MappingView(Sized): def __init__(self, mapping: Mapping[Any, Any]) -> None: ... # undocumented diff --git a/test_cases/stdlib/builtins/check_set.py b/test_cases/stdlib/builtins/check_set.py new file mode 100644 index 000000000000..3afb234679c2 --- /dev/null +++ b/test_cases/stdlib/builtins/check_set.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing_extensions import assert_type + +# We special case `AbstractSet[None] in set.__sub__ and frozenset.__sub__ +# so that it can be used for narrowing `set[T|None]` to `set[T]` +x = {"foo", "bar", None} +y = frozenset(x) +assert_type(x - {None}, set[str]) +assert_type(y - {None}, frozenset[str]) + +# For most other cases of set subtraction, we're pretty restrictive about what's allowed. +# `set[T] - set[S]` is an error, even though it won't cause an exception at runtime, +# as it will always be a useless no-op +{"foo", "bar"} - {1, 2} # type: ignore + +# But subtracting set[T|None] from set[T] is allowed, as a convenience; +# this comes up a lot in real-life code: +assert_type({"foo", "bar"} - {"foo", None}, set[str]) +x = {"foo", "bar"} +x.difference_update({"foo", "bar", None}) +name: str | None = "foo" +x.discard(name) From 669a209043f93880b6fac4d7cc6194dcbdf34100 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sun, 11 Feb 2024 15:50:42 +0000 Subject: [PATCH 2/3] py38 compat --- test_cases/stdlib/builtins/check_set.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test_cases/stdlib/builtins/check_set.py b/test_cases/stdlib/builtins/check_set.py index 3afb234679c2..4dd003bf30f5 100644 --- a/test_cases/stdlib/builtins/check_set.py +++ b/test_cases/stdlib/builtins/check_set.py @@ -1,13 +1,14 @@ from __future__ import annotations +from typing import FrozenSet, Set from typing_extensions import assert_type # We special case `AbstractSet[None] in set.__sub__ and frozenset.__sub__ # so that it can be used for narrowing `set[T|None]` to `set[T]` x = {"foo", "bar", None} y = frozenset(x) -assert_type(x - {None}, set[str]) -assert_type(y - {None}, frozenset[str]) +assert_type(x - {None}, Set[str]) +assert_type(y - {None}, FrozenSet[str]) # For most other cases of set subtraction, we're pretty restrictive about what's allowed. # `set[T] - set[S]` is an error, even though it won't cause an exception at runtime, @@ -16,7 +17,7 @@ # But subtracting set[T|None] from set[T] is allowed, as a convenience; # this comes up a lot in real-life code: -assert_type({"foo", "bar"} - {"foo", None}, set[str]) +assert_type({"foo", "bar"} - {"foo", None}, Set[str]) x = {"foo", "bar"} x.difference_update({"foo", "bar", None}) name: str | None = "foo" From 481407f1e9840117a5c4bddc5ec713a4adc4512e Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sun, 11 Feb 2024 16:58:44 +0000 Subject: [PATCH 3/3] Update _weakrefset.pyi --- stdlib/_weakrefset.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/_weakrefset.pyi b/stdlib/_weakrefset.pyi index 02c2eb213d5e..370e611825a0 100644 --- a/stdlib/_weakrefset.pyi +++ b/stdlib/_weakrefset.pyi @@ -25,7 +25,7 @@ class WeakSet(MutableSet[_T]): def __len__(self) -> int: ... def __iter__(self) -> Iterator[_T]: ... def __ior__(self, other: Iterable[_T]) -> Self: ... # type: ignore[override,misc] - def difference(self, other: Iterable[_T]) -> Self: ... + def difference(self, other: Iterable[_T | None]) -> Self: ... @overload # type: ignore[override] def __sub__(self: AbstractSet[_S | None], other: Iterable[None]) -> WeakSet[_S]: ... # type: ignore[overload-overlap] @overload