diff --git a/bson/binary.py b/bson/binary.py index 48eb12b0ac..ad8337032c 100644 --- a/bson/binary.py +++ b/bson/binary.py @@ -65,6 +65,9 @@ from array import array as _array from mmap import mmap as _mmap + import numpy as np + import numpy.typing as npt + class UuidRepresentation: UNSPECIFIED = 0 @@ -234,13 +237,20 @@ class BinaryVector: __slots__ = ("data", "dtype", "padding") - def __init__(self, data: Sequence[float | int], dtype: BinaryVectorDtype, padding: int = 0): + def __init__( + self, + data: Union[Sequence[float | int], npt.NDArray[np.number]], + dtype: BinaryVectorDtype, + padding: int = 0, + ): """ :param data: Sequence of numbers representing the mathematical vector. :param dtype: The data type stored in binary :param padding: The number of bits in the final byte that are to be ignored when a vector element's size is less than a byte and the length of the vector is not a multiple of 8. + (Padding is equivalent to a negative value of `count` in + `numpy.unpackbits `_) """ self.data = data self.dtype = dtype @@ -424,10 +434,20 @@ def from_vector( ) -> Binary: ... + @classmethod + @overload + def from_vector( + cls: Type[Binary], + vector: npt.NDArray[np.number], + dtype: BinaryVectorDtype, + padding: int = 0, + ) -> Binary: + ... + @classmethod def from_vector( cls: Type[Binary], - vector: Union[BinaryVector, list[int], list[float]], + vector: Union[BinaryVector, list[int], list[float], npt.NDArray[np.number]], dtype: Optional[BinaryVectorDtype] = None, padding: Optional[int] = None, ) -> Binary: @@ -459,25 +479,60 @@ def from_vector( vector = vector.data # type: ignore padding = 0 if padding is None else padding - if dtype == BinaryVectorDtype.INT8: # pack ints in [-128, 127] as signed int8 - format_str = "b" - if padding: - raise ValueError(f"padding does not apply to {dtype=}") - elif dtype == BinaryVectorDtype.PACKED_BIT: # pack ints in [0, 255] as unsigned uint8 - format_str = "B" - if 0 <= padding > 7: - raise ValueError(f"{padding=}. It must be in [0,1, ..7].") - if padding and not vector: - raise ValueError("Empty vector with non-zero padding.") - elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32 - format_str = "f" - if padding: - raise ValueError(f"padding does not apply to {dtype=}") - else: - raise NotImplementedError("%s not yet supported" % dtype) - + if not isinstance(dtype, BinaryVectorDtype): + raise TypeError( + "dtype must be a bson.BinaryVectorDtype, such as BinaryVectorDtype.FLOAT32" + ) metadata = struct.pack(" 7: + raise ValueError(f"{padding=}. It must be in [0,1, ..7].") + if padding and not vector: + raise ValueError("Empty vector with non-zero padding.") + elif dtype == BinaryVectorDtype.FLOAT32: # pack floats as float32 + format_str = "f" + if padding: + raise ValueError(f"padding does not apply to {dtype=}") + else: + raise NotImplementedError("%s not yet supported" % dtype) + data = struct.pack(f"<{len(vector)}{format_str}", *vector) + else: # vector is numpy array or incorrect type. + try: + import numpy as np + except ImportError as exc: + raise ImportError( + "Failed to create binary from vector. Check type. If numpy array, numpy must be installed." + ) from exc + if not isinstance(vector, np.ndarray): + raise TypeError("Vector must be a numpy array.") + if vector.ndim != 1: + raise ValueError( + "from_numpy_vector only supports 1D arrays as it creates a single vector." + ) + + if dtype == BinaryVectorDtype.FLOAT32: + vector = vector.astype(np.dtype("float32"), copy=False) + elif dtype == BinaryVectorDtype.INT8: + if vector.min() >= -128 and vector.max() <= 127: + vector = vector.astype(np.dtype("int8"), copy=False) + else: + raise ValueError("Values found outside INT8 range.") + elif dtype == BinaryVectorDtype.PACKED_BIT: + if vector.min() >= 0 and vector.max() <= 127: + vector = vector.astype(np.dtype("uint8"), copy=False) + else: + raise ValueError("Values found outside UINT8 range.") + else: + raise NotImplementedError("%s not yet supported" % dtype) + data = vector.tobytes() + if padding and len(vector) and not (data[-1] & ((1 << padding) - 1)) == 0: raise ValueError( "Vector has a padding P, but bits in the final byte lower than P are non-zero. They must be zero." @@ -549,6 +604,54 @@ def subtype(self) -> int: """Subtype of this binary data.""" return self.__subtype + def as_numpy_vector(self) -> BinaryVector: + """From the Binary, create a BinaryVector where data is a 1-dim numpy array. + dtype still follows our typing (BinaryVectorDtype), + and padding is as we define it, notably equivalent to a negative value of count + in `numpy.unpackbits `_. + + :return: BinaryVector + + .. versionadded:: 4.16 + """ + if self.subtype != VECTOR_SUBTYPE: + raise ValueError(f"Cannot decode subtype {self.subtype} as a vector") + try: + import numpy as np + except ImportError as exc: + raise ImportError( + "Converting binary to numpy.ndarray requires numpy to be installed." + ) from exc + + dtype, padding = struct.unpack_from(" 7 or padding < 0: + raise ValueError(f"Corrupt data. Padding ({padding}) must be between 0 and 7.") + data = np.frombuffer(self[2:], dtype="uint8") + if padding and np.unpackbits(data[-1])[-padding:].sum() > 0: + warnings.warn( + "Vector has a padding P, but bits in the final byte lower than P are non-zero. For pymongo>=5.0, they must be zero.", + DeprecationWarning, + stacklevel=2, + ) + else: + raise ValueError(f"Unsupported dtype code: {dtype!r}") + return BinaryVector(data, dtype, padding) + def __getnewargs__(self) -> Tuple[bytes, int]: # type: ignore[override] # Work around http://bugs.python.org/issue7382 data = super().__getnewargs__()[0] diff --git a/doc/changelog.rst b/doc/changelog.rst index f3eb4f6f23..e7a642aeed 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -16,6 +16,7 @@ PyMongo 4.16 brings a number of changes including: Python 3.10+. The minimum version is ``2.6.1`` to account for `CVE-2023-29483 `_. - Removed support for Eventlet. Eventlet is actively being sunset by its maintainers and has compatibility issues with PyMongo's dnspython dependency. +- Added support for NumPy 1D-arrays in BSON Binary Vectors. Changes in Version 4.15.3 (2025/10/07) -------------------------------------- diff --git a/justfile b/justfile index 17b95e87b7..50957149e8 100644 --- a/justfile +++ b/justfile @@ -2,7 +2,7 @@ set shell := ["bash", "-c"] # Commonly used command segments. -typing_run := "uv run --group typing --extra aws --extra encryption --extra ocsp --extra snappy --extra test --extra zstd" +typing_run := "uv run --group typing --extra aws --extra encryption --with numpy --extra ocsp --extra snappy --extra test --extra zstd" docs_run := "uv run --extra docs" doc_build := "./doc/_build" mypy_args := "--install-types --non-interactive" @@ -39,14 +39,14 @@ typing: && resync [group('typing')] typing-mypy: && resync - {{typing_run}} mypy {{mypy_args}} bson gridfs tools pymongo - {{typing_run}} mypy {{mypy_args}} --config-file mypy_test.ini test - {{typing_run}} mypy {{mypy_args}} test/test_typing.py test/test_typing_strict.py + {{typing_run}} python -m mypy {{mypy_args}} bson gridfs tools pymongo + {{typing_run}} python -m mypy {{mypy_args}} --config-file mypy_test.ini test + {{typing_run}} python -m mypy {{mypy_args}} test/test_typing.py test/test_typing_strict.py [group('typing')] typing-pyright: && resync - {{typing_run}} pyright test/test_typing.py test/test_typing_strict.py - {{typing_run}} pyright -p strict_pyrightconfig.json test/test_typing_strict.py + {{typing_run}} python -m pyright test/test_typing.py test/test_typing_strict.py + {{typing_run}} python -m pyright -p strict_pyrightconfig.json test/test_typing_strict.py [group('lint')] lint *args="": && resync @@ -58,7 +58,13 @@ lint-manual *args="": && resync [group('test')] test *args="-v --durations=5 --maxfail=10": && resync - uv run --extra test pytest {{args}} + uv run --extra test python -m pytest {{args}} + +[group('test')] +test-bson *args="-v --durations=5 --maxfail=10": && resync + uv run --extra test --with numpy python -m pytest test/test_bson.py + + [group('test')] run-tests *args: && resync diff --git a/test/test_bson.py b/test/test_bson.py index f792db1e89..cc670bc228 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -19,6 +19,7 @@ import array import collections import datetime +import importlib.util import mmap import os import pickle @@ -71,6 +72,8 @@ from bson.timestamp import Timestamp from bson.tz_util import FixedOffset, utc +_NUMPY_AVAILABLE = importlib.util.find_spec("numpy") is not None + class NotADict(abc.MutableMapping): """Non-dict type that implements the mapping protocol.""" @@ -871,6 +874,62 @@ def test_binaryvector_equality(self): BinaryVector([1], BinaryVectorDtype.INT8), BinaryVector([2], BinaryVectorDtype.INT8) ) + @unittest.skipIf(not _NUMPY_AVAILABLE, "numpy optional-dependency not installed.") + def test_vector_from_numpy(self): + """Follows test_vector except for input type numpy.ndarray""" + # Simple data values could be treated as any of our BinaryVectorDtypes + import numpy as np + + arr = np.array([2, 3]) + # INT8 + binary_vector_int8 = Binary.from_vector(arr, BinaryVectorDtype.INT8) + # as_vector + vector = binary_vector_int8.as_vector() + assert isinstance(vector, BinaryVector) + assert vector.data == arr.tolist() + # as_numpy_vector + vector_np = binary_vector_int8.as_numpy_vector() + assert isinstance(vector_np, BinaryVector) + assert np.all(vector.data == arr) + # PACKED_BIT + binary_vector_uint8 = Binary.from_vector(arr, BinaryVectorDtype.PACKED_BIT) + # as_vector + vector = binary_vector_uint8.as_vector() + assert isinstance(vector, BinaryVector) + assert vector.data == arr.tolist() + # as_numpy_vector + vector_np = binary_vector_uint8.as_numpy_vector() + assert isinstance(vector_np, BinaryVector) + assert np.all(vector_np.data == arr) + # FLOAT32 + binary_vector_float32 = Binary.from_vector(arr, BinaryVectorDtype.FLOAT32) + # as_vector + vector = binary_vector_float32.as_vector() + assert isinstance(vector, BinaryVector) + assert vector.data == arr.tolist() + # as_numpy_vector + vector_np = binary_vector_float32.as_numpy_vector() + assert isinstance(vector_np, BinaryVector) + assert np.all(vector_np.data == arr) + + # Invalid cases + with self.assertRaises(ValueError): + Binary.from_vector(np.array([-1]), BinaryVectorDtype.PACKED_BIT) + with self.assertRaises(ValueError): + Binary.from_vector(np.array([128]), BinaryVectorDtype.PACKED_BIT) + with self.assertRaises(ValueError): + Binary.from_vector(np.array([-198]), BinaryVectorDtype.INT8) + + # Unexpected cases + # Creating a vector of INT8 from a list of doubles will be caught by struct.pack + # Numpy's default behavior is to cast to the type requested. + list_floats = [-1.1, 1.1] + cast_bin = Binary.from_vector(np.array(list_floats), BinaryVectorDtype.INT8) + vector = cast_bin.as_vector() + vector_np = cast_bin.as_numpy_vector() + assert vector.data != list_floats + assert vector.data == vector_np.data.tolist() == [-1, 1] + def test_unicode_regex(self): """Tests we do not get a segfault for C extension on unicode RegExs. This had been happening.