Skip to content

Commit 4103ce9

Browse files
authored
Simplify callbacks (#4289)
1 parent 771ac6c commit 4103ce9

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

utils/callbacks.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,119 +58,118 @@ def get_registered_actions(self, hook=None):
5858
else:
5959
return self._callbacks
6060

61-
@staticmethod
62-
def run_callbacks(register, *args, **kwargs):
61+
def run_callbacks(self, hook, *args, **kwargs):
6362
"""
6463
Loop through the registered actions and fire all callbacks
6564
"""
66-
for logger in register:
65+
for logger in self._callbacks[hook]:
6766
# print(f"Running callbacks.{logger['callback'].__name__}()")
6867
logger['callback'](*args, **kwargs)
6968

7069
def on_pretrain_routine_start(self, *args, **kwargs):
7170
"""
7271
Fires all registered callbacks at the start of each pretraining routine
7372
"""
74-
self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)
73+
self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
7574

7675
def on_pretrain_routine_end(self, *args, **kwargs):
7776
"""
7877
Fires all registered callbacks at the end of each pretraining routine
7978
"""
80-
self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)
79+
self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
8180

8281
def on_train_start(self, *args, **kwargs):
8382
"""
8483
Fires all registered callbacks at the start of each training
8584
"""
86-
self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)
85+
self.run_callbacks('on_train_start', *args, **kwargs)
8786

8887
def on_train_epoch_start(self, *args, **kwargs):
8988
"""
9089
Fires all registered callbacks at the start of each training epoch
9190
"""
92-
self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)
91+
self.run_callbacks('on_train_epoch_start', *args, **kwargs)
9392

9493
def on_train_batch_start(self, *args, **kwargs):
9594
"""
9695
Fires all registered callbacks at the start of each training batch
9796
"""
98-
self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)
97+
self.run_callbacks('on_train_batch_start', *args, **kwargs)
9998

10099
def optimizer_step(self, *args, **kwargs):
101100
"""
102101
Fires all registered callbacks on each optimizer step
103102
"""
104-
self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)
103+
self.run_callbacks('optimizer_step', *args, **kwargs)
105104

106105
def on_before_zero_grad(self, *args, **kwargs):
107106
"""
108107
Fires all registered callbacks before zero grad
109108
"""
110-
self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)
109+
self.run_callbacks('on_before_zero_grad', *args, **kwargs)
111110

112111
def on_train_batch_end(self, *args, **kwargs):
113112
"""
114113
Fires all registered callbacks at the end of each training batch
115114
"""
116-
self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)
115+
self.run_callbacks('on_train_batch_end', *args, **kwargs)
117116

118117
def on_train_epoch_end(self, *args, **kwargs):
119118
"""
120119
Fires all registered callbacks at the end of each training epoch
121120
"""
122-
self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)
121+
self.run_callbacks('on_train_epoch_end', *args, **kwargs)
123122

124123
def on_val_start(self, *args, **kwargs):
125124
"""
126125
Fires all registered callbacks at the start of the validation
127126
"""
128-
self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)
127+
self.run_callbacks('on_val_start', *args, **kwargs)
129128

130129
def on_val_batch_start(self, *args, **kwargs):
131130
"""
132131
Fires all registered callbacks at the start of each validation batch
133132
"""
134-
self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)
133+
self.run_callbacks('on_val_batch_start', *args, **kwargs)
135134

136135
def on_val_image_end(self, *args, **kwargs):
137136
"""
138137
Fires all registered callbacks at the end of each val image
139138
"""
140-
self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)
139+
self.run_callbacks('on_val_image_end', *args, **kwargs)
141140

142141
def on_val_batch_end(self, *args, **kwargs):
143142
"""
144143
Fires all registered callbacks at the end of each validation batch
145144
"""
146-
self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)
145+
self.run_callbacks('on_val_batch_end', *args, **kwargs)
147146

148147
def on_val_end(self, *args, **kwargs):
149148
"""
150149
Fires all registered callbacks at the end of the validation
151150
"""
152-
self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)
151+
self.run_callbacks('on_val_end', *args, **kwargs)
153152

154153
def on_fit_epoch_end(self, *args, **kwargs):
155154
"""
156155
Fires all registered callbacks at the end of each fit (train+val) epoch
157156
"""
158-
self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)
157+
self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
159158

160159
def on_model_save(self, *args, **kwargs):
161160
"""
162161
Fires all registered callbacks after each model save
163162
"""
164-
self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)
163+
self.run_callbacks('on_model_save', *args, **kwargs)
165164

166165
def on_train_end(self, *args, **kwargs):
167166
"""
168167
Fires all registered callbacks at the end of training
169168
"""
170-
self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)
169+
self.run_callbacks('on_train_end', *args, **kwargs)
171170

172171
def teardown(self, *args, **kwargs):
173172
"""
174173
Fires all registered callbacks before teardown
175174
"""
176-
self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
175+
self.run_callbacks('teardown', *args, **kwargs)

0 commit comments

Comments
 (0)