1 change: 1 addition & 0 deletions changelog.d/11411.misc
@@ -0,0 +1 @@
+Add type hints to storage classes.
4 changes: 3 additions & 1 deletion mypy.ini
@@ -33,7 +33,6 @@ exclude = (?x)
 |synapse/storage/databases/main/event_federation.py
 |synapse/storage/databases/main/event_push_actions.py
 |synapse/storage/databases/main/events_bg_updates.py
-|synapse/storage/databases/main/events_worker.py
 |synapse/storage/databases/main/group_server.py
 |synapse/storage/databases/main/metrics.py
 |synapse/storage/databases/main/monthly_active_users.py
@@ -184,6 +183,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.directory]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.events_worker]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.room_batch]
 disallow_untyped_defs = True
 
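With events_worker.py removed from the exclude list and given its own `disallow_untyped_defs = True` section, mypy now requires every function in that module to carry a full signature. A minimal illustration of what the flag enforces (the functions below are hypothetical, not from the Synapse codebase):

from typing import Optional

def get_cached_event(event_id):
    # mypy would reject this under disallow_untyped_defs:
    # "error: Function is missing a type annotation"
    ...

def get_cached_event_typed(event_id: str) -> Optional[dict]:
    # Fully annotated, so this passes the strict check.
    return None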
22 changes: 10 additions & 12 deletions synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -14,10 +14,18 @@
 from typing import List, Optional, Tuple
 
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import _load_current_id
+from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
 
 
-class SlavedIdTracker:
+class SlavedIdTracker(AbstractStreamIdTracker):
+    """Tracks the "current" stream ID of a stream with a single writer.
+
+    See `AbstractStreamIdTracker` for more details.
+
+    Note that this class does not work correctly when there are multiple
+    writers.
+    """
+
     def __init__(
         self,
         db_conn: LoggingDatabaseConnection,
@@ -36,17 +44,7 @@ def advance(self, instance_name: Optional[str], new_id: int):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self) -> int:
-        """
-
-        Returns:
-            int
-        """
         return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
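The docstrings removed here presumably now live on the base class. `AbstractStreamIdTracker` itself is not part of this diff; judging from the methods `SlavedIdTracker` implements, a minimal sketch of the interface might look like the following (signatures assumed, not copied from the Synapse source):

import abc
from typing import Optional


class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
    """Tracks the "current" position of a replication stream."""

    @abc.abstractmethod
    def advance(self, instance_name: Optional[str], new_id: int) -> None:
        """Advance the position of the named writer to new_id."""
        ...

    @abc.abstractmethod
    def get_current_token(self) -> int:
        """Return the maximum stream ID fully persisted so far."""
        ...

    @abc.abstractmethod
    def get_current_token_for_writer(self, instance_name: str) -> int:
        """Return the position of the given writer."""
        ...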
4 changes: 0 additions & 4 deletions synapse/replication/slave/storage/push_rule.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 
@@ -25,9 +24,6 @@ def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        # We assert this for the benefit of mypy
-        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
-
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
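The removed assert existed only to narrow the attribute's type for mypy; once `_push_rules_stream_id_gen` is declared with an interface that itself exposes `advance`, the call type-checks directly. A sketch of the effect using `typing.Protocol` (Synapse uses the `AbstractStreamIdTracker` base class instead; the names below are illustrative):

from typing import Protocol


class StreamIdTrackerLike(Protocol):
    # Minimal protocol mirroring the one method used below.
    def advance(self, instance_name: str, new_id: int) -> None: ...


class PushRulesWorker:
    # Declared via the interface, so any conforming implementation fits.
    _push_rules_stream_id_gen: StreamIdTrackerLike

    def process_replication_rows(self, instance_name: str, token: int) -> None:
        # No `assert isinstance(...)` needed: the declared type already
        # guarantees `advance` exists, so mypy accepts this call.
        self._push_rules_stream_id_gen.advance(instance_name, token)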
6 changes: 3 additions & 3 deletions synapse/replication/tcp/streams/events.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
 
 import attr
 
@@ -157,7 +157,7 @@ async def _update_function(
 
         # now we fetch up to that many rows from the events table
 
-        event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
+        event_rows = await self._store.get_all_new_forward_event_rows(
            instance_name, from_token, current_token, target_row_count
         )
 
@@ -191,7 +191,7 @@ async def _update_function(
         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.
 
-        ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
+        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
            instance_name, from_token, upper_limit
         )
 
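The explicit `List[Tuple]` annotations could be dropped because the store methods are now annotated themselves, so mypy infers a precise type at each call site. Roughly, the mechanism is as sketched below (the return type is a simplified assumption, not the actual row shape):

from typing import List, Tuple


async def get_all_new_forward_event_rows(
    instance_name: str, from_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
    # Stand-in for the now-annotated store method; the tuple shape here
    # is an assumption for illustration only.
    return []


async def caller() -> None:
    # No annotation needed: mypy infers the element type from the callee.
    event_rows = await get_all_new_forward_event_rows("master", 0, 10, 100)
    print(len(event_rows))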
2 changes: 1 addition & 1 deletion synapse/state/__init__.py
@@ -764,7 +764,7 @@ class StateResolutionStore:
     store: "DataStore"
 
     def get_events(
-        self, event_ids: Iterable[str], allow_rejected: bool = False
+        self, event_ids: Collection[str], allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database
 
3 changes: 2 additions & 1 deletion synapse/state/v1.py
@@ -17,6 +17,7 @@
 from typing import (
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -44,7 +45,7 @@ async def resolve_events_with_store(
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+    state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
 ) -> StateMap[str]:
     """
     Args:
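Both state-resolution changes tighten `Iterable[str]` to `Collection[str]`. The distinction matters because a bare `Iterable` may be a one-shot generator, while a `Collection` additionally guarantees `len()` and repeated iteration, which batched database fetches typically rely on. A small sketch:

from typing import Collection, Iterable


def fetch_events(event_ids: Collection[str]) -> None:
    # Safe: Collection supports len() and can be iterated more than once,
    # e.g. to split the IDs into SQL batches and then verify the results.
    print(f"fetching {len(event_ids)} events")
    for event_id in event_ids:
        pass


def fetch_events_iterable(event_ids: Iterable[str]) -> None:
    # Calling len(event_ids) here would be rejected by mypy:
    # Iterable declares no __len__.
    ...


fetch_events(["$event1", "$event2"])  # lists, sets, etc. all qualify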
4 changes: 2 additions & 2 deletions synapse/storage/_base.py
@@ -21,7 +21,7 @@
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -48,7 +48,7 @@ def process_replication_rows(
         self,
         stream_name: str,
         instance_name: str,
-        token: StreamToken,
+        token: int,
         rows: Iterable[Any],
     ) -> None:
         pass
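The token passed through replication is a plain integer stream position rather than a composite `StreamToken`, so the annotation is corrected to `int`. A sketch of a worker store overriding the hook under the corrected signature (the class and stream names are hypothetical):

from typing import Any, Iterable


class ExampleWorkerStore:
    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,  # a single stream position, not a composite StreamToken
        rows: Iterable[Any],
    ) -> None:
        if stream_name == "example_stream":
            print(f"{instance_name} advanced to position {token}")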
29 changes: 17 additions & 12 deletions synapse/storage/databases/main/events.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 import itertools
 import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -41,9 +41,10 @@
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
@@ -64,9 +65,6 @@
 )
 
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
 @attr.s(slots=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -108,16 +106,21 @@ def __init__(
         self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
-        # Ideally we'd move these ID gens here, unfortunately some other ID
-        # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
-        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
+        # Since we have been configured to write, we ought to have id generators,
+        # rather than id trackers.
+        assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+        assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
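Beyond the runtime check, the new `assert isinstance(...)` lines narrow the static type: afterwards mypy treats the attributes as `AbstractStreamIdGenerator`, so the annotated assignments type-check. A minimal demonstration of the pattern with stand-in classes:

import abc


class Tracker(abc.ABC):
    # Stand-in for AbstractStreamIdTracker: read-only position tracking.
    @abc.abstractmethod
    def get_current_token(self) -> int: ...


class Generator(Tracker):
    # Stand-in for AbstractStreamIdGenerator: can also allocate new IDs.
    def get_current_token(self) -> int:
        return 0

    def get_next_id(self) -> int:
        return 1


def configure_writer(id_gen: Tracker) -> Generator:
    # Writers must hold a generator; the assert both enforces this at
    # runtime and narrows the static type from Tracker to Generator.
    assert isinstance(id_gen, Generator)
    return id_gen  # mypy accepts: narrowed to Generator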
@@ -1553,11 +1556,13 @@ def _add_to_cache(self, txn, events_and_contexts):
         for row in rows:
             event = ev_map[row["event_id"]]
             if not row["rejects"] and not row["redacts"]:
-                to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+                to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
         def prefill():
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+                self.store._get_event_cache.set(
+                    (cache_entry.event.event_id,), cache_entry
+                )
 
         txn.call_after(prefill)
 
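`EventCacheEntry` replaces the module-local `_EventCacheEntry` namedtuple and is imported from `events_worker`, which is why the prefill loop switches from index access (`cache_entry[0]`) to attribute access (`cache_entry.event`). The class itself is not shown in this diff; it is plausibly an attrs class along these lines (field types are simplified stand-ins for synapse.events.EventBase):

from typing import Any, Optional

import attr


@attr.s(slots=True, auto_attribs=True)
class EventCacheEntry:
    # Sketch of the class assumed to be defined in events_worker.py:
    # the cached event plus its redacted form, if any.
    event: Any
    redacted_event: Optional[Any]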