Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ pub struct ClientActorId {

impl ClientActorId {
#[cfg(test)]
pub fn for_test(identity: Identity) -> Self {
pub fn for_test(identity: Option<Identity>, address: Option<Address>) -> Self {
ClientActorId {
identity,
address: Address::ZERO,
identity: identity.unwrap_or(Identity::from_byte_array(rand::random())),
address: address.unwrap_or(Address::from_arr(&rand::random())),
name: ClientName(0),
}
}
Expand Down
25 changes: 16 additions & 9 deletions crates/core/src/client/client_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,27 @@ pub enum ClientSendError {
}

impl ClientConnectionSender {
pub fn dummy(id: ClientActorId, protocol: Protocol) -> Self {
let (sendtx, _) = mpsc::channel(1);
pub fn dummy_with_channel(id: ClientActorId, protocol: Protocol) -> (Self, mpsc::Receiver<SerializableMessage>) {
let (sendtx, rx) = mpsc::channel(10);
// just make something up, it doesn't need to be attached to a real task
let abort_handle = match tokio::runtime::Handle::try_current() {
Ok(h) => h.spawn(async {}).abort_handle(),
Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(),
};
Self {
id,
protocol,
sendtx,
abort_handle,
cancelled: AtomicBool::new(false),
}
(
Self {
id,
protocol,
sendtx,
abort_handle,
cancelled: AtomicBool::new(false),
},
rx,
)
}

pub fn dummy(id: ClientActorId, protocol: Protocol) -> Self {
Self::dummy_with_channel(id, protocol).0
}

pub fn send_message(&self, message: impl Into<SerializableMessage>) -> Result<(), ClientSendError> {
Expand Down
117 changes: 115 additions & 2 deletions crates/core/src/subscription/module_subscription_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,15 @@ impl ModuleSubscriptions {
#[cfg(test)]
mod tests {
use super::ModuleSubscriptions;
use crate::client::messages::{SerializableMessage, SubscriptionUpdate, TransactionUpdateMessage};
use crate::client::{ClientActorId, ClientConnectionSender, Protocol};
use crate::db::relational_db::tests_utils::TestDB;
use crate::energy::EnergyQuanta;
use crate::execution_context::ExecutionContext;
use crate::host::module_host::{
DatabaseTableUpdate, DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall, ProtocolDatabaseUpdate,
};
use crate::host::{ArgsTuple, Timestamp};
use spacetimedb_client_api_messages::client_api::Subscribe;
use spacetimedb_lib::{error::ResultTest, AlgebraicType, Identity};
use spacetimedb_sats::product;
Expand All @@ -204,8 +210,8 @@ mod tests {
})?;

let id = Identity::ZERO;
let client = ClientActorId::for_test(id);
let sender = Arc::new(ClientConnectionSender::dummy(client, Protocol::Binary));
let client_id = ClientActorId::for_test(None, None);
let sender = Arc::new(ClientConnectionSender::dummy(client_id, Protocol::Binary));
let module_subscriptions = ModuleSubscriptions::new(db.clone(), id);

let (send, mut recv) = mpsc::unbounded_channel();
Expand Down Expand Up @@ -250,4 +256,111 @@ mod tests {

Ok(())
}

#[test]
/// checks if multiple clients with the same identity are properly handled
fn test_subscriptions_for_the_same_client_identity() -> ResultTest<()> {
let test_db = TestDB::durable()?;
let runtime = test_db.runtime().cloned().unwrap();

// Create table with no rows
let db = Arc::new(test_db.db.clone());
let table_id = db.create_table_for_test("T", &[("a", AlgebraicType::U8)], &[])?;

let id = ClientActorId::for_test(None, None);
let sender = Arc::new(ClientConnectionSender::dummy(id, Protocol::Binary));
let module_subscriptions = ModuleSubscriptions::new(db.clone(), id.identity);

let client_id0 = ClientActorId::for_test(None, None);
let client_id1 = ClientActorId::for_test(Some(client_id0.identity), None);
let (client0, mut rx0) = ClientConnectionSender::dummy_with_channel(client_id0, Protocol::Binary);
let (client1, mut rx1) = ClientConnectionSender::dummy_with_channel(client_id1, Protocol::Binary);

// Subscribing to T should return a single row
let query_strings = vec!["select * from T where a = 1".into()];
module_subscriptions
.add_subscriber(
Arc::new(client0),
Subscribe {
query_strings,
request_id: 0,
},
Instant::now(),
None,
)
.unwrap();

let query_strings = vec!["select * from T where a = 2".into()];
module_subscriptions
.add_subscriber(
Arc::new(client1),
Subscribe {
query_strings,
request_id: 1,
},
Instant::now(),
None,
)
.unwrap();

let inserts = Arc::new([product!(1u8), product!(2u8), product!(2u8)]);
let table_update = DatabaseTableUpdate {
table_id,
table_name: Box::from("T"),
inserts,
deletes: Arc::new([]),
};
let database_update = DatabaseUpdate {
tables: vec![table_update],
};
let event = Arc::new(ModuleEvent {
timestamp: Timestamp::now(),
caller_identity: client_id0.identity,
caller_address: None,
function_call: ModuleFunctionCall {
reducer: "DummyReducer".into(),
args: ArgsTuple::nullary(),
},
status: EventStatus::Committed(database_update),
energy_quanta_used: EnergyQuanta::ZERO,
host_execution_duration: Duration::default(),
request_id: None,
timer: None,
});

runtime.block_on(async move {
tokio::task::block_in_place(move || {
let subscriptions = module_subscriptions.subscriptions.read();
module_subscriptions.blocking_broadcast_event(Some(&sender), &subscriptions, event);
});
tokio::time::sleep(Duration::from_secs(4)).await;
tokio::time::timeout(Duration::from_millis(100), async move {
rx0.recv().await.expect("Expected subscription update message");
let m0 = rx0.recv().await.expect("Expected transaction update message");
rx1.recv().await.expect("Expected subscription update message");
let m1 = rx1.recv().await.expect("Expected transaction update message");

// check if the first client got the update with only 1 row and the second client
// got the update with 2 rows
assert!(matches!(m0,
SerializableMessage::ProtocolUpdate(
TransactionUpdateMessage {
database_update: SubscriptionUpdate {
database_update: ProtocolDatabaseUpdate { tables, .. },
..},
..}) if tables.clone().left().unwrap()[0].table_row_operations.len() == 1));
assert!(matches!(m1,
SerializableMessage::ProtocolUpdate(
TransactionUpdateMessage {
database_update: SubscriptionUpdate {
database_update: ProtocolDatabaseUpdate { tables, .. },
..},
..}) if tables.clone().left().unwrap()[0].table_row_operations.len() == 2));
})
.await
.unwrap();
});

Ok(())
}
}