@@ -173,6 +173,16 @@ def parse_args():
173173 parser .add_argument (
174174 "--hub_private_repo" , action = "store_true" , help = "Whether or not to create a private repository."
175175 )
176+ parser .add_argument (
177+ "--logger" ,
178+ type = str ,
179+ default = "tensorboard" ,
180+ choices = ["tensorboard" , "wandb" ],
181+ help = (
182+ "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
183+ " for experiment tracking and logging of model metrics and model checkpoints"
184+ ),
185+ )
176186 parser .add_argument (
177187 "--logging_dir" ,
178188 type = str ,
@@ -248,7 +258,7 @@ def main(args):
248258 accelerator = Accelerator (
249259 gradient_accumulation_steps = args .gradient_accumulation_steps ,
250260 mixed_precision = args .mixed_precision ,
251- log_with = "tensorboard" ,
261+ log_with = args . logger ,
252262 logging_dir = logging_dir ,
253263 )
254264
@@ -477,9 +487,11 @@ def transforms(examples):
477487
478488 # denormalize the images and save to tensorboard
479489 images_processed = (images * 255 ).round ().astype ("uint8" )
480- accelerator .trackers [0 ].writer .add_images (
481- "test_samples" , images_processed .transpose (0 , 3 , 1 , 2 ), epoch
482- )
490+
491+ if args .logger == "tensorboard" :
492+ accelerator .get_tracker ("tensorboard" ).add_images (
493+ "test_samples" , images_processed .transpose (0 , 3 , 1 , 2 ), epoch
494+ )
483495
484496 if epoch % args .save_model_epochs == 0 or epoch == args .num_epochs - 1 :
485497 # save the model
0 commit comments