1313# limitations under the License.
1414
1515import logging
16- from typing import Optional , Tuple
16+ from typing import List , Optional , Tuple , Union
1717
1818import attr
1919
2020from synapse .api .constants import RelationTypes
2121from synapse .events import EventBase
2222from synapse .storage ._base import SQLBaseStore
23+ from synapse .storage .database import LoggingTransaction
2324from synapse .storage .databases .main .stream import generate_pagination_where_clause
2425from synapse .storage .relations import (
2526 AggregationPaginationToken ,
@@ -63,7 +64,7 @@ async def get_relations_for_event(
6364 """
6465
6566 where_clause = ["relates_to_id = ?" ]
66- where_args = [event_id ]
67+ where_args : List [ Union [ str , int ]] = [event_id ]
6768
6869 if relation_type is not None :
6970 where_clause .append ("relation_type = ?" )
@@ -80,8 +81,8 @@ async def get_relations_for_event(
8081 pagination_clause = generate_pagination_where_clause (
8182 direction = direction ,
8283 column_names = ("topological_ordering" , "stream_ordering" ),
83- from_token = attr .astuple (from_token ) if from_token else None ,
84- to_token = attr .astuple (to_token ) if to_token else None ,
84+ from_token = attr .astuple (from_token ) if from_token else None , # type: ignore[arg-type]
85+ to_token = attr .astuple (to_token ) if to_token else None , # type: ignore[arg-type]
8586 engine = self .database_engine ,
8687 )
8788
@@ -106,7 +107,9 @@ async def get_relations_for_event(
106107 order ,
107108 )
108109
109- def _get_recent_references_for_event_txn (txn ):
110+ def _get_recent_references_for_event_txn (
111+ txn : LoggingTransaction ,
112+ ) -> PaginationChunk :
110113 txn .execute (sql , where_args + [limit + 1 ])
111114
112115 last_topo_id = None
@@ -160,7 +163,7 @@ async def get_aggregation_groups_for_event(
160163 """
161164
162165 where_clause = ["relates_to_id = ?" , "relation_type = ?" ]
163- where_args = [event_id , RelationTypes .ANNOTATION ]
166+ where_args : List [ Union [ str , int ]] = [event_id , RelationTypes .ANNOTATION ]
164167
165168 if event_type :
166169 where_clause .append ("type = ?" )
@@ -169,8 +172,8 @@ async def get_aggregation_groups_for_event(
169172 having_clause = generate_pagination_where_clause (
170173 direction = direction ,
171174 column_names = ("COUNT(*)" , "MAX(stream_ordering)" ),
172- from_token = attr .astuple (from_token ) if from_token else None ,
173- to_token = attr .astuple (to_token ) if to_token else None ,
175+ from_token = attr .astuple (from_token ) if from_token else None , # type: ignore[arg-type]
176+ to_token = attr .astuple (to_token ) if to_token else None , # type: ignore[arg-type]
174177 engine = self .database_engine ,
175178 )
176179
@@ -199,7 +202,9 @@ async def get_aggregation_groups_for_event(
199202 having_clause = having_clause ,
200203 )
201204
202- def _get_aggregation_groups_for_event_txn (txn ):
205+ def _get_aggregation_groups_for_event_txn (
206+ txn : LoggingTransaction ,
207+ ) -> PaginationChunk :
203208 txn .execute (sql , where_args + [limit + 1 ])
204209
205210 next_batch = None
@@ -254,11 +259,12 @@ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
254259 LIMIT 1
255260 """
256261
257- def _get_applicable_edit_txn (txn ) :
262+ def _get_applicable_edit_txn (txn : LoggingTransaction ) -> Optional [ str ] :
258263 txn .execute (sql , (event_id , RelationTypes .REPLACE ))
259264 row = txn .fetchone ()
260265 if row :
261266 return row [0 ]
267+ return None
262268
263269 edit_id = await self .db_pool .runInteraction (
264270 "get_applicable_edit" , _get_applicable_edit_txn
@@ -267,7 +273,7 @@ def _get_applicable_edit_txn(txn):
267273 if not edit_id :
268274 return None
269275
270- return await self .get_event (edit_id , allow_none = True )
276+ return await self .get_event (edit_id , allow_none = True ) # type: ignore[attr-defined]
271277
272278 @cached ()
273279 async def get_thread_summary (
@@ -283,7 +289,9 @@ async def get_thread_summary(
283289 The number of items in the thread and the most recent response, if any.
284290 """
285291
286- def _get_thread_summary_txn (txn ) -> Tuple [int , Optional [str ]]:
292+ def _get_thread_summary_txn (
293+ txn : LoggingTransaction ,
294+ ) -> Tuple [int , Optional [str ]]:
287295 # Fetch the count of threaded events and the latest event ID.
288296 # TODO Should this only allow m.room.message events.
289297 sql = """
@@ -312,7 +320,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
312320 AND relation_type = ?
313321 """
314322 txn .execute (sql , (event_id , RelationTypes .THREAD ))
315- count = txn .fetchone ()[0 ]
323+ count = txn .fetchone ()[0 ] # type: ignore[index]
316324
317325 return count , latest_event_id
318326
@@ -322,7 +330,7 @@ def _get_thread_summary_txn(txn) -> Tuple[int, Optional[str]]:
322330
323331 latest_event = None
324332 if latest_event_id :
325- latest_event = await self .get_event (latest_event_id , allow_none = True )
333+ latest_event = await self .get_event (latest_event_id , allow_none = True ) # type: ignore[attr-defined]
326334
327335 return count , latest_event
328336
@@ -354,7 +362,7 @@ async def has_user_annotated_event(
354362 LIMIT 1;
355363 """
356364
357- def _get_if_user_has_annotated_event (txn ) :
365+ def _get_if_user_has_annotated_event (txn : LoggingTransaction ) -> bool :
358366 txn .execute (
359367 sql ,
360368 (
0 commit comments