diff --git a/crates/core/src/host/instance_env.rs b/crates/core/src/host/instance_env.rs index 6313f7fa5b5..bc6626fcb16 100644 --- a/crates/core/src/host/instance_env.rs +++ b/crates/core/src/host/instance_env.rs @@ -21,7 +21,7 @@ use spacetimedb_sats::buffer::BufWriter; use spacetimedb_sats::db::def::{IndexDef, IndexType}; use spacetimedb_sats::relation::{FieldExpr, FieldName}; use spacetimedb_sats::{ProductType, Typespace}; -use spacetimedb_vm::expr::{Code, ColumnOp}; +use spacetimedb_vm::expr::{Code, ColumnOp, SourceSet}; #[derive(Clone)] pub struct InstanceEnv { @@ -368,11 +368,14 @@ impl InstanceEnv { filter, ) .map_err(NodesError::DecodeFilter)?; - let q = spacetimedb_vm::dsl::query(&*schema).with_select(filter_to_column_op(&schema.table_name, filter)); + + let q = + spacetimedb_vm::dsl::query(schema.as_ref()).with_select(filter_to_column_op(&schema.table_name, filter)); //TODO: How pass the `caller` here? let mut tx: TxMode = tx.into(); let p = &mut DbProgram::new(ctx, stdb, &mut tx, AuthCtx::for_current(self.dbic.identity)); - let results = match spacetimedb_vm::eval::run_ast(p, q.into()) { + // SQL queries can never reference `MemTable`s, so pass in an empty `SourceSet`. + let results = match spacetimedb_vm::eval::run_ast(p, q.into(), SourceSet::default()) { Code::Table(table) => table, _ => unreachable!("query should always return a table"), }; diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index ff31bc099c2..92b8421a968 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -268,9 +268,9 @@ impl Module for WasmModuleHostActor { let auth = AuthCtx::new(self.database_instance_context.identity, caller_identity); log::debug!("One-off query: {query}"); let ctx = &ExecutionContext::sql(db.address()); - let compiled = db.with_read_only(ctx, |tx| { - sql::compiler::compile_sql(db, tx, &query)? - .into_iter() + let compiled: Vec<_> = db.with_read_only(ctx, |tx| { + let ast = sql::compiler::compile_sql(db, tx, &query)?; + ast.into_iter() .map(|expr| { if matches!(expr, CrudExpr::Query { .. }) { Ok(expr) diff --git a/crates/core/src/sql/compiler.rs b/crates/core/src/sql/compiler.rs index a7401e9dc6e..1ad2a5c2b76 100644 --- a/crates/core/src/sql/compiler.rs +++ b/crates/core/src/sql/compiler.rs @@ -14,12 +14,13 @@ use std::sync::Arc; use super::ast::TableSchemaView; -/// Compile the `SQL` expression into a `ast` +/// Compile the `SQL` expression into an `ast` #[tracing::instrument(skip_all)] pub fn compile_sql(db: &RelationalDB, tx: &T, sql_text: &str) -> Result, DBError> { tracing::trace!(sql = sql_text); let ast = compile_to_ast(db, tx, sql_text)?; + // TODO(perf, bikeshedding): SmallVec? let mut results = Vec::with_capacity(ast.len()); for sql in ast { @@ -133,23 +134,25 @@ fn compile_select(table: From, project: Vec, selection: Option { - let t = db_table(rhs, rhs.table_id); + let rhs_source_expr = SourceExpr::DbTable(db_table(rhs, rhs.table_id)); match on.op { OpCmp::Eq => {} x => unreachable!("Unsupported operator `{x}` for joins"), } - q = q.with_join_inner(t, on.lhs.clone(), on.rhs.clone()); + q = q.with_join_inner(rhs_source_expr, on.lhs.clone(), on.rhs.clone()); } } } @@ -196,7 +199,7 @@ fn compile_insert( columns: Vec, values: Vec>, ) -> Result { - let db_table = compile_columns(&table, columns); + let source_expr = SourceExpr::DbTable(compile_columns(&table, columns)); let mut rows = Vec::with_capacity(values.len()); for x in values { @@ -215,7 +218,7 @@ fn compile_insert( } Ok(CrudExpr::Insert { - source: SourceExpr::DbTable(db_table), + source: source_expr, rows, }) } @@ -299,7 +302,6 @@ mod tests { use spacetimedb_primitives::{ColId, TableId}; use spacetimedb_sats::AlgebraicType; use spacetimedb_vm::expr::{IndexJoin, IndexScan, JoinExpr, Query}; - use spacetimedb_vm::relation::Table; fn assert_index_scan( op: Query, @@ -939,7 +941,7 @@ mod tests { table: ref probe_table, field: ref probe_field, }, - index_side: Table::DbTable(DbTable { + index_side: SourceExpr::DbTable(DbTable { table_id: index_table, .. }), index_col, diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index b3c8a943303..08f2717a83c 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -1,16 +1,15 @@ -use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_lib::{ProductType, ProductValue}; -use spacetimedb_vm::eval::run_ast; -use spacetimedb_vm::expr::{CodeResult, CrudExpr, Expr}; -use spacetimedb_vm::relation::MemTable; -use tracing::info; - +use super::compiler::compile_sql; use crate::database_instance_context_controller::DatabaseInstanceContextController; use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::error::{DBError, DatabaseError}; use crate::execution_context::ExecutionContext; -use crate::sql::compiler::compile_sql; use crate::vm::{DbProgram, TxMode}; +use spacetimedb_lib::identity::AuthCtx; +use spacetimedb_lib::{ProductType, ProductValue}; +use spacetimedb_vm::eval::run_ast; +use spacetimedb_vm::expr::{CodeResult, CrudExpr, Expr, SourceSet}; +use spacetimedb_vm::relation::MemTable; +use tracing::info; pub struct StmtResult { pub schema: ProductType, @@ -59,13 +58,14 @@ pub fn execute_single_sql( tx: &Tx, ast: CrudExpr, auth: AuthCtx, + sources: SourceSet, ) -> Result, DBError> { let mut tx: TxMode = tx.into(); let p = &mut DbProgram::new(cx, db, &mut tx, auth); let q = Expr::Crud(Box::new(ast)); let mut result = Vec::with_capacity(1); - collect_result(&mut result, run_ast(p, q).into())?; + collect_result(&mut result, run_ast(p, q, sources).into())?; Ok(result) } @@ -75,6 +75,7 @@ pub fn execute_sql_mut_tx( tx: &mut MutTx, ast: Vec, auth: AuthCtx, + sources: SourceSet, ) -> Result, DBError> { let total = ast.len(); let mut tx: TxMode = tx.into(); @@ -83,7 +84,7 @@ pub fn execute_sql_mut_tx( let q = Expr::Block(ast.into_iter().map(|x| Expr::Crud(Box::new(x))).collect()); let mut result = Vec::with_capacity(total); - collect_result(&mut result, run_ast(p, q).into())?; + collect_result(&mut result, run_ast(p, q, sources).into())?; Ok(result) } @@ -100,13 +101,15 @@ pub fn execute_sql(db: &RelationalDB, ast: Vec, auth: AuthCtx) -> Resu let mut tx: TxMode = mut_tx.into(); let q = Expr::Block(ast.into_iter().map(|x| Expr::Crud(Box::new(x))).collect()); let p = &mut DbProgram::new(&ctx, db, &mut tx, auth); - collect_result(&mut result, run_ast(p, q).into()) + // SQL queries can never reference `MemTable`s, so pass an empty `SourceSet`. + collect_result(&mut result, run_ast(p, q, SourceSet::default()).into()) }), true => db.with_read_only(&ctx, |tx| { let mut tx = TxMode::Tx(tx); let q = Expr::Block(ast.into_iter().map(|x| Expr::Crud(Box::new(x))).collect()); let p = &mut DbProgram::new(&ctx, db, &mut tx, auth); - collect_result(&mut result, run_ast(p, q).into()) + // SQL queries can never reference `MemTable`s, so pass an empty `SourceSet`. + collect_result(&mut result, run_ast(p, q, SourceSet::default()).into()) }), }?; diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index b9f33e2de55..05391439b0d 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use super::{ query::compile_read_only_query, subscription::{ExecutionSet, Subscription}, @@ -20,6 +18,7 @@ use parking_lot::RwLock; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::Identity; +use std::sync::Arc; type Subscriptions = Arc>>; #[derive(Debug)] diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index c14e494e8f7..f7a9d28e1eb 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; -use std::time::Instant; - use crate::db::db_metrics::{DB_METRICS, MAX_QUERY_COMPILE_TIME}; use crate::db::relational_db::{RelationalDB, Tx}; use crate::error::{DBError, SubscriptionError}; @@ -16,9 +13,10 @@ use spacetimedb_lib::Address; use spacetimedb_sats::db::auth::StAccess; use spacetimedb_sats::relation::{Column, FieldName, Header}; use spacetimedb_sats::AlgebraicType; -use spacetimedb_vm::expr; -use spacetimedb_vm::expr::{Crud, CrudExpr, DbType, QueryExpr}; +use spacetimedb_vm::expr::{self, Crud, CrudExpr, DbType, QueryExpr, SourceSet}; use spacetimedb_vm::relation::MemTable; +use std::sync::Arc; +use std::time::Instant; use super::subscription::get_all; @@ -67,9 +65,12 @@ pub fn to_mem_table_with_op_type(head: Arc
, table_access: StAccess, data /// /// To be able to reify the `op_type` of the individual operations in the update, /// each virtual row is extended with a column [`OP_TYPE_FIELD_NAME`]. -pub fn to_mem_table(mut of: QueryExpr, data: &DatabaseTableUpdate) -> QueryExpr { - of.source = to_mem_table_with_op_type(of.source.head().clone(), of.source.table_access(), data).into(); - of +pub fn to_mem_table(mut of: QueryExpr, data: &DatabaseTableUpdate) -> (QueryExpr, SourceSet) { + let mem_table = to_mem_table_with_op_type(of.source.head().clone(), of.source.table_access(), data); + let mut sources = SourceSet::default(); + let source_expr = sources.add_mem_table(mem_table); + of.source = source_expr; + (of, sources) } /// Runs a query that evaluates if the changes made should be reported to the [ModuleSubscriptionManager] @@ -80,26 +81,21 @@ pub(crate) fn run_query( tx: &Tx, query: &QueryExpr, auth: AuthCtx, + sources: SourceSet, ) -> Result, DBError> { - execute_single_sql(cx, db, tx, CrudExpr::Query(query.clone()), auth) + execute_single_sql(cx, db, tx, CrudExpr::Query(query.clone()), auth, sources) } // TODO: It's semantically wrong to `SUBSCRIBE_TO_ALL_QUERY` // as it can only return back the changes valid for the tables in scope *right now* // instead of **continuously updating** the db changes // with system table modifications (add/remove tables, indexes, ...). -/// Compile from `SQL` into a [`Query`], rejecting empty queries and queries that attempt to modify the data in any way. -/// -/// NOTE: When the `input` query is equal to [`SUBSCRIBE_TO_ALL_QUERY`], -/// **compilation is bypassed** and the equivalent of the following is done: +// +/// Variant of [`compile_read_only_query`] which appends `SourceExpr`s into a given `SourceBuilder`, +/// rather than returning a new `SourceSet`. /// -///```rust,ignore -/// for t in db.user_tables { -/// query.push(format!("SELECT * FROM {t}")); -/// } -/// ``` -/// -/// WARNING: [`SUBSCRIBE_TO_ALL_QUERY`] is only valid for repeated calls as long there is not change on database schema, and the clients must `unsubscribe` before modifying it. +/// This is necessary when merging multiple SQL queries into a single query set, +/// as in [`crate::subscription::module_subscription_actor::ModuleSubscriptions::add_subscriber`]. #[tracing::instrument(skip(relational_db, auth, tx))] pub fn compile_read_only_query( relational_db: &RelationalDB, @@ -173,7 +169,7 @@ fn record_query_compilation_metrics(workload: WorkloadType, db: &Address, query: } /// The kind of [`QueryExpr`] currently supported for incremental evaluation. -#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)] pub enum Supported { /// A scan or [`QueryExpr::Select`] of a single table. Scan, @@ -218,7 +214,7 @@ mod tests { use spacetimedb_sats::db::def::*; use spacetimedb_sats::relation::FieldName; use spacetimedb_sats::{product, ProductType, ProductValue}; - use spacetimedb_vm::dsl::{db_table, mem_table, scalar}; + use spacetimedb_vm::dsl::{mem_table, scalar}; use spacetimedb_vm::operator::OpCmp; fn insert_op(table_id: TableId, table_name: &str, row: ProductValue) -> DatabaseTableUpdate { @@ -263,7 +259,8 @@ mod tests { }; let schema = db.schema_for_table_mut(tx, table_id).unwrap().into_owned(); - let q = QueryExpr::new(db_table(&schema, table_id)); + + let q = QueryExpr::new(&schema); Ok((schema, table, data, q)) } @@ -323,8 +320,15 @@ mod tests { q: &QueryExpr, data: &DatabaseTableUpdate, ) -> ResultTest<()> { - let q = to_mem_table(q.clone(), data); - let result = run_query(&ExecutionContext::default(), db, tx, &q, AuthCtx::for_testing())?; + let (q, sources) = to_mem_table(q.clone(), data); + let result = run_query( + &ExecutionContext::default(), + db, + tx, + &q, + AuthCtx::for_testing(), + sources, + )?; assert_eq!( Some(table.as_without_table_name()), @@ -386,6 +390,10 @@ mod tests { Ok(()) } + fn singleton_execution_set(expr: QueryExpr) -> ResultTest { + Ok(ExecutionSet::from_iter([SupportedQuery::try_from(expr)?])) + } + #[test] fn test_eval_incr_for_index_scan() -> ResultTest<()> { let (db, _tmp) = make_test_db()?; @@ -423,7 +431,7 @@ mod tests { panic!("unexpected query {:#?}", exp[0]); }; - let query: ExecutionSet = query.try_into()?; + let query: ExecutionSet = singleton_execution_set(query)?; let result = query.eval_incr(&db, &tx, &update, AuthCtx::for_testing())?; @@ -487,7 +495,7 @@ mod tests { panic!("unexpected query {:#?}", exp[0]); }; - let query: ExecutionSet = query.try_into()?; + let query: ExecutionSet = singleton_execution_set(query)?; db.release_tx(&ExecutionContext::default(), tx); @@ -758,13 +766,13 @@ mod tests { check_query(&db, &table, &tx, &q, &data)?; //SELECT * FROM inventory WHERE inventory_id = 1 - let q_id = QueryExpr::new(db_table(&schema, schema.table_id)).with_select_cmp( + let q_id = QueryExpr::new(&schema).with_select_cmp( OpCmp::Eq, FieldName::named("_inventory", "inventory_id"), scalar(1u64), ); - let s = ExecutionSet::from_iter([q_id.try_into()?]); + let s = singleton_execution_set(q_id)?; let row2 = TableOp::insert(row.clone()); @@ -780,9 +788,9 @@ mod tests { check_query_incr(&db, &tx, &s, &update, 1, &[row])?; - let q = QueryExpr::new(db_table(&schema, schema.table_id)); + let q = QueryExpr::new(&schema); - let q = to_mem_table(q, &data); + let (q, sources) = to_mem_table(q, &data); //Try access the private table match run_query( &ExecutionContext::default(), @@ -790,6 +798,7 @@ mod tests { &tx, &q, AuthCtx::new(Identity::__dummy(), Identity::from_byte_array([1u8; 32])), + sources, ) { Ok(_) => { panic!("it allows to execute against private table") @@ -863,6 +872,7 @@ mod tests { &tx, q.as_expr(), AuthCtx::for_testing(), + SourceSet::default(), )?; assert_eq!(result.len(), 1, "Join query did not return any rows"); } diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index a32119f3b4a..2800e373d25 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -37,11 +37,11 @@ use anyhow::Context; use itertools::Either; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_lib::ProductValue; use spacetimedb_primitives::TableId; use spacetimedb_sats::db::auth::{StAccess, StTableType}; -use spacetimedb_sats::relation::{Header, Relation}; -use spacetimedb_vm::expr::{self, IndexJoin, QueryExpr}; +use spacetimedb_sats::relation::Header; +use spacetimedb_sats::ProductValue; +use spacetimedb_vm::expr::{self, IndexJoin, QueryExpr, SourceSet}; use spacetimedb_vm::relation::MemTable; use std::collections::{hash_map, HashMap, HashSet}; use std::ops::Deref; @@ -82,7 +82,7 @@ impl Subscription { /// A [`QueryExpr`] tagged with [`query::Supported`]. /// /// Constructed via `TryFrom`, which rejects unsupported queries. -#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct SupportedQuery { kind: query::Supported, expr: QueryExpr, @@ -128,9 +128,10 @@ fn eval_secondary_updates<'a>( auth: AuthCtx, tx: &Tx, query: &QueryExpr, + sources: SourceSet, ) -> Result, DBError> { let ctx = ExecutionContext::incremental_update(db.address()); - Ok(run_query(&ctx, db, tx, query, auth)? + Ok(run_query(&ctx, db, tx, query, auth, sources)? .into_iter() .flat_map(|data| data.data)) } @@ -144,9 +145,10 @@ fn eval_primary_updates<'a>( auth: AuthCtx, tx: &Tx, query: &QueryExpr, + sources: SourceSet, ) -> Result, DBError> { let ctx = ExecutionContext::incremental_update(db.address()); - let updates = run_query(&ctx, db, tx, query, auth)?; + let updates = run_query(&ctx, db, tx, query, auth, sources)?; let updates = updates.into_iter().flat_map(|MemTable { data, head, .. }| { // Remove the special __op_type field before computing each row's primary key. let pos_op_type = head.find_pos_by_name(OP_TYPE_FIELD_NAME).unwrap_or_else(|| { @@ -182,11 +184,12 @@ fn eval_updates<'a>( tx: &Tx, query: &QueryExpr, is_primary: bool, + sources: SourceSet, ) -> Result, DBError> { Ok(if is_primary { - Either::Left(eval_primary_updates(db, auth, tx, query)?.map(|(_, row)| row)) + Either::Left(eval_primary_updates(db, auth, tx, query, sources)?.map(|(_, row)| row)) } else { - Either::Right(eval_secondary_updates(db, auth, tx, query)?) + Either::Right(eval_secondary_updates(db, auth, tx, query, sources)?) }) } @@ -345,42 +348,49 @@ impl<'a> IncrementalJoin<'a> { let mut inserts = { // Replan query after replacing the indexed table with a virtual table, // since join order may need to be reversed. - let join_a = with_delta_table(self.join.clone(), true, self.index_side.inserts()); + let (join_a, join_a_sources) = with_delta_table(self.join.clone(), Some(self.index_side.inserts()), None); let join_a = QueryExpr::from(join_a).optimize(&|table_id, table_name| db.row_count(table_id, table_name)); // No need to replan after replacing the probe side with a virtual table, // since no new constraints have been added. - let join_b = with_delta_table(self.join.clone(), false, self.probe_side.inserts()).into(); + let (join_b, join_b_sources) = with_delta_table(self.join.clone(), None, Some(self.probe_side.inserts())); + let join_b = join_b.into(); // {A+ join B} - let a = eval_updates(db, *auth, tx, &join_a, self.join.return_index_rows)?; + let a = eval_updates(db, *auth, tx, &join_a, self.join.return_index_rows, join_a_sources)?; // {A join B+} - let b = eval_updates(db, *auth, tx, &join_b, !self.join.return_index_rows)?; + let b = eval_updates(db, *auth, tx, &join_b, !self.join.return_index_rows, join_b_sources)?; + // {A+ join B} U {A join B+} itertools::chain![a, b].collect::>() }; let mut deletes = { // Replan query after replacing the indexed table with a virtual table, // since join order may need to be reversed. - let join_a = with_delta_table(self.join.clone(), true, self.index_side.deletes()); + let (join_a, join_a_sources) = with_delta_table(self.join.clone(), Some(self.index_side.deletes()), None); let join_a = QueryExpr::from(join_a).optimize(&|table_id, table_name| db.row_count(table_id, table_name)); // No need to replan after replacing the probe side with a virtual table, // since no new constraints have been added. - let join_b = with_delta_table(self.join.clone(), false, self.probe_side.deletes()).into(); + let (join_b, join_b_sources) = with_delta_table(self.join.clone(), None, Some(self.probe_side.deletes())); + let join_b = join_b.into(); // No need to replan after replacing both sides with a virtual tables, // since there are no indexes available to us. // The only valid plan in this case is that of an inner join. - let join_c = with_delta_table(self.join.clone(), true, self.index_side.deletes()); - let join_c = with_delta_table(join_c, false, self.probe_side.deletes()).into(); + let (join_c, join_c_sources) = with_delta_table( + self.join.clone(), + Some(self.index_side.deletes()), + Some(self.probe_side.deletes()), + ); + let join_c = join_c.into(); // {A- join B} - let a = eval_updates(db, *auth, tx, &join_a, self.join.return_index_rows)?; + let a = eval_updates(db, *auth, tx, &join_a, self.join.return_index_rows, join_a_sources)?; // {A join B-} - let b = eval_updates(db, *auth, tx, &join_b, !self.join.return_index_rows)?; + let b = eval_updates(db, *auth, tx, &join_b, !self.join.return_index_rows, join_b_sources)?; // {A- join B-} - let c = eval_updates(db, *auth, tx, &join_c, true)?; + let c = eval_updates(db, *auth, tx, &join_c, true, join_c_sources)?; // {A- join B} U {A join B-} U {A- join B-} itertools::chain![a, b, c].collect::>() }; @@ -397,7 +407,11 @@ impl<'a> IncrementalJoin<'a> { /// Replace an [IndexJoin]'s scan or fetch operation with a delta table. /// A delta table consists purely of updates or changes to the base table. -fn with_delta_table(mut join: IndexJoin, index_side: bool, delta: DatabaseTableUpdate) -> IndexJoin { +fn with_delta_table( + mut join: IndexJoin, + index_side: Option, + probe_side: Option, +) -> (IndexJoin, SourceSet) { fn to_mem_table(head: Arc
, table_access: StAccess, delta: DatabaseTableUpdate) -> MemTable { MemTable::new( head, @@ -406,47 +420,49 @@ fn with_delta_table(mut join: IndexJoin, index_side: bool, delta: DatabaseTableU ) } - // We are replacing the indexed table, - // and the rows of the indexed table are being returned. - // Therefore we must add a column with the op type. - if index_side && join.return_index_rows { - let head = join.index_side.head().clone(); - let table_access = join.index_side.table_access(); - join.index_side = to_mem_table_with_op_type(head, table_access, &delta).into(); - return join; - } - // We are replacing the indexed table, - // but the rows of the indexed table are not being returned. - // Therefore we do not need to add a column with the op type. - if index_side && !join.return_index_rows { + let mut sources = SourceSet::default(); + + if let Some(index_side) = index_side { let head = join.index_side.head().clone(); let table_access = join.index_side.table_access(); - join.index_side = to_mem_table(head, table_access, delta).into(); - return join; - } - // We are replacing the probe table, - // but the rows of the indexed table are being returned. - // Therefore we do not need to add a column with the op type. - if !index_side && join.return_index_rows { - let head = join.probe_side.source.head().clone(); - let table_access = join.probe_side.source.table_access(); - join.probe_side.source = to_mem_table(head, table_access, delta).into(); - return join; + let mem_table = if join.return_index_rows { + // We are replacing the indexed table, + // and the rows of the indexed table are being returned. + // Therefore we must add a column with the op type. + to_mem_table_with_op_type(head, table_access, &index_side) + } else { + // We are replacing the indexed table, + // but the rows of the indexed table are not being returned. + // Therefore we do not need to add a column with the op type. + to_mem_table(head, table_access, index_side) + }; + let source_expr = sources.add_mem_table(mem_table); + join.index_side = source_expr; } - // We are replacing the probe table, - // and the rows of the probe table are being returned. - // Therefore we must add a column with the op type. - if !index_side && !join.return_index_rows { + + if let Some(probe_side) = probe_side { let head = join.probe_side.source.head().clone(); let table_access = join.probe_side.source.table_access(); - join.probe_side.source = to_mem_table_with_op_type(head, table_access, &delta).into(); - return join; + let mem_table = if join.return_index_rows { + // We are replacing the probe table, + // but the rows of the indexed table are being returned. + // Therefore we do not need to add a column with the op type. + to_mem_table(head, table_access, probe_side) + } else { + // We are replacing the probe table, + // and the rows of the probe table are being returned. + // Therefore we must add a column with the op type. + to_mem_table_with_op_type(head, table_access, &probe_side) + }; + let source_expr = sources.add_mem_table(mem_table); + join.probe_side.source = source_expr; } - join + + (join, sources) } /// The atomic unit of execution within a subscription set. -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, PartialEq, Eq, Hash)] struct ExecutionUnit { table_id: TableId, table_name: String, @@ -459,7 +475,9 @@ impl ExecutionUnit { let ctx = ExecutionContext::subscribe(db.address()); let ops = match &self.queries[..] { // special-case single query - we don't have to deduplicate - [query] => run_query(&ctx, db, tx, &query.expr, auth)? + // Raw SQL queries (not incrementalized) never reference `MemTable`s, + // so pass an empty `SourceSet`. + [query] => run_query(&ctx, db, tx, &query.expr, auth, SourceSet::default())? .into_iter() .flat_map(|table| table.data) .map(TableOp::insert) @@ -469,7 +487,9 @@ impl ExecutionUnit { let mut ops = Vec::new(); for SupportedQuery { kind: _, expr } in queries { - for table in run_query(&ctx, db, tx, expr, auth)? { + // Raw SQL queries (not incrementalized) never reference `MemTable`s, + // so pass an empty `SourceSet`. + for table in run_query(&ctx, db, tx, expr, auth, SourceSet::default())? { ops.extend(table.data.into_iter().map(TableOp::insert)); } } @@ -505,9 +525,9 @@ impl ExecutionUnit { .find(|update| update.table_id == self.table_id) { // Replace table reference in original query plan with virtual MemTable - let plan = query::to_mem_table(query.expr.clone(), rows); + let (plan, sources) = query::to_mem_table(query.expr.clone(), rows); // Evaluate the new plan and capture the new row operations. - eval_primary_updates(db, auth, tx, &plan)? + eval_primary_updates(db, auth, tx, &plan, sources)? .map(|r| TableOp::new(r.0, r.1)) .collect() } else { @@ -537,9 +557,11 @@ impl ExecutionUnit { .find(|update| update.table_id == self.table_id) { // Replace table reference in original query plan with virtual MemTable - let plan = query::to_mem_table(query.expr.clone(), rows); + let (plan, sources) = query::to_mem_table(query.expr.clone(), rows); // Evaluate the new plan and capture the new row operations. - ops.extend(eval_primary_updates(db, auth, tx, &plan)?.map(|r| TableOp::new(r.0, r.1))); + ops.extend( + eval_primary_updates(db, auth, tx, &plan, sources)?.map(|r| TableOp::new(r.0, r.1)), + ); } } Semijoin => { @@ -630,11 +652,6 @@ impl FromIterator for ExecutionSet { } } - for exec_unit in &mut exec_units { - exec_unit.queries.sort(); - } - exec_units.sort(); - ExecutionSet { exec_units } } } @@ -645,15 +662,6 @@ impl From> for ExecutionSet { } } -#[cfg(test)] -impl TryFrom for ExecutionSet { - type Error = DBError; - - fn try_from(expr: QueryExpr) -> Result { - Ok(ExecutionSet::from_iter(vec![SupportedQuery::try_from(expr)?])) - } -} - /// Queries all the [`StTableType::User`] tables *right now* /// and turns them into [`QueryExpr`], /// the moral equivalent of `SELECT * FROM table`. @@ -682,7 +690,6 @@ mod tests { use spacetimedb_sats::relation::{DbTable, FieldName}; use spacetimedb_sats::{product, AlgebraicType}; use spacetimedb_vm::expr::{CrudExpr, IndexJoin, Query, SourceExpr}; - use spacetimedb_vm::relation::Table; #[test] // Compile an index join after replacing the index side with a virtual table. @@ -730,7 +737,8 @@ mod tests { }; // Optimize the query plan for the incremental update. - let expr: QueryExpr = with_delta_table(join, true, delta).into(); + let (expr, _sources) = with_delta_table(join, Some(delta), None); + let expr: QueryExpr = expr.into(); let mut expr = expr.optimize(&|_, _| i64::MAX); assert_eq!(expr.source.table_name(), "lhs"); assert_eq!(expr.query.len(), 1); @@ -743,7 +751,7 @@ mod tests { let IndexJoin { probe_side: QueryExpr { - source: SourceExpr::MemTable(_), + source: SourceExpr::MemTable { .. }, query: ref lhs, }, probe_field: @@ -751,7 +759,7 @@ mod tests { table: ref probe_table, field: ref probe_field, }, - index_side: Table::DbTable(DbTable { + index_side: SourceExpr::DbTable(DbTable { table_id: index_table, .. }), index_select: Some(_), @@ -818,11 +826,13 @@ mod tests { }; // Optimize the query plan for the incremental update. - let expr: QueryExpr = with_delta_table(join, false, delta).into(); + let (expr, _sources) = with_delta_table(join, None, Some(delta)); + let expr = QueryExpr::from(expr); let mut expr = expr.optimize(&|_, _| i64::MAX); assert_eq!(expr.source.table_name(), "lhs"); assert_eq!(expr.query.len(), 1); + assert!(expr.source.is_db_table()); let join = expr.query.pop().unwrap(); let Query::IndexJoin(join) = join else { @@ -832,7 +842,7 @@ mod tests { let IndexJoin { probe_side: QueryExpr { - source: SourceExpr::MemTable(_), + source: SourceExpr::MemTable { .. }, query: ref rhs, }, probe_field: @@ -840,7 +850,7 @@ mod tests { table: ref probe_table, field: ref probe_field, }, - index_side: Table::DbTable(DbTable { + index_side: SourceExpr::DbTable(DbTable { table_id: index_table, .. }), index_select: None, diff --git a/crates/core/src/vm.rs b/crates/core/src/vm.rs index bcaf73ab88a..ea303d60921 100644 --- a/crates/core/src/vm.rs +++ b/crates/core/src/vm.rs @@ -46,12 +46,28 @@ pub fn build_query<'a>( stdb: &'a RelationalDB, tx: &'a TxMode, query: QueryCode, + sources: &mut SourceSet, ) -> Result>, ErrorVm> { - let db_table = matches!(&query.table, Table::DbTable(_)); - let mut result = get_table(ctx, stdb, tx, query.table.into())?; + let db_table = query.table.is_db_table(); + + // We're incrementally building a query iterator by applying each operation in the `query.query`. + // Most such operations will modify their parent, but certain operations (i.e. `IndexJoin`s) + // are only valid as the first operation in the list, + // and construct a new base query. + // + // Branches which use `result` will do `unwrap_or_else(|| get_table(ctx, stdb, tx, &query.table, sources))` + // to get an `IterRows` defaulting to the `query.table`. + // + // Branches which do not use the `result` will assert that it is `None`, + // i.e. that they are the first operator. + // + // TODO(bikeshedding): Avoid duplication of the ugly `result.take().map(...).unwrap_or_else(...)?` expr? + // TODO(bikeshedding): Refactor `QueryCode` to separate `IndexJoin` from other `Query` variants, + // removing the need for this convoluted logic? + let mut result = None; for op in query.query { - result = match op { + result = Some(match op { Query::IndexScan(IndexScan { table, columns, @@ -63,6 +79,10 @@ pub fn build_query<'a>( iter_by_col_range(ctx, stdb, tx, table, col_id, (lower_bound, upper_bound))? } Query::IndexScan(index_scan) => { + let result = result + .take() + .map(Ok) + .unwrap_or_else(|| get_table(ctx, stdb, tx, &query.table, sources))?; let header = result.head().clone(); let cmp: ColumnOp = index_scan.into(); let iter = result.select(move |row| cmp.compare(row, &header)); @@ -77,15 +97,20 @@ pub fn build_query<'a>( // It should not be possible for the planner to produce an invalid plan. Query::IndexJoin( join @ IndexJoin { - index_side: Table::MemTable(_), + index_side: SourceExpr::MemTable { .. }, .. }, - ) => build_query(ctx, stdb, tx, join.to_inner_join().into())?, + ) => { + if result.is_some() { + return Err(anyhow::anyhow!("Invalid query: `IndexJoin` must be the first operator").into()); + } + build_query(ctx, stdb, tx, join.to_inner_join().into(), sources)? + } Query::IndexJoin(IndexJoin { probe_side, probe_field, index_side: - Table::DbTable(DbTable { + SourceExpr::DbTable(DbTable { head: index_header, table_id: index_table, .. @@ -94,7 +119,10 @@ pub fn build_query<'a>( index_col, return_index_rows, }) => { - let probe_side = build_query(ctx, stdb, tx, probe_side.into())?; + if result.is_some() { + return Err(anyhow::anyhow!("Invalid query: `IndexJoin` must be the first operator").into()); + } + let probe_side = build_query(ctx, stdb, tx, probe_side.into(), sources)?; Box::new(IndexSemiJoin { ctx, db: stdb, @@ -110,11 +138,19 @@ pub fn build_query<'a>( }) } Query::Select(cmp) => { + let result = result + .take() + .map(Ok) + .unwrap_or_else(|| get_table(ctx, stdb, tx, &query.table, sources))?; let header = result.head().clone(); let iter = result.select(move |row| cmp.compare(row, &header)); Box::new(iter) } Query::Project(cols, _) => { + let result = result + .take() + .map(Ok) + .unwrap_or_else(|| get_table(ctx, stdb, tx, &query.table, sources))?; if cols.is_empty() { result } else { @@ -126,12 +162,19 @@ pub fn build_query<'a>( } } Query::JoinInner(join) => { - let iter = join_inner(ctx, stdb, tx, result, join, false)?; + let result = result + .take() + .map(Ok) + .unwrap_or_else(|| get_table(ctx, stdb, tx, &query.table, sources))?; + let iter = join_inner(ctx, stdb, tx, result, join, false, sources)?; Box::new(iter) } - } + }) } - Ok(result) + + result + .map(Ok) + .unwrap_or_else(|| get_table(ctx, stdb, tx, &query.table, sources)) } fn join_inner<'a>( @@ -141,13 +184,14 @@ fn join_inner<'a>( lhs: impl RelOps<'a> + 'a, rhs: JoinExpr, semi: bool, + sources: &mut SourceSet, ) -> Result + 'a, ErrorVm> { let col_lhs = FieldExpr::Name(rhs.col_lhs); let col_rhs = FieldExpr::Name(rhs.col_rhs); let key_lhs = [col_lhs.clone()]; let key_rhs = [col_rhs.clone()]; - let rhs = build_query(ctx, db, tx, rhs.rhs.into())?; + let rhs = build_query(ctx, db, tx, rhs.rhs.into(), sources)?; let key_lhs_header = lhs.head().clone(); let key_rhs_header = rhs.head().clone(); let col_lhs_header = lhs.head().clone(); @@ -179,22 +223,37 @@ fn join_inner<'a>( ) } +/// Resolve `query` to a table iterator, either an [`IterRows`] or a [`TableCursor`]. +/// +/// If `query` refers to a `MemTable`, this will `Option::take` said `MemTable` out of `sources`, +/// leaving `None`. +/// This means that a query cannot refer to a `MemTable` multiple times, +/// nor can a `SourceSet` which contains a `MemTable` be reused. +/// This is because [`IterRows`] takes ownership of the `MemTable`. +/// +/// On the other hand, if the `query` is a `DbTable`, `sources` is unused +/// and therefore unmodified. fn get_table<'a>( ctx: &'a ExecutionContext, stdb: &'a RelationalDB, tx: &'a TxMode, - query: SourceExpr, + query: &SourceExpr, + sources: &mut SourceSet, ) -> Result + 'a>, ErrorVm> { let head = query.head().clone(); let row_count = query.row_count(); Ok(match query { - SourceExpr::MemTable(x) => Box::new(RelIter::new(head, row_count, x)) as Box>, + SourceExpr::MemTable { source_id, .. } => Box::new(RelIter::new( + head, + row_count, + sources.take_mem_table(*source_id).expect("Unable to get MemTable"), + )) as Box>, SourceExpr::DbTable(x) => { let iter = match tx { TxMode::MutTx(tx) => stdb.iter_mut(ctx, tx, x.table_id)?, TxMode::Tx(tx) => stdb.iter(ctx, tx, x.table_id)?, }; - Box::new(TableCursor::new(x, iter)?) as Box> + Box::new(TableCursor::new(x.clone(), iter)?) as Box> } }) } @@ -325,11 +384,11 @@ impl<'db, 'tx> DbProgram<'db, 'tx> { } #[tracing::instrument(skip_all)] - fn _eval_query(&mut self, query: QueryCode) -> Result { + fn _eval_query(&mut self, query: QueryCode, sources: &mut SourceSet) -> Result { let table_access = query.table.table_access(); tracing::trace!(table = query.table.table_name()); - let result = build_query(self.ctx, self.db, self.tx, query)?; + let result = build_query(self.ctx, self.db, self.tx, query, sources)?; let head = result.head().clone(); let rows = result.collect_vec(|row| row.into_product_value())?; @@ -366,9 +425,11 @@ impl<'db, 'tx> DbProgram<'db, 'tx> { } } - fn _delete_query(&mut self, query: QueryCode) -> Result { - let table = query.table.clone(); - let result = self._eval_query(query)?; + fn _delete_query(&mut self, query: QueryCode, sources: &mut SourceSet) -> Result { + let table = sources + .take_table(&query.table) + .expect("Cannot delete from a `MemTable`"); + let result = self._eval_query(query, sources)?; match result { Code::Table(result) => self._execute_delete(&table, result.data), @@ -432,22 +493,26 @@ impl ProgramVm for DbProgram<'_, '_> { } // Safety: For DbProgram with tx = TxMode::Tx variant, all queries must match to CrudCode::Query and no other branch. - fn eval_query(&mut self, query: CrudCode) -> Result { + fn eval_query(&mut self, query: CrudCode, sources: &mut SourceSet) -> Result { query.check_auth(self.auth.owner, self.auth.caller)?; match query { - CrudCode::Query(query) => self._eval_query(query), - CrudCode::Insert { table, rows } => self._execute_insert(&table, rows), + CrudCode::Query(query) => self._eval_query(query, sources), + CrudCode::Insert { table, rows } => { + let src = sources.take_table(&table).unwrap(); + self._execute_insert(&src, rows) + } CrudCode::Update { delete, mut assignments, } => { let table = delete.table.clone(); - let result = self._eval_query(delete)?; + let result = self._eval_query(delete, sources)?; let Code::Table(deleted) = result else { return Ok(result); }; + let table = sources.take_table(&table).unwrap(); self._execute_delete(&table, deleted.data.clone())?; // Replace the columns in the matched rows with the assigned @@ -483,7 +548,7 @@ impl ProgramVm for DbProgram<'_, '_> { self._execute_insert(&table, insert_rows) } CrudCode::Delete { query } => { - let result = self._delete_query(query)?; + let result = self._delete_query(query, sources)?; Ok(result) } CrudCode::CreateTable { table } => { @@ -632,14 +697,17 @@ pub(crate) mod tests { let row = product!(1u64, "health"); let table_id = create_table_from_program(p, "inventory", head.clone(), &[row])?; - let inv = db_table(head, table_id); + let schema = TableDef::from_product("test", head).into_schema(table_id); let data = MemTable::from_value(scalar(1u64)); let rhs = data.get_field_pos(0).unwrap().clone(); - let q = query(inv).with_join_inner(data, FieldName::positional("inventory", 0), rhs); + let mut sources = SourceSet::default(); + let rhs_source_expr = sources.add_mem_table(data); + + let q = query(&schema).with_join_inner(rhs_source_expr, FieldName::positional("inventory", 0), rhs); - let result = match run_ast(p, q.into()) { + let result = match run_ast(p, q.into(), sources) { Code::Table(x) => x, x => panic!("invalid result {x}"), }; @@ -661,7 +729,7 @@ pub(crate) mod tests { } fn check_catalog(p: &mut DbProgram, name: &str, row: ProductValue, q: QueryExpr, schema: DbTable) { - let result = run_ast(p, q.into()); + let result = run_ast(p, q.into(), SourceSet::default()); //The expected result let input = mem_table(schema.head.clone_for_error(), vec![row]); diff --git a/crates/sats/src/relation.rs b/crates/sats/src/relation.rs index 01e3ead9600..45a77c54c80 100644 --- a/crates/sats/src/relation.rs +++ b/crates/sats/src/relation.rs @@ -419,7 +419,7 @@ impl From for Header { } /// An estimate for the range of rows in the [Relation] -#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)] pub struct RowCount { pub min: usize, pub max: Option, diff --git a/crates/vm/src/errors.rs b/crates/vm/src/errors.rs index a490bb9b556..dd098f5af87 100644 --- a/crates/vm/src/errors.rs +++ b/crates/vm/src/errors.rs @@ -3,6 +3,8 @@ use spacetimedb_sats::AlgebraicValue; use std::fmt; use thiserror::Error; +use crate::expr::SourceId; + /// Typing Errors #[derive(Error, Debug)] pub enum ErrorType { @@ -25,6 +27,8 @@ pub enum ErrorVm { Auth(#[from] AuthError), #[error("Unsupported: {0}")] Unsupported(String), + #[error("No source table with index {0:?}")] + NoSuchSource(SourceId), #[error("{0}")] Other(#[from] anyhow::Error), } @@ -120,6 +124,11 @@ impl From for ErrorLang { ErrorVm::Unsupported(err) => ErrorLang::new(ErrorKind::Compiler, Some(&err)), ErrorVm::Lang(err) => err, ErrorVm::Auth(err) => ErrorLang::new(ErrorKind::Unauthorized, Some(&err.to_string())), + err @ ErrorVm::NoSuchSource(_) => ErrorLang { + kind: ErrorKind::Invalid, + msg: Some(format!("{err:?}")), + context: None, + }, } } } diff --git a/crates/vm/src/eval.rs b/crates/vm/src/eval.rs index 835cd00cc96..36860d0d399 100644 --- a/crates/vm/src/eval.rs +++ b/crates/vm/src/eval.rs @@ -1,24 +1,18 @@ use std::sync::Arc; use crate::errors::ErrorVm; -use crate::expr::{Code, CrudCode, CrudExpr, QueryCode, QueryExpr, SourceExpr}; +use crate::expr::{Code, CrudCode, CrudExpr, QueryCode, QueryExpr, SourceSet}; use crate::expr::{Expr, Query}; use crate::iterators::RelIter; use crate::program::ProgramVm; use crate::rel_ops::RelOps; -use crate::relation::{RelValue, Table}; +use crate::relation::RelValue; use spacetimedb_sats::relation::{FieldExpr, Relation}; fn compile_query(q: QueryExpr) -> QueryCode { - match q.source { - SourceExpr::MemTable(x) => QueryCode { - table: Table::MemTable(x), - query: q.query.clone(), - }, - SourceExpr::DbTable(x) => QueryCode { - table: Table::DbTable(x), - query: q.query.clone(), - }, + QueryCode { + table: q.source, + query: q.query, } } @@ -26,16 +20,7 @@ fn compile_query_expr(q: CrudExpr) -> Code { match q { CrudExpr::Query(q) => Code::Crud(CrudCode::Query(compile_query(q))), CrudExpr::Insert { source, rows } => { - let q = match source { - SourceExpr::MemTable(x) => CrudCode::Insert { - table: Table::MemTable(x), - rows, - }, - SourceExpr::DbTable(x) => CrudCode::Insert { - table: Table::DbTable(x), - rows, - }, - }; + let q = CrudCode::Insert { table: source, rows }; Code::Crud(q) } CrudExpr::Update { delete, assignments } => { @@ -61,8 +46,17 @@ fn compile_query_expr(q: CrudExpr) -> Code { pub type IterRows<'a> = dyn RelOps<'a> + 'a; +/// `sources` should be a `Vec` +/// where the `idx`th element is the table referred to in the `query` as `SourceId(idx)`. +/// While constructing the query, the `sources` will be destructively modified with `Option::take` +/// to extract the sources, +/// so the `query` cannot refer to the same `SourceId` multiple times. #[tracing::instrument(skip_all)] -pub fn build_query<'a>(mut result: Box>, query: Vec) -> Result>, ErrorVm> { +pub fn build_query<'a>( + mut result: Box>, + query: Vec, + sources: &mut SourceSet, +) -> Result>, ErrorVm> { for q in query { result = match q { Query::IndexScan(_) => { @@ -96,14 +90,18 @@ pub fn build_query<'a>(mut result: Box>, query: Vec) -> Resu let row_rhs = q.rhs.source.row_count(); let head = q.rhs.source.head().clone(); - let rhs = match q.rhs.source { - SourceExpr::MemTable(x) => Box::new(RelIter::new(head, row_rhs, x)) as Box>, - SourceExpr::DbTable(_) => { - todo!("How pass the db iter?") - } + let rhs = if let Some(rhs_source_id) = q.rhs.source.source_id() { + let Some(rhs_table) = sources.take_mem_table(rhs_source_id) else { + panic!( + "Query plan specifies a `MemTable` for {rhs_source_id:?}, but found a `DbTable` or nothing" + ); + }; + Box::new(RelIter::new(head, row_rhs, rhs_table)) as Box> + } else { + todo!("How pass the db iter?") }; - let rhs = build_query(rhs, q.rhs.query)?; + let rhs = build_query(rhs, q.rhs.query, sources)?; let lhs = result; let key_lhs_header = lhs.head().clone(); @@ -138,13 +136,13 @@ fn build_ast(ast: CrudExpr) -> Code { /// Execute the code #[tracing::instrument(skip_all)] -fn eval(p: &mut P, code: Code) -> Code { +fn eval(p: &mut P, code: Code, sources: &mut SourceSet) -> Code { match code { Code::Value(_) => code.clone(), Code::Block(lines) => { let mut result = Vec::with_capacity(lines.len()); for x in lines { - let r = eval(p, x); + let r = eval(p, x, sources); if r != Code::Pass { result.push(r); } @@ -156,7 +154,7 @@ fn eval(p: &mut P, code: Code) -> Code { _ => Code::Block(result), } } - Code::Crud(q) => p.eval_query(q).unwrap_or_else(|err| Code::Halt(err.into())), + Code::Crud(q) => p.eval_query(q, sources).unwrap_or_else(|err| Code::Halt(err.into())), Code::Pass => Code::Pass, Code::Halt(_) => code, Code::Table(_) => code, @@ -178,7 +176,7 @@ fn to_vec(of: Vec) -> Code { /// Optimize, compile & run the [Expr] #[tracing::instrument(skip_all)] -pub fn run_ast(p: &mut P, ast: Expr) -> Code { +pub fn run_ast(p: &mut P, ast: Expr, mut sources: SourceSet) -> Code { let code = match ast { Expr::Block(x) => to_vec(x), Expr::Crud(x) => build_ast(*x), @@ -186,7 +184,7 @@ pub fn run_ast(p: &mut P, ast: Expr) -> Code { Expr::Halt(err) => Code::Halt(err), Expr::Ident(x) => Code::Halt(ErrorVm::Unsupported(format!("Ident {x}")).into()), }; - eval(p, code) + eval(p, code, &mut sources) } /// Used internally for testing SQL JOINS. @@ -232,6 +230,7 @@ pub mod tests { use super::test_data::*; use super::*; use crate::dsl::{mem_table, query, scalar}; + use crate::expr::SourceSet; use crate::program::Program; use crate::relation::MemTable; use spacetimedb_lib::identity::AuthCtx; @@ -241,8 +240,8 @@ pub mod tests { use spacetimedb_sats::relation::FieldName; use spacetimedb_sats::{product, AlgebraicType, ProductType}; - fn run_query(p: &mut Program, ast: Expr) -> MemTable { - match run_ast(p, ast) { + fn run_query(p: &mut Program, ast: Expr, sources: SourceSet) -> MemTable { + match run_ast(p, ast, sources) { Code::Table(x) => x, x => panic!("Unexpected result on query: {x}"), } @@ -253,12 +252,14 @@ pub mod tests { let p = &mut Program::new(AuthCtx::for_testing()); let input = MemTable::from_value(scalar(1)); let field = input.get_field_pos(0).unwrap().clone(); + let mut sources = SourceSet::default(); + let source_expr = sources.add_mem_table(input); - let q = query(input).with_select_cmp(OpCmp::Eq, field, scalar(1)); + let q = query(source_expr).with_select_cmp(OpCmp::Eq, field, scalar(1)); let head = q.source.head().clone(); - let result = run_ast(p, q.into()); + let result = run_ast(p, q.into(), sources); let row = scalar(1).into(); assert_eq!( result, @@ -272,13 +273,16 @@ pub mod tests { let p = &mut Program::new(AuthCtx::for_testing()); let input = scalar(1); let table = MemTable::from_value(scalar(1)); - let field = table.get_field_pos(0).unwrap().clone(); - let source = query(table.clone()); + let mut sources = SourceSet::default(); + let source_expr = sources.add_mem_table(table.clone()); + + let source = query(source_expr); + let field = table.get_field_pos(0).unwrap().clone(); let q = source.clone().with_project(&[field.into()], None); let head = q.source.head().clone(); - let result = run_ast(p, q.into()); + let result = run_ast(p, q.into(), sources); let row = input.into(); assert_eq!( result, @@ -286,10 +290,14 @@ pub mod tests { "Project" ); + let mut sources = SourceSet::default(); + let source_expr = sources.add_mem_table(table.clone()); + + let source = query(source_expr); let field = FieldName::positional(&table.head.table_name, 1); let q = source.with_project(&[field.clone().into()], None); - let result = run_ast(p, q.into()); + let result = run_ast(p, q.into(), sources); assert_eq!( result, Code::Halt(RelationError::FieldNotFound(head.clone_for_error(), field).into()), @@ -303,8 +311,12 @@ pub mod tests { let table = MemTable::from_value(scalar(1)); let field = table.get_field_pos(0).unwrap().clone(); - let q = query(table.clone()).with_join_inner(table, field.clone(), field); - let result = match run_ast(p, q.into()) { + let mut sources = SourceSet::default(); + let source_expr = sources.add_mem_table(table.clone()); + let second_source_expr = sources.add_mem_table(table); + + let q = query(source_expr).with_join_inner(second_source_expr, field.clone(), field); + let result = match run_ast(p, q.into(), sources) { Code::Table(x) => x, x => panic!("Invalid result {x}"), }; @@ -331,15 +343,21 @@ pub mod tests { let input = mem_table(inv, vec![row]); let inv = input.clone(); - let q = query(input.clone()).with_select_cmp(OpLogic::And, scalar(true), scalar(true)); + let mut sources = SourceSet::default(); + let source_expr = sources.add_mem_table(input.clone()); - let result = run_ast(p, q.into()); + let q = query(source_expr.clone()).with_select_cmp(OpLogic::And, scalar(true), scalar(true)); + + let result = run_ast(p, q.into(), sources); assert_eq!(result, Code::Table(inv.clone()), "Query And"); - let q = query(input).with_select_cmp(OpLogic::Or, scalar(true), scalar(false)); + let mut sources = SourceSet::default(); + let source_expr = sources.add_mem_table(input); + + let q = query(source_expr).with_select_cmp(OpLogic::Or, scalar(true), scalar(false)); - let result = run_ast(p, q.into()); + let result = run_ast(p, q.into(), sources); assert_eq!(result, Code::Table(inv), "Query Or"); } @@ -357,9 +375,13 @@ pub mod tests { let input = mem_table(inv, vec![row]); let field = input.get_field_pos(0).unwrap().clone(); - let q = query(input.clone()).with_join_inner(input, field.clone(), field); + let mut sources = SourceSet::default(); + let source_expr = sources.add_mem_table(input.clone()); + let second_source_expr = sources.add_mem_table(input); + + let q = query(source_expr).with_join_inner(second_source_expr, field.clone(), field); - let result = match run_ast(p, q.into()) { + let result = match run_ast(p, q.into(), sources) { Code::Table(x) => x, x => panic!("Invalid result {x}"), }; @@ -396,6 +418,10 @@ pub mod tests { let location_x = data.location.get_field_named("x").unwrap().clone(); let location_z = data.location.get_field_named("z").unwrap().clone(); + let mut sources = SourceSet::default(); + let player_source_expr = sources.add_mem_table(data.player.clone()); + let location_source_expr = sources.add_mem_table(data.location.clone()); + // SELECT // Player.* // FROM @@ -403,9 +429,9 @@ pub mod tests { // JOIN Location // ON Location.entity_id = Player.entity_id // WHERE x > 0 AND x <= 32 AND z > 0 AND z <= 32 - let q = query(data.player.clone()) + let q = query(player_source_expr) .with_join_inner( - data.location.clone(), + location_source_expr, player_entity_id.clone(), location_entity_id.clone(), ) @@ -418,7 +444,7 @@ pub mod tests { None, ); - let result = run_query(p, q.into()); + let result = run_query(p, q.into(), sources); let head = ProductType::from([("entity_id", AlgebraicType::U64), ("inventory_id", AlgebraicType::U64)]); let row1 = product!(100u64, 1u64); @@ -426,6 +452,11 @@ pub mod tests { assert_eq!(result.as_without_table_name(), input.as_without_table_name(), "Player"); + let mut sources = SourceSet::default(); + let player_source_expr = sources.add_mem_table(data.player); + let location_source_expr = sources.add_mem_table(data.location); + let inventory_source_expr = sources.add_mem_table(data.inv); + // SELECT // Inventory.* // FROM @@ -435,16 +466,16 @@ pub mod tests { // JOIN Location // ON Player.entity_id = Location.entity_id // WHERE x > 0 AND x <= 32 AND z > 0 AND z <= 32 - let q = query(data.inv) - .with_join_inner(data.player, inv_inventory_id.clone(), player_inventory_id) - .with_join_inner(data.location, player_entity_id, location_entity_id) + let q = query(inventory_source_expr) + .with_join_inner(player_source_expr, inv_inventory_id.clone(), player_inventory_id) + .with_join_inner(location_source_expr, player_entity_id, location_entity_id) .with_select_cmp(OpCmp::Gt, location_x.clone(), scalar(0.0f32)) .with_select_cmp(OpCmp::LtEq, location_x, scalar(32.0f32)) .with_select_cmp(OpCmp::Gt, location_z.clone(), scalar(0.0f32)) .with_select_cmp(OpCmp::LtEq, location_z, scalar(32.0f32)) .with_project(&[inv_inventory_id.into(), inv_name.into()], None); - let result = run_query(p, q.into()); + let result = run_query(p, q.into(), sources); let head = ProductType::from([("inventory_id", AlgebraicType::U64), ("name", AlgebraicType::String)]); let row1 = product!(1u64, "health"); diff --git a/crates/vm/src/expr.rs b/crates/vm/src/expr.rs index ec3685c13cd..911abf4baf2 100644 --- a/crates/vm/src/expr.rs +++ b/crates/vm/src/expr.rs @@ -10,12 +10,9 @@ use spacetimedb_sats::db::auth::{StAccess, StTableType}; use spacetimedb_sats::db::def::{TableDef, TableSchema}; use spacetimedb_sats::db::error::AuthError; use spacetimedb_sats::relation::{Column, DbTable, FieldExpr, FieldName, Header, Relation, RowCount}; -use spacetimedb_sats::satn::Satn; -use spacetimedb_sats::{ProductValue, Typespace, WithTypespace}; -use std::cmp::Ordering; +use spacetimedb_sats::ProductValue; use std::collections::{HashMap, VecDeque}; use std::fmt; -use std::hash::{Hash, Hasher}; use std::ops::Bound; use std::sync::Arc; @@ -184,7 +181,7 @@ impl From for ColumnOp { let columns = value.columns; assert_eq!(columns.len(), 1, "multi-column predicates are not yet supported"); - let field = table.head.fields[usize::from(columns.head())].field.clone(); + let field = table.head().fields[usize::from(columns.head())].field.clone(); match (value.lower_bound, value.upper_bound) { // Inclusive lower bound => field >= value (Bound::Included(value), Bound::Unbounded) => ColumnOp::Cmp { @@ -244,64 +241,145 @@ impl From for Option { } } -#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, From)] -pub enum SourceExpr { - MemTable(MemTable), - DbTable(DbTable), -} +#[derive(Debug, PartialEq, Eq, Clone, Default)] +#[repr(transparent)] +/// A set of [`MemTable`]s referenced by a query plan by their [`SourceId`]. +/// +/// Rather than embedding [`MemTable`]s in query plans, we store a `SourceExpr::MemTable`, +/// which contains the information necessary for optimization along with a [`SourceId`]. +/// Query execution then executes the plan, and when it encounters a `SourceExpr::MemTable`, +/// retrieves the [`MemTable`] from the corresponding `SourceSet`. +/// This allows query plans to be re-used, though each execution requires a new `SourceSet`. +/// +/// Internally, the `SourceSet` stores an `Option` for each planned [`SourceId`]. +/// During execution, the VM will [`Option::take`] the [`MemTable`] to consume them. +/// This means that a query plan may not include multiple references to the same [`SourceId`]. +pub struct SourceSet(Vec>); + +impl SourceSet { + /// Get a fresh `SourceId` which can be used as the id for a new entry. + fn next_id(&self) -> SourceId { + SourceId(self.0.len()) + } + + /// Insert a [`MemTable`] into this `SourceSet` so it can be used in a query plan, + /// and return a [`SourceExpr`] which can be embedded in that plan. + pub fn add_mem_table(&mut self, table: MemTable) -> SourceExpr { + let source_id = self.next_id(); + let expr = SourceExpr::from_mem_table(&table, source_id); + self.0.push(Some(table)); + expr + } + + /// Extract the [`MemTable`] referred to by `id` from this `SourceSet`, + /// leaving a "gap" in its place. + /// + /// Subsequent calls to `take_mem_table` on the same `id` will return `None`. + pub fn take_mem_table(&mut self, id: SourceId) -> Option { + self.0.get_mut(id.0)?.take() + } -impl Hash for SourceExpr { - fn hash(&self, state: &mut H) { - // IMPORTANT: Required for hashing query plans. - // In general a query plan will only contain static data. - // However, currently it is possible to inline a virtual table. - // Such plans though are hybrids and should not be hashed, - // Since they contain raw data values. - // Therefore we explicitly disallow it here. - match self { - SourceExpr::DbTable(t) => { - t.hash(state); - } - SourceExpr::MemTable(_) => { - panic!("Cannot hash a virtual table"); - } + /// Resolve `source` to a `Table` for use in query execution. + /// + /// If the `source` is a [`SourceExpr::DbTable`], this simply clones the [`DbTable`] and returns it. + /// ([`DbTable::clone`] is inexpensive.) + /// In this case, `self` is not modified. + /// + /// If the `source` is a [`SourceExpr::MemTable`], this behaves like [`Self::take_mem_table`]. + /// Subsequent calls to `take_table` or `take_mem_table` with the same `source` will fail. + pub fn take_table(&mut self, source: &SourceExpr) -> Option { + match source { + SourceExpr::DbTable(db_table) => Some(Table::DbTable(db_table.clone())), + SourceExpr::MemTable { source_id, .. } => self.take_mem_table(*source_id).map(Table::MemTable), } } } +impl std::ops::Index for SourceSet { + type Output = Option; + + fn index(&self, idx: SourceId) -> &Option { + &self.0[idx.0] + } +} + +impl std::ops::IndexMut for SourceSet { + fn index_mut(&mut self, idx: SourceId) -> &mut Option { + &mut self.0[idx.0] + } +} + +/// An identifier for a data source (i.e. a table) in a query plan. +/// +/// When compiling a query plan, rather than embedding the inputs in the plan, +/// we annotate each input with a `SourceId`, and the compiled plan refers to its inputs by id. +/// This allows the plan to be re-used with distinct inputs, +/// assuming the inputs obey the same schema. +/// +/// Note that re-using a query plan is only a good idea +/// if the new inputs are similar to those used for compilation +/// in terms of cardinality and distribution. +#[derive(Debug, Copy, Clone, PartialEq, Eq, From, Hash)] +pub struct SourceId(pub usize); + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +/// A reference to a table within a query plan, +/// used as the source for selections, scans, filters and joins. +pub enum SourceExpr { + /// A plan for a "virtual" or projected table. + /// + /// The actual [`MemTable`] is not stored within the query plan; + /// rather, the `source_id` is an index with a corresponding [`SourceSet`] + /// which contains the [`MemTable`]. + /// + /// This allows query plans to be reused by supplying a new [`SourceSet`]. + MemTable { + source_id: SourceId, + header: Arc
, + table_type: StTableType, + table_access: StAccess, + row_count: RowCount, + }, + /// A plan for a database table. Because [`DbTable`] is small and efficiently cloneable, + /// no indirection into a [`SourceSet`] is required. + DbTable(DbTable), +} + impl SourceExpr { - pub fn get_db_table(&self) -> Option<&DbTable> { - match self { - SourceExpr::DbTable(x) => Some(x), - _ => None, + /// If `self` refers to a [`MemTable`], returns the [`SourceId`] for its location in the plan's [`SourceSet`]. + /// + /// Returns `None` if `self` refers to a [`DbTable`], as [`DbTable`]s are stored directly in the `SourceExpr`, + /// rather than indirected through the [`SourceSet`]. + pub fn source_id(&self) -> Option { + if let SourceExpr::MemTable { source_id, .. } = self { + Some(*source_id) + } else { + None } } pub fn table_name(&self) -> &str { - match self { - SourceExpr::MemTable(x) => &x.head.table_name, - SourceExpr::DbTable(x) => &x.head.table_name, - } + &self.head().table_name } pub fn table_type(&self) -> StTableType { match self { - SourceExpr::MemTable(_) => StTableType::User, - SourceExpr::DbTable(x) => x.table_type, + SourceExpr::MemTable { table_type, .. } => *table_type, + SourceExpr::DbTable(db_table) => db_table.table_type, } } pub fn table_access(&self) -> StAccess { match self { - SourceExpr::MemTable(x) => x.table_access, - SourceExpr::DbTable(x) => x.table_access, + SourceExpr::MemTable { table_access, .. } => *table_access, + SourceExpr::DbTable(db_table) => db_table.table_access, } } pub fn head(&self) -> &Arc
{ match self { - SourceExpr::MemTable(x) => &x.head, - SourceExpr::DbTable(x) => &x.head, + SourceExpr::MemTable { header, .. } => header, + SourceExpr::DbTable(db_table) => &db_table.head, } } @@ -311,38 +389,56 @@ impl SourceExpr { pub fn get_column_by_field<'a>(&'a self, field: &'a FieldName) -> Option<&Column> { self.head().column(field) } -} -impl Relation for SourceExpr { - fn head(&self) -> &Arc
{ - match self { - SourceExpr::MemTable(x) => x.head(), - SourceExpr::DbTable(x) => x.head(), + pub fn is_mem_table(&self) -> bool { + matches!(self, SourceExpr::MemTable { .. }) + } + + pub fn is_db_table(&self) -> bool { + matches!(self, SourceExpr::DbTable(_)) + } + + pub fn from_mem_table(mem_table: &MemTable, id: SourceId) -> Self { + SourceExpr::MemTable { + source_id: id, + header: mem_table.head.clone(), + table_type: StTableType::User, + table_access: mem_table.table_access, + row_count: RowCount::exact(mem_table.data.len()), } } - fn row_count(&self) -> RowCount { - match self { - SourceExpr::MemTable(x) => x.row_count(), - SourceExpr::DbTable(x) => x.row_count(), + pub fn table_id(&self) -> Option { + if let SourceExpr::DbTable(db_table) = self { + Some(db_table.table_id) + } else { + None } } -} -impl From
for SourceExpr { - fn from(value: Table) -> Self { - match value { - Table::MemTable(t) => SourceExpr::MemTable(t), - Table::DbTable(t) => SourceExpr::DbTable(t), + /// If `self` refers to a [`DbTable`], get a reference to it. + /// + /// Returns `None` if `self` refers to a [`MemTable`]. + /// In that case, retrieving the [`MemTable`] requires inspecting the plan's corresponding [`SourceSet`] + /// via [`SourceSet::take_mem_table`] or [`SourceSet::take_table`]. + pub fn get_db_table(&self) -> Option<&DbTable> { + if let SourceExpr::DbTable(db_table) = self { + Some(db_table) + } else { + None } } } -impl From for Table { - fn from(value: SourceExpr) -> Self { - match value { - SourceExpr::MemTable(t) => Table::MemTable(t), - SourceExpr::DbTable(t) => Table::DbTable(t), +impl Relation for SourceExpr { + fn head(&self) -> &Arc
{ + self.head() + } + + fn row_count(&self) -> RowCount { + match self { + SourceExpr::MemTable { row_count, .. } => *row_count, + SourceExpr::DbTable(_) => RowCount::unknown(), } } } @@ -358,22 +454,13 @@ impl From<&TableSchema> for SourceExpr { } } -impl From<&SourceExpr> for DbTable { - fn from(value: &SourceExpr) -> Self { - match value { - SourceExpr::MemTable(_) => unreachable!(), - SourceExpr::DbTable(t) => t.clone(), - } - } -} - // A descriptor for an index join operation. // The semantics are those of a semijoin with rows from the index or the probe side being returned. -#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct IndexJoin { pub probe_side: QueryExpr, pub probe_field: FieldName, - pub index_side: Table, + pub index_side: SourceExpr, pub index_select: Option, pub index_col: ColId, pub return_index_rows: bool, @@ -382,7 +469,7 @@ pub struct IndexJoin { impl From for QueryExpr { fn from(join: IndexJoin) -> Self { let source: SourceExpr = if join.return_index_rows { - join.index_side.clone().into() + join.index_side.clone() } else { join.probe_side.source.clone() }; @@ -399,7 +486,7 @@ impl IndexJoin { // A delta table is a virtual table consisting of changes or updates to a physical table. pub fn reorder(self, row_count: impl Fn(TableId, &str) -> i64) -> Self { // The probe table must be a physical table. - if matches!(self.probe_side.source, SourceExpr::MemTable(_)) { + if self.probe_side.source.is_mem_table() { return self; } // It must have an index defined on the join field. @@ -424,18 +511,19 @@ impl IndexJoin { // The existence of this column has already been verified, // during construction of the index join. let probe_column = self.probe_side.source.head().column(&self.probe_field).unwrap().col_id; - match self.index_side { + match self.index_side.get_db_table() { // If the size of the indexed table is sufficiently large, // do not reorder. // // TODO: This determination is quite arbitrary. // Ultimately we should be using cardinality estimation. - Table::DbTable(DbTable { table_id, ref head, .. }) if row_count(table_id, &head.table_name) > 3000 => self, + Some(DbTable { head, table_id, .. }) if row_count(*table_id, &head.table_name) > 3000 => self, // If this is a delta table, we must reorder. // If this is a sufficiently small physical table, we should reorder. - table => { + _ => { // For the same reason the compiler also ensures this unwrap is safe. - let index_field = table + let index_field = self + .index_side .head() .fields .iter() @@ -457,11 +545,11 @@ impl IndexJoin { // Push any selections on the index side to the probe side. let probe_side = if let Some(predicate) = self.index_select { QueryExpr { - source: table.into(), + source: self.index_side, query: vec![predicate.into()], } } else { - table.into() + self.index_side.into() }; IndexJoin { // The new probe side consists of the updated rows. @@ -470,7 +558,7 @@ impl IndexJoin { // The new probe field is the previous index field. probe_field: index_field, // The original probe table is now the table that is being probed. - index_side: self.probe_side.source.into(), + index_side: self.probe_side.source, // Any selections from the original probe side are pulled above the index lookup. index_select: predicate, // The new index field is the previous probe field. @@ -502,8 +590,8 @@ impl IndexJoin { .map(|Column { field, .. }| field.into()) .collect(); - let table = self.index_side.get_db_table().map(|t| t.table_id); - let source = self.index_side.into(); + let table = self.index_side.table_id(); + let source = self.index_side; let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs)); let project = Query::Project(fields, table); let query = if let Some(predicate) = self.index_select { @@ -531,7 +619,7 @@ impl IndexJoin { .map(|Column { field, .. }| field.into()) .collect(); - let table = self.probe_side.source.get_db_table().map(|t| t.table_id); + let table = self.probe_side.source.table_id(); let source = self.probe_side.source; let inner_join = Query::JoinInner(JoinExpr::new(rhs, col_lhs, col_rhs)); let project = Query::Project(fields, table); @@ -541,7 +629,7 @@ impl IndexJoin { } } -#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct JoinExpr { pub rhs: QueryExpr, pub col_lhs: FieldName, @@ -554,7 +642,7 @@ impl JoinExpr { } } -#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum DbType { Table, Index, @@ -562,7 +650,7 @@ pub enum DbType { Constraint, } -#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum Crud { Query, Insert, @@ -604,8 +692,8 @@ impl CrudExpr { } } - pub fn is_reads(exprs: &[CrudExpr]) -> bool { - exprs.iter().all(|expr| matches!(expr, CrudExpr::Query(_))) + pub fn is_reads<'a>(exprs: impl IntoIterator) -> bool { + exprs.into_iter().all(|expr| matches!(expr, CrudExpr::Query(_))) } } @@ -617,65 +705,8 @@ pub struct IndexScan { pub upper_bound: Bound, } -impl PartialOrd for IndexScan { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for IndexScan { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - #[derive(Eq, PartialEq)] - struct RangeBound<'a, T: Ord>(&'a Bound); - - impl<'a, T: Ord> PartialOrd for RangeBound<'a, T> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } - } - - impl<'a, T: Ord> Ord for RangeBound<'a, T> { - fn cmp(&self, other: &Self) -> Ordering { - match (&self.0, &other.0) { - (Bound::Included(ref l), Bound::Included(ref r)) - | (Bound::Excluded(ref l), Bound::Excluded(ref r)) => l.cmp(r), - (Bound::Included(ref l), Bound::Excluded(ref r)) => match l.cmp(r) { - Ordering::Equal => Ordering::Less, - ord => ord, - }, - (Bound::Excluded(ref l), Bound::Included(ref r)) => match l.cmp(r) { - Ordering::Equal => Ordering::Greater, - ord => ord, - }, - (Bound::Unbounded, Bound::Unbounded) => Ordering::Equal, - (Bound::Unbounded, _) => Ordering::Less, - (_, Bound::Unbounded) => Ordering::Greater, - } - } - } - - let order = self.table.cmp(&other.table); - let Ordering::Equal = order else { - return order; - }; - - let order = self.columns.cmp(&other.columns); - let Ordering::Equal = order else { - return order; - }; - - match ( - RangeBound(&self.lower_bound).cmp(&RangeBound(&other.lower_bound)), - RangeBound(&self.upper_bound).cmp(&RangeBound(&other.upper_bound)), - ) { - (Ordering::Equal, ord) => ord, - (ord, _) => ord, - } - } -} - // An individual operation in a query. -#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord, From, Hash)] +#[derive(Debug, Clone, Eq, PartialEq, From, Hash)] pub enum Query { // Fetching rows via an index. IndexScan(IndexScan), @@ -703,7 +734,7 @@ impl Query { pub fn sources(&self) -> QuerySources { match self { Self::Select(..) | Self::Project(..) => QuerySources::None, - Self::IndexScan(scan) => QuerySources::One(Some(scan.table.clone().into())), + Self::IndexScan(scan) => QuerySources::One(Some(SourceExpr::DbTable(scan.table.clone()))), Self::IndexJoin(join) => QuerySources::Expr(join.probe_side.sources()), Self::JoinInner(join) => QuerySources::Expr(join.rhs.sources()), } @@ -789,36 +820,15 @@ fn is_sargable(table: &SourceExpr, op: &ColumnOp) -> Option { } } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct QueryExpr { pub source: SourceExpr, pub query: Vec, } -impl From for QueryExpr { - fn from(value: MemTable) -> Self { - QueryExpr { - source: value.into(), - query: vec![], - } - } -} - -impl From for QueryExpr { - fn from(value: DbTable) -> Self { - QueryExpr { - source: value.into(), - query: vec![], - } - } -} - -impl From
for QueryExpr { - fn from(value: Table) -> Self { - QueryExpr { - source: value.into(), - query: vec![], - } +impl From for QueryExpr { + fn from(source: SourceExpr) -> Self { + QueryExpr { source, query: vec![] } } } @@ -862,18 +872,13 @@ impl QueryExpr { /// Does this query read from a given table? pub fn reads_from_table(&self, id: &TableId) -> bool { - self.source - .get_db_table() - .is_some_and(|DbTable { table_id, .. }| table_id == id) + self.source.table_id() == Some(*id) || self.query.iter().any(|q| match q { Query::Select(_) | Query::Project(_, _) => false, Query::IndexScan(scan) => scan.table.table_id == *id, Query::JoinInner(join) => join.rhs.reads_from_table(id), Query::IndexJoin(join) => { - join.index_side - .get_db_table() - .is_some_and(|DbTable { table_id, .. }| table_id == id) - || join.probe_side.reads_from_table(id) + join.index_side.table_id() == Some(*id) || join.probe_side.reads_from_table(id) } }) } @@ -897,15 +902,12 @@ impl QueryExpr { Query::JoinInner(JoinExpr { rhs: QueryExpr { - source: - SourceExpr::DbTable(DbTable { - table_id: rhs_table_id, .. - }), + source: SourceExpr::DbTable(ref db_table), .. }, .. - }) if table.table_id != rhs_table_id => { - self = self.with_index_eq(table, columns, value); + }) if table.table_id != db_table.table_id => { + self = self.with_index_eq(db_table.clone(), columns, value); self.query.push(query); self } @@ -975,14 +977,11 @@ impl QueryExpr { Query::JoinInner(JoinExpr { rhs: QueryExpr { - source: - SourceExpr::DbTable(DbTable { - table_id: rhs_table_id, .. - }), + source: SourceExpr::DbTable(ref db_table), .. }, .. - }) if table.table_id != rhs_table_id => { + }) if table.table_id != db_table.table_id => { self = self.with_index_lower_bound(table, columns, value, inclusive); self.query.push(query); self @@ -1083,14 +1082,11 @@ impl QueryExpr { Query::JoinInner(JoinExpr { rhs: QueryExpr { - source: - SourceExpr::DbTable(DbTable { - table_id: rhs_table_id, .. - }), + source: SourceExpr::DbTable(ref db_table), .. }, .. - }) if table.table_id != rhs_table_id => { + }) if table.table_id != db_table.table_id => { self = self.with_index_upper_bound(table, columns, value, inclusive); self.query.push(query); self @@ -1270,10 +1266,11 @@ impl QueryExpr { return query; } - let Some(table) = query.source.get_db_table().cloned() else { + // If the source is a `MemTable`, it doesn't have any indexes, + // so we can't plan an index join. + let Some(source_table_id) = query.source.table_id() else { return query; }; - let source = query.source; let second = query.query.pop().unwrap(); let first = query.query.pop().unwrap(); @@ -1292,15 +1289,15 @@ impl QueryExpr { col_lhs: index_field, col_rhs: probe_field, }) => { - if !probe_side.query.is_empty() && wildcard_table_id == table.table_id { + if !probe_side.query.is_empty() && wildcard_table_id == source_table_id { // An applicable join must have an index defined on the correct field. - if let Some(col) = table.head.column(&index_field) { + if let Some(col) = source.head().column(&index_field) { let index_col = col.col_id; - if table.head().has_constraint(&index_field, Constraints::indexed()) { + if source.head().has_constraint(&index_field, Constraints::indexed()) { let index_join = IndexJoin { probe_side, probe_field, - index_side: table.into(), + index_side: source.clone(), index_select: None, index_col, return_index_rows: true, @@ -1338,7 +1335,9 @@ impl QueryExpr { match is_sargable(schema, op) { // found sargable equality condition for one of the table schemas Some(IndexArgument::Eq { col_id, value }) => { - q = q.with_index_eq(schema.into(), col_id.into(), value); + // `unwrap` here is infallible because `is_sargable(schema, op)` implies `schema.is_db_table` + // for any `op`. + q = q.with_index_eq(schema.get_db_table().unwrap().clone(), col_id.into(), value); continue 'outer; } // found sargable range condition for one of the table schemas @@ -1347,7 +1346,14 @@ impl QueryExpr { value, inclusive, }) => { - q = q.with_index_lower_bound(schema.into(), col_id.into(), value, inclusive); + q = q.with_index_lower_bound( + // `unwrap` here is infallible because `is_sargable(schema, op)` implies `schema.is_db_table` + // for any `op`. + schema.get_db_table().unwrap().clone(), + col_id.into(), + value, + inclusive, + ); continue 'outer; } // found sargable range condition for one of the table schemas @@ -1356,7 +1362,14 @@ impl QueryExpr { value, inclusive, }) => { - q = q.with_index_upper_bound(schema.into(), col_id.into(), value, inclusive); + q = q.with_index_upper_bound( + // `unwrap` here is infallible because `is_sargable(schema, op)` implies `schema.is_db_table` + // for any `op`. + schema.get_db_table().unwrap().clone(), + col_id.into(), + value, + inclusive, + ); continue 'outer; } None => {} @@ -1469,21 +1482,12 @@ impl From for Expr { } } -pub(crate) fn fmt_value(ty: &AlgebraicType, val: &AlgebraicValue) -> String { - let ts = Typespace::new(vec![]); - WithTypespace::new(&ts, ty).with_value(val).to_satn() -} - impl fmt::Display for SourceExpr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - SourceExpr::MemTable(x) => { - let ty = &AlgebraicType::Product(x.head().ty()); - for row in &x.data { - let x = fmt_value(ty, &row.clone().into()); - write!(f, "{x}")?; - } - Ok(()) + SourceExpr::MemTable { header, source_id, .. } => { + let ty = AlgebraicType::Product(header.ty()); + write!(f, "SourceExpr({source_id:?} => virtual {ty:?})") } SourceExpr::DbTable(x) => { write!(f, "DbTable({})", x.table_id) @@ -1525,20 +1529,38 @@ impl fmt::Display for Query { } #[derive(Debug, Clone, PartialEq, Eq)] +// TODO(bikeshedding): Refactor this struct so that `IndexJoin`s replace the `table`, +// rather than appearing as the first element of the `query`. +// +// `IndexJoin`s do not behave like filters; in fact they behave more like data sources. +// A query conceptually starts with either a single table or an `IndexJoin`, +// and then stacks a set of filters on top of that. pub struct QueryCode { - pub table: Table, + pub table: SourceExpr, pub query: Vec, } impl From for QueryCode { fn from(value: QueryExpr) -> Self { QueryCode { - table: value.source.into(), + table: value.source, query: value.query, } } } +impl AuthAccess for SourceExpr { + fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> { + if owner == caller || self.table_access() == StAccess::Public { + return Ok(()); + } + + Err(AuthError::TablePrivate { + named: self.table_name().to_string(), + }) + } +} + impl AuthAccess for Table { fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> { if owner == caller || self.table_access() == StAccess::Public { @@ -1579,7 +1601,7 @@ impl Relation for QueryCode { pub enum CrudCode { Query(QueryCode), Insert { - table: Table, + table: SourceExpr, rows: Vec, }, Update { @@ -1669,8 +1691,6 @@ impl From for CodeResult { #[cfg(test)] mod tests { - use crate::relation::{MemTable, Table}; - use super::*; const ALICE: Identity = Identity::from_byte_array([1; 32]); @@ -1679,18 +1699,20 @@ mod tests { // TODO(kim): Should better do property testing here, but writing generators // on recursive types (ie. `Query` and friends) is tricky. - fn tables() -> [Table; 2] { + fn tables() -> [SourceExpr; 2] { [ - Table::MemTable(MemTable { - head: Arc::new(Header { + SourceExpr::MemTable { + source_id: SourceId(0), + header: Arc::new(Header { table_name: "foo".into(), fields: vec![], constraints: Default::default(), }), - data: vec![], + row_count: RowCount::unknown(), + table_type: StTableType::User, table_access: StAccess::Private, - }), - Table::DbTable(DbTable { + }, + SourceExpr::DbTable(DbTable { head: Arc::new(Header { table_name: "foo".into(), fields: vec![], @@ -1704,14 +1726,12 @@ mod tests { } fn queries() -> impl IntoIterator { - let [Table::MemTable(mem_table), Table::DbTable(db_table)] = tables() else { - unreachable!() - }; + let [mem_table, db_table] = tables(); // Skip `Query::Select` and `QueryProject` -- they don't have table // information [ Query::IndexScan(IndexScan { - table: db_table, + table: db_table.get_db_table().unwrap().clone(), columns: ColList::new(42.into()), lower_bound: Bound::Included(22.into()), upper_bound: Bound::Unbounded, @@ -1722,7 +1742,7 @@ mod tests { table: "foo".into(), field: "bar".into(), }, - index_side: Table::DbTable(DbTable { + index_side: SourceExpr::DbTable(DbTable { head: Arc::new(Header { table_name: "bar".into(), fields: vec![], @@ -1752,10 +1772,7 @@ mod tests { fn query_codes() -> impl IntoIterator { tables().map(|table| { - let expr = match table { - Table::DbTable(table) => QueryExpr::from(table), - Table::MemTable(table) => QueryExpr::from(table), - }; + let expr = QueryExpr::from(table); let mut code = QueryCode::from(expr); code.query = queries().into_iter().collect(); code @@ -1775,9 +1792,8 @@ mod tests { assert!(matches!(auth.check_auth(ALICE, BOB), Err(AuthError::OwnerRequired))); } - fn mem_table(name: &str, fields: &[(&str, AlgebraicType, bool)]) -> MemTable { + fn mem_table(name: &str, fields: &[(&str, AlgebraicType, bool)]) -> SourceExpr { let table_access = StAccess::Public; - let data = Vec::new(); let head = Header::new( name.into(), fields @@ -1792,10 +1808,12 @@ mod tests { .map(|(i, _)| (ColId(i as u32).into(), Constraints::indexed())) .collect(), ); - MemTable { - head: Arc::new(head), - data, + SourceExpr::MemTable { + source_id: SourceId(0), + header: Arc::new(head), + row_count: RowCount::unknown(), table_access, + table_type: StTableType::User, } } @@ -1810,7 +1828,7 @@ mod tests { &[("c", AlgebraicType::U8, false), ("b", AlgebraicType::U8, true)], ); - let probe_field = probe_side.head.fields[1].field.clone(); + let probe_field = probe_side.head().fields[1].field.clone(); let select_field = FieldName::Name { table: "index".into(), field: "a".into(), @@ -1819,7 +1837,7 @@ mod tests { let join = IndexJoin { probe_side: probe_side.clone().into(), probe_field, - index_side: index_side.clone().into(), + index_side: index_side.clone(), index_select: Some(index_select.clone()), index_col: 1.into(), return_index_rows: false, @@ -1827,7 +1845,7 @@ mod tests { let expr = join.to_inner_join(); - assert_eq!(expr.source, SourceExpr::MemTable(probe_side)); + assert_eq!(expr.source, probe_side); assert_eq!(expr.query.len(), 2); let Query::JoinInner(ref join) = expr.query[0] else { @@ -1839,7 +1857,7 @@ mod tests { assert_eq!( join.rhs, QueryExpr { - source: SourceExpr::MemTable(index_side), + source: index_side, query: vec![index_select.into()] } ); diff --git a/crates/vm/src/program.rs b/crates/vm/src/program.rs index 05237bf062a..dbb50fde4be 100644 --- a/crates/vm/src/program.rs +++ b/crates/vm/src/program.rs @@ -4,10 +4,10 @@ use crate::errors::ErrorVm; use crate::eval::{build_query, IterRows}; -use crate::expr::{Code, CrudCode}; +use crate::expr::{Code, CrudCode, SourceSet}; use crate::iterators::RelIter; use crate::rel_ops::RelOps; -use crate::relation::{MemTable, Table}; +use crate::relation::MemTable; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::Address; use spacetimedb_sats::relation::Relation; @@ -25,7 +25,7 @@ pub trait ProgramVm { /// Allows to execute the query with the state carried by the implementation of this /// trait - fn eval_query(&mut self, query: CrudCode) -> Result; + fn eval_query(&mut self, query: CrudCode, sources: &mut SourceSet) -> Result; } pub struct ProgramStore

{ @@ -64,20 +64,22 @@ impl ProgramVm for Program { } #[tracing::instrument(skip_all)] - fn eval_query(&mut self, query: CrudCode) -> Result { + fn eval_query(&mut self, query: CrudCode, sources: &mut SourceSet) -> Result { match query { CrudCode::Query(query) => { let head = query.head().clone(); let row_count = query.row_count(); let table_access = query.table.table_access(); - let result = match query.table { - Table::MemTable(x) => Box::new(RelIter::new(head, row_count, x)) as Box>, - Table::DbTable(_) => { - panic!("DB not set") - } + let result = if let Some(source_id) = query.table.source_id() { + let Some(result_table) = sources.take_mem_table(source_id) else { + panic!("Query plan specifies a `MemTable` for {source_id:?}, but found a `DbTable` or nothing"); + }; + Box::new(RelIter::new(head, row_count, result_table)) as Box> + } else { + panic!("DB not set") }; - let result = build_query(result, query.query)?; + let result = build_query(result, query.query, sources)?; let head = result.head().clone(); let rows: Vec<_> = result.collect_vec(|row| row.into_product_value())?;