5555from diffusers .utils .torch_utils import is_compiled_module
5656
5757
58+ if is_wandb_available ():
59+ import wandb
60+
5861# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
5962check_min_version ("0.26.0.dev0" )
6063
6770TORCH_DTYPE_MAPPING = {"fp32" : torch .float32 , "fp16" : torch .float16 , "bf16" : torch .bfloat16 }
6871
6972
73+ def log_validation (
74+ pipeline ,
75+ args ,
76+ accelerator ,
77+ generator ,
78+ global_step ,
79+ is_final_validation = False ,
80+ ):
81+ logger .info (
82+ f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
83+ f" { args .validation_prompt } ."
84+ )
85+
86+ pipeline = pipeline .to (accelerator .device )
87+ pipeline .set_progress_bar_config (disable = True )
88+
89+ val_save_dir = os .path .join (args .output_dir , "validation_images" )
90+ if not os .path .exists (val_save_dir ):
91+ os .makedirs (val_save_dir )
92+
93+ original_image = (
94+ lambda image_url_or_path : load_image (image_url_or_path )
95+ if urlparse (image_url_or_path ).scheme
96+ else Image .open (image_url_or_path ).convert ("RGB" )
97+ )(args .val_image_url_or_path )
98+
99+ with torch .autocast (str (accelerator .device ).replace (":0" , "" ), enabled = accelerator .mixed_precision == "fp16" ):
100+ edited_images = []
101+ # Run inference
102+ for val_img_idx in range (args .num_validation_images ):
103+ a_val_img = pipeline (
104+ args .validation_prompt ,
105+ image = original_image ,
106+ num_inference_steps = 20 ,
107+ image_guidance_scale = 1.5 ,
108+ guidance_scale = 7 ,
109+ generator = generator ,
110+ ).images [0 ]
111+ edited_images .append (a_val_img )
112+ # Save validation images
113+ a_val_img .save (os .path .join (val_save_dir , f"step_{ global_step } _val_img_{ val_img_idx } .png" ))
114+
115+ for tracker in accelerator .trackers :
116+ if tracker .name == "wandb" :
117+ wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
118+ for edited_image in edited_images :
119+ wandb_table .add_data (wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt )
120+ logger_name = "test" if is_final_validation else "validation"
121+ tracker .log ({logger_name : wandb_table })
122+
123+
70124def import_model_class_from_model_name_or_path (
71125 pretrained_model_name_or_path : str , revision : str , subfolder : str = "text_encoder"
72126):
@@ -447,11 +501,6 @@ def main():
447501
448502 generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed )
449503
450- if args .report_to == "wandb" :
451- if not is_wandb_available ():
452- raise ImportError ("Make sure to install wandb if you want to use it for logging during training." )
453- import wandb
454-
455504 # Make one log on every process with the configuration for debugging.
456505 logging .basicConfig (
457506 format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -1111,11 +1160,6 @@ def collate_fn(examples):
11111160 ### BEGIN: Perform validation every `validation_epochs` steps
11121161 if global_step % args .validation_steps == 0 :
11131162 if (args .val_image_url_or_path is not None ) and (args .validation_prompt is not None ):
1114- logger .info (
1115- f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
1116- f" { args .validation_prompt } ."
1117- )
1118-
11191163 # create pipeline
11201164 if args .use_ema :
11211165 # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1179,16 @@ def collate_fn(examples):
11351179 variant = args .variant ,
11361180 torch_dtype = weight_dtype ,
11371181 )
1138- pipeline = pipeline .to (accelerator .device )
1139- pipeline .set_progress_bar_config (disable = True )
1140-
1141- # run inference
1142- # Save validation images
1143- val_save_dir = os .path .join (args .output_dir , "validation_images" )
1144- if not os .path .exists (val_save_dir ):
1145- os .makedirs (val_save_dir )
1146-
1147- original_image = (
1148- lambda image_url_or_path : load_image (image_url_or_path )
1149- if urlparse (image_url_or_path ).scheme
1150- else Image .open (image_url_or_path ).convert ("RGB" )
1151- )(args .val_image_url_or_path )
1152- with torch .autocast (
1153- str (accelerator .device ).replace (":0" , "" ), enabled = accelerator .mixed_precision == "fp16"
1154- ):
1155- edited_images = []
1156- for val_img_idx in range (args .num_validation_images ):
1157- a_val_img = pipeline (
1158- args .validation_prompt ,
1159- image = original_image ,
1160- num_inference_steps = 20 ,
1161- image_guidance_scale = 1.5 ,
1162- guidance_scale = 7 ,
1163- generator = generator ,
1164- ).images [0 ]
1165- edited_images .append (a_val_img )
1166- a_val_img .save (os .path .join (val_save_dir , f"step_{ global_step } _val_img_{ val_img_idx } .png" ))
1167-
1168- for tracker in accelerator .trackers :
1169- if tracker .name == "wandb" :
1170- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1171- for edited_image in edited_images :
1172- wandb_table .add_data (
1173- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
1174- )
1175- tracker .log ({"validation" : wandb_table })
1182+
1183+ log_validation (
1184+ pipeline ,
1185+ args ,
1186+ accelerator ,
1187+ generator ,
1188+ global_step ,
1189+ is_final_validation = False ,
1190+ )
1191+
11761192 if args .use_ema :
11771193 # Switch back to the original UNet parameters.
11781194 ema_unet .restore (unet .parameters ())
@@ -1187,7 +1203,6 @@ def collate_fn(examples):
11871203 # Create the pipeline using the trained modules and save it.
11881204 accelerator .wait_for_everyone ()
11891205 if accelerator .is_main_process :
1190- unet = unwrap_model (unet )
11911206 if args .use_ema :
11921207 ema_unet .copy_to (unet .parameters ())
11931208
@@ -1198,10 +1213,11 @@ def collate_fn(examples):
11981213 tokenizer = tokenizer_1 ,
11991214 tokenizer_2 = tokenizer_2 ,
12001215 vae = vae ,
1201- unet = unet ,
1216+ unet = unwrap_model ( unet ) ,
12021217 revision = args .revision ,
12031218 variant = args .variant ,
12041219 )
1220+
12051221 pipeline .save_pretrained (args .output_dir )
12061222
12071223 if args .push_to_hub :
@@ -1212,30 +1228,15 @@ def collate_fn(examples):
12121228 ignore_patterns = ["step_*" , "epoch_*" ],
12131229 )
12141230
1215- if args .validation_prompt is not None :
1216- edited_images = []
1217- pipeline = pipeline .to (accelerator .device )
1218- with torch .autocast (str (accelerator .device ).replace (":0" , "" )):
1219- for _ in range (args .num_validation_images ):
1220- edited_images .append (
1221- pipeline (
1222- args .validation_prompt ,
1223- image = original_image ,
1224- num_inference_steps = 20 ,
1225- image_guidance_scale = 1.5 ,
1226- guidance_scale = 7 ,
1227- generator = generator ,
1228- ).images [0 ]
1229- )
1230-
1231- for tracker in accelerator .trackers :
1232- if tracker .name == "wandb" :
1233- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1234- for edited_image in edited_images :
1235- wandb_table .add_data (
1236- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
1237- )
1238- tracker .log ({"test" : wandb_table })
1231+ if (args .val_image_url_or_path is not None ) and (args .validation_prompt is not None ):
1232+ log_validation (
1233+ pipeline ,
1234+ args ,
1235+ accelerator ,
1236+ generator ,
1237+ global_step ,
1238+ is_final_validation = True ,
1239+ )
12391240
12401241 accelerator .end_training ()
12411242
0 commit comments