Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12423.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some type hints to datastore.
4 changes: 2 additions & 2 deletions synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ async def _send_renewal_emails(self) -> None:
expiring_users = await self.store.get_users_expiring_soon()

if expiring_users:
for user in expiring_users:
for user_id, expiration_ts_ms in expiring_users:
await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
user_id=user_id, expiration_ts=expiration_ts_ms
)

async def send_renewal_email_to_user(self, user_id: str) -> None:
Expand Down
28 changes: 18 additions & 10 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple

from synapse.appservice import (
ApplicationService,
Expand All @@ -26,7 +26,11 @@
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
Expand Down Expand Up @@ -92,7 +96,7 @@ def get_max_as_txn_id(txn: Cursor) -> int:

super().__init__(database, db_conn, hs)

def get_app_services(self):
def get_app_services(self) -> List[ApplicationService]:
return self.services_cache

def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
Expand Down Expand Up @@ -256,7 +260,7 @@ async def create_appservice_txn(
A new transaction.
"""

def _create_appservice_txn(txn):
def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)

# Insert new txn into txn table
Expand Down Expand Up @@ -291,7 +295,7 @@ async def complete_appservice_txn(
service: The application service which was sent this transaction.
"""

def _complete_appservice_txn(txn):
def _complete_appservice_txn(txn: LoggingTransaction) -> None:
# Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn(
txn,
Expand Down Expand Up @@ -322,7 +326,9 @@ async def get_oldest_unsent_txn(
An AppServiceTransaction or None.
"""

def _get_oldest_unsent_txn(txn):
def _get_oldest_unsent_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
Expand Down Expand Up @@ -364,7 +370,7 @@ def _get_oldest_unsent_txn(txn):
)

async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
Expand All @@ -378,7 +384,9 @@ async def get_new_events_for_appservice(
) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""

def get_new_events_for_appservice_txn(txn):
def get_new_events_for_appservice_txn(
txn: LoggingTransaction,
) -> Tuple[int, List[str]]:
sql = (
"SELECT e.stream_ordering, e.event_id"
" FROM events AS e"
Expand Down Expand Up @@ -416,7 +424,7 @@ async def get_type_stream_id_for_appservice(
% (type,)
)

def get_type_stream_id_for_appservice_txn(txn):
def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
stream_id_type = "%s_stream_id" % type
txn.execute(
# We do NOT want to escape `stream_id_type`.
Expand Down Expand Up @@ -444,7 +452,7 @@ async def set_appservice_stream_type_pos(
% (stream_type,)
)

def set_appservice_stream_type_pos_txn(txn):
def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
Expand Down
Loading