diff --git a/changelog.d/19038.feature b/changelog.d/19038.feature new file mode 100644 index 00000000000..13a82e7c061 --- /dev/null +++ b/changelog.d/19038.feature @@ -0,0 +1 @@ +Add more support for MSC4140, namely the ability to inspect sent, cancelled, or failed delayed events, aka "finalised" delayed events. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index e83c0de5a47..3dcb0c28cfa 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -58,6 +58,7 @@ from synapse.storage.databases.main import FilteringWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore +from synapse.storage.databases.main.delayed_events import DelayedEventsStore from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore @@ -272,6 +273,7 @@ class Store( RelationsWorkerStore, EventFederationWorkerStore, SlidingSyncStore, + DelayedEventsStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f82e8572f22..8d0a33606c6 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -551,6 +551,21 @@ def read_config( # MSC4133: Custom profile fields self.msc4133_enabled: bool = experimental.get("msc4133_enabled", False) + # MSC4140: How many delayed events a user is allowed to have scheduled at a time. 
+ self.msc4140_max_delayed_events_per_user = experimental.get( + "msc4140_max_delayed_events_per_user", 100 + ) + + # MSC4140: How long to keep finalised delayed events in the database before deleting them. + self.msc4140_finalised_retention_period = self.parse_duration( + experimental.get("msc4140_finalised_retention_period", "7d") + ) + + # MSC4140: How many finalised delayed events to keep per user before deleting them. + self.msc4140_finalised_per_user_retention_limit = experimental.get( + "msc4140_finalised_per_user_retention_limit", 1000 + ) + # MSC4143: Matrix RTC Transport using Livekit Backend self.msc4143_enabled: bool = experimental.get("msc4143_enabled", False) diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index b89b7416e63..d22966c5a70 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -18,12 +18,13 @@ from twisted.internet.interfaces import IDelayedCall from synapse.api.constants import EventTypes -from synapse.api.errors import ShadowBanError, SynapseError +from synapse.api.errors import ShadowBanError, SynapseError, cs_error from synapse.api.ratelimiting import Ratelimiter from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.replication.http.delayed_events import ( ReplicationAddedDelayedEventRestServlet, ) @@ -43,6 +44,7 @@ UserID, create_requester, ) +from synapse.util.constants import MILLISECONDS_PER_SECOND, ONE_MINUTE_SECONDS from synapse.util.events import generate_fake_event_id from synapse.util.metrics import Measure from synapse.util.sentinel import Sentinel @@ -125,6 +127,12 @@ async def _schedule_db_events() -> None: else: self._repl_client = 
ReplicationAddedDelayedEventRestServlet.make_client(hs) + if hs.config.worker.run_background_tasks: + self._clock.looping_call( + self._prune_finalised_events, + 5 * ONE_MINUTE_SECONDS * MILLISECONDS_PER_SECOND, + ) + @property def _is_master(self) -> bool: return self._repl_client is None @@ -132,7 +140,7 @@ def _is_master(self) -> bool: def notify_new_event(self) -> None: """ Called when there may be more state event deltas to process, - which should cancel pending delayed events for the same state. + which should cancel scheduled delayed events for the same state. """ if self._event_processing: return @@ -156,8 +164,7 @@ async def _unsafe_process_new_event(self) -> None: room_max_stream_ordering = self._store.get_room_max_stream_ordering() # Check that there are actually any delayed events to process. If not, bail early. - delayed_events_count = await self._store.get_count_of_delayed_events() - if delayed_events_count == 0: + if not await self._store.has_scheduled_delayed_events(): # There are no delayed events to process. Update the # `delayed_events_stream_pos` to the latest `events` stream pos and # exit early. @@ -228,7 +235,7 @@ async def _unsafe_process_new_event(self) -> None: async def _handle_state_deltas(self, deltas: list[StateDelta]) -> None: """ - Process current state deltas to cancel other users' pending delayed events + Process current state deltas to cancel other users' scheduled delayed events that target the same state. 
""" # Get the senders of each delta's state event (as sender information is @@ -316,11 +323,20 @@ async def _handle_state_deltas(self, deltas: list[StateDelta]) -> None: if sender.domain == self._config.server.server_name else "" ), + finalised_ts=self._get_current_ts(), ) if self._next_send_ts_changed(next_send_ts): self._schedule_next_at_or_none(next_send_ts) + @wrap_as_background_process("_prune_finalised_events") + async def _prune_finalised_events(self) -> None: + await self._store.prune_finalised_delayed_events( + self._get_current_ts(), + self.hs.config.experimental.msc4140_finalised_retention_period, + self.hs.config.experimental.msc4140_finalised_per_user_retention_limit, + ) + async def add( self, requester: Requester, @@ -380,6 +396,7 @@ async def add( origin_server_ts=origin_server_ts, content=content, delay=delay, + limit=self.hs.config.experimental.msc4140_max_delayed_events_per_user, ) if self._repl_client is not None: @@ -420,6 +437,7 @@ async def cancel(self, requester: Requester, delay_id: str) -> None: next_send_ts = await self._store.cancel_delayed_event( delay_id=delay_id, user_localpart=requester.user.localpart, + finalised_ts=self._get_current_ts(), ) if self._next_send_ts_changed(next_send_ts): @@ -477,18 +495,19 @@ async def send(self, requester: Requester, delay_id: str) -> None: if self._next_send_ts_changed(next_send_ts): self._schedule_next_at_or_none(next_send_ts) - await self._send_event( - DelayedEventDetails( - delay_id=DelayID(delay_id), - user_localpart=UserLocalpart(requester.user.localpart), - room_id=event.room_id, - type=event.type, - state_key=event.state_key, - origin_server_ts=event.origin_server_ts, - content=event.content, - device_id=event.device_id, + if event: + await self._send_event( + DelayedEventDetails( + delay_id=DelayID(delay_id), + user_localpart=UserLocalpart(requester.user.localpart), + room_id=event.room_id, + type=event.type, + state_key=event.state_key, + origin_server_ts=event.origin_server_ts, + 
content=event.content, + device_id=event.device_id, + ) ) - ) async def _send_on_timeout(self) -> None: self._next_delayed_event_call = None @@ -513,17 +532,19 @@ async def _send_events(self, events: list[DelayedEventDetails]) -> None: state_info = None try: # TODO: send in background if message event or non-conflicting state event - await self._send_event(event) + finalised_ts = await self._send_event(event) if state_info is not None: sent_state.add(state_info) except Exception: logger.exception("Failed to send delayed event") + finalised_ts = self._get_current_ts() for room_id, event_type, state_key in sent_state: - await self._store.delete_processed_delayed_state_events( + await self._store.finalise_processed_delayed_state_events( room_id=str(room_id), event_type=event_type, state_key=state_key, + finalised_ts=finalised_ts, ) def _schedule_next_at_or_none(self, next_send_ts: Optional[Timestamp]) -> None: @@ -547,21 +568,49 @@ def _schedule_next_at(self, next_send_ts: Timestamp) -> None: else: self._next_delayed_event_call.reset(delay_sec) - async def get_all_for_user(self, requester: Requester) -> list[JsonDict]: - """Return all pending delayed events requested by the given user.""" + async def get_delayed_events_for_user( + self, + requester: Requester, + delay_ids: Optional[list[str]], + get_scheduled: bool, + get_finalised: bool, + ) -> dict[str, list[JsonDict]]: + """ + Return all scheduled delayed events for the given user. + + Args: + requester: The user whose delayed events to get. + delay_ids: The IDs of the delayed events to get, or None to get all of them. + get_scheduled: Whether to look up scheduled delayed events. + get_finalised: Whether to look up finalised delayed events. 
+ """ await self._delayed_event_mgmt_ratelimiter.ratelimit( requester, (requester.user.to_string(), requester.device_id), ) - return await self._store.get_all_delayed_events_for_user( - requester.user.localpart - ) + + # TODO: Support Pagination stream API + ret = {} + if get_scheduled: + ret["scheduled"] = await self._store.get_scheduled_delayed_events_for_user( + requester.user.localpart, + delay_ids, + ) + if get_finalised: + ret["finalised"] = await self._store.get_finalised_delayed_events_for_user( + requester.user.localpart, + delay_ids, + self._get_current_ts(), + self.hs.config.experimental.msc4140_finalised_retention_period, + self.hs.config.experimental.msc4140_finalised_per_user_retention_limit, + ) + return ret async def _send_event( self, event: DelayedEventDetails, txn_id: Optional[str] = None, - ) -> None: + ) -> Timestamp: user_id = UserID(event.user_localpart, self._config.server.server_name) user_id_str = user_id.to_string() # Create a new requester from what data is currently available @@ -571,6 +620,7 @@ async def _send_event( device_id=event.device_id, ) + finalised_ts = None try: if event.state_key is not None and event.type == EventTypes.Member: membership = event.content.get("membership") @@ -606,18 +656,34 @@ async def _send_event( txn_id=txn_id, ) event_id = sent_event.event_id + if event.origin_server_ts is None: + finalised_ts = Timestamp(sent_event.origin_server_ts) except ShadowBanError: event_id = generate_fake_event_id() + send_error = None + except SynapseError as e: + send_error = e.error_dict(None) + except Exception: + send_error = cs_error("Internal server error") + else: + send_error = None finally: # TODO: If this is a temporary error, retry. 
Otherwise, consider notifying clients of the failure + if finalised_ts is None: + finalised_ts = self._get_current_ts() try: - await self._store.delete_processed_delayed_event( - event.delay_id, event.user_localpart + await self._store.finalise_processed_delayed_event( + event.delay_id, + event.user_localpart, + send_error or event_id, + finalised_ts, ) except Exception: - logger.exception("Failed to delete processed delayed event") + logger.exception("Failed to finalise processed delayed event") - set_tag("event_id", event_id) + if send_error is None: + set_tag("event_id", event_id) + return finalised_ts def _get_current_ts(self) -> Timestamp: return Timestamp(self._clock.time_msec()) diff --git a/synapse/rest/client/delayed_events.py b/synapse/rest/client/delayed_events.py index 80abacbc9d6..0f6a7572427 100644 --- a/synapse/rest/client/delayed_events.py +++ b/synapse/rest/client/delayed_events.py @@ -21,7 +21,12 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_string_from_args, + parse_strings_from_args, +) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.types import JsonDict @@ -38,6 +43,11 @@ class _UpdateDelayedEventAction(Enum): SEND = "send" +class _DelayedEventStatus(Enum): + SCHEDULED = "scheduled" + FINALISED = "finalised" + + class UpdateDelayedEventServlet(RestServlet): PATTERNS = client_patterns( r"/org\.matrix\.msc4140/delayed_events/(?P[^/]+)$", @@ -97,10 +107,27 @@ def __init__(self, hs: "HomeServer"): async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - # TODO: Support Pagination stream API ("from" query parameter) - delayed_events = await self.delayed_events_handler.get_all_for_user(requester) - 
ret = {"delayed_events": delayed_events} + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: dict[bytes, list[bytes]] = request.args # type: ignore + statuses = parse_strings_from_args( + args, + "status", + allowed_values=tuple(s.value for s in _DelayedEventStatus), + ) + delay_ids = parse_strings_from_args(args, "delay_id") + # TODO: Support Pagination stream API + _from_token = parse_string_from_args(args, "from") + + ret = await self.delayed_events_handler.get_delayed_events_for_user( + requester, + delay_ids, + statuses is None or _DelayedEventStatus.SCHEDULED.value in statuses, + statuses is None or _DelayedEventStatus.FINALISED.value in statuses, + ) + # TODO: This is here for backwards compatibility. Remove eventually + if statuses is None: + ret["delayed_events"] = ret[_DelayedEventStatus.SCHEDULED.value] return 200, ret diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 6ad161db33d..c64d415983e 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -13,18 +13,27 @@ # import logging -from typing import NewType, Optional +from http import HTTPStatus +from typing import TYPE_CHECKING, NewType, Optional, Union import attr -from synapse.api.errors import NotFoundError +from synapse.api.errors import NotFoundError, StoreError, SynapseError, cs_error from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import LoggingTransaction, StoreError +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomID from synapse.util import stringutils from synapse.util.json import json_encoder +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -55,10 +64,25 @@ 
class DelayedEventDetails(EventDetails): class DelayedEventsStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + update_name="delayed_events_finalised_ts", + index_name="delayed_events_finalised_ts", + table="delayed_events", + columns=("finalised_ts",), + ) + async def get_delayed_events_stream_pos(self) -> int: """ Gets the stream position of the background process to watch for state events - that target the same piece of state as any pending delayed events. + that target the same piece of state as any scheduled delayed events. """ return await self.db_pool.simple_select_one_onecol( table="delayed_events_stream_pos", @@ -70,7 +94,7 @@ async def get_delayed_events_stream_pos(self) -> int: async def update_delayed_events_stream_pos(self, stream_id: Optional[int]) -> None: """ Updates the stream position of the background process to watch for state events - that target the same piece of state as any pending delayed events. + that target the same piece of state as any scheduled delayed events. Must only be used by the worker running the background process. """ @@ -93,6 +117,7 @@ async def add_delayed_event( origin_server_ts: Optional[int], content: JsonDict, delay: int, + limit: int, ) -> tuple[DelayID, Timestamp]: """ Inserts a new delayed event in the DB. @@ -100,11 +125,33 @@ async def add_delayed_event( Returns: The generated ID assigned to the added delayed event, and the send time of the next delayed event to be sent, which is either the event just added or one added earlier. + + Raises: + SynapseError: if the user has reached the limit of how many + delayed events they may have scheduled at a time. 
""" delay_id = _generate_delay_id() send_ts = Timestamp(creation_ts + delay) def add_delayed_event_txn(txn: LoggingTransaction) -> Timestamp: + txn.execute( + """ + SELECT COUNT(*) FROM delayed_events + WHERE user_localpart = ? + AND finalised_ts IS NULL + """, + (user_localpart,), + ) + num_existing: int = txn.fetchall()[0][0] + if num_existing >= limit: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "The maximum number of delayed events has been reached.", + additional_fields={ + "org.matrix.msc4140.errcode": "M_MAX_DELAYED_EVENTS_EXCEEDED", + }, + ) + self.db_pool.simple_insert_txn( txn, table="delayed_events", @@ -154,6 +201,7 @@ async def restart_delayed_event( Raises: NotFoundError: if there is no matching delayed event. + SynapseError: if the delayed event has already been finalised. """ def restart_delayed_event_txn( @@ -165,6 +213,7 @@ def restart_delayed_event_txn( SET send_ts = ? + delay WHERE delay_id = ? AND user_localpart = ? AND NOT is_processed + AND finalised_ts IS NULL """, ( current_ts, @@ -173,7 +222,25 @@ def restart_delayed_event_txn( ), ) if txn.rowcount == 0: - raise NotFoundError("Delayed event not found") + txn.execute( + """ + SELECT finalised_event_id IS NOT NULL + FROM delayed_events + WHERE delay_id = ? AND user_localpart = ? 
+ AND finalised_ts IS NOT NULL + """, + ( + delay_id, + user_localpart, + ), + ) + row = txn.fetchone() + if not row: + raise NotFoundError("Delayed event not found") + raise SynapseError( + HTTPStatus.CONFLICT, + f"Delayed event has already been {'sent' if row[0] else 'cancelled'}", + ) next_send_ts = self._get_next_delayed_event_send_ts_txn(txn) assert next_send_ts is not None @@ -183,30 +250,115 @@ def restart_delayed_event_txn( "restart_delayed_event", restart_delayed_event_txn ) - async def get_count_of_delayed_events(self) -> int: - """Returns the number of pending delayed events in the DB.""" + async def has_scheduled_delayed_events(self) -> bool: + """Returns whether there are any scheduled delayed events in the DB.""" - def _get_count_of_delayed_events(txn: LoggingTransaction) -> int: - sql = "SELECT count(*) FROM delayed_events" + rows = await self.db_pool.execute( + "has_scheduled_delayed_events", + """ + SELECT 1 WHERE EXISTS ( + SELECT * FROM delayed_events + WHERE finalised_ts IS NULL + ) + """, + ) + return bool(rows) - txn.execute(sql) - resp = txn.fetchone() - return resp[0] if resp is not None else 0 + async def prune_finalised_delayed_events( + self, + current_ts: Timestamp, + retention_period: int, + retention_limit: int, + ) -> None: + def prune_finalised_delayed_events(txn: LoggingTransaction) -> None: + self._prune_expired_finalised_delayed_events( + txn, current_ts, retention_period + ) - return await self.db_pool.runInteraction( - "get_count_of_delayed_events", - _get_count_of_delayed_events, + txn.execute( + """ + SELECT DISTINCT(user_localpart) + FROM delayed_events + WHERE finalised_ts IS NOT NULL + """ + ) + for [user_localpart] in txn.fetchall(): + self._prune_excess_finalised_delayed_events_for_user( + txn, user_localpart, retention_limit + ) + + await self.db_pool.runInteraction( + "prune_finalised_delayed_events", prune_finalised_delayed_events + ) + + def _prune_expired_finalised_delayed_events( + self, txn: LoggingTransaction, 
current_ts: Timestamp, retention_period: int + ) -> None: + """ + Delete all finalised delayed events that had finalised + before the end of the given retention period. + """ + txn.execute( + """ + DELETE FROM delayed_events + WHERE ? - finalised_ts > ? + """, + ( + current_ts, + retention_period, + ), ) - async def get_all_delayed_events_for_user( + def _prune_excess_finalised_delayed_events_for_user( + self, txn: LoggingTransaction, user_localpart: str, retention_limit: int + ) -> None: + """ + Delete the oldest finalised delayed events for the given user, + such that no more of them remain than the given retention limit. + """ + txn.execute( + """ + SELECT COUNT(*) FROM delayed_events + WHERE user_localpart = ? + AND finalised_ts IS NOT NULL + """, + (user_localpart,), + ) + num_existing: int = txn.fetchall()[0][0] + if num_existing > retention_limit: + txn.execute( + """ + DELETE FROM delayed_events + WHERE (user_localpart, delay_id) IN ( + SELECT user_localpart, delay_id FROM delayed_events + WHERE user_localpart = ? + AND finalised_ts IS NOT NULL + ORDER BY finalised_ts + LIMIT ? + ) + """, + ( + user_localpart, + num_existing - retention_limit, + ), + ) + + async def get_scheduled_delayed_events_for_user( self, user_localpart: str, + delay_ids: Optional[list[str]], ) -> list[JsonDict]: - """Returns all pending delayed events owned by the given user.""" + """Returns all scheduled delayed events for the given user.""" # TODO: Support Pagination stream API ("next_batch" field) + sql_where = "WHERE user_localpart = ? 
AND finalised_ts IS NULL" + sql_args = [user_localpart] + if delay_ids: + delay_id_clause_sql, delay_id_clause_args = make_in_list_sql_clause( + self.database_engine, "delay_id", delay_ids + ) + sql_where += f" AND {delay_id_clause_sql}" + sql_args.extend(delay_id_clause_args) rows = await self.db_pool.execute( - "get_all_delayed_events_for_user", - """ + "get_scheduled_delayed_events_for_user", + f""" SELECT delay_id, room_id, @@ -216,10 +368,10 @@ async def get_all_delayed_events_for_user( send_ts, content FROM delayed_events - WHERE user_localpart = ? AND NOT is_processed + {sql_where} ORDER BY send_ts """, - user_localpart, + *sql_args, ) return [ { @@ -234,6 +386,98 @@ async def get_all_delayed_events_for_user( for row in rows ] + async def get_finalised_delayed_events_for_user( + self, + user_localpart: str, + delay_ids: Optional[list[str]], + current_ts: Timestamp, + retention_period: int, + retention_limit: int, + ) -> list[JsonDict]: + """Returns all finalised delayed events for the given user.""" + # TODO: Support Pagination stream API ("next_batch" field) + + def get_finalised_delayed_events_for_user( + txn: LoggingTransaction, + ) -> list[JsonDict]: + # Clear up some space in the DB before returning any results. + self._prune_expired_finalised_delayed_events( + txn, current_ts, retention_period + ) + self._prune_excess_finalised_delayed_events_for_user( + txn, user_localpart, retention_limit + ) + + sql_where = "WHERE user_localpart = ? 
AND finalised_ts IS NOT NULL" + sql_args = [user_localpart] + if delay_ids: + delay_id_clause_sql, delay_id_clause_args = make_in_list_sql_clause( + self.database_engine, "delay_id", delay_ids + ) + sql_where += f" AND {delay_id_clause_sql}" + sql_args.extend(delay_id_clause_args) + txn.execute( + f""" + SELECT + delay_id, + room_id, + event_type, + state_key, + delay, + send_ts, + content, + finalised_error, + finalised_event_id, + finalised_ts + FROM delayed_events + {sql_where} + ORDER BY finalised_ts + """, + sql_args, + ) + return [ + { + "delayed_event": { + "delay_id": DelayID(row[0]), + "room_id": str(RoomID.from_string(row[1])), + "type": EventType(row[2]), + **( + {"state_key": StateKey(row[3])} + if row[3] is not None + else {} + ), + "delay": Delay(row[4]), + "running_since": Timestamp(row[5] - row[4]), + "content": db_to_json(row[6]), + }, + "outcome": "cancel" if row[8] is None else "send", + "reason": ( + "finalised_error" + if row[7] is not None + else "action" + if row[9] < row[5] + else "delay" + ), + **( + {"finalised_error": db_to_json(row[7])} + if row[7] is not None + else {} + ), + **( + {"finalised_event_id": str(row[8])} + if row[8] is not None + else {} + ), + "origin_server_ts": Timestamp(row[9]), + } + for row in txn + ] + + return await self.db_pool.runInteraction( + "get_finalised_delayed_events_for_user", + get_finalised_delayed_events_for_user, + ) + async def process_timeout_delayed_events( self, current_ts: Timestamp ) -> tuple[ @@ -268,7 +512,10 @@ def process_timeout_delayed_events_txn( ) ) sql_update = "UPDATE delayed_events SET is_processed = TRUE" - sql_where = "WHERE send_ts <= ? AND NOT is_processed" + sql_where = """WHERE send_ts <= ? 
+ AND NOT is_processed + AND finalised_ts IS NULL + """ sql_args = (current_ts,) sql_order = "ORDER BY send_ts" if isinstance(self.database_engine, PostgresEngine): @@ -323,7 +570,7 @@ async def process_target_delayed_event( delay_id: str, user_localpart: str, ) -> tuple[ - EventDetails, + Optional[EventDetails], Optional[Timestamp], ]: """ @@ -339,12 +586,13 @@ async def process_target_delayed_event( Raises: NotFoundError: if there is no matching delayed event. + SynapseError: if the delayed event has already been cancelled. """ def process_target_delayed_event_txn( txn: LoggingTransaction, ) -> tuple[ - EventDetails, + Optional[EventDetails], Optional[Timestamp], ]: txn.execute( @@ -353,6 +601,7 @@ def process_target_delayed_event_txn( SET is_processed = TRUE WHERE delay_id = ? AND user_localpart = ? AND NOT is_processed + AND finalised_ts IS NULL RETURNING room_id, event_type, @@ -367,8 +616,28 @@ def process_target_delayed_event_txn( ), ) row = txn.fetchone() - if row is None: - raise NotFoundError("Delayed event not found") + if not row: + txn.execute( + """ + SELECT finalised_event_id IS NOT NULL + FROM delayed_events + WHERE delay_id = ? AND user_localpart = ? + AND finalised_ts IS NOT NULL + """, + ( + delay_id, + user_localpart, + ), + ) + row = txn.fetchone() + if not row: + raise NotFoundError("Delayed event not found") + elif not row[0]: + raise SynapseError( + HTTPStatus.CONFLICT, + "Delayed event has already been cancelled", + ) + return None, None event = EventDetails( RoomID.from_string(row[0]), @@ -390,6 +659,7 @@ async def cancel_delayed_event( *, delay_id: str, user_localpart: str, + finalised_ts: Timestamp, ) -> Optional[Timestamp]: """ Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed. @@ -402,27 +672,48 @@ async def cancel_delayed_event( Raises: NotFoundError: if there is no matching delayed event. + SynapseError: if the delayed event has already been sent. 
""" def cancel_delayed_event_txn( txn: LoggingTransaction, ) -> Optional[Timestamp]: - try: - self.db_pool.simple_delete_one_txn( - txn, - table="delayed_events", - keyvalues={ - "delay_id": delay_id, - "user_localpart": user_localpart, - "is_processed": False, - }, + txn.execute( + """ + UPDATE delayed_events + SET finalised_ts = ? + WHERE delay_id = ? AND user_localpart = ? + AND NOT is_processed + AND finalised_ts IS NULL + """, + ( + finalised_ts, + delay_id, + user_localpart, + ), + ) + if txn.rowcount == 0: + txn.execute( + """ + SELECT finalised_event_id IS NOT NULL + FROM delayed_events + WHERE delay_id = ? AND user_localpart = ? + AND finalised_ts IS NOT NULL + """, + ( + delay_id, + user_localpart, + ), ) - except StoreError: - if txn.rowcount == 0: + row = txn.fetchone() + if not row: raise NotFoundError("Delayed event not found") - else: - raise - + elif row[0]: + raise SynapseError( + HTTPStatus.CONFLICT, + "Delayed event has already been sent", + ) + return None return self._get_next_delayed_event_send_ts_txn(txn) return await self.db_pool.runInteraction( @@ -436,6 +727,7 @@ async def cancel_delayed_state_events( event_type: str, state_key: str, not_from_localpart: str, + finalised_ts: Timestamp, ) -> Optional[Timestamp]: """ Cancels all matching delayed state events, i.e. remove them as long as they haven't been processed. @@ -455,12 +747,18 @@ def cancel_delayed_state_events_txn( ) -> Optional[Timestamp]: txn.execute( """ - DELETE FROM delayed_events + UPDATE delayed_events + SET + finalised_error = ?, + finalised_ts = ? WHERE room_id = ? AND event_type = ? AND state_key = ? AND user_localpart <> ? 
AND NOT is_processed + AND finalised_ts IS NULL """, ( + _generate_cancelled_by_state_update_json(), + finalised_ts, room_id, event_type, state_key, @@ -473,57 +771,99 @@ def cancel_delayed_state_events_txn( "cancel_delayed_state_events", cancel_delayed_state_events_txn ) - async def delete_processed_delayed_event( + async def finalise_processed_delayed_event( self, delay_id: DelayID, user_localpart: UserLocalpart, + result_or_error: Union[str, JsonDict], + finalised_ts: Timestamp, ) -> None: """ - Delete the matching delayed event, as long as it has been marked as processed. + Finalise the matching delayed event, as long as it has been marked as processed. Throws: StoreError: if there is no matching delayed event, or if it has not yet been processed. """ - return await self.db_pool.simple_delete_one( - table="delayed_events", - keyvalues={ - "delay_id": delay_id, - "user_localpart": user_localpart, - "is_processed": True, - }, - desc="delete_processed_delayed_event", + if isinstance(result_or_error, str): + event_id = result_or_error + send_error = None + else: + event_id = None + send_error = result_or_error + + def finalise_processed_delayed_event_txn(txn: LoggingTransaction) -> None: + table = "delayed_events" + txn.execute( + f""" + UPDATE {table} + SET + finalised_error = ?, + finalised_event_id = ?, + finalised_ts = ? + WHERE delay_id = ? AND user_localpart = ? 
+ AND is_processed + AND finalised_ts IS NULL + """, + ( + json_encoder.encode(send_error) if send_error is not None else None, + event_id, + finalised_ts, + delay_id, + user_localpart, + ), + ) + rowcount = txn.rowcount + if rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + await self.db_pool.runInteraction( + "finalise_processed_delayed_event", + finalise_processed_delayed_event_txn, ) - async def delete_processed_delayed_state_events( + async def finalise_processed_delayed_state_events( self, *, room_id: str, event_type: str, state_key: str, + finalised_ts: Timestamp, ) -> None: """ - Delete the matching delayed state events that have been marked as processed. + Finalise the matching delayed state events that have been marked as processed. """ - await self.db_pool.simple_delete( - table="delayed_events", - keyvalues={ - "room_id": room_id, - "event_type": event_type, - "state_key": state_key, - "is_processed": True, - }, - desc="delete_processed_delayed_state_events", + await self.db_pool.execute( + "finalise_processed_delayed_state_events", + """ + UPDATE delayed_events + SET + finalised_error = ?, + finalised_ts = ? + WHERE room_id = ? AND event_type = ? AND state_key = ? + AND is_processed + AND finalised_ts IS NULL + """, + _generate_cancelled_by_state_update_json(), + finalised_ts, + room_id, + event_type, + state_key, ) async def unprocess_delayed_events(self) -> None: """ Unmark all delayed events for processing. 
""" - await self.db_pool.simple_update( - table="delayed_events", - keyvalues={"is_processed": True}, - updatevalues={"is_processed": False}, - desc="unprocess_delayed_events", + await self.db_pool.execute( + "unprocess_delayed_events", + """ + UPDATE delayed_events SET is_processed = FALSE + WHERE is_processed + AND finalised_ts IS NULL + """, ) async def get_next_delayed_event_send_ts(self) -> Optional[Timestamp]: @@ -539,14 +879,15 @@ async def get_next_delayed_event_send_ts(self) -> Optional[Timestamp]: def _get_next_delayed_event_send_ts_txn( self, txn: LoggingTransaction ) -> Optional[Timestamp]: - result = self.db_pool.simple_select_one_onecol_txn( - txn, - table="delayed_events", - keyvalues={"is_processed": False}, - retcol="MIN(send_ts)", - allow_none=True, + txn.execute( + """ + SELECT MIN(send_ts) FROM delayed_events + WHERE NOT is_processed + AND finalised_ts IS NULL + """ ) - return Timestamp(result) if result is not None else None + resp = txn.fetchone() + return Timestamp(resp[0]) if resp is not None else None def _generate_delay_id() -> DelayID: @@ -558,3 +899,15 @@ def _generate_delay_id() -> DelayID: # the same ID to exist for multiple users. return DelayID(f"syd_{stringutils.random_string(20)}") + + +def _generate_cancelled_by_state_update_json() -> str: + return json_encoder.encode( + cs_error( + "The delayed event did not get sent because a different user updated the same state event. 
" + + "So the scheduled event might change it in an undesired way.", + **{ + "org.matrix.msc4140.errcode": "M_CANCELLED_BY_STATE_UPDATE", + }, + ) + ) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 3c3b13437ef..1171d9ed28d 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -19,7 +19,7 @@ # # -SCHEMA_VERSION = 92 # remember to update the list below when updating +SCHEMA_VERSION = 93 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -168,6 +168,10 @@ Changes in SCHEMA_VERSION = 92 - Cleaned up a trigger that was added in #18260 and then reverted. + +Changes in SCHEMA_VERSION = 93 + - MSC4140: Add `finalised_delayed_events` table that keeps track of delayed events + that have been sent, cancelled, or failed to be sent due to an error. """ diff --git a/synapse/storage/schema/main/delta/93/01_add_finalised_delayed_events.sql b/synapse/storage/schema/main/delta/93/01_add_finalised_delayed_events.sql new file mode 100644 index 00000000000..a0b4933d37e --- /dev/null +++ b/synapse/storage/schema/main/delta/93/01_add_finalised_delayed_events.sql @@ -0,0 +1,20 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . 
+ +-- Store when delayed events have either been sent, cancelled, or not sent due to an error (MSC4140) +ALTER TABLE delayed_events ADD COLUMN finalised_error bytea; +ALTER TABLE delayed_events ADD COLUMN finalised_event_id TEXT; +ALTER TABLE delayed_events ADD COLUMN finalised_ts BIGINT; + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('delayed_events_finalised_ts', '{}'); diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py index c67ffc76683..3816d8f8858 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py @@ -89,7 +89,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: ) def test_delayed_events_empty_on_startup(self) -> None: - self.assertListEqual([], self._get_delayed_events()) + scheduled, finalised = self._get_delayed_events() + self.assertListEqual([], scheduled) + self.assertListEqual([], finalised) def test_delayed_state_events_are_sent_on_timeout(self) -> None: state_key = "to_send_on_timeout" @@ -105,9 +107,13 @@ def test_delayed_state_events_are_sent_on_timeout(self) -> None: self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - events = self._get_delayed_events() - self.assertEqual(1, len(events), events) - content = self._get_delayed_event_content(events[0]) + scheduled, finalised = self._get_delayed_events() + self.assertEqual(1, len(scheduled), scheduled) + self.assertListEqual([], finalised) + + scheduled_event = scheduled[0] + content = self._get_delayed_event_content(scheduled_event) + self.assertEqual(setter_expected, content.get(setter_key), content) self.helper.get_state( self.room_id, @@ -118,7 +124,17 @@ def test_delayed_state_events_are_sent_on_timeout(self) -> None: ) self.reactor.advance(1) - self.assertListEqual([], self._get_delayed_events()) + scheduled, finalised = self._get_delayed_events() + self.assertListEqual([], scheduled) + self.assertEqual(1, 
len(finalised), finalised) + + finalised_event_info = finalised[0] + self.assertDictEqual(scheduled_event, finalised_event_info["delayed_event"]) + self.assertEqual("send", finalised_event_info["outcome"]) + self.assertEqual("delay", finalised_event_info["reason"]) + self.assertNotIn("finalised_error", finalised_event_info) + self.assertIsNotNone(finalised_event_info["finalised_event_id"]) + content = self.helper.get_state( self.room_id, _EVENT_TYPE, @@ -224,9 +240,11 @@ def test_cancel_delayed_state_event(self) -> None: self.assertIsNotNone(delay_id) self.reactor.advance(1) - events = self._get_delayed_events() - self.assertEqual(1, len(events), events) - content = self._get_delayed_event_content(events[0]) + scheduled = self._get_scheduled_delayed_events() + self.assertEqual(1, len(scheduled), scheduled) + + scheduled_event = scheduled[0] + content = self._get_delayed_event_content(scheduled_event) self.assertEqual(setter_expected, content.get(setter_key), content) self.helper.get_state( self.room_id, @@ -243,7 +261,17 @@ def test_cancel_delayed_state_event(self) -> None: self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - self.assertListEqual([], self._get_delayed_events()) + + scheduled, finalised = self._get_delayed_events() + self.assertListEqual([], scheduled) + self.assertEqual(1, len(finalised), finalised) + + finalised_event_info = finalised[0] + self.assertDictEqual(scheduled_event, finalised_event_info["delayed_event"]) + self.assertEqual("cancel", finalised_event_info["outcome"]) + self.assertEqual("action", finalised_event_info["reason"]) + self.assertNotIn("finalised_error", finalised_event_info) + self.assertNotIn("finalised_event_id", finalised_event_info) self.reactor.advance(1) content = self.helper.get_state( @@ -317,9 +345,11 @@ def test_send_delayed_state_event(self) -> None: self.assertIsNotNone(delay_id) self.reactor.advance(1) - events = self._get_delayed_events() - self.assertEqual(1, len(events), 
events) - content = self._get_delayed_event_content(events[0]) + scheduled = self._get_scheduled_delayed_events() + self.assertEqual(1, len(scheduled), scheduled) + + scheduled_event = scheduled[0] + content = self._get_delayed_event_content(scheduled_event) self.assertEqual(setter_expected, content.get(setter_key), content) self.helper.get_state( self.room_id, @@ -336,7 +366,18 @@ def test_send_delayed_state_event(self) -> None: self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - self.assertListEqual([], self._get_delayed_events()) + + scheduled, finalised = self._get_delayed_events() + self.assertListEqual([], scheduled) + self.assertEqual(1, len(finalised), finalised) + + finalised_event_info = finalised[0] + self.assertDictEqual(scheduled_event, finalised_event_info["delayed_event"]) + self.assertEqual("send", finalised_event_info["outcome"]) + self.assertEqual("action", finalised_event_info["reason"]) + self.assertNotIn("finalised_error", finalised_event_info) + self.assertIsNotNone(finalised_event_info["finalised_event_id"]) + content = self.helper.get_state( self.room_id, _EVENT_TYPE, @@ -406,9 +447,11 @@ def test_restart_delayed_state_event(self) -> None: self.assertIsNotNone(delay_id) self.reactor.advance(1) - events = self._get_delayed_events() - self.assertEqual(1, len(events), events) - content = self._get_delayed_event_content(events[0]) + scheduled = self._get_scheduled_delayed_events() + self.assertEqual(1, len(scheduled), scheduled) + + scheduled_event = scheduled[0] + content = self._get_delayed_event_content(scheduled_event) self.assertEqual(setter_expected, content.get(setter_key), content) self.helper.get_state( self.room_id, @@ -427,9 +470,12 @@ def test_restart_delayed_state_event(self) -> None: self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.reactor.advance(1) - events = self._get_delayed_events() - self.assertEqual(1, len(events), events) - content = 
self._get_delayed_event_content(events[0]) + + scheduled, finalised = self._get_delayed_events() + self.assertEqual(1, len(scheduled), scheduled) + self.assertListEqual([], finalised) + + content = self._get_delayed_event_content(scheduled[0]) self.assertEqual(setter_expected, content.get(setter_key), content) self.helper.get_state( self.room_id, @@ -440,7 +486,16 @@ def test_restart_delayed_state_event(self) -> None: ) self.reactor.advance(1) - self.assertListEqual([], self._get_delayed_events()) + scheduled, finalised = self._get_delayed_events() + self.assertListEqual([], scheduled) + self.assertEqual(1, len(finalised), finalised) + + finalised_event_info = finalised[0] + self.assertEqual("send", finalised_event_info["outcome"]) + self.assertEqual("delay", finalised_event_info["reason"]) + self.assertNotIn("finalised_error", finalised_event_info) + self.assertIsNotNone(finalised_event_info["finalised_event_id"]) + content = self.helper.get_state( self.room_id, _EVENT_TYPE, @@ -510,8 +565,8 @@ def test_delayed_state_is_not_cancelled_by_new_state_from_same_user( self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - events = self._get_delayed_events() - self.assertEqual(1, len(events), events) + scheduled = self._get_scheduled_delayed_events() + self.assertEqual(1, len(scheduled), scheduled) self.helper.send_state( self.room_id, @@ -522,8 +577,8 @@ def test_delayed_state_is_not_cancelled_by_new_state_from_same_user( self.user1_access_token, state_key=state_key, ) - events = self._get_delayed_events() - self.assertEqual(1, len(events), events) + scheduled = self._get_scheduled_delayed_events() + self.assertEqual(1, len(scheduled), scheduled) self.reactor.advance(1) content = self.helper.get_state( @@ -549,8 +604,9 @@ def test_delayed_state_is_cancelled_by_new_state_from_other_user( self.user1_access_token, ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - events = self._get_delayed_events() - self.assertEqual(1, 
len(events), events) + scheduled = self._get_scheduled_delayed_events() + self.assertEqual(1, len(scheduled), scheduled) + scheduled_event = scheduled[0] setter_expected = "other_user" self.helper.send_state( @@ -562,7 +618,23 @@ def test_delayed_state_is_cancelled_by_new_state_from_other_user( self.user2_access_token, state_key=state_key, ) - self.assertListEqual([], self._get_delayed_events()) + + scheduled, finalised = self._get_delayed_events() + self.assertListEqual([], scheduled) + self.assertEqual(1, len(finalised), finalised) + + finalised_event_info = finalised[0] + self.assertDictEqual(scheduled_event, finalised_event_info["delayed_event"]) + self.assertEqual("cancel", finalised_event_info["outcome"]) + self.assertEqual("finalised_error", finalised_event_info["reason"]) + self.assert_dict( + { + "errcode": "M_UNKNOWN", + "org.matrix.msc4140.errcode": "M_CANCELLED_BY_STATE_UPDATE", + }, + finalised_event_info["finalised_error"], + ) + self.assertNotIn("finalised_event_id", finalised_event_info) self.reactor.advance(1) content = self.helper.get_state( @@ -573,7 +645,7 @@ def test_delayed_state_is_cancelled_by_new_state_from_other_user( ) self.assertEqual(setter_expected, content.get(setter_key), content) - def _get_delayed_events(self) -> list[JsonDict]: + def _get_delayed_events(self) -> tuple[list[JsonDict], list[JsonDict]]: channel = self.make_request( "GET", PATH_PREFIX, @@ -581,13 +653,56 @@ def _get_delayed_events(self) -> list[JsonDict]: ) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) - key = "delayed_events" - self.assertIn(key, channel.json_body) + scheduled = self._validate_scheduled_delayed_events(channel.json_body) + finalised = self._validate_finalised_delayed_events(channel.json_body) + + return scheduled, finalised + + def _get_scheduled_delayed_events(self) -> list[JsonDict]: + channel = self.make_request( + "GET", + PATH_PREFIX + "?status=scheduled", + access_token=self.user1_access_token, + ) + 
self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + scheduled = self._validate_scheduled_delayed_events(channel.json_body) + + return scheduled + + def _get_finalised_delayed_events(self) -> list[JsonDict]: + channel = self.make_request( + "GET", + PATH_PREFIX + "?status=finalised", + access_token=self.user1_access_token, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + finalised = self._validate_finalised_delayed_events(channel.json_body) + + return finalised + + def _validate_scheduled_delayed_events(self, json_body: JsonDict) -> list[JsonDict]: + key = "scheduled" + self.assertIn(key, json_body) + + scheduled = json_body[key] + self.assertIsInstance(scheduled, list) + + return scheduled + + def _validate_finalised_delayed_events(self, json_body: JsonDict) -> list[JsonDict]: + key = "finalised" + self.assertIn(key, json_body) + + finalised = json_body[key] + self.assertIsInstance(finalised, list) - events = channel.json_body[key] - self.assertIsInstance(events, list) + for item in finalised: + for key in ("delayed_event", "outcome", "reason", "origin_server_ts"): + self.assertIsNotNone(item.get(key)) - return events + return finalised def _get_delayed_event_content(self, event: JsonDict) -> JsonDict: key = "content" diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4142aed3632..1fe4a39b368 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2402,6 +2402,8 @@ def test_room_message_filter_wildcard(self) -> None: class RoomDelayedEventTestCase(RoomBase): """Tests delayed events.""" + servlets = RoomBase.servlets + [admin.register_servlets] + user_id = "@sid1:red" def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -2533,6 +2535,54 @@ def test_add_delayed_event_ratelimit(self) -> None: channel = self.make_request(*args) self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + @unittest.override_config( + { + 
"max_event_delay_duration": "24h", + "experimental_features": { + "msc4140_max_delayed_events_per_user": 1, + }, + } + ) + def test_add_delayed_event_num_limit(self) -> None: + """Test that users may not have too many scheduled delayed events at once.""" + user2_user_id = self.register_user("user2", "pass") + + room_id = self.helper.create_room_as(self.user_id, is_public=True) + self.helper.join(room_id, user2_user_id) + + txn_id = 0 + + def add_delayed_event( + expect_success: bool, + user_id: str = self.user_id, + ) -> None: + nonlocal room_id, txn_id + self.helper.auth_user_id = user_id + txn_id += 1 + channel = self.make_request( + "PUT", + ( + "rooms/%s/send/m.room.message/%s?org.matrix.msc4140.delay=2000" + % (room_id, txn_id) + ).encode("ascii"), + {"body": "test", "msgtype": "m.text"}, + ) + if expect_success: + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + else: + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) + self.assertEqual( + "M_MAX_DELAYED_EVENTS_EXCEEDED", + channel.json_body.get("org.matrix.msc4140.errcode"), + channel.json_body, + ) + + add_delayed_event(True) + add_delayed_event(False) + add_delayed_event(True, user2_user_id) + self.reactor.advance(2) + add_delayed_event(True) + class RoomSearchTestCase(unittest.HomeserverTestCase): servlets = [