-
Notifications
You must be signed in to change notification settings - Fork 28
Add get_model_state to get validated doc.
#309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2,7 +2,19 @@ | |||||
|
|
||||||
| from functools import partial | ||||||
| from inspect import iscoroutinefunction | ||||||
| from typing import Any, Awaitable, Callable, Generic, Iterable, Literal, Type, TypeVar, Union, cast, overload | ||||||
| from typing import ( | ||||||
| Any, | ||||||
| Awaitable, | ||||||
| Callable, | ||||||
| Generic, | ||||||
| Iterable, | ||||||
| Literal, | ||||||
| Type, | ||||||
| TypeVar, | ||||||
| Union, | ||||||
| cast, | ||||||
| overload, | ||||||
| ) | ||||||
|
|
||||||
| from anyio import BrokenResourceError, create_memory_object_stream | ||||||
| from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream | ||||||
|
|
@@ -15,7 +27,9 @@ | |||||
| from ._transaction import NewTransaction, ReadTransaction, Transaction | ||||||
|
|
||||||
| T = TypeVar("T", bound=BaseType) | ||||||
| TransactionOrSubdocsEvent = TypeVar("TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent) | ||||||
| TransactionOrSubdocsEvent = TypeVar( | ||||||
| "TransactionOrSubdocsEvent", bound=TransactionEvent | SubdocsEvent | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| class Doc(BaseDoc, Generic[T]): | ||||||
|
|
@@ -35,7 +49,7 @@ def __init__( | |||||
| client_id: int | None = None, | ||||||
| skip_gc: bool | None = None, | ||||||
| doc: _Doc | None = None, | ||||||
| Model=None, | ||||||
| Model: Any | None = None, | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| allow_multithreading: bool = False, | ||||||
| ) -> None: | ||||||
| """ | ||||||
|
|
@@ -47,8 +61,9 @@ def __init__( | |||||
| allow_multithreading: Whether to allow the document to be used in different threads. | ||||||
| """ | ||||||
| super().__init__( | ||||||
| client_id=client_id, skip_gc=skip_gc, doc=doc, Model=Model, allow_multithreading=allow_multithreading | ||||||
| client_id=client_id, skip_gc=skip_gc, doc=doc, allow_multithreading=allow_multithreading | ||||||
| ) | ||||||
| self._Model = Model | ||||||
| for k, v in init.items(): | ||||||
| self[k] = v | ||||||
| if Model is not None: | ||||||
|
|
@@ -150,6 +165,16 @@ def get_state(self) -> bytes: | |||||
| assert txn._txn is not None | ||||||
| return self._doc.get_state(txn._txn) | ||||||
|
|
||||||
| def get_model_state(self) -> Any: | ||||||
| if self._Model is None: | ||||||
| raise RuntimeError( | ||||||
| "no Model defined for doc. Instantiate Doc with Doc(Model=PydanticModel)" | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| ) | ||||||
| with self.transaction() as txn: | ||||||
| assert txn._txn is not None | ||||||
| all_roots = self._doc.to_py(txn._txn) | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering why you don't do the same as in d = {k: self._doc[k].to_py() for k in self._Model.model_fields}
self._Model.model_validate(d) |
||||||
| return self._Model.model_validate(all_roots) | ||||||
|
|
||||||
| def get_update(self, state: bytes | None = None) -> bytes: | ||||||
| """ | ||||||
| Args: | ||||||
|
|
@@ -174,7 +199,7 @@ def apply_update(self, update: bytes) -> None: | |||||
| twin_doc.apply_update(update) | ||||||
| d = {k: twin_doc[k].to_py() for k in self._Model.model_fields} | ||||||
| try: | ||||||
| self._Model(**d) | ||||||
| self._Model.model_validate(d) | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the style change i refer to in the PR description. |
||||||
| except Exception as e: | ||||||
| self._twin_doc = Doc(dict(self)) | ||||||
| raise e | ||||||
|
|
@@ -292,7 +317,8 @@ def _roots(self) -> dict[str, T]: | |||||
|
|
||||||
| def observe( | ||||||
| self, | ||||||
| callback: Callable[[TransactionEvent], None] | Callable[[TransactionEvent], Awaitable[None]], | ||||||
| callback: Callable[[TransactionEvent], None] | ||||||
| | Callable[[TransactionEvent], Awaitable[None]], | ||||||
| ) -> Subscription: | ||||||
| """ | ||||||
| Subscribes a callback to be called with the document change event. | ||||||
|
|
@@ -405,7 +431,9 @@ async def main(): | |||||
| observe = self.observe_subdocs if subdocs else self.observe | ||||||
| if not self._send_streams[subdocs]: | ||||||
| if async_transactions: | ||||||
| self._event_subscription[subdocs] = observe(partial(self._async_send_event, subdocs)) | ||||||
| self._event_subscription[subdocs] = observe( | ||||||
| partial(self._async_send_event, subdocs) | ||||||
| ) | ||||||
| else: | ||||||
| self._event_subscription[subdocs] = observe(partial(self._send_event, subdocs)) | ||||||
| send_stream, receive_stream = create_memory_object_stream[ | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,20 +1,21 @@ | ||
| use pyo3::prelude::*; | ||
| use pyo3::IntoPyObjectExt; | ||
| use pyo3::exceptions::{PyRuntimeError, PyValueError}; | ||
| use pyo3::types::{PyBool, PyBytes, PyDict, PyInt, PyList}; | ||
| use yrs::{ | ||
| Doc as _Doc, Options, ReadTxn, StateVector, SubdocsEvent as _SubdocsEvent, Transact, TransactionCleanupEvent, TransactionMut, Update, WriteTxn | ||
| }; | ||
| use yrs::updates::encoder::{Encode, Encoder}; | ||
| use yrs::updates::decoder::Decode; | ||
| use crate::text::Text; | ||
| use crate::array::Array; | ||
| use crate::map::Map; | ||
| use crate::transaction::Transaction; | ||
| use crate::subscription::Subscription; | ||
| use crate::text::Text; | ||
| use crate::transaction::Transaction; | ||
| use crate::type_conversions::ToPython; | ||
| use crate::xml::XmlFragment; | ||
|
|
||
| use pyo3::exceptions::{PyRuntimeError, PyValueError}; | ||
| use pyo3::prelude::*; | ||
| use pyo3::types::{PyBool, PyBytes, PyDict, PyInt, PyList}; | ||
| use pyo3::IntoPyObjectExt; | ||
| use yrs::updates::decoder::Decode; | ||
| use yrs::updates::encoder::{Encode, Encoder}; | ||
| use yrs::{ | ||
| Array as YArray, Doc as _Doc, GetString, Map as YMap, Options, ReadTxn, StateVector, | ||
| SubdocsEvent as _SubdocsEvent, Transact, TransactionCleanupEvent, TransactionMut, Update, | ||
| WriteTxn, | ||
| }; | ||
|
|
||
| #[pyclass] | ||
| #[derive(Clone)] | ||
|
|
@@ -41,7 +42,8 @@ impl Doc { | |
| let mut encoder = yrs::updates::encoder::EncoderV1::new(); | ||
| { | ||
| let txn = original.doc.transact(); | ||
| txn.encode_state_from_snapshot(&snapshot.snapshot, &mut encoder).unwrap(); | ||
| txn.encode_state_from_snapshot(&snapshot.snapshot, &mut encoder) | ||
| .unwrap(); | ||
| } | ||
| let update = yrs::Update::decode_v1(&encoder.to_vec()).unwrap(); | ||
| { | ||
|
|
@@ -53,11 +55,19 @@ impl Doc { | |
| let txn_orig = original.doc.transact(); | ||
| for (name, root) in txn_orig.root_refs() { | ||
| match root { | ||
| yrs::Out::YText(_) => { let _ = new_doc.get_or_insert_text(name); }, | ||
| yrs::Out::YArray(_) => { let _ = new_doc.get_or_insert_array(name); }, | ||
| yrs::Out::YMap(_) => { let _ = new_doc.get_or_insert_map(name); }, | ||
| yrs::Out::YXmlFragment(_) => { let _ = new_doc.get_or_insert_xml_fragment(name); }, | ||
| _ => {}, // ignore unknown types | ||
| yrs::Out::YText(_) => { | ||
| let _ = new_doc.get_or_insert_text(name); | ||
| } | ||
| yrs::Out::YArray(_) => { | ||
| let _ = new_doc.get_or_insert_array(name); | ||
| } | ||
| yrs::Out::YMap(_) => { | ||
| let _ = new_doc.get_or_insert_map(name); | ||
| } | ||
| yrs::Out::YXmlFragment(_) => { | ||
| let _ = new_doc.get_or_insert_xml_fragment(name); | ||
| } | ||
| _ => {} // ignore unknown types | ||
| } | ||
| } | ||
| drop(txn_orig); | ||
|
|
@@ -71,14 +81,16 @@ impl Doc { | |
| fn new(client_id: &Bound<'_, PyAny>, skip_gc: &Bound<'_, PyAny>) -> PyResult<Self> { | ||
| let mut options = Options::default(); | ||
| if !client_id.is_none() { | ||
| let _client_id: u64 = client_id.downcast::<PyInt>() | ||
| let _client_id: u64 = client_id | ||
| .downcast::<PyInt>() | ||
| .map_err(|_| PyValueError::new_err("client_id must be an integer"))? | ||
| .extract() | ||
| .map_err(|_| PyValueError::new_err("client_id must be a valid u64"))?; | ||
| options.client_id = _client_id; | ||
| } | ||
| if !skip_gc.is_none() { | ||
| let _skip_gc: bool = skip_gc.downcast::<PyBool>() | ||
| let _skip_gc: bool = skip_gc | ||
| .downcast::<PyBool>() | ||
| .map_err(|_| PyValueError::new_err("skip_gc must be a boolean"))? | ||
| .extract() | ||
| .map_err(|_| PyValueError::new_err("skip_gc must be a valid bool"))?; | ||
|
|
@@ -90,7 +102,11 @@ impl Doc { | |
|
|
||
| #[staticmethod] | ||
| #[pyo3(name = "from_snapshot")] | ||
| pub fn from_snapshot(py: Python<'_>, snapshot: PyRef<'_, crate::snapshot::Snapshot>, doc: PyRef<'_, Doc>) -> PyResult<Py<Doc>> { | ||
| pub fn from_snapshot( | ||
| py: Python<'_>, | ||
| snapshot: PyRef<'_, crate::snapshot::Snapshot>, | ||
| doc: PyRef<'_, Doc>, | ||
| ) -> PyResult<Py<Doc>> { | ||
| let restored = Doc::_from_snapshot_impl(&doc, &snapshot); | ||
| Py::new(py, restored) | ||
| } | ||
|
|
@@ -103,23 +119,38 @@ impl Doc { | |
| self.doc.client_id() | ||
| } | ||
|
|
||
| fn get_or_insert_text(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult<Py<Text>> { | ||
| fn get_or_insert_text( | ||
| &mut self, | ||
| py: Python<'_>, | ||
| txn: &mut Transaction, | ||
| name: &str, | ||
| ) -> PyResult<Py<Text>> { | ||
| let mut _t = txn.transaction(); | ||
| let t = _t.as_mut().unwrap().as_mut(); | ||
| let text = t.get_or_insert_text(name); | ||
| let pytext: Py<Text> = Py::new(py, Text::from(text))?; | ||
| Ok(pytext) | ||
| } | ||
|
|
||
| fn get_or_insert_array(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult<Py<Array>> { | ||
| fn get_or_insert_array( | ||
| &mut self, | ||
| py: Python<'_>, | ||
| txn: &mut Transaction, | ||
| name: &str, | ||
| ) -> PyResult<Py<Array>> { | ||
| let mut _t = txn.transaction(); | ||
| let t = _t.as_mut().unwrap().as_mut(); | ||
| let shared = t.get_or_insert_array(name); | ||
| let pyshared: Py<Array > = Py::new(py, Array::from(shared))?; | ||
| let pyshared: Py<Array> = Py::new(py, Array::from(shared))?; | ||
| Ok(pyshared) | ||
| } | ||
|
|
||
| fn get_or_insert_map(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult<Py<Map>> { | ||
| fn get_or_insert_map( | ||
| &mut self, | ||
| py: Python<'_>, | ||
| txn: &mut Transaction, | ||
| name: &str, | ||
| ) -> PyResult<Py<Map>> { | ||
| let mut _t = txn.transaction(); | ||
| let t = _t.as_mut().unwrap().as_mut(); | ||
| let shared = t.get_or_insert_map(name); | ||
|
|
@@ -141,7 +172,11 @@ impl Doc { | |
| Err(PyRuntimeError::new_err("Already in a transaction")) | ||
| } | ||
|
|
||
| fn create_transaction_with_origin(&self, py: Python<'_>, origin: i128) -> PyResult<Py<Transaction>> { | ||
| fn create_transaction_with_origin( | ||
| &self, | ||
| py: Python<'_>, | ||
| origin: i128, | ||
| ) -> PyResult<Py<Transaction>> { | ||
| if let Ok(txn) = self.doc.try_transact_mut_with(origin) { | ||
| let t: Py<Transaction> = Py::new(py, Transaction::from(txn))?; | ||
| return Ok(t); | ||
|
|
@@ -160,7 +195,9 @@ impl Doc { | |
| let mut _t = txn.transaction(); | ||
| let t = _t.as_mut().unwrap().as_mut(); | ||
| let state: &[u8] = state.extract()?; | ||
| let Ok(state_vector) = StateVector::decode_v1(&state) else { return Err(PyValueError::new_err("Cannot decode state")) }; | ||
| let Ok(state_vector) = StateVector::decode_v1(&state) else { | ||
| return Err(PyValueError::new_err("Cannot decode state")); | ||
| }; | ||
| let update = t.encode_diff_v1(&state_vector); | ||
| let bytes: Py<PyAny> = Python::attach(|py| PyBytes::new(py, &update).into()); | ||
| Ok(bytes) | ||
|
|
@@ -186,8 +223,53 @@ impl Doc { | |
| result.into() | ||
| } | ||
|
|
||
| fn to_py(&self, py: Python<'_>, txn: &mut Transaction) -> PyResult<Py<PyAny>> { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this convert nested shared data to Python too? |
||
| let mut _t = txn.transaction(); | ||
| let t = _t.as_mut().unwrap().as_mut(); | ||
| let result = PyDict::new(py); | ||
|
|
||
| let roots_info: Vec<_> = t | ||
| .root_refs() | ||
| .map(|(name, root)| (name.to_string(), root)) | ||
| .collect(); | ||
|
|
||
| for (name, root) in roots_info { | ||
| match root { | ||
| yrs::Out::YText(_) => { | ||
| let text = t.get_or_insert_text(name.as_str()); | ||
| let value = text.get_string(t); | ||
| result.set_item(name, value)?; | ||
| } | ||
| yrs::Out::YArray(_) => { | ||
| let array = t.get_or_insert_array(name.as_str()); | ||
| let list = PyList::empty(py); | ||
| for item in array.iter(t) { | ||
| list.append(item.into_py(py))?; | ||
| } | ||
| result.set_item(name, list)?; | ||
| } | ||
| yrs::Out::YMap(_) => { | ||
| let map = t.get_or_insert_map(name.as_str()); | ||
| let dict = PyDict::new(py); | ||
| for (key, value) in map.iter(t) { | ||
| dict.set_item(key, value.into_py(py))?; | ||
| } | ||
| result.set_item(name, dict)?; | ||
| } | ||
| yrs::Out::YXmlFragment(_) => { | ||
| let xml = t.get_or_insert_xml_fragment(name.as_str()); | ||
| let xml_py = Py::new(py, XmlFragment::from(xml))?; | ||
| result.set_item(name, xml_py)?; | ||
| } | ||
| _ => {} // ignore other types | ||
| } | ||
| } | ||
| Ok(result.into()) | ||
| } | ||
|
|
||
| pub fn observe(&mut self, py: Python<'_>, f: Py<PyAny>) -> PyResult<Py<Subscription>> { | ||
| let sub = self.doc | ||
| let sub = self | ||
| .doc | ||
| .observe_transaction_cleanup(move |txn, event| { | ||
| if !event.delete_set.is_empty() || event.before_state != event.after_state { | ||
| Python::attach(|py| { | ||
|
|
@@ -204,7 +286,8 @@ impl Doc { | |
| } | ||
|
|
||
| pub fn observe_subdocs(&mut self, py: Python<'_>, f: Py<PyAny>) -> PyResult<Py<Subscription>> { | ||
| let sub = self.doc | ||
| let sub = self | ||
| .doc | ||
| .observe_subdocs(move |_, event| { | ||
| Python::attach(|py| { | ||
| let event = SubdocsEvent::new(py, event); | ||
|
|
@@ -326,11 +409,20 @@ pub struct SubdocsEvent { | |
|
|
||
| impl SubdocsEvent { | ||
| fn new<'py>(py: Python<'py>, event: &_SubdocsEvent) -> Self { | ||
| let added: Vec<String> = event.added().map(|d| d.guid().clone().to_string()).collect(); | ||
| let added: Vec<String> = event | ||
| .added() | ||
| .map(|d| d.guid().clone().to_string()) | ||
| .collect(); | ||
| let added = PyList::new(py, added).unwrap().into_py_any(py).unwrap(); | ||
| let removed: Vec<String> = event.removed().map(|d| d.guid().clone().to_string()).collect(); | ||
| let removed: Vec<String> = event | ||
| .removed() | ||
| .map(|d| d.guid().clone().to_string()) | ||
| .collect(); | ||
| let removed = PyList::new(py, removed).unwrap().into_py_any(py).unwrap(); | ||
| let loaded: Vec<String> = event.loaded().map(|d| d.guid().clone().to_string()).collect(); | ||
| let loaded: Vec<String> = event | ||
| .loaded() | ||
| .map(|d| d.guid().clone().to_string()) | ||
| .collect(); | ||
| let loaded = PyList::new(py, loaded).unwrap().into_py_any(py).unwrap(); | ||
| SubdocsEvent { | ||
| added, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since you're introducing
ruff, maybe we should usepre-committo check and lint?