Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ test = [
"mypy",
"coverage[toml] >=7",
"exceptiongroup; python_version<'3.11'",
"ruff>=0.13.3",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you're introducing ruff, maybe we should use pre-commit to check and lint?

]
docs = [
"mkdocs",
Expand Down
1 change: 0 additions & 1 deletion python/pycrdt/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class BaseDoc:
_txn_lock: threading.Lock
_txn_async_lock: anyio.Lock
_allow_multithreading: bool
_Model: Any
_subscriptions: list[Subscription]
_origins: dict[int, Any]
_task_group: TaskGroup | None
Expand Down
40 changes: 33 additions & 7 deletions python/pycrdt/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,19 @@

from functools import partial
from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, Type, TypeVar, Union, cast, overload
from typing import (
Any,
Awaitable,
Callable,
Generic,
Iterable,
Literal,
Type,
TypeVar,
Union,
cast,
overload,
)

from anyio import BrokenResourceError, create_memory_object_stream
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand All @@ -15,7 +27,9 @@
from ._transaction import NewTransaction, ReadTransaction, Transaction

T = TypeVar("T", bound=BaseType)
TransactionOrSubdocsEvent = TypeVar("TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent)
TransactionOrSubdocsEvent = TypeVar(
"TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent
)


class Doc(BaseDoc, Generic[T]):
Expand All @@ -35,7 +49,7 @@ def __init__(
client_id: int | None = None,
skip_gc: bool | None = None,
doc: _Doc | None = None,
Model=None,
Model: Any | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Model: Any | None = None,
Model: Any = None,

allow_multithreading: bool = False,
) -> None:
"""
Expand All @@ -47,8 +61,9 @@ def __init__(
allow_multithreading: Whether to allow the document to be used in different threads.
"""
super().__init__(
client_id=client_id, skip_gc=skip_gc, doc=doc, Model=Model, allow_multithreading=allow_multithreading
client_id=client_id, skip_gc=skip_gc, doc=doc, allow_multithreading=allow_multithreading
)
self._Model = Model
for k, v in init.items():
self[k] = v
if Model is not None:
Expand Down Expand Up @@ -150,6 +165,14 @@ def get_state(self) -> bytes:
assert txn._txn is not None
return self._doc.get_state(txn._txn)

def get_model_state(self) -> Any:
if self._Model is None:
raise RuntimeError(
"no Model defined for doc. Instantiate Doc with Doc(Model=PydanticModel)"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"no Model defined for doc. Instantiate Doc with Doc(Model=PydanticModel)"
"Document has no model"

)
d = {k: self[k].to_py() for k in self._Model.model_fields}
return self._Model.model_validate(d)

def get_update(self, state: bytes | None = None) -> bytes:
"""
Args:
Expand All @@ -174,7 +197,7 @@ def apply_update(self, update: bytes) -> None:
twin_doc.apply_update(update)
d = {k: twin_doc[k].to_py() for k in self._Model.model_fields}
try:
self._Model(**d)
self._Model.model_validate(d)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the style change i refer to in the PR description.

except Exception as e:
self._twin_doc = Doc(dict(self))
raise e
Expand Down Expand Up @@ -292,7 +315,8 @@ def _roots(self) -> dict[str, T]:

def observe(
self,
callback: Callable[[TransactionEvent], None] | Callable[[TransactionEvent], Awaitable[None]],
callback: Callable[[TransactionEvent], None]
| Callable[[TransactionEvent], Awaitable[None]],
) -> Subscription:
"""
Subscribes a callback to be called with the document change event.
Expand Down Expand Up @@ -405,7 +429,9 @@ async def main():
observe = self.observe_subdocs if subdocs else self.observe
if not self._send_streams[subdocs]:
if async_transactions:
self._event_subscription[subdocs] = observe(partial(self._async_send_event, subdocs))
self._event_subscription[subdocs] = observe(
partial(self._async_send_event, subdocs)
)
else:
self._event_subscription[subdocs] = observe(partial(self._send_event, subdocs))
send_stream, receive_stream = create_memory_object_stream[
Expand Down
4 changes: 1 addition & 3 deletions python/pycrdt/_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,7 @@ def insert_embed(self, index: int, value: Any, attrs: dict[str, Any] | None = No
self._do_and_integrate("insert", value, txn._txn, index, _attrs)
else:
# primitive type
self.integrated.insert_embed(
txn._txn, index, value, _attrs
)
self.integrated.insert_embed(txn._txn, index, value, _attrs)

def format(self, start: int, stop: int, attrs: dict[str, Any]) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def test_get_update_exception():
def test_apply_update_exception():
doc = Doc()
with pytest.raises(ValueError) as excinfo:
doc.apply_update(b"\xFF\xFF\xFF\xFF")
doc.apply_update(b"\xff\xff\xff\xff")
assert "Cannot decode update" in str(excinfo.value)


Expand Down
9 changes: 8 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
from typing import Tuple

import pytest
Expand Down Expand Up @@ -52,3 +52,10 @@ class Delivery(BaseModel):

assert str(local_doc["timestamp"]) == "2020-02-02T03:04:05Z"
assert list(local_doc["dimensions"]) == ["10", "30"]

decoded = local_doc.get_model_state()
assert decoded.timestamp == datetime(2020, 2, 2, 3, 4, 5, tzinfo=timezone.utc)
assert decoded.dimensions == (
10,
30,
)
1 change: 0 additions & 1 deletion tests/test_transaction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import gc
import platform
import sys
import time
from functools import partial
Expand Down
Loading