Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
99 changes: 49 additions & 50 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
# limitations under the License.

import logging
from typing import Iterable, List, TypeVar
from typing import Dict, Iterable, List, TypeVar

import attr

from twisted.internet import defer

from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -347,55 +346,54 @@ def get_state_group_delta(self, state_group: int):

return self.stores.state.get_state_group_delta(state_group)

@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
) -> Dict[int, StateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events

Args:
_room_id (str): id of the room for these events
event_ids (iterable[str]): ids of the events
_room_id: id of the room for these events
event_ids: ids of the events

Returns:
Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
return {}

event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups)
group_to_state = await self.stores.state._get_state_for_groups(groups)

return group_to_state

@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
async def get_state_ids_for_group(self, state_group: int) -> dict:
"""Get the event IDs of all the state in the given state group

Args:
state_group (int)
state_group

Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = yield self._get_state_for_groups((state_group,))
group_to_state = await self._get_state_for_groups((state_group,))

return group_to_state[state_group]

@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
async def get_state_groups(
self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids
Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}

group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)

state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[
ev_id
for group_ids in group_to_ids.values()
Expand Down Expand Up @@ -429,25 +427,26 @@ def _get_state_groups_from_groups(

return self.stores.state._get_state_groups_from_groups(groups, state_filter)

@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids (list[string])
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids
state_filter: The state filter used to fetch state

Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)

state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
Expand All @@ -463,24 +462,24 @@ def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):

return {event: event_to_state[event] for event in event_ids}

@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)

Args:
event_ids(list(str)): events whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.

Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)

Expand All @@ -491,36 +490,36 @@ def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):

return {event: event_to_state[event] for event in event_ids}

@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dict corresponding to a particular event

Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.

Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id]

@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dict corresponding to a particular event

Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.

Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]

def _get_state_for_groups(
Expand Down
60 changes: 36 additions & 24 deletions tests/storage/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def test_get_state_groups_ids(self):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)

state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id]
state_group_map = yield defer.ensureDeferred(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
Expand All @@ -108,8 +108,8 @@ def test_get_state_groups(self):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)

state_group_map = yield self.storage.state.get_state_groups(
self.room, [e2.event_id]
state_group_map = yield defer.ensureDeferred(
self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
Expand Down Expand Up @@ -150,7 +150,9 @@ def test_get_state_for_event(self):
)

# check we get the full state as of the final event
state = yield self.storage.state.get_state_for_event(e5.event_id)
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(e5.event_id)
)

self.assertIsNotNone(e4)

Expand All @@ -166,22 +168,28 @@ def test_get_state_for_event(self):
)

# check we can filter to the m.room.name event (with a '' state key)
state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
)

self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)

# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
)

self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)

# check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
)

self.assertStateMapEqual(
Expand All @@ -190,12 +198,14 @@ def test_get_state_for_event(self):

# check we can grab a specific room member without filtering out the
# other event types
state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True,
),
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True,
),
)
)

self.assertStateMapEqual(
Expand All @@ -208,11 +218,13 @@ def test_get_state_for_event(self):
)

# check that we can grab everything except members
state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
),
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
),
)
)

self.assertStateMapEqual(
Expand All @@ -224,8 +236,8 @@ def test_get_state_for_event(self):
#######################################################

room_id = self.room.to_string()
group_ids = yield self.storage.state.get_state_groups_ids(
room_id, [e5.event_id]
group_ids = yield defer.ensureDeferred(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]

Expand Down