1515# See the License for the specific language governing permissions and
1616# limitations under the License.
1717import logging
18- import sys
1918import time
20- from typing import Iterable , Tuple
19+ from time import monotonic as monotonic_time
20+ from typing import Any , Callable , Dict , Iterable , Iterator , List , Optional , Tuple
2121
2222from six import iteritems , iterkeys , itervalues
2323from six .moves import intern , range
3232from synapse .logging .context import LoggingContext , make_deferred_yieldable
3333from synapse .metrics .background_process_metrics import run_as_background_process
3434from synapse .storage .background_updates import BackgroundUpdater
35- from synapse .storage .engines import PostgresEngine , Sqlite3Engine
35+ from synapse .storage .engines import BaseDatabaseEngine , PostgresEngine , Sqlite3Engine
36+ from synapse .storage .types import Connection , Cursor
3637from synapse .util .stringutils import exception_to_unicode
3738
38- # import a function which will return a monotonic time, in seconds
39- try :
40- # on python 3, use time.monotonic, since time.clock can go backwards
41- from time import monotonic as monotonic_time
42- except ImportError :
43- # ... but python 2 doesn't have it
44- from time import clock as monotonic_time
45-
4639logger = logging .getLogger (__name__ )
4740
48- try :
49- MAX_TXN_ID = sys .maxint - 1
50- except AttributeError :
51- # python 3 does not have a maximum int value
52- MAX_TXN_ID = 2 ** 63 - 1
41+ # python 3 does not have a maximum int value
42+ MAX_TXN_ID = 2 ** 63 - 1
5343
5444sql_logger = logging .getLogger ("synapse.storage.SQL" )
5545transaction_logger = logging .getLogger ("synapse.storage.txn" )
7767
7868
7969def make_pool (
80- reactor , db_config : DatabaseConnectionConfig , engine
70+ reactor , db_config : DatabaseConnectionConfig , engine : BaseDatabaseEngine
8171) -> adbapi .ConnectionPool :
8272 """Get the connection pool for the database.
8373 """
@@ -90,7 +80,9 @@ def make_pool(
9080 )
9181
9282
93- def make_conn (db_config : DatabaseConnectionConfig , engine ):
83+ def make_conn (
84+ db_config : DatabaseConnectionConfig , engine : BaseDatabaseEngine
85+ ) -> Connection :
9486 """Make a new connection to the database and return it.
9587
9688 Returns:
@@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
10799 return db_conn
108100
109101
110- class LoggingTransaction (object ):
102+ # The type of entry which goes on our after_callbacks and exception_callbacks lists.
103+ #
104+ # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
105+ # that mypy sees the type but the runtime python doesn't.
106+ _CallbackListEntry = Tuple ["Callable[..., None]" , Iterable [Any ], Dict [str , Any ]]
107+
108+
109+ class LoggingTransaction :
111110 """An object that almost-transparently proxies for the 'txn' object
112111 passed to the constructor. Adds logging and metrics to the .execute()
113112 method.
114113
115114 Args:
116115 txn: The database transcation object to wrap.
117- name (str) : The name of this transactions for logging.
118- database_engine (Sqlite3Engine|PostgresEngine)
119- after_callbacks(list|None) : A list that callbacks will be appended to
116+ name: The name of this transactions for logging.
117+ database_engine
118+ after_callbacks: A list that callbacks will be appended to
120119 that have been added by `call_after` which should be run on
121120 successful completion of the transaction. None indicates that no
122121 callbacks should be allowed to be scheduled to run.
123- exception_callbacks(list|None) : A list that callbacks will be appended
122+ exception_callbacks: A list that callbacks will be appended
124123 to that have been added by `call_on_exception` which should be run
125124 if transaction ends with an error. None indicates that no callbacks
126125 should be allowed to be scheduled to run.
@@ -135,46 +134,67 @@ class LoggingTransaction(object):
135134 ]
136135
137136 def __init__ (
138- self , txn , name , database_engine , after_callbacks = None , exception_callbacks = None
137+ self ,
138+ txn : Cursor ,
139+ name : str ,
140+ database_engine : BaseDatabaseEngine ,
141+ after_callbacks : Optional [List [_CallbackListEntry ]] = None ,
142+ exception_callbacks : Optional [List [_CallbackListEntry ]] = None ,
139143 ):
140- object . __setattr__ ( self , " txn" , txn )
141- object . __setattr__ ( self , " name" , name )
142- object . __setattr__ ( self , " database_engine" , database_engine )
143- object . __setattr__ ( self , " after_callbacks" , after_callbacks )
144- object . __setattr__ ( self , " exception_callbacks" , exception_callbacks )
144+ self . txn = txn
145+ self . name = name
146+ self . database_engine = database_engine
147+ self . after_callbacks = after_callbacks
148+ self . exception_callbacks = exception_callbacks
145149
146- def call_after (self , callback , * args , ** kwargs ):
150+ def call_after (self , callback : "Callable[..., None]" , * args , ** kwargs ):
147151 """Call the given callback on the main twisted thread after the
148152 transaction has finished. Used to invalidate the caches on the
149153 correct thread.
150154 """
155+ # if self.after_callbacks is None, that means that whatever constructed the
156+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
157+ # is not the case.
158+ assert self .after_callbacks is not None
151159 self .after_callbacks .append ((callback , args , kwargs ))
152160
153- def call_on_exception (self , callback , * args , ** kwargs ):
161+ def call_on_exception (self , callback : "Callable[..., None]" , * args , ** kwargs ):
162+ # if self.exception_callbacks is None, that means that whatever constructed the
163+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
164+ # is not the case.
165+ assert self .exception_callbacks is not None
154166 self .exception_callbacks .append ((callback , args , kwargs ))
155167
156- def __getattr__ (self , name ) :
157- return getattr ( self .txn , name )
168+ def fetchall (self ) -> List [ Tuple ] :
169+ return self .txn . fetchall ( )
158170
159- def __setattr__ (self , name , value ) :
160- setattr ( self .txn , name , value )
171+ def fetchone (self ) -> Tuple :
172+ return self .txn . fetchone ( )
161173
162- def __iter__ (self ):
174+ def __iter__ (self ) -> Iterator [ Tuple ] :
163175 return self .txn .__iter__ ()
164176
177+ @property
178+ def rowcount (self ) -> int :
179+ return self .txn .rowcount
180+
181+ @property
182+ def description (self ) -> Any :
183+ return self .txn .description
184+
165185 def execute_batch (self , sql , args ):
166186 if isinstance (self .database_engine , PostgresEngine ):
167- from psycopg2 .extras import execute_batch
187+ from psycopg2 .extras import execute_batch # type: ignore
168188
169189 self ._do_execute (lambda * x : execute_batch (self .txn , * x ), sql , args )
170190 else :
171191 for val in args :
172192 self .execute (sql , val )
173193
174- def execute (self , sql , * args ):
194+ def execute (self , sql : str , * args : Any ):
175195 self ._do_execute (self .txn .execute , sql , * args )
176196
177- def executemany (self , sql , * args ):
197+ def executemany (self , sql : str , * args : Any ):
178198 self ._do_execute (self .txn .executemany , sql , * args )
179199
180200 def _make_sql_one_line (self , sql ):
@@ -207,6 +227,9 @@ def _do_execute(self, func, sql, *args):
207227 sql_logger .debug ("[SQL time] {%s} %f sec" , self .name , secs )
208228 sql_query_timer .labels (sql .split ()[0 ]).observe (secs )
209229
230+ def close (self ):
231+ self .txn .close ()
232+
210233
211234class PerformanceCounters (object ):
212235 def __init__ (self ):
@@ -251,17 +274,19 @@ class Database(object):
251274
252275 _TXN_ID = 0
253276
254- def __init__ (self , hs , database_config : DatabaseConnectionConfig , engine ):
277+ def __init__ (
278+ self , hs , database_config : DatabaseConnectionConfig , engine : BaseDatabaseEngine
279+ ):
255280 self .hs = hs
256281 self ._clock = hs .get_clock ()
257282 self ._database_config = database_config
258283 self ._db_pool = make_pool (hs .get_reactor (), database_config , engine )
259284
260285 self .updates = BackgroundUpdater (hs , self )
261286
262- self ._previous_txn_total_time = 0
263- self ._current_txn_total_time = 0
264- self ._previous_loop_ts = 0
287+ self ._previous_txn_total_time = 0.0
288+ self ._current_txn_total_time = 0.0
289+ self ._previous_loop_ts = 0.0
265290
266291 # TODO(paul): These can eventually be removed once the metrics code
267292 # is running in mainline, and we have some nice monitoring frontends
@@ -463,23 +488,23 @@ def new_transaction(
463488 sql_txn_timer .labels (desc ).observe (duration )
464489
465490 @defer .inlineCallbacks
466- def runInteraction (self , desc , func , * args , ** kwargs ):
491+ def runInteraction (self , desc : str , func : Callable , * args : Any , ** kwargs : Any ):
467492 """Starts a transaction on the database and runs a given function
468493
469494 Arguments:
470- desc (str) : description of the transaction, for logging and metrics
471- func (func) : callback function, which will be called with a
495+ desc: description of the transaction, for logging and metrics
496+ func: callback function, which will be called with a
472497 database transaction (twisted.enterprise.adbapi.Transaction) as
473498 its first argument, followed by `args` and `kwargs`.
474499
475- args (list) : positional args to pass to `func`
476- kwargs (dict) : named args to pass to `func`
500+ args: positional args to pass to `func`
501+ kwargs: named args to pass to `func`
477502
478503 Returns:
479504 Deferred: The result of func
480505 """
481- after_callbacks = []
482- exception_callbacks = []
506+ after_callbacks = [] # type: List[_CallbackListEntry]
507+ exception_callbacks = [] # type: List[_CallbackListEntry]
483508
484509 if LoggingContext .current_context () == LoggingContext .sentinel :
485510 logger .warning ("Starting db txn '%s' from sentinel context" , desc )
@@ -505,15 +530,15 @@ def runInteraction(self, desc, func, *args, **kwargs):
505530 return result
506531
507532 @defer .inlineCallbacks
508- def runWithConnection (self , func , * args , ** kwargs ):
533+ def runWithConnection (self , func : Callable , * args : Any , ** kwargs : Any ):
509534 """Wraps the .runWithConnection() method on the underlying db_pool.
510535
511536 Arguments:
512- func (func) : callback function, which will be called with a
537+ func: callback function, which will be called with a
513538 database connection (twisted.enterprise.adbapi.Connection) as
514539 its first argument, followed by `args` and `kwargs`.
515- args (list) : positional args to pass to `func`
516- kwargs (dict) : named args to pass to `func`
540+ args: positional args to pass to `func`
541+ kwargs: named args to pass to `func`
517542
518543 Returns:
519544 Deferred: The result of func
@@ -800,7 +825,7 @@ def _getwhere(key):
800825 return False
801826
802827 # We didn't find any existing rows, so insert a new one
803- allvalues = {}
828+ allvalues = {} # type: Dict[str, Any]
804829 allvalues .update (keyvalues )
805830 allvalues .update (values )
806831 allvalues .update (insertion_values )
@@ -829,7 +854,7 @@ def simple_upsert_txn_native_upsert(
829854 Returns:
830855 None
831856 """
832- allvalues = {}
857+ allvalues = {} # type: Dict[str, Any]
833858 allvalues .update (keyvalues )
834859 allvalues .update (insertion_values )
835860
@@ -916,7 +941,7 @@ def simple_upsert_many_txn_native_upsert(
916941 Returns:
917942 None
918943 """
919- allnames = []
944+ allnames = [] # type: List[str]
920945 allnames .extend (key_names )
921946 allnames .extend (value_names )
922947
@@ -1100,7 +1125,7 @@ def simple_select_many_batch(
11001125 keyvalues : dict of column names and values to select the rows with
11011126 retcols : list of strings giving the names of the columns to return
11021127 """
1103- results = []
1128+ results = [] # type: List[Dict[str, Any]]
11041129
11051130 if not iterable :
11061131 return results
@@ -1439,7 +1464,7 @@ def simple_select_list_paginate_txn(
14391464 raise ValueError ("order_direction must be one of 'ASC' or 'DESC'." )
14401465
14411466 where_clause = "WHERE " if filters or keyvalues else ""
1442- arg_list = []
1467+ arg_list = [] # type: List[Any]
14431468 if filters :
14441469 where_clause += " AND " .join ("%s LIKE ?" % (k ,) for k in filters )
14451470 arg_list += list (filters .values ())
0 commit comments