diff --git a/openwakeword/train.py b/openwakeword/train.py
index 411fc43..3901999 100755
--- a/openwakeword/train.py
+++ b/openwakeword/train.py
@@ -360,26 +360,30 @@ def train_model(self, X, max_steps, warmup_steps, hold_steps, X_val=None,
             w[pos_ndcs] = 1
             w = w[..., None]
 
-            # Do backpropagation, with gradient accumulation if the batch-size after selecting high loss examples is too small
-            loss = self.loss(predictions, y_ if self.n_classes == 1 else y, w.to(self.device))
-            loss = loss/accumulation_steps
-            accumulated_samples += predictions.shape[0]
-            if accumulated_samples < 128:
-                accumulation_steps += 1
+            if predictions.shape[0] != 0:  # edge case where a batch is empty after selecting high loss examples
+                # Do backpropagation, with gradient accumulation if the batch-size after selecting high loss examples is too small
+                loss = self.loss(predictions, y_ if self.n_classes == 1 else y, w.to(self.device))
+                loss = loss/accumulation_steps
+                accumulated_samples += predictions.shape[0]
+                if accumulated_samples < 128:
+                    accumulation_steps += 1
+                else:
+                    loss.backward()
+                    self.optimizer.step()
+                    accumulation_steps = 1
+                    accumulated_samples = 0
+
+                self.history["loss"].append(loss.detach().cpu().numpy())
+
+                # Compute training metrics and log them
+                fp = self.fp(predictions, y_ if self.n_classes == 1 else y)
+                self.n_fp += fp
+                self.history["recall"].append(self.recall(predictions, y_).detach().cpu().numpy())
+
+                if self.n_classes != 1:
+                    self.history["accuracy"].append(self.acc(predictions, y).detach().cpu().numpy())
             else:
-                loss.backward()
-                self.optimizer.step()
-                accumulation_steps = 1
-                accumulated_samples = 0
-
-            # Compute training metrics and log them
-            fp = self.fp(predictions, y_ if self.n_classes == 1 else y)
-            self.n_fp += fp
-
-            self.history["loss"].append(loss.detach().cpu().numpy())
-            self.history["recall"].append(self.recall(predictions, y_).detach().cpu().numpy())
-            if self.n_classes != 1:
-                self.history["accuracy"].append(self.acc(predictions, y).detach().cpu().numpy())
+                logging.warning("Empty batch after selecting high loss examples! Your model may be overfit to the training data.")
 
             # Run validation and log validation metrics
             if step_ndx in val_steps and step_ndx > 1 and false_positive_val_data is not None:
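
For context, a minimal, self-contained sketch of the pattern this patch guards: when hard-example mining can leave a batch empty, skip the loss/backward call (an empty batch would produce a NaN loss), and accumulate small surviving batches until enough samples have contributed before stepping the optimizer. The model, optimizer, loss, and batch sizes below are illustrative stand-ins, not openwakeword's actual objects, and the sketch uses the textbook accumulate-every-step variant (backward on each batch, step once the threshold is reached) rather than the patch's deferred-backward bookkeeping.

```python
import logging

import torch

# Hypothetical stand-ins for the trainer's model, optimizer, and loss.
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.BCEWithLogitsLoss()

MIN_SAMPLES = 128        # step the optimizer once this many samples contribute
accumulated_samples = 0

# Batch sizes vary after (simulated) hard-example mining, including an
# empty batch that would crash or NaN an unguarded loss computation.
batches = [torch.randn(n, 16) for n in (48, 0, 96, 64)]

for x in batches:
    y = torch.randint(0, 2, (x.shape[0],)).float()
    predictions = model(x).squeeze(-1)

    if predictions.shape[0] == 0:
        # Mirror of the patch's guard: skip the step instead of computing
        # a mean loss over zero samples.
        logging.warning("Empty batch after selecting high loss examples!")
        continue

    # Accumulate gradients across small batches, then step and reset once
    # enough samples have been seen.
    loss = loss_fn(predictions, y)
    loss.backward()
    accumulated_samples += predictions.shape[0]

    if accumulated_samples >= MIN_SAMPLES:
        optimizer.step()
        optimizer.zero_grad()
        accumulated_samples = 0
```

Running this processes the 48-sample batch, warns on the empty one, and steps the optimizer only after the 96-sample batch pushes the accumulated count past 128, which is the same effective behavior the guarded code above aims for.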