
Commit a7c93f5

pcuenca authored and patrickvonplaten committed
Add state checkpointing to other training scripts (huggingface#1687)
* Add state checkpointing to other training scripts
* Fix first_epoch
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* Update Dreambooth checkpoint help message.
* Dreambooth docs: checkpoints, inference from a checkpoint.
* make style

Co-authored-by: Patrick von Platen <[email protected]>
1 parent ec31ea1 commit a7c93f5

6 files changed: +267 -25 lines changed

docs/source/training/dreambooth.mdx

Lines changed: 55 additions & 12 deletions
@@ -21,8 +21,6 @@ The [Dreambooth training script](https://github.com/huggingface/diffusers/tree/m
<Tip warning={true}>

-<!-- TODO: replace with our blog when it's done -->
-
Dreambooth fine-tuning is very sensitive to hyperparameters and easy to overfit. We recommend you take a look at our [in-depth analysis](https://huggingface.co/blog/dreambooth) with recommended settings for different subjects, and go from there.

</Tip>
@@ -44,17 +42,9 @@ Then initialize and configure a [🤗 Accelerate](https://github.com/huggingface
accelerate config
```

-You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
-
-You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
-
-Run the following command to authenticate your token
-
-```bash
-huggingface-cli login
-```
+In this example we'll use model version `v1-4`, so please visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4) and carefully read the license before proceeding.

-If you have already cloned the repo, then you won't need to go through these steps. Instead, you can pass the path to your local checkout to the training script and it will be loaded from there.
+The command below will download and cache the model weights from the Hub because we use the model's Hub id `CompVis/stable-diffusion-v1-4`. You may also clone the repo locally and use the local path in your system where the checkout was saved.

### Dog toy example

@@ -111,6 +101,59 @@ accelerate launch train_dreambooth.py \
  --max_train_steps=800
```

+### Saving checkpoints while training
+
+It's easy to overfit while training with Dreambooth, so sometimes it's useful to save regular checkpoints during the process. One of the intermediate checkpoints might work better than the final model! To use this feature you need to pass the following argument to the training script:
+
+```bash
+  --checkpointing_steps=500
+```
+
+This will save the full training state in subfolders of your `output_dir`. Subfolder names begin with the prefix `checkpoint-`, and then the number of steps performed so far; for example: `checkpoint-1500` would be a checkpoint saved after 1500 training steps.
+
+#### Resuming training from a saved checkpoint
+
+If you want to resume training from any of the saved checkpoints, you can pass the argument `--resume_from_checkpoint` and then indicate the name of the checkpoint you want to use. You can also use the special string `"latest"` to resume from the last checkpoint saved (i.e., the one with the largest number of steps). For example, the following would resume training from the checkpoint saved after 1500 steps:
+
+```bash
+  --resume_from_checkpoint="checkpoint-1500"
+```
+
+This would be a good opportunity to tweak some of your hyperparameters if you wish.
+
+#### Performing inference using a saved checkpoint
+
+Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders and learning rate.
+
+You can use a checkpoint for inference, but first you need to convert it to an inference pipeline. This is how you could do it:
+
+```python
+from accelerate import Accelerator
+from diffusers import DiffusionPipeline
+
+# Load the pipeline with the same arguments (model, revision) that were used for training
+model_id = "CompVis/stable-diffusion-v1-4"
+pipeline = DiffusionPipeline.from_pretrained(model_id)
+
+accelerator = Accelerator()
+
+# Use text_encoder if `--train_text_encoder` was used for the initial training
+unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)
+
+# Restore state from a checkpoint path. You have to use the absolute path here.
+accelerator.load_state("/sddata/dreambooth/daruma-v2-1/checkpoint-100")
+
+# Rebuild the pipeline with the unwrapped models (assignment to .unet and .text_encoder should work too)
+pipeline = DiffusionPipeline.from_pretrained(
+    model_id,
+    unet=accelerator.unwrap_model(unet),
+    text_encoder=accelerator.unwrap_model(text_encoder),
+)
+
+# Perform inference, or save, or push to the hub
+pipeline.save_pretrained("dreambooth-pipeline")
+```

### Training on a 16GB GPU

With the help of gradient checkpointing and the 8-bit optimizer from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes), it's possible to train dreambooth on a 16GB GPU.
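Once the checkpoint has been converted and saved as a regular pipeline (the `pipeline.save_pretrained("dreambooth-pipeline")` step documented above), it can be reloaded and used like any other `DiffusionPipeline`. A minimal usage sketch follows; the prompt, device and output filename are illustrative assumptions, not part of the commit:

```python
import torch
from diffusers import DiffusionPipeline

# Reload the pipeline that was exported from the training checkpoint above.
pipeline = DiffusionPipeline.from_pretrained("dreambooth-pipeline", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")  # assumes a CUDA GPU is available

# The prompt should use the same identifier the model was fine-tuned on.
image = pipeline("a photo of sks dog in a bucket", num_inference_steps=50).images[0]
image.save("dog-bucket.png")
```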

examples/dreambooth/train_dreambooth.py

Lines changed: 2 additions & 1 deletion
@@ -155,7 +155,8 @@ def parse_args(input_args=None):
        type=int,
        default=500,
        help=(
-            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )

examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py

Lines changed: 54 additions & 3 deletions
@@ -242,6 +242,25 @@ def parse_args():
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+            " checkpoints in case they are better than the last checkpoint and are suitable for resuming training"
+            " using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -591,6 +610,7 @@ def collate_fn(examples):
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )
+    accelerator.register_for_checkpointing(lr_scheduler)

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
@@ -628,14 +648,39 @@ def collate_fn(examples):
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1]
+        accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
+        global_step = int(path.split("-")[1])
+
+        resume_global_step = global_step * args.gradient_accumulation_steps
+        first_epoch = resume_global_step // num_update_steps_per_epoch
+        resume_step = resume_global_step % num_update_steps_per_epoch
+
    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
-    global_step = 0

-    for epoch in range(args.num_train_epochs):
+    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
+            # Skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                if step % args.gradient_accumulation_steps == 0:
+                    progress_bar.update(1)
+                continue
+
            with accelerator.accumulate(unet):
                # Convert images to latent space
@@ -719,6 +764,12 @@ def collate_fn(examples):
                progress_bar.update(1)
                global_step += 1

+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
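The resume logic added above maps the step count encoded in a `checkpoint-<N>` folder name back to an (epoch, in-epoch step) position so the dataloader can be fast-forwarded. A standalone sketch of that arithmetic with made-up numbers; in the scripts the real values come from the parsed arguments and the prepared dataloader:

```python
# Illustrative sketch of the resume arithmetic used in the training loops above.
gradient_accumulation_steps = 2        # assumed value of args.gradient_accumulation_steps
num_update_steps_per_epoch = 400       # assumed optimizer updates per epoch

path = "checkpoint-1500"               # folder chosen via --resume_from_checkpoint
global_step = int(path.split("-")[1])  # 1500 optimizer updates already performed

resume_global_step = global_step * gradient_accumulation_steps    # 3000
first_epoch = resume_global_step // num_update_steps_per_epoch    # 7   -> epoch to restart in
resume_step = resume_global_step % num_update_steps_per_epoch     # 200 -> steps to skip in that epoch
print(first_epoch, resume_step)
```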

examples/text_to_image/train_text_to_image.py

Lines changed: 52 additions & 3 deletions
@@ -216,6 +216,24 @@ def parse_args():
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -528,6 +546,7 @@ def collate_fn(examples):
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )
+    accelerator.register_for_checkpointing(lr_scheduler)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
@@ -567,16 +586,40 @@ def collate_fn(examples):
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1]
+        accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
+        global_step = int(path.split("-")[1])
+
+        resume_global_step = global_step * args.gradient_accumulation_steps
+        first_epoch = resume_global_step // num_update_steps_per_epoch
+        resume_step = resume_global_step % num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
-    global_step = 0

-    for epoch in range(args.num_train_epochs):
+    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
+            # Skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                if step % args.gradient_accumulation_steps == 0:
+                    progress_bar.update(1)
+                continue
+
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
@@ -629,6 +672,12 @@ def collate_fn(examples):
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
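All of these scripts lean on the same three Accelerate calls for full-state checkpointing: `register_for_checkpointing`, `save_state` and `load_state`. A minimal, self-contained sketch of that roundtrip, using a toy model and optimizer in place of the real UNet/text encoder; the output path and hyperparameters are made up:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Toy stand-ins for the real model, optimizer and scheduler.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model, optimizer = accelerator.prepare(model, optimizer)
# The scripts register the scheduler explicitly so it is included in the saved state.
accelerator.register_for_checkpointing(lr_scheduler)

# ... training steps would go here ...

# Writes model, optimizer, registered scheduler and RNG state under the given folder.
accelerator.save_state("output_dir/checkpoint-500")

# Later (with the same objects constructed and prepared the same way),
# restore everything and continue training where it left off.
accelerator.load_state("output_dir/checkpoint-500")
```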

examples/textual_inversion/textual_inversion.py

Lines changed: 53 additions & 3 deletions
@@ -205,6 +205,24 @@ def parse_args():
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -512,6 +530,7 @@ def main():
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )
+    accelerator.register_for_checkpointing(lr_scheduler)

    # Move vae and unet to device
    vae.to(accelerator.device)
@@ -543,17 +562,42 @@ def main():
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1]
+        accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
+        global_step = int(path.split("-")[1])
+
+        resume_global_step = global_step * args.gradient_accumulation_steps
+        first_epoch = resume_global_step // num_update_steps_per_epoch
+        resume_step = resume_global_step % num_update_steps_per_epoch
+
    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
-    global_step = 0

    # keep original embeddings as reference
    orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()

-    for epoch in range(args.num_train_epochs):
+    for epoch in range(first_epoch, args.num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
+            # Skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                if step % args.gradient_accumulation_steps == 0:
+                    progress_bar.update(1)
+                continue
+
            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
@@ -605,6 +649,12 @@ def main():
                    save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)

+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
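For reference, the checkpointing block above leaves one `checkpoint-<global_step>` subfolder per save inside `--output_dir`, and the `"latest"` option simply picks the folder with the highest step count. A small illustrative sketch; the directory names are assumed, matching the naming the scripts generate:

```python
# With --checkpointing_steps=500 and roughly 1600 optimization steps, output_dir
# would contain checkpoint-500/, checkpoint-1000/ and checkpoint-1500/.
import os

output_dir = "output_dir"  # whatever was passed as --output_dir
checkpoints = sorted(
    (d for d in os.listdir(output_dir) if d.startswith("checkpoint")),
    key=lambda name: int(name.split("-")[1]),
)
print(checkpoints)      # e.g. ['checkpoint-500', 'checkpoint-1000', 'checkpoint-1500']
print(checkpoints[-1])  # 'checkpoint-1500' -- what --resume_from_checkpoint="latest" picks
```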
