Skip to content

Commit aa98e9f

Browse files
ash0tssliard
authored andcommitted
[Examples] Update train_unconditional.py to include logging argument for Wandb (huggingface#1719)
Update train_unconditional.py Add logger flag to choose between tensorboard and wandb
1 parent 63af40c commit aa98e9f

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,16 @@ def parse_args():
173173
parser.add_argument(
174174
"--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
175175
)
176+
parser.add_argument(
177+
"--logger",
178+
type=str,
179+
default="tensorboard",
180+
choices=["tensorboard", "wandb"],
181+
help=(
182+
"Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
183+
" for experiment tracking and logging of model metrics and model checkpoints"
184+
),
185+
)
176186
parser.add_argument(
177187
"--logging_dir",
178188
type=str,
@@ -248,7 +258,7 @@ def main(args):
248258
accelerator = Accelerator(
249259
gradient_accumulation_steps=args.gradient_accumulation_steps,
250260
mixed_precision=args.mixed_precision,
251-
log_with="tensorboard",
261+
log_with=args.logger,
252262
logging_dir=logging_dir,
253263
)
254264

@@ -477,9 +487,11 @@ def transforms(examples):
477487

478488
# denormalize the images and save to tensorboard
479489
images_processed = (images * 255).round().astype("uint8")
480-
accelerator.trackers[0].writer.add_images(
481-
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
482-
)
490+
491+
if args.logger == "tensorboard":
492+
accelerator.get_tracker("tensorboard").add_images(
493+
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
494+
)
483495

484496
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
485497
# save the model

0 commit comments

Comments
 (0)