1616import logging
1717from enum import Enum
1818from itertools import chain
19- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple
19+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , cast
2020
2121from typing_extensions import Counter
2222
2323from twisted .internet .defer import DeferredLock
2424
2525from synapse .api .constants import EventContentFields , EventTypes , Membership
2626from synapse .api .errors import StoreError
27- from synapse .storage .database import DatabasePool , LoggingDatabaseConnection
27+ from synapse .storage .database import (
28+ DatabasePool ,
29+ LoggingDatabaseConnection ,
30+ LoggingTransaction ,
31+ )
2832from synapse .storage .databases .main .state_deltas import StateDeltasStore
2933from synapse .types import JsonDict
3034from synapse .util .caches .descriptors import cached
@@ -122,7 +126,9 @@ def __init__(
122126 self .db_pool .updates .register_noop_background_update ("populate_stats_cleanup" )
123127 self .db_pool .updates .register_noop_background_update ("populate_stats_prepare" )
124128
125- async def _populate_stats_process_users (self , progress , batch_size ):
129+ async def _populate_stats_process_users (
130+ self , progress : JsonDict , batch_size : int
131+ ) -> int :
126132 """
127133 This is a background update which regenerates statistics for users.
128134 """
@@ -134,7 +140,7 @@ async def _populate_stats_process_users(self, progress, batch_size):
134140
135141 last_user_id = progress .get ("last_user_id" , "" )
136142
137- def _get_next_batch (txn ) :
143+ def _get_next_batch (txn : LoggingTransaction ) -> List [ str ] :
138144 sql = """
139145 SELECT DISTINCT name FROM users
140146 WHERE name > ?
@@ -168,7 +174,9 @@ def _get_next_batch(txn):
168174
169175 return len (users_to_work_on )
170176
171- async def _populate_stats_process_rooms (self , progress , batch_size ):
177+ async def _populate_stats_process_rooms (
178+ self , progress : JsonDict , batch_size : int
179+ ) -> int :
172180 """This is a background update which regenerates statistics for rooms."""
173181 if not self .stats_enabled :
174182 await self .db_pool .updates ._end_background_update (
@@ -178,7 +186,7 @@ async def _populate_stats_process_rooms(self, progress, batch_size):
178186
179187 last_room_id = progress .get ("last_room_id" , "" )
180188
181- def _get_next_batch (txn ) :
189+ def _get_next_batch (txn : LoggingTransaction ) -> List [ str ] :
182190 sql = """
183191 SELECT DISTINCT room_id FROM current_state_events
184192 WHERE room_id > ?
@@ -307,7 +315,7 @@ async def bulk_update_stats_delta(
307315 stream_id: Current position.
308316 """
309317
310- def _bulk_update_stats_delta_txn (txn ) :
318+ def _bulk_update_stats_delta_txn (txn : LoggingTransaction ) -> None :
311319 for stats_type , stats_updates in updates .items ():
312320 for stats_id , fields in stats_updates .items ():
313321 logger .debug (
@@ -339,7 +347,7 @@ async def update_stats_delta(
339347 stats_type : str ,
340348 stats_id : str ,
341349 fields : Dict [str , int ],
342- complete_with_stream_id : Optional [ int ] ,
350+ complete_with_stream_id : int ,
343351 absolute_field_overrides : Optional [Dict [str , int ]] = None ,
344352 ) -> None :
345353 """
@@ -372,14 +380,14 @@ async def update_stats_delta(
372380
373381 def _update_stats_delta_txn (
374382 self ,
375- txn ,
376- ts ,
377- stats_type ,
378- stats_id ,
379- fields ,
380- complete_with_stream_id ,
381- absolute_field_overrides = None ,
382- ):
383+ txn : LoggingTransaction ,
384+ ts : int ,
385+ stats_type : str ,
386+ stats_id : str ,
387+ fields : Dict [ str , int ] ,
388+ complete_with_stream_id : int ,
389+ absolute_field_overrides : Optional [ Dict [ str , int ]] = None ,
390+ ) -> None :
383391 if absolute_field_overrides is None :
384392 absolute_field_overrides = {}
385393
@@ -422,20 +430,23 @@ def _update_stats_delta_txn(
422430 )
423431
424432 def _upsert_with_additive_relatives_txn (
425- self , txn , table , keyvalues , absolutes , additive_relatives
426- ):
433+ self ,
434+ txn : LoggingTransaction ,
435+ table : str ,
436+ keyvalues : Dict [str , Any ],
437+ absolutes : Dict [str , Any ],
438+ additive_relatives : Dict [str , int ],
439+ ) -> None :
427440 """Used to update values in the stats tables.
428441
429442 This is basically a slightly convoluted upsert that *adds* to any
430443 existing rows.
431444
432445 Args:
433- txn
434- table (str): Table name
435- keyvalues (dict[str, any]): Row-identifying key values
436- absolutes (dict[str, any]): Absolute (set) fields
437- additive_relatives (dict[str, int]): Fields that will be added onto
438- if existing row present.
446+ table: Table name
447+ keyvalues: Row-identifying key values
448+ absolutes: Absolute (set) fields
449+ additive_relatives: Fields that will be added onto if existing row present.
439450 """
440451 if self .database_engine .can_native_upsert :
441452 absolute_updates = [
@@ -491,20 +502,17 @@ def _upsert_with_additive_relatives_txn(
491502 current_row .update (absolutes )
492503 self .db_pool .simple_update_one_txn (txn , table , keyvalues , current_row )
493504
494- async def _calculate_and_set_initial_state_for_room (
495- self , room_id : str
496- ) -> Tuple [dict , dict , int ]:
505+ async def _calculate_and_set_initial_state_for_room (self , room_id : str ) -> None :
497506 """Calculate and insert an entry into room_stats_current.
498507
499508 Args:
500509 room_id: The room ID under calculation.
501-
502- Returns:
503- A tuple of room state, membership counts and stream position.
504510 """
505511
506- def _fetch_current_state_stats (txn ):
507- pos = self .get_room_max_stream_ordering ()
512+ def _fetch_current_state_stats (
513+ txn : LoggingTransaction ,
514+ ) -> Tuple [List [str ], Dict [str , int ], int , List [str ], int ]:
515+ pos = self .get_room_max_stream_ordering () # type: ignore[attr-defined]
508516
509517 rows = self .db_pool .simple_select_many_txn (
510518 txn ,
@@ -524,7 +532,7 @@ def _fetch_current_state_stats(txn):
524532 retcols = ["event_id" ],
525533 )
526534
527- event_ids = [ row ["event_id" ] for row in rows ]
535+ event_ids = cast ( List [ str ], [ row ["event_id" ] for row in rows ])
528536
529537 txn .execute (
530538 """
@@ -544,9 +552,9 @@ def _fetch_current_state_stats(txn):
544552 (room_id ,),
545553 )
546554
547- ( current_state_events_count ,) = txn .fetchone ()
555+ current_state_events_count = cast ( Tuple [ int ], txn .fetchone ())[ 0 ]
548556
549- users_in_room = self .get_users_in_room_txn (txn , room_id )
557+ users_in_room = self .get_users_in_room_txn (txn , room_id ) # type: ignore[attr-defined]
550558
551559 return (
552560 event_ids ,
@@ -566,7 +574,7 @@ def _fetch_current_state_stats(txn):
566574 "get_initial_state_for_room" , _fetch_current_state_stats
567575 )
568576
569- state_event_map = await self .get_events (event_ids , get_prev_content = False )
577+ state_event_map = await self .get_events (event_ids , get_prev_content = False ) # type: ignore[attr-defined]
570578
571579 room_state = {
572580 "join_rules" : None ,
@@ -622,8 +630,10 @@ def _fetch_current_state_stats(txn):
622630 },
623631 )
624632
625- async def _calculate_and_set_initial_state_for_user (self , user_id ):
626- def _calculate_and_set_initial_state_for_user_txn (txn ):
633+ async def _calculate_and_set_initial_state_for_user (self , user_id : str ) -> None :
634+ def _calculate_and_set_initial_state_for_user_txn (
635+ txn : LoggingTransaction ,
636+ ) -> Tuple [int , int ]:
627637 pos = self ._get_max_stream_id_in_current_state_deltas_txn (txn )
628638
629639 txn .execute (
@@ -634,7 +644,7 @@ def _calculate_and_set_initial_state_for_user_txn(txn):
634644 """ ,
635645 (user_id ,),
636646 )
637- ( count ,) = txn .fetchone ()
647+ count = cast ( Tuple [ int ], txn .fetchone ())[ 0 ]
638648 return count , pos
639649
640650 joined_rooms , pos = await self .db_pool .runInteraction (
@@ -678,7 +688,9 @@ async def get_users_media_usage_paginate(
678688 users that exist given this query
679689 """
680690
681- def get_users_media_usage_paginate_txn (txn ):
691+ def get_users_media_usage_paginate_txn (
692+ txn : LoggingTransaction ,
693+ ) -> Tuple [List [JsonDict ], int ]:
682694 filters = []
683695 args = [self .hs .config .server .server_name ]
684696
@@ -733,7 +745,7 @@ def get_users_media_usage_paginate_txn(txn):
733745 sql_base = sql_base ,
734746 )
735747 txn .execute (sql , args )
736- count = txn .fetchone ()[0 ]
748+ count = cast ( Tuple [ int ], txn .fetchone () )[0 ]
737749
738750 sql = """
739751 SELECT
0 commit comments