Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/core/src/host/module_host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,9 @@ impl ModuleHost {
reducer_name: &str,
args: ReducerArgs,
) -> Result<ReducerCallResult, ReducerCallError> {
if reducer_name.starts_with("__") && reducer_name.ends_with("__") {
return Err(ReducerCallError::NoSuchReducer);
}
let res = self
.call_reducer_inner(
caller_identity,
Expand Down
12 changes: 9 additions & 3 deletions crates/core/src/host/wasm_common/module_host_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::time::Duration;

use spacetimedb_lib::buffer::DecodeError;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::{bsatn, Address, ModuleDef, TableDesc};
use spacetimedb_lib::{bsatn, Address, ModuleDef, ModuleValidationError, TableDesc};
use spacetimedb_vm::expr::CrudExpr;

use super::instrumentation::CallTimes;
Expand Down Expand Up @@ -90,6 +90,8 @@ pub(crate) struct WasmModuleHostActor<T: WasmModule> {
pub enum InitializationError {
#[error(transparent)]
Validation(#[from] ValidationError),
#[error(transparent)]
ModuleValidation(#[from] ModuleValidationError),
#[error("setup function returned an error: {0}")]
Setup(Box<str>),
#[error("wasm trap while calling {func:?}")]
Expand Down Expand Up @@ -153,7 +155,8 @@ impl<T: WasmModule> WasmModuleHostActor<T> {
)?;

let desc = instance.extract_descriptions()?;
let desc = bsatn::from_slice(&desc).map_err(DescribeError::Decode)?;
let desc: ModuleDef = bsatn::from_slice(&desc).map_err(DescribeError::Decode)?;
desc.validate_reducers()?;
let ModuleDef {
mut typespace,
mut tables,
Expand All @@ -173,7 +176,10 @@ impl<T: WasmModule> WasmModuleHostActor<T> {
tables
.into_iter()
.map(|x| (x.schema.table_name.clone(), EntityDef::Table(x))),
reducers.iter().map(|x| (x.name.clone(), EntityDef::Reducer(x.clone()))),
reducers
.iter()
.filter(|r| !(r.name.starts_with("__") && r.name.ends_with("__")))
.map(|x| (x.name.clone(), EntityDef::Reducer(x.clone()))),
)
.collect();
let reducers = ReducersMap(reducers.into_iter().map(|x| (x.name.clone(), x)).collect());
Expand Down
32 changes: 32 additions & 0 deletions crates/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,35 @@ pub struct TypeAlias {
pub name: String,
pub ty: sats::AlgebraicTypeRef,
}

impl ModuleDef {
pub fn validate_reducers(&self) -> Result<(), ModuleValidationError> {
for reducer in &self.reducers {
match &*reducer.name {
// in the future, these should maybe be flagged as lifecycle reducers by a MiscModuleExport
// or something, rather than by magic names
"__init__" => {}
"__identity_connected__" | "__identity_disconnected__" | "__update__" | "__migrate__" => {
if !reducer.args.is_empty() {
return Err(ModuleValidationError::InvalidLifecycleReducer {
reducer: reducer.name.clone(),
});
}
}
name if name.starts_with("__") && name.ends_with("__") => {
return Err(ModuleValidationError::UnknownDunderscore)
}
_ => {}
}
}
Ok(())
}
}

#[derive(thiserror::Error, Debug)]
pub enum ModuleValidationError {
#[error("lifecycle reducer {reducer:?} has invalid signature")]
InvalidLifecycleReducer { reducer: Box<str> },
#[error("reducers with double-underscores at the start and end of their names are not allowed")]
UnknownDunderscore,
}
12 changes: 12 additions & 0 deletions smoketests/tests/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,15 @@ def test_private_table(self):
with self.assertRaises(Exception):
self.spacetime("sql", self.address, "select * from _Secret")


class LifecycleReducers(Smoketest):
def test_lifecycle_reducers_cant_be_called(self):
"""Ensure that reducers like __init__ can't be called"""

with self.assertRaises(Exception):
self.call("__init__")
with self.assertRaises(Exception):
self.call("__identity_connected__")
with self.assertRaises(Exception):
self.call("__identity_disconnected__")