
Commit 64d1957

Prathik Rao (prathikr) authored and committed
manually update train_unconditional_ort (huggingface#1694)
* manually update train_unconditional_ort
* formatting

Co-authored-by: Prathik Rao <[email protected]>
1 parent d72d732 commit 64d1957
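
In short, this commit brings the ORT example in line with the main train_unconditional.py script: it adds a documented parse_args() helper (including new --prediction_type, --ddpm_num_steps, and --ddpm_beta_schedule options), constructs the DDPMScheduler in a backward-compatible way, adds an SNR-weighted loss for "sample" prediction, clips gradients only on sync steps, and seeds the evaluation generator on the pipeline device. As a minimal sketch (not part of the commit itself), the backward-compatible scheduler construction looks like this in isolation, assuming a diffusers install that provides DDPMScheduler; the hard-coded values simply mirror the script's defaults:

    import inspect

    from diffusers import DDPMScheduler

    # Pass prediction_type only when the installed DDPMScheduler accepts it,
    # so the script keeps working with older diffusers releases.
    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())

    if accepts_prediction_type:
        noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="epsilon")
    else:
        noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")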

File tree

1 file changed (+231 / -61 lines)


examples/unconditional_image_generation/train_unconditional_ort.py

Lines changed: 231 additions & 61 deletions
@@ -1,4 +1,5 @@
 import argparse
+import inspect
 import math
 import os
 from pathlib import Path
@@ -31,9 +32,192 @@
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")
 
+
 logger = get_logger(__name__)
 
 
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+    """
+    Extract values from a 1-D numpy array for a batch of indices.
+
+    :param arr: the 1-D numpy array.
+    :param timesteps: a tensor of indices into the array to extract.
+    :param broadcast_shape: a larger shape of K dimensions with the batch
+                            dimension equal to the length of timesteps.
+    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+    """
+    if not isinstance(arr, torch.Tensor):
+        arr = torch.from_numpy(arr)
+    res = arr[timesteps].float().to(timesteps.device)
+    while len(res.shape) < len(broadcast_shape):
+        res = res[..., None]
+    return res.expand(broadcast_shape)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help=(
+            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that HF Datasets can understand."
+        ),
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The config of the Dataset, leave as None if there's only one config.",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. Folder contents must follow the structure described in"
+            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="ddpm-model-64",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument("--overwrite_output_dir", action="store_true")
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=64,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
+    parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+    parser.add_argument(
+        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="cosine",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument(
+        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+    )
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use Exponential Moving Average for the final model weights.",
+    )
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose"
+            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+            "and an Nvidia Ampere GPU."
+        ),
+    )
+
+    parser.add_argument(
+        "--prediction_type",
+        type=str,
+        default="epsilon",
+        choices=["epsilon", "sample"],
+        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+    )
+
+    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.dataset_name is None and args.train_data_dir is None:
+        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+    return args
+
+
 def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     if token is None:
         token = HfFolder.get_token()
@@ -77,7 +261,17 @@ def main(args):
         ),
     )
     model = ORTModule(model)
-    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+
+    if accepts_prediction_type:
+        noise_scheduler = DDPMScheduler(
+            num_train_timesteps=args.ddpm_num_steps,
+            beta_schedule=args.ddpm_beta_schedule,
+            prediction_type=args.prediction_type,
+        )
+    else:
+        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
+
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -101,7 +295,6 @@ def main(args):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
-            use_auth_token=True if args.use_auth_token else None,
             split="train",
         )
     else:
@@ -111,8 +304,12 @@ def transforms(examples):
         images = [augmentations(image.convert("RGB")) for image in examples["image"]]
         return {"input": images}
 
+    logger.info(f"Dataset size: {len(dataset)}")
+
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+    )
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
@@ -127,7 +324,12 @@ def transforms(examples):
 
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
-    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
+    ema_model = EMAModel(
+        accelerator.unwrap_model(model),
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay,
+    )
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -171,11 +373,26 @@ def transforms(examples):
 
             with accelerator.accumulate(model):
                 # Predict the noise residual
-                noise_pred = model(noisy_images, timesteps, return_dict=True)[0]
-                loss = F.mse_loss(noise_pred, noise)
+                model_output = model(noisy_images, timesteps, return_dict=True)[0]
+
+                if args.prediction_type == "epsilon":
+                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
+                elif args.prediction_type == "sample":
+                    alpha_t = _extract_into_tensor(
+                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
+                    )
+                    snr_weights = alpha_t / (1 - alpha_t)
+                    loss = snr_weights * F.mse_loss(
+                        model_output, clean_images, reduction="none"
+                    )  # use SNR weighting from distillation paper
+                    loss = loss.mean()
+                else:
+                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
+
                 accelerator.backward(loss)
 
-                accelerator.clip_grad_norm_(model.parameters(), 1.0)
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
                 if args.use_ema:
@@ -204,9 +421,13 @@ def transforms(examples):
                     scheduler=noise_scheduler,
                 )
 
-                generator = torch.manual_seed(0)
+                generator = torch.Generator(device=pipeline.device).manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
+                images = pipeline(
+                    generator=generator,
+                    batch_size=args.eval_batch_size,
+                    output_type="numpy",
+                ).images
 
                 # denormalize the images and save to tensorboard
                 images_processed = (images * 255).round().astype("uint8")
@@ -225,56 +446,5 @@ def transforms(examples):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument("--local_rank", type=int, default=-1)
-    parser.add_argument("--dataset_name", type=str, default=None)
-    parser.add_argument("--dataset_config_name", type=str, default=None)
-    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
-    parser.add_argument("--overwrite_output_dir", action="store_true")
-    parser.add_argument("--cache_dir", type=str, default=None)
-    parser.add_argument("--resolution", type=int, default=64)
-    parser.add_argument("--train_batch_size", type=int, default=16)
-    parser.add_argument("--eval_batch_size", type=int, default=16)
-    parser.add_argument("--num_epochs", type=int, default=100)
-    parser.add_argument("--save_images_epochs", type=int, default=10)
-    parser.add_argument("--save_model_epochs", type=int, default=10)
-    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--learning_rate", type=float, default=1e-4)
-    parser.add_argument("--lr_scheduler", type=str, default="cosine")
-    parser.add_argument("--lr_warmup_steps", type=int, default=500)
-    parser.add_argument("--adam_beta1", type=float, default=0.95)
-    parser.add_argument("--adam_beta2", type=float, default=0.999)
-    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
-    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
-    parser.add_argument("--use_ema", action="store_true", default=True)
-    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
-    parser.add_argument("--ema_power", type=float, default=3 / 4)
-    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
-    parser.add_argument("--push_to_hub", action="store_true")
-    parser.add_argument("--use_auth_token", action="store_true")
-    parser.add_argument("--hub_token", type=str, default=None)
-    parser.add_argument("--hub_model_id", type=str, default=None)
-    parser.add_argument("--hub_private_repo", action="store_true")
-    parser.add_argument("--logging_dir", type=str, default="logs")
-    parser.add_argument(
-        "--mixed_precision",
-        type=str,
-        default="no",
-        choices=["no", "fp16", "bf16"],
-        help=(
-            "Whether to use mixed precision. Choose"
-            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
-            "and an Nvidia Ampere GPU."
-        ),
-    )
-
-    args = parser.parse_args()
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
-    if args.dataset_name is None and args.train_data_dir is None:
-        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
-
+    args = parse_args()
     main(args)
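
For readers scanning the diff, the heart of the new loss path is the _extract_into_tensor helper combined with the SNR weighting applied when --prediction_type sample is selected. The following self-contained toy sketch exercises just that branch; the dummy batch size, resolution, and the linearly spaced stand-in for noise_scheduler.alphas_cumprod are assumptions for illustration only and do not come from the script.

    # Toy illustration of the "sample" prediction loss branch added by this commit.
    import numpy as np
    import torch
    import torch.nn.functional as F


    def _extract_into_tensor(arr, timesteps, broadcast_shape):
        # Same helper the commit adds: gather per-timestep values and broadcast to the image shape.
        if not isinstance(arr, torch.Tensor):
            arr = torch.from_numpy(arr)
        res = arr[timesteps].float().to(timesteps.device)
        while len(res.shape) < len(broadcast_shape):
            res = res[..., None]
        return res.expand(broadcast_shape)


    batch_size, resolution = 4, 64
    alphas_cumprod = np.linspace(0.999, 0.01, 1000)       # stand-in for noise_scheduler.alphas_cumprod
    timesteps = torch.randint(0, 1000, (batch_size,))
    clean_images = torch.randn(batch_size, 3, resolution, resolution)
    model_output = torch.randn_like(clean_images)          # stand-in for the UNet's prediction of x0

    alpha_t = _extract_into_tensor(alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1))
    snr_weights = alpha_t / (1 - alpha_t)
    loss = snr_weights * F.mse_loss(model_output, clean_images, reduction="none")
    loss = loss.mean()
    print(loss.item())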
