Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ test = [
"mypy",
"coverage[toml] >=7",
"exceptiongroup; python_version<'3.11'",
"ruff>=0.13.3",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you're introducing ruff, maybe we should use pre-commit to check and lint?

]
docs = [
"mkdocs",
Expand Down
1 change: 0 additions & 1 deletion python/pycrdt/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class BaseDoc:
_txn_lock: threading.Lock
_txn_async_lock: anyio.Lock
_allow_multithreading: bool
_Model: Any
_subscriptions: list[Subscription]
_origins: dict[int, Any]
_task_group: TaskGroup | None
Expand Down
42 changes: 35 additions & 7 deletions python/pycrdt/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,19 @@

from functools import partial
from inspect import iscoroutinefunction
from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, Type, TypeVar, Union, cast, overload
from typing import (
Any,
Awaitable,
Callable,
Generic,
Iterable,
Literal,
Type,
TypeVar,
Union,
cast,
overload,
)

from anyio import BrokenResourceError, create_memory_object_stream
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
Expand All @@ -15,7 +27,9 @@
from ._transaction import NewTransaction, ReadTransaction, Transaction

T = TypeVar("T", bound=BaseType)
TransactionOrSubdocsEvent = TypeVar("TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent)
TransactionOrSubdocsEvent = TypeVar(
"TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent
)


class Doc(BaseDoc, Generic[T]):
Expand All @@ -35,7 +49,7 @@ def __init__(
client_id: int | None = None,
skip_gc: bool | None = None,
doc: _Doc | None = None,
Model=None,
Model: Any | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Model: Any | None = None,
Model: Any = None,

allow_multithreading: bool = False,
) -> None:
"""
Expand All @@ -47,8 +61,9 @@ def __init__(
allow_multithreading: Whether to allow the document to be used in different threads.
"""
super().__init__(
client_id=client_id, skip_gc=skip_gc, doc=doc, Model=Model, allow_multithreading=allow_multithreading
client_id=client_id, skip_gc=skip_gc, doc=doc, allow_multithreading=allow_multithreading
)
self._Model = Model
for k, v in init.items():
self[k] = v
if Model is not None:
Expand Down Expand Up @@ -150,6 +165,16 @@ def get_state(self) -> bytes:
assert txn._txn is not None
return self._doc.get_state(txn._txn)

def get_model_state(self) -> Any:
if self._Model is None:
raise RuntimeError(
"no Model defined for doc. Instantiate Doc with Doc(Model=PydanticModel)"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"no Model defined for doc. Instantiate Doc with Doc(Model=PydanticModel)"
"Document has no model"

)
with self.transaction() as txn:
assert txn._txn is not None
all_roots = self._doc.to_py(txn._txn)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering why you don't do the same as in Doc.apply_update:

d = {k: self._doc[k].to_py() for k in self._Model.model_fields}
self._Model.model_validate(d)

return self._Model.model_validate(all_roots)

def get_update(self, state: bytes | None = None) -> bytes:
"""
Args:
Expand All @@ -174,7 +199,7 @@ def apply_update(self, update: bytes) -> None:
twin_doc.apply_update(update)
d = {k: twin_doc[k].to_py() for k in self._Model.model_fields}
try:
self._Model(**d)
self._Model.model_validate(d)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the style change I refer to in the PR description.

except Exception as e:
self._twin_doc = Doc(dict(self))
raise e
Expand Down Expand Up @@ -292,7 +317,8 @@ def _roots(self) -> dict[str, T]:

def observe(
self,
callback: Callable[[TransactionEvent], None] | Callable[[TransactionEvent], Awaitable[None]],
callback: Callable[[TransactionEvent], None]
| Callable[[TransactionEvent], Awaitable[None]],
) -> Subscription:
"""
Subscribes a callback to be called with the document change event.
Expand Down Expand Up @@ -405,7 +431,9 @@ async def main():
observe = self.observe_subdocs if subdocs else self.observe
if not self._send_streams[subdocs]:
if async_transactions:
self._event_subscription[subdocs] = observe(partial(self._async_send_event, subdocs))
self._event_subscription[subdocs] = observe(
partial(self._async_send_event, subdocs)
)
else:
self._event_subscription[subdocs] = observe(partial(self._send_event, subdocs))
send_stream, receive_stream = create_memory_object_stream[
Expand Down
3 changes: 3 additions & 0 deletions python/pycrdt/_pycrdt.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class Doc:
def roots(self, txn: Transaction) -> dict[str, Text | Array | Map]:
"""Get top-level (root) shared types available in current document."""

def to_py(self, txn: Transaction) -> dict[str, Any]:
"""Get top-level (root) shared types as native Python objects."""

def observe(self, callback: Callable[[TransactionEvent], None]) -> Subscription:
"""Subscribes a callback to be called with the shared document change event.
Returns a subscription that can be used to unsubscribe."""
Expand Down
4 changes: 1 addition & 3 deletions python/pycrdt/_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,7 @@ def insert_embed(self, index: int, value: Any, attrs: dict[str, Any] | None = No
self._do_and_integrate("insert", value, txn._txn, index, _attrs)
else:
# primitive type
self.integrated.insert_embed(
txn._txn, index, value, _attrs
)
self.integrated.insert_embed(txn._txn, index, value, _attrs)

def format(self, start: int, stop: int, attrs: dict[str, Any]) -> None:
"""
Expand Down
156 changes: 124 additions & 32 deletions src/doc.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use pyo3::prelude::*;
use pyo3::IntoPyObjectExt;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::types::{PyBool, PyBytes, PyDict, PyInt, PyList};
use yrs::{
Doc as _Doc, Options, ReadTxn, StateVector, SubdocsEvent as _SubdocsEvent, Transact, TransactionCleanupEvent, TransactionMut, Update, WriteTxn
};
use yrs::updates::encoder::{Encode, Encoder};
use yrs::updates::decoder::Decode;
use crate::text::Text;
use crate::array::Array;
use crate::map::Map;
use crate::transaction::Transaction;
use crate::subscription::Subscription;
use crate::text::Text;
use crate::transaction::Transaction;
use crate::type_conversions::ToPython;
use crate::xml::XmlFragment;

use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyBytes, PyDict, PyInt, PyList};
use pyo3::IntoPyObjectExt;
use yrs::updates::decoder::Decode;
use yrs::updates::encoder::{Encode, Encoder};
use yrs::{
Array as YArray, Doc as _Doc, GetString, Map as YMap, Options, ReadTxn, StateVector,
SubdocsEvent as _SubdocsEvent, Transact, TransactionCleanupEvent, TransactionMut, Update,
WriteTxn,
};

#[pyclass]
#[derive(Clone)]
Expand All @@ -41,7 +42,8 @@ impl Doc {
let mut encoder = yrs::updates::encoder::EncoderV1::new();
{
let txn = original.doc.transact();
txn.encode_state_from_snapshot(&snapshot.snapshot, &mut encoder).unwrap();
txn.encode_state_from_snapshot(&snapshot.snapshot, &mut encoder)
.unwrap();
}
let update = yrs::Update::decode_v1(&encoder.to_vec()).unwrap();
{
Expand All @@ -53,11 +55,19 @@ impl Doc {
let txn_orig = original.doc.transact();
for (name, root) in txn_orig.root_refs() {
match root {
yrs::Out::YText(_) => { let _ = new_doc.get_or_insert_text(name); },
yrs::Out::YArray(_) => { let _ = new_doc.get_or_insert_array(name); },
yrs::Out::YMap(_) => { let _ = new_doc.get_or_insert_map(name); },
yrs::Out::YXmlFragment(_) => { let _ = new_doc.get_or_insert_xml_fragment(name); },
_ => {}, // ignore unknown types
yrs::Out::YText(_) => {
let _ = new_doc.get_or_insert_text(name);
}
yrs::Out::YArray(_) => {
let _ = new_doc.get_or_insert_array(name);
}
yrs::Out::YMap(_) => {
let _ = new_doc.get_or_insert_map(name);
}
yrs::Out::YXmlFragment(_) => {
let _ = new_doc.get_or_insert_xml_fragment(name);
}
_ => {} // ignore unknown types
}
}
drop(txn_orig);
Expand All @@ -71,14 +81,16 @@ impl Doc {
fn new(client_id: &Bound<'_, PyAny>, skip_gc: &Bound<'_, PyAny>) -> PyResult<Self> {
let mut options = Options::default();
if !client_id.is_none() {
let _client_id: u64 = client_id.downcast::<PyInt>()
let _client_id: u64 = client_id
.downcast::<PyInt>()
.map_err(|_| PyValueError::new_err("client_id must be an integer"))?
.extract()
.map_err(|_| PyValueError::new_err("client_id must be a valid u64"))?;
options.client_id = _client_id;
}
if !skip_gc.is_none() {
let _skip_gc: bool = skip_gc.downcast::<PyBool>()
let _skip_gc: bool = skip_gc
.downcast::<PyBool>()
.map_err(|_| PyValueError::new_err("skip_gc must be a boolean"))?
.extract()
.map_err(|_| PyValueError::new_err("skip_gc must be a valid bool"))?;
Expand All @@ -90,7 +102,11 @@ impl Doc {

#[staticmethod]
#[pyo3(name = "from_snapshot")]
pub fn from_snapshot(py: Python<'_>, snapshot: PyRef<'_, crate::snapshot::Snapshot>, doc: PyRef<'_, Doc>) -> PyResult<Py<Doc>> {
pub fn from_snapshot(
py: Python<'_>,
snapshot: PyRef<'_, crate::snapshot::Snapshot>,
doc: PyRef<'_, Doc>,
) -> PyResult<Py<Doc>> {
let restored = Doc::_from_snapshot_impl(&doc, &snapshot);
Py::new(py, restored)
}
Expand All @@ -103,23 +119,38 @@ impl Doc {
self.doc.client_id()
}

fn get_or_insert_text(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult<Py<Text>> {
fn get_or_insert_text(
&mut self,
py: Python<'_>,
txn: &mut Transaction,
name: &str,
) -> PyResult<Py<Text>> {
let mut _t = txn.transaction();
let t = _t.as_mut().unwrap().as_mut();
let text = t.get_or_insert_text(name);
let pytext: Py<Text> = Py::new(py, Text::from(text))?;
Ok(pytext)
}

fn get_or_insert_array(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult<Py<Array>> {
fn get_or_insert_array(
&mut self,
py: Python<'_>,
txn: &mut Transaction,
name: &str,
) -> PyResult<Py<Array>> {
let mut _t = txn.transaction();
let t = _t.as_mut().unwrap().as_mut();
let shared = t.get_or_insert_array(name);
let pyshared: Py<Array > = Py::new(py, Array::from(shared))?;
let pyshared: Py<Array> = Py::new(py, Array::from(shared))?;
Ok(pyshared)
}

fn get_or_insert_map(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult<Py<Map>> {
fn get_or_insert_map(
&mut self,
py: Python<'_>,
txn: &mut Transaction,
name: &str,
) -> PyResult<Py<Map>> {
let mut _t = txn.transaction();
let t = _t.as_mut().unwrap().as_mut();
let shared = t.get_or_insert_map(name);
Expand All @@ -141,7 +172,11 @@ impl Doc {
Err(PyRuntimeError::new_err("Already in a transaction"))
}

fn create_transaction_with_origin(&self, py: Python<'_>, origin: i128) -> PyResult<Py<Transaction>> {
fn create_transaction_with_origin(
&self,
py: Python<'_>,
origin: i128,
) -> PyResult<Py<Transaction>> {
if let Ok(txn) = self.doc.try_transact_mut_with(origin) {
let t: Py<Transaction> = Py::new(py, Transaction::from(txn))?;
return Ok(t);
Expand All @@ -160,7 +195,9 @@ impl Doc {
let mut _t = txn.transaction();
let t = _t.as_mut().unwrap().as_mut();
let state: &[u8] = state.extract()?;
let Ok(state_vector) = StateVector::decode_v1(&state) else { return Err(PyValueError::new_err("Cannot decode state")) };
let Ok(state_vector) = StateVector::decode_v1(&state) else {
return Err(PyValueError::new_err("Cannot decode state"));
};
let update = t.encode_diff_v1(&state_vector);
let bytes: Py<PyAny> = Python::attach(|py| PyBytes::new(py, &update).into());
Ok(bytes)
Expand All @@ -186,8 +223,53 @@ impl Doc {
result.into()
}

fn to_py(&self, py: Python<'_>, txn: &mut Transaction) -> PyResult<Py<PyAny>> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this convert nested shared data to Python too?

let mut _t = txn.transaction();
let t = _t.as_mut().unwrap().as_mut();
let result = PyDict::new(py);

let roots_info: Vec<_> = t
.root_refs()
.map(|(name, root)| (name.to_string(), root))
.collect();

for (name, root) in roots_info {
match root {
yrs::Out::YText(_) => {
let text = t.get_or_insert_text(name.as_str());
let value = text.get_string(t);
result.set_item(name, value)?;
}
yrs::Out::YArray(_) => {
let array = t.get_or_insert_array(name.as_str());
let list = PyList::empty(py);
for item in array.iter(t) {
list.append(item.into_py(py))?;
}
result.set_item(name, list)?;
}
yrs::Out::YMap(_) => {
let map = t.get_or_insert_map(name.as_str());
let dict = PyDict::new(py);
for (key, value) in map.iter(t) {
dict.set_item(key, value.into_py(py))?;
}
result.set_item(name, dict)?;
}
yrs::Out::YXmlFragment(_) => {
let xml = t.get_or_insert_xml_fragment(name.as_str());
let xml_py = Py::new(py, XmlFragment::from(xml))?;
result.set_item(name, xml_py)?;
}
_ => {} // ignore other types
}
}
Ok(result.into())
}

pub fn observe(&mut self, py: Python<'_>, f: Py<PyAny>) -> PyResult<Py<Subscription>> {
let sub = self.doc
let sub = self
.doc
.observe_transaction_cleanup(move |txn, event| {
if !event.delete_set.is_empty() || event.before_state != event.after_state {
Python::attach(|py| {
Expand All @@ -204,7 +286,8 @@ impl Doc {
}

pub fn observe_subdocs(&mut self, py: Python<'_>, f: Py<PyAny>) -> PyResult<Py<Subscription>> {
let sub = self.doc
let sub = self
.doc
.observe_subdocs(move |_, event| {
Python::attach(|py| {
let event = SubdocsEvent::new(py, event);
Expand Down Expand Up @@ -326,11 +409,20 @@ pub struct SubdocsEvent {

impl SubdocsEvent {
fn new<'py>(py: Python<'py>, event: &_SubdocsEvent) -> Self {
let added: Vec<String> = event.added().map(|d| d.guid().clone().to_string()).collect();
let added: Vec<String> = event
.added()
.map(|d| d.guid().clone().to_string())
.collect();
let added = PyList::new(py, added).unwrap().into_py_any(py).unwrap();
let removed: Vec<String> = event.removed().map(|d| d.guid().clone().to_string()).collect();
let removed: Vec<String> = event
.removed()
.map(|d| d.guid().clone().to_string())
.collect();
let removed = PyList::new(py, removed).unwrap().into_py_any(py).unwrap();
let loaded: Vec<String> = event.loaded().map(|d| d.guid().clone().to_string()).collect();
let loaded: Vec<String> = event
.loaded()
.map(|d| d.guid().clone().to_string())
.collect();
let loaded = PyList::new(py, loaded).unwrap().into_py_any(py).unwrap();
SubdocsEvent {
added,
Expand Down
Loading
Loading