Skip to content

Commit e4ef1f8

Browse files
Fix multi-threading logger
1 parent 53f50a8 commit e4ef1f8

File tree

3 files changed

+75
-60
lines changed

3 files changed

+75
-60
lines changed

bootstrap/lib/logger.py

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import sqlite3
1717
import sys
18+
from contextlib import closing
1819

1920

2021
class Logger(object):
@@ -71,15 +72,22 @@ def code(value):
7172

7273
log_level = None # log level
7374
dir_logs = None
74-
sqlite_cur = None
7575
sqlite_file = None
76-
sqlite_conn = None
76+
connection = None
7777
path_txt = None
7878
file_txt = None
7979
name = None
8080
max_lineno_width = 3
8181

82-
def __new__(cls, dir_logs=None, name='logs'):
82+
def __new__(cls, dir_logs=None, name=None):
83+
return Logger._get_instance(dir_logs, name)
84+
85+
def __call__(self, *args, **kwargs):
86+
return self.log_message(*args, **kwargs, stack_displacement=2)
87+
88+
@staticmethod
89+
def _get_instance(dir_logs=None, name=None):
90+
name = name or 'logs'
8391
if Logger._instance is None:
8492
Logger._instance = object.__new__(Logger)
8593
Logger._instance.set_level(Logger._instance.INFO)
@@ -90,16 +98,12 @@ def __new__(cls, dir_logs=None, name='logs'):
9098
Logger._instance.path_txt = os.path.join(dir_logs, '{}.txt'.format(name))
9199
Logger._instance.file_txt = open(os.path.join(dir_logs, '{}.txt'.format(name)), 'a+')
92100
Logger._instance.sqlite_file = os.path.join(dir_logs, '{}.sqlite'.format(name))
93-
Logger._instance.init_sqlite()
94101
else:
95102
Logger._instance.log_message('No logs files will be created (dir_logs attribute is empty)',
96103
log_level=Logger.WARNING)
97104

98105
return Logger._instance
99106

100-
def __call__(self, *args, **kwargs):
101-
return self.log_message(*args, **kwargs, stack_displacement=2)
102-
103107
def set_level(self, log_level):
104108
self.log_level = log_level
105109

@@ -173,24 +177,24 @@ def print_subitem(prefix, subdictionary, stack_displacement=3):
173177
self.log_message('{}: {}'.format(group, description), log_level=log_level, stack_displacement=stack_displacement)
174178
print_subitem(' ', dictionary, stack_displacement=stack_displacement + 1)
175179

176-
def _execute(self, statement, parameters=None, commit=True):
180+
def _execute(self, statement, parameters=None, commit=True, cursor=None):
177181
assert parameters is None or isinstance(parameters, tuple)
178182
parameters = parameters or ()
179-
return_value = self.sqlite_cur.execute(statement, parameters)
183+
return_value = cursor.execute(statement, parameters)
180184
if commit:
181-
self.sqlite_conn.commit()
185+
self.get_conn().commit()
182186
return return_value
183187

184-
def _run_query(self, query, parameters=None):
185-
return self._execute(query, parameters, commit=False)
188+
def _run_query(self, query, parameters=None, cursor=None):
189+
return self._execute(query, parameters, commit=False, cursor=cursor)
186190

187191
def _get_internal_table_name(self, table_name):
188192
return f'_{table_name}'
189193

190-
def _check_table_exists(self, table_name):
194+
def _check_table_exists(self, table_name, cursor=None):
191195
table_name = self._get_internal_table_name(table_name)
192196
query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
193-
return self._run_query(query, (table_name,))
197+
return self._run_query(query, (table_name,), cursor=cursor)
194198

195199
def _create_table(self, table_name):
196200
table_name = self._get_internal_table_name(table_name)
@@ -200,15 +204,17 @@ def _create_table(self, table_name):
200204
"__timestamp" DATETIME DEFAULT CURRENT_TIMESTAMP
201205
);
202206
"""
203-
self._execute(statement)
207+
with closing(self.get_conn().cursor()) as cursor:
208+
self._execute(statement, cursor=cursor)
204209

205210
def _list_columns(self, table_name):
206211
table_name = self._get_internal_table_name(table_name)
207212
query = "SELECT name FROM PRAGMA_TABLE_INFO(?)"
208-
qry_cur = self._run_query(query, (table_name,))
209-
columns = (res[0] for res in qry_cur)
210-
# remove __id and __timestamp columns
211-
columns = [c for c in columns if not c.startswith('__')]
213+
with closing(self.get_conn().cursor()) as cursor:
214+
qry_cur = self._run_query(query, (table_name,), cursor=cursor)
215+
columns = (res[0] for res in qry_cur)
216+
# remove __id and __timestamp columns
217+
columns = [c for c in columns if not c.startswith('__')]
212218
return columns
213219

214220
@staticmethod
@@ -223,7 +229,8 @@ def _add_column(self, table_name, column_name, value_sample=None):
223229
table_name = self._get_internal_table_name(table_name)
224230
value_type = self._get_data_type(value_sample)
225231
statement = f'ALTER TABLE {table_name} ADD COLUMN "{column_name}" {value_type}'
226-
return self._execute(statement)
232+
with closing(self.get_conn().cursor()) as cursor:
233+
return self._execute(statement, cursor=cursor)
227234

228235
def _flatten_dict(self, dictionary, flatten_dict=None, prefix=''):
229236
flatten_dict = flatten_dict if flatten_dict is not None else {}
@@ -247,51 +254,58 @@ def _insert_row(self, table_name, flat_dictionary):
247254
value_placeholder = ', '.join(['?'] * len(columns))
248255
statement = f'INSERT INTO {table_name} ({column_string}) VALUES({value_placeholder})'
249256
parameters = tuple(val for val in flat_dictionary.values())
250-
return self._execute(statement, parameters)
257+
with closing(self.get_conn().cursor()) as cursor:
258+
return self._execute(statement, parameters, cursor=cursor)
251259

252260
def log_dict(self, group, dictionary, description='', stack_displacement=2, should_print=False, log_level=SUMMARY):
253261
if log_level < self.log_level:
254262
return -1
255263

256264
flat_dictionary = self._flatten_dict(dictionary)
257-
if self._check_table_exists(group).fetchone():
258-
columns = self._list_columns(group)
259-
for key in flat_dictionary:
260-
if key not in columns:
261-
self.log_message(f'Key "{key}" is unknown. New keys are not allowed', log_level=self.ERROR)
262-
for column_name in columns:
263-
if column_name not in flat_dictionary:
264-
self.log_message(f'Key "{column_name}" not in the dictionary to be logged', log_level=self.ERROR)
265-
else:
266-
self._create_table(group)
267-
for key, value in flat_dictionary.items():
268-
self._add_column(group, key, value)
265+
with closing(self.get_conn().cursor()) as cursor:
266+
if self._check_table_exists(group, cursor=cursor).fetchone():
267+
columns = self._list_columns(group)
268+
for key in flat_dictionary:
269+
if key not in columns:
270+
self.log_message(f'Key "{key}" is unknown. New keys are not allowed', log_level=self.ERROR)
271+
for column_name in columns:
272+
if column_name not in flat_dictionary:
273+
self.log_message(f'Key "{column_name}" not in the dictionary to be logged', log_level=self.ERROR)
274+
else:
275+
self._create_table(group)
276+
for key, value in flat_dictionary.items():
277+
self._add_column(group, key, value)
269278

270279
self._insert_row(group, flat_dictionary)
271280

272281
if should_print:
273282
self.log_dict_message(group, dictionary, description, stack_displacement + 1, log_level)
274283

275-
def select(self, group, columns=None):
276-
table_name = self._get_internal_table_name(group)
277-
table_columns = self._list_columns(group)
284+
@staticmethod
285+
def select(group, columns=None):
286+
logger = Logger._get_instance(dir_logs=None, name=None)
287+
table_name = logger._get_internal_table_name(group)
288+
table_columns = logger._list_columns(group)
278289
if columns is None:
279290
column_string = '*'
280291
else:
281292
for c in columns:
282293
if c not in table_columns:
283-
self.log_message(f'Unknown column "{c}"', log_level=self.ERROR)
294+
logger.log_message(f'Unknown column "{c}"', log_level=Logger.ERROR)
284295
column_string = ', '.join([f'"{c}"' for c in columns])
285296
statement = f'SELECT {column_string} FROM {table_name}'
286-
return self._execute(statement)
287-
288-
def init_sqlite(self):
289-
pre_existing = os.path.isfile(self.sqlite_file)
290-
self.sqlite_conn = sqlite3.connect(self.sqlite_file)
291-
self.sqlite_cur = self.sqlite_conn.cursor()
292-
if not pre_existing:
293-
self._create_table('bootstrap')
297+
with closing(logger.get_conn().cursor()) as cursor:
298+
return logger._execute(statement, cursor=cursor, commit=False).fetchall()
299+
300+
def get_conn(self):
301+
if self.connection is None:
302+
pre_existing = os.path.isfile(self.sqlite_file)
303+
connection = sqlite3.connect(self.sqlite_file, check_same_thread=False, isolation_level='IMMEDIATE')
304+
self.connection = connection
305+
if not pre_existing:
306+
self._create_table('bootstrap')
307+
return self.connection
294308

295309
def flush(self):
296310
if self.dir_logs:
297-
self.sqlite_conn.commit()
311+
self.get_conn().commit()

bootstrap/views/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66

77
def factory(engine=None):
8-
Logger()('Creating views...')
9-
108
# if views does not exist, pick view
119
# to support backward compatibility
1210
if 'views' in Options():
@@ -24,6 +22,8 @@ def factory(engine=None):
2422

2523
exp_dir = Options()['exp.dir']
2624

25+
Logger(exp_dir)('Creating views...')
26+
2727
if 'names' in opt:
2828
view = make_multi_views(opt, exp_dir)
2929
return view

bootstrap/views/plotly.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import json
32
import math
43
import plotly.graph_objects as go
54
from plotly.subplots import make_subplots
@@ -49,7 +48,6 @@ def generate(self):
4948
for log_name in log_names:
5049
path_sqlite = os.path.join(self.exp_dir, '{}.sqlite'.format(log_name))
5150
if os.path.isfile(path_sqlite):
52-
Logger._instance = None
5351
data_sqlite = Logger(dir_logs=self.exp_dir, name=log_name)
5452
data_dict[log_name] = data_sqlite
5553
else:
@@ -93,16 +91,19 @@ def generate(self):
9391
group = view['view_name'].split('.')[0]
9492
columns = [view['view_name'].split('.')[1]]
9593

96-
y = [x[0] for x in data_dict[view['log_name']].select(group, columns)]
97-
x = list(range(len(y)))
98-
99-
scatter = go.Scatter(
100-
x=x,
101-
y=y,
102-
name=view['view_interim'],
103-
line={'color': color}
104-
)
105-
figure.append_trace(scatter, figure_pos_y, figure_pos_x)
94+
try:
95+
y = [x[0] for x in data_dict[view['log_name']].select(group, columns)]
96+
x = list(range(len(y)))
97+
98+
scatter = go.Scatter(
99+
x=x,
100+
y=y,
101+
name=view['view_interim'],
102+
line={'color': color}
103+
)
104+
figure.append_trace(scatter, figure_pos_y, figure_pos_x)
105+
except Exception:
106+
pass
106107

107108
figure['layout'].update(
108109
autosize=True,

0 commit comments

Comments
 (0)