|
18 | 18 | import math |
19 | 19 | import os |
20 | 20 | import random |
| 21 | +import warnings |
21 | 22 | from pathlib import Path |
22 | 23 | from typing import Optional |
23 | 24 |
|
|
54 | 55 | from diffusers.utils.import_utils import is_xformers_available |
55 | 56 |
|
56 | 57 |
|
| 58 | +if is_wandb_available(): |
| 59 | + import wandb |
| 60 | + |
57 | 61 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): |
58 | 62 | PIL_INTERPOLATION = { |
59 | 63 | "linear": PIL.Image.Resampling.BILINEAR, |
|
79 | 83 | logger = get_logger(__name__) |
80 | 84 |
|
81 | 85 |
|
| 86 | +def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): |
| 87 | + logger.info( |
| 88 | + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" |
| 89 | + f" {args.validation_prompt}." |
| 90 | + ) |
| 91 | + # create pipeline (note: unet and vae are loaded again in float32) |
| 92 | + pipeline = DiffusionPipeline.from_pretrained( |
| 93 | + args.pretrained_model_name_or_path, |
| 94 | + text_encoder=accelerator.unwrap_model(text_encoder), |
| 95 | + tokenizer=tokenizer, |
| 96 | + unet=unet, |
| 97 | + vae=vae, |
| 98 | + revision=args.revision, |
| 99 | + torch_dtype=weight_dtype, |
| 100 | + ) |
| 101 | + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) |
| 102 | + pipeline = pipeline.to(accelerator.device) |
| 103 | + pipeline.set_progress_bar_config(disable=True) |
| 104 | + |
| 105 | + # run inference |
| 106 | + generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) |
| 107 | + images = [] |
| 108 | + for _ in range(args.num_validation_images): |
| 109 | + with torch.autocast("cuda"): |
| 110 | + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] |
| 111 | + images.append(image) |
| 112 | + |
| 113 | + for tracker in accelerator.trackers: |
| 114 | + if tracker.name == "tensorboard": |
| 115 | + np_images = np.stack([np.asarray(img) for img in images]) |
| 116 | + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") |
| 117 | + if tracker.name == "wandb": |
| 118 | + tracker.log( |
| 119 | + { |
| 120 | + "validation": [ |
| 121 | + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) |
| 122 | + ] |
| 123 | + } |
| 124 | + ) |
| 125 | + |
| 126 | + del pipeline |
| 127 | + torch.cuda.empty_cache() |
| 128 | + |
| 129 | + |
82 | 130 | def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path): |
83 | 131 | logger.info("Saving embeddings") |
84 | 132 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] |
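
A note on the tracker loop in the new log_validation helper: TensorBoard wants a single stacked NHWC array per tag, while wandb takes a list of captioned wandb.Image objects. A minimal standalone sketch of the TensorBoard half, with dummy images and a made-up log dir (not part of this script):

    import numpy as np
    from PIL import Image
    from torch.utils.tensorboard import SummaryWriter

    # Stand-ins for the PIL images the pipeline returns.
    images = [Image.new("RGB", (64, 64), color=(i * 40, 0, 0)) for i in range(4)]

    writer = SummaryWriter("runs/validation-demo")  # hypothetical log dir
    np_images = np.stack([np.asarray(img) for img in images])  # shape (N, H, W, C)
    writer.add_images("validation", np_images, global_step=0, dataformats="NHWC")
    writer.close()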
@@ -268,12 +316,22 @@ def parse_args(): |
268 | 316 | default=4, |
269 | 317 | help="Number of images that should be generated during validation with `validation_prompt`.", |
270 | 318 | ) |
| 319 | + parser.add_argument( |
| 320 | + "--validation_steps", |
| 321 | + type=int, |
| 322 | + default=100, |
| 323 | + help=( |
| 324 | + "Run validation every X steps. Validation consists of running the prompt" |
| 325 | + " `args.validation_prompt` `args.num_validation_images` times" |
| 326 | + " and logging the images." |
| 327 | + ), |
| 328 | + ) |
271 | 329 | parser.add_argument( |
272 | 330 | "--validation_epochs", |
273 | 331 | type=int, |
274 | | - default=50, |
| 332 | + default=None, |
275 | 333 | help=( |
276 | | - "Run validation every X epochs. Validation consists of running the prompt" |
| 334 | + "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt" |
277 | 335 | " `args.validation_prompt` multiple times: `args.num_validation_images`" |
278 | 336 | " and logging the images." |
279 | 337 | ), |
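
For illustration, roughly how the new flag behaves once parsed; the values below are made up and only show that validation now fires on step counts rather than epoch boundaries:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--validation_steps", type=int, default=100)
    parser.add_argument("--validation_epochs", type=int, default=None)
    args = parser.parse_args(["--validation_steps", "250"])  # example CLI values

    # With a made-up budget of 1000 optimization steps, validation would run at:
    steps = [s for s in range(1, 1001) if s % args.validation_steps == 0]
    print(steps)  # [250, 500, 750, 1000]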
@@ -488,7 +546,6 @@ def main(): |
488 | 546 | if args.report_to == "wandb": |
489 | 547 | if not is_wandb_available(): |
490 | 548 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") |
491 | | - import wandb |
492 | 549 |
|
493 | 550 | # Make one log on every process with the configuration for debugging. |
494 | 551 | logging.basicConfig( |
@@ -627,6 +684,15 @@ def main(): |
627 | 684 | train_dataloader = torch.utils.data.DataLoader( |
628 | 685 | train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers |
629 | 686 | ) |
| 687 | + if args.validation_epochs is not None: |
| 688 | + warnings.warn( |
| 689 | + f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}." |
| 690 | + " Deprecated validation_epochs in favor of `validation_steps`" |
| 691 | + f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}", |
| 692 | + FutureWarning, |
| 693 | + stacklevel=2, |
| 694 | + ) |
| 695 | + args.validation_steps = args.validation_epochs * len(train_dataset) |
630 | 696 |
|
631 | 697 | # Scheduler and math around the number of training steps. |
632 | 698 | overrode_max_train_steps = False |
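
The epochs-to-steps conversion above can be checked in isolation; the numbers below are made up and only illustrate what a user of the old flag will see and get:

    import warnings

    validation_epochs = 5     # value of the deprecated flag (made up)
    train_dataset_len = 20    # e.g. a small textual-inversion image folder

    warnings.warn(
        f"validation_epochs={validation_epochs} is deprecated in favor of validation_steps;"
        f" setting validation_steps to {validation_epochs * train_dataset_len}.",
        FutureWarning,
        stacklevel=2,
    )
    validation_steps = validation_epochs * train_dataset_len  # -> 100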
@@ -683,7 +749,6 @@ def main(): |
683 | 749 | logger.info(f" Total optimization steps = {args.max_train_steps}") |
684 | 750 | global_step = 0 |
685 | 751 | first_epoch = 0 |
686 | | - |
687 | 752 | # Potentially load in the weights and states from a previous save |
688 | 753 | if args.resume_from_checkpoint: |
689 | 754 | if args.resume_from_checkpoint != "latest": |
@@ -783,60 +848,15 @@ def main(): |
783 | 848 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") |
784 | 849 | accelerator.save_state(save_path) |
785 | 850 | logger.info(f"Saved state to {save_path}") |
| 851 | + if args.validation_prompt is not None and global_step % args.validation_steps == 0: |
| 852 | + log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) |
786 | 853 |
|
787 | 854 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} |
788 | 855 | progress_bar.set_postfix(**logs) |
789 | 856 | accelerator.log(logs, step=global_step) |
790 | 857 |
|
791 | 858 | if global_step >= args.max_train_steps: |
792 | 859 | break |
793 | | - |
794 | | - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: |
795 | | - logger.info( |
796 | | - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" |
797 | | - f" {args.validation_prompt}." |
798 | | - ) |
799 | | - # create pipeline (note: unet and vae are loaded again in float32) |
800 | | - pipeline = DiffusionPipeline.from_pretrained( |
801 | | - args.pretrained_model_name_or_path, |
802 | | - text_encoder=accelerator.unwrap_model(text_encoder), |
803 | | - tokenizer=tokenizer, |
804 | | - unet=unet, |
805 | | - vae=vae, |
806 | | - revision=args.revision, |
807 | | - torch_dtype=weight_dtype, |
808 | | - ) |
809 | | - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) |
810 | | - pipeline = pipeline.to(accelerator.device) |
811 | | - pipeline.set_progress_bar_config(disable=True) |
812 | | - |
813 | | - # run inference |
814 | | - generator = ( |
815 | | - None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) |
816 | | - ) |
817 | | - images = [] |
818 | | - for _ in range(args.num_validation_images): |
819 | | - with torch.autocast("cuda"): |
820 | | - image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] |
821 | | - images.append(image) |
822 | | - |
823 | | - for tracker in accelerator.trackers: |
824 | | - if tracker.name == "tensorboard": |
825 | | - np_images = np.stack([np.asarray(img) for img in images]) |
826 | | - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") |
827 | | - if tracker.name == "wandb": |
828 | | - tracker.log( |
829 | | - { |
830 | | - "validation": [ |
831 | | - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") |
832 | | - for i, image in enumerate(images) |
833 | | - ] |
834 | | - } |
835 | | - ) |
836 | | - |
837 | | - del pipeline |
838 | | - torch.cuda.empty_cache() |
839 | | - |
840 | 860 | # Create the pipeline using using the trained modules and save it. |
841 | 861 | accelerator.wait_for_everyone() |
842 | 862 | if accelerator.is_main_process: |
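
To reproduce the validation images outside of training, a rough standalone equivalent of the inference block is sketched below; the model id, prompt, and device are placeholders rather than anything this script pins down:

    import torch
    from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

    pipeline = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder model id
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to("cuda")
    pipeline.set_progress_bar_config(disable=True)

    generator = torch.Generator(device="cuda").manual_seed(0)  # fixed seed for comparable runs
    image = pipeline("a photo of a backpack on a beach", num_inference_steps=25, generator=generator).images[0]
    image.save("validation.png")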
|