Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ def train(args):
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth")

loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
Expand All @@ -216,7 +218,6 @@ def train(args):
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()

loss_total = 0
for step, batch in enumerate(train_dataloader):
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
Expand Down Expand Up @@ -291,16 +292,21 @@ def train(args):
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)

if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / (step+1)
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
logs = {"epoch_loss": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)

accelerator.wait_for_everyone()
Expand Down
12 changes: 9 additions & 3 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ def train(args):
if accelerator.is_main_process:
accelerator.init_trackers("network_train")

loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
Expand All @@ -386,7 +388,6 @@ def train(args):

network.on_epoch_start(text_encoder, unet)

loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(network):
with torch.no_grad():
Expand Down Expand Up @@ -446,8 +447,13 @@ def train(args):
global_step += 1

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / (step+1)
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

Expand All @@ -459,7 +465,7 @@ def train(args):
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)

accelerator.wait_for_everyone()
Expand Down