@@ -1,4 +1,5 @@
 import argparse
+import inspect
 import math
 import os
 from pathlib import Path
@@ -31,9 +32,192 @@
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")

+
 logger = get_logger(__name__)


+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+    """
+    Extract values from a 1-D numpy array for a batch of indices.
+
+    :param arr: the 1-D numpy array.
+    :param timesteps: a tensor of indices into the array to extract.
+    :param broadcast_shape: a larger shape of K dimensions with the batch
+                            dimension equal to the length of timesteps.
+    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+    """
+    if not isinstance(arr, torch.Tensor):
+        arr = torch.from_numpy(arr)
+    res = arr[timesteps].float().to(timesteps.device)
+    while len(res.shape) < len(broadcast_shape):
+        res = res[..., None]
+    return res.expand(broadcast_shape)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help=(
+            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that HF Datasets can understand."
+        ),
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The config of the Dataset, leave as None if there's only one config.",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. Folder contents must follow the structure described in"
+            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="ddpm-model-64",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument("--overwrite_output_dir", action="store_true")
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=64,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
+    parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+    parser.add_argument(
+        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="cosine",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument(
+        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+    )
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use Exponential Moving Average for the final model weights.",
+    )
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
+            " Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
+        ),
+    )
+
+    parser.add_argument(
+        "--prediction_type",
+        type=str,
+        default="epsilon",
+        choices=["epsilon", "sample"],
+        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+    )
+
+    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+
+    args = parser.parse_args()
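+    # torchrun / torch.distributed.launch export the rank via the LOCAL_RANK
+    # environment variable; let it take precedence over the CLI flag.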
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.dataset_name is None and args.train_data_dir is None:
+        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+    return args
+
+
 def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     if token is None:
         token = HfFolder.get_token()
@@ -77,7 +261,17 @@ def main(args):
         ),
     )
     model = ORTModule(model)
-    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
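+    # Pass `prediction_type` only if this version of DDPMScheduler accepts it,
+    # so the script keeps working against older diffusers releases.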
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+
+    if accepts_prediction_type:
+        noise_scheduler = DDPMScheduler(
+            num_train_timesteps=args.ddpm_num_steps,
+            beta_schedule=args.ddpm_beta_schedule,
+            prediction_type=args.prediction_type,
+        )
+    else:
+        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -101,7 +295,6 @@ def main(args):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
-            use_auth_token=True if args.use_auth_token else None,
             split="train",
         )
     else:
@@ -111,8 +304,12 @@ def transforms(examples):
         images = [augmentations(image.convert("RGB")) for image in examples["image"]]
         return {"input": images}

+    logger.info(f"Dataset size: {len(dataset)}")
+
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+    )

     lr_scheduler = get_scheduler(
         args.lr_scheduler,
@@ -127,7 +324,12 @@ def transforms(examples):

     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

-    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
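+    # Track the EMA of the unwrapped model so the shadow parameters follow the
+    # underlying weights rather than the accelerate/ORTModule wrapper.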
+    ema_model = EMAModel(
+        accelerator.unwrap_model(model),
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay,
+    )

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -171,11 +373,26 @@ def transforms(examples):

             with accelerator.accumulate(model):
                 # Predict the noise residual
-                noise_pred = model(noisy_images, timesteps, return_dict=True)[0]
-                loss = F.mse_loss(noise_pred, noise)
+                model_output = model(noisy_images, timesteps, return_dict=True)[0]
+
+                if args.prediction_type == "epsilon":
+                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
+                elif args.prediction_type == "sample":
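+                    # Weight the x0 reconstruction loss by the signal-to-noise ratio
+                    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t) of the noised sample.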
+                    alpha_t = _extract_into_tensor(
+                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
+                    )
+                    snr_weights = alpha_t / (1 - alpha_t)
+                    loss = snr_weights * F.mse_loss(
+                        model_output, clean_images, reduction="none"
+                    )  # use SNR weighting from distillation paper
+                    loss = loss.mean()
+                else:
+                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
+
176392 accelerator .backward (loss )
177393
178- accelerator .clip_grad_norm_ (model .parameters (), 1.0 )
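+                # When accumulating gradients, clip only on steps where gradients
+                # are synchronized and the optimizer actually steps.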
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
                 if args.use_ema:
@@ -204,9 +421,13 @@ def transforms(examples):
                     scheduler=noise_scheduler,
                 )

-            generator = torch.manual_seed(0)
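+            # Seed the Generator on the pipeline's device so evaluation sampling
+            # is reproducible on CPU and GPU alike.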
+            generator = torch.Generator(device=pipeline.device).manual_seed(0)
             # run pipeline in inference (sample random noise and denoise)
-            images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
+            images = pipeline(
+                generator=generator,
+                batch_size=args.eval_batch_size,
+                output_type="numpy",
+            ).images

             # denormalize the images and save to tensorboard
             images_processed = (images * 255).round().astype("uint8")
@@ -225,56 +446,5 @@ def transforms(examples):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument("--local_rank", type=int, default=-1)
-    parser.add_argument("--dataset_name", type=str, default=None)
-    parser.add_argument("--dataset_config_name", type=str, default=None)
-    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
-    parser.add_argument("--overwrite_output_dir", action="store_true")
-    parser.add_argument("--cache_dir", type=str, default=None)
-    parser.add_argument("--resolution", type=int, default=64)
-    parser.add_argument("--train_batch_size", type=int, default=16)
-    parser.add_argument("--eval_batch_size", type=int, default=16)
-    parser.add_argument("--num_epochs", type=int, default=100)
-    parser.add_argument("--save_images_epochs", type=int, default=10)
-    parser.add_argument("--save_model_epochs", type=int, default=10)
-    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--learning_rate", type=float, default=1e-4)
-    parser.add_argument("--lr_scheduler", type=str, default="cosine")
-    parser.add_argument("--lr_warmup_steps", type=int, default=500)
-    parser.add_argument("--adam_beta1", type=float, default=0.95)
-    parser.add_argument("--adam_beta2", type=float, default=0.999)
-    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
-    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
-    parser.add_argument("--use_ema", action="store_true", default=True)
-    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
-    parser.add_argument("--ema_power", type=float, default=3 / 4)
-    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
-    parser.add_argument("--push_to_hub", action="store_true")
-    parser.add_argument("--use_auth_token", action="store_true")
-    parser.add_argument("--hub_token", type=str, default=None)
-    parser.add_argument("--hub_model_id", type=str, default=None)
-    parser.add_argument("--hub_private_repo", action="store_true")
-    parser.add_argument("--logging_dir", type=str, default="logs")
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="no",
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose"
-            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."
-        ),
-    )
-
-    args = parser.parse_args()
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
-    if args.dataset_name is None and args.train_data_dir is None:
-        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
-
+    args = parse_args()
     main(args)