Skip to content

Commit 1aa8dfc

Browse files
committed
Implement to_py for Doc.
This iterates through the roots of the Doc, converting them into their native python types (using the underlying type's to_py() fn). tbd if this can be used to replace `_roots` as well?
1 parent 663d169 commit 1aa8dfc

File tree

3 files changed

+131
-34
lines changed

3 files changed

+131
-34
lines changed

python/pycrdt/_doc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,10 @@ def get_model_state(self) -> Any:
170170
raise RuntimeError(
171171
"no Model defined for doc. Instantiate Doc with Doc(Model=PydanticModel)"
172172
)
173-
d = {k: self[k].to_py() for k in self._Model.model_fields}
174-
return self._Model.model_validate(d)
173+
with self.transaction() as txn:
174+
assert txn._txn is not None
175+
all_roots = self._doc.to_py(txn._txn)
176+
return self._Model.model_validate(all_roots)
175177

176178
def get_update(self, state: bytes | None = None) -> bytes:
177179
"""

python/pycrdt/_pycrdt.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class Doc:
6161
def roots(self, txn: Transaction) -> dict[str, Text | Array | Map]:
6262
"""Get top-level (root) shared types available in current document."""
6363

64+
def to_py(self, txn: Transaction) -> dict[str, Any]:
65+
"""Get top-level (root) shared types as native Python objects."""
66+
6467
def observe(self, callback: Callable[[TransactionEvent], None]) -> Subscription:
6568
"""Subscribes a callback to be called with the shared document change event.
6669
Returns a subscription that can be used to unsubscribe."""

src/doc.rs

Lines changed: 124 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1-
use pyo3::prelude::*;
2-
use pyo3::IntoPyObjectExt;
3-
use pyo3::exceptions::{PyRuntimeError, PyValueError};
4-
use pyo3::types::{PyBool, PyBytes, PyDict, PyInt, PyList};
5-
use yrs::{
6-
Doc as _Doc, Options, ReadTxn, StateVector, SubdocsEvent as _SubdocsEvent, Transact, TransactionCleanupEvent, TransactionMut, Update, WriteTxn
7-
};
8-
use yrs::updates::encoder::{Encode, Encoder};
9-
use yrs::updates::decoder::Decode;
10-
use crate::text::Text;
111
use crate::array::Array;
122
use crate::map::Map;
13-
use crate::transaction::Transaction;
143
use crate::subscription::Subscription;
4+
use crate::text::Text;
5+
use crate::transaction::Transaction;
156
use crate::type_conversions::ToPython;
167
use crate::xml::XmlFragment;
17-
8+
use pyo3::exceptions::{PyRuntimeError, PyValueError};
9+
use pyo3::prelude::*;
10+
use pyo3::types::{PyBool, PyBytes, PyDict, PyInt, PyList};
11+
use pyo3::IntoPyObjectExt;
12+
use yrs::updates::decoder::Decode;
13+
use yrs::updates::encoder::{Encode, Encoder};
14+
use yrs::{
15+
Array as YArray, Doc as _Doc, GetString, Map as YMap, Options, ReadTxn, StateVector,
16+
SubdocsEvent as _SubdocsEvent, Transact, TransactionCleanupEvent, TransactionMut, Update,
17+
WriteTxn,
18+
};
1819

1920
#[pyclass]
2021
#[derive(Clone)]
@@ -41,7 +42,8 @@ impl Doc {
4142
let mut encoder = yrs::updates::encoder::EncoderV1::new();
4243
{
4344
let txn = original.doc.transact();
44-
txn.encode_state_from_snapshot(&snapshot.snapshot, &mut encoder).unwrap();
45+
txn.encode_state_from_snapshot(&snapshot.snapshot, &mut encoder)
46+
.unwrap();
4547
}
4648
let update = yrs::Update::decode_v1(&encoder.to_vec()).unwrap();
4749
{
@@ -53,11 +55,19 @@ impl Doc {
5355
let txn_orig = original.doc.transact();
5456
for (name, root) in txn_orig.root_refs() {
5557
match root {
56-
yrs::Out::YText(_) => { let _ = new_doc.get_or_insert_text(name); },
57-
yrs::Out::YArray(_) => { let _ = new_doc.get_or_insert_array(name); },
58-
yrs::Out::YMap(_) => { let _ = new_doc.get_or_insert_map(name); },
59-
yrs::Out::YXmlFragment(_) => { let _ = new_doc.get_or_insert_xml_fragment(name); },
60-
_ => {}, // ignore unknown types
58+
yrs::Out::YText(_) => {
59+
let _ = new_doc.get_or_insert_text(name);
60+
}
61+
yrs::Out::YArray(_) => {
62+
let _ = new_doc.get_or_insert_array(name);
63+
}
64+
yrs::Out::YMap(_) => {
65+
let _ = new_doc.get_or_insert_map(name);
66+
}
67+
yrs::Out::YXmlFragment(_) => {
68+
let _ = new_doc.get_or_insert_xml_fragment(name);
69+
}
70+
_ => {} // ignore unknown types
6171
}
6272
}
6373
drop(txn_orig);
@@ -71,14 +81,16 @@ impl Doc {
7181
fn new(client_id: &Bound<'_, PyAny>, skip_gc: &Bound<'_, PyAny>) -> PyResult<Self> {
7282
let mut options = Options::default();
7383
if !client_id.is_none() {
74-
let _client_id: u64 = client_id.downcast::<PyInt>()
84+
let _client_id: u64 = client_id
85+
.downcast::<PyInt>()
7586
.map_err(|_| PyValueError::new_err("client_id must be an integer"))?
7687
.extract()
7788
.map_err(|_| PyValueError::new_err("client_id must be a valid u64"))?;
7889
options.client_id = _client_id;
7990
}
8091
if !skip_gc.is_none() {
81-
let _skip_gc: bool = skip_gc.downcast::<PyBool>()
92+
let _skip_gc: bool = skip_gc
93+
.downcast::<PyBool>()
8294
.map_err(|_| PyValueError::new_err("skip_gc must be a boolean"))?
8395
.extract()
8496
.map_err(|_| PyValueError::new_err("skip_gc must be a valid bool"))?;
@@ -90,7 +102,11 @@ impl Doc {
90102

91103
#[staticmethod]
92104
#[pyo3(name = "from_snapshot")]
93-
pub fn from_snapshot(py: Python<'_>, snapshot: PyRef<'_, crate::snapshot::Snapshot>, doc: PyRef<'_, Doc>) -> PyResult<Py<Doc>> {
105+
pub fn from_snapshot(
106+
py: Python<'_>,
107+
snapshot: PyRef<'_, crate::snapshot::Snapshot>,
108+
doc: PyRef<'_, Doc>,
109+
) -> PyResult<Py<Doc>> {
94110
let restored = Doc::_from_snapshot_impl(&doc, &snapshot);
95111
Py::new(py, restored)
96112
}
@@ -103,23 +119,38 @@ impl Doc {
103119
self.doc.client_id()
104120
}
105121

106-
fn get_or_insert_text(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult<Py<Text>> {
122+
fn get_or_insert_text(
123+
&mut self,
124+
py: Python<'_>,
125+
txn: &mut Transaction,
126+
name: &str,
127+
) -> PyResult<Py<Text>> {
107128
let mut _t = txn.transaction();
108129
let t = _t.as_mut().unwrap().as_mut();
109130
let text = t.get_or_insert_text(name);
110131
let pytext: Py<Text> = Py::new(py, Text::from(text))?;
111132
Ok(pytext)
112133
}
113134

114-
fn get_or_insert_array(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult<Py<Array>> {
135+
fn get_or_insert_array(
136+
&mut self,
137+
py: Python<'_>,
138+
txn: &mut Transaction,
139+
name: &str,
140+
) -> PyResult<Py<Array>> {
115141
let mut _t = txn.transaction();
116142
let t = _t.as_mut().unwrap().as_mut();
117143
let shared = t.get_or_insert_array(name);
118-
let pyshared: Py<Array > = Py::new(py, Array::from(shared))?;
144+
let pyshared: Py<Array> = Py::new(py, Array::from(shared))?;
119145
Ok(pyshared)
120146
}
121147

122-
fn get_or_insert_map(&mut self, py: Python<'_>, txn: &mut Transaction, name: &str) -> PyResult<Py<Map>> {
148+
fn get_or_insert_map(
149+
&mut self,
150+
py: Python<'_>,
151+
txn: &mut Transaction,
152+
name: &str,
153+
) -> PyResult<Py<Map>> {
123154
let mut _t = txn.transaction();
124155
let t = _t.as_mut().unwrap().as_mut();
125156
let shared = t.get_or_insert_map(name);
@@ -141,7 +172,11 @@ impl Doc {
141172
Err(PyRuntimeError::new_err("Already in a transaction"))
142173
}
143174

144-
fn create_transaction_with_origin(&self, py: Python<'_>, origin: i128) -> PyResult<Py<Transaction>> {
175+
fn create_transaction_with_origin(
176+
&self,
177+
py: Python<'_>,
178+
origin: i128,
179+
) -> PyResult<Py<Transaction>> {
145180
if let Ok(txn) = self.doc.try_transact_mut_with(origin) {
146181
let t: Py<Transaction> = Py::new(py, Transaction::from(txn))?;
147182
return Ok(t);
@@ -160,7 +195,9 @@ impl Doc {
160195
let mut _t = txn.transaction();
161196
let t = _t.as_mut().unwrap().as_mut();
162197
let state: &[u8] = state.extract()?;
163-
let Ok(state_vector) = StateVector::decode_v1(&state) else { return Err(PyValueError::new_err("Cannot decode state")) };
198+
let Ok(state_vector) = StateVector::decode_v1(&state) else {
199+
return Err(PyValueError::new_err("Cannot decode state"));
200+
};
164201
let update = t.encode_diff_v1(&state_vector);
165202
let bytes: Py<PyAny> = Python::attach(|py| PyBytes::new(py, &update).into());
166203
Ok(bytes)
@@ -186,8 +223,53 @@ impl Doc {
186223
result.into()
187224
}
188225

226+
fn to_py(&self, py: Python<'_>, txn: &mut Transaction) -> PyResult<Py<PyAny>> {
227+
let mut _t = txn.transaction();
228+
let t = _t.as_mut().unwrap().as_mut();
229+
let result = PyDict::new(py);
230+
231+
let roots_info: Vec<_> = t
232+
.root_refs()
233+
.map(|(name, root)| (name.to_string(), root))
234+
.collect();
235+
236+
for (name, root) in roots_info {
237+
match root {
238+
yrs::Out::YText(_) => {
239+
let text = t.get_or_insert_text(name.as_str());
240+
let value = text.get_string(t);
241+
result.set_item(name, value)?;
242+
}
243+
yrs::Out::YArray(_) => {
244+
let array = t.get_or_insert_array(name.as_str());
245+
let list = PyList::empty(py);
246+
for item in array.iter(t) {
247+
list.append(item.into_py(py))?;
248+
}
249+
result.set_item(name, list)?;
250+
}
251+
yrs::Out::YMap(_) => {
252+
let map = t.get_or_insert_map(name.as_str());
253+
let dict = PyDict::new(py);
254+
for (key, value) in map.iter(t) {
255+
dict.set_item(key, value.into_py(py))?;
256+
}
257+
result.set_item(name, dict)?;
258+
}
259+
yrs::Out::YXmlFragment(_) => {
260+
let xml = t.get_or_insert_xml_fragment(name.as_str());
261+
let xml_py = Py::new(py, XmlFragment::from(xml))?;
262+
result.set_item(name, xml_py)?;
263+
}
264+
_ => {} // ignore other types
265+
}
266+
}
267+
Ok(result.into())
268+
}
269+
189270
pub fn observe(&mut self, py: Python<'_>, f: Py<PyAny>) -> PyResult<Py<Subscription>> {
190-
let sub = self.doc
271+
let sub = self
272+
.doc
191273
.observe_transaction_cleanup(move |txn, event| {
192274
if !event.delete_set.is_empty() || event.before_state != event.after_state {
193275
Python::attach(|py| {
@@ -204,7 +286,8 @@ impl Doc {
204286
}
205287

206288
pub fn observe_subdocs(&mut self, py: Python<'_>, f: Py<PyAny>) -> PyResult<Py<Subscription>> {
207-
let sub = self.doc
289+
let sub = self
290+
.doc
208291
.observe_subdocs(move |_, event| {
209292
Python::attach(|py| {
210293
let event = SubdocsEvent::new(py, event);
@@ -326,11 +409,20 @@ pub struct SubdocsEvent {
326409

327410
impl SubdocsEvent {
328411
fn new<'py>(py: Python<'py>, event: &_SubdocsEvent) -> Self {
329-
let added: Vec<String> = event.added().map(|d| d.guid().clone().to_string()).collect();
412+
let added: Vec<String> = event
413+
.added()
414+
.map(|d| d.guid().clone().to_string())
415+
.collect();
330416
let added = PyList::new(py, added).unwrap().into_py_any(py).unwrap();
331-
let removed: Vec<String> = event.removed().map(|d| d.guid().clone().to_string()).collect();
417+
let removed: Vec<String> = event
418+
.removed()
419+
.map(|d| d.guid().clone().to_string())
420+
.collect();
332421
let removed = PyList::new(py, removed).unwrap().into_py_any(py).unwrap();
333-
let loaded: Vec<String> = event.loaded().map(|d| d.guid().clone().to_string()).collect();
422+
let loaded: Vec<String> = event
423+
.loaded()
424+
.map(|d| d.guid().clone().to_string())
425+
.collect();
334426
let loaded = PyList::new(py, loaded).unwrap().into_py_any(py).unwrap();
335427
SubdocsEvent {
336428
added,

0 commit comments

Comments
 (0)