Merged (changes from 9 commits)
12 changes: 6 additions & 6 deletions examples/text_to_image/README.md
@@ -55,11 +55,11 @@ With `gradient_checkpointing` and `mixed_precision` it should be possible to fin…
<!-- accelerate_snippet_start -->
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export dataset_name="lambdalabs/pokemon-blip-captions"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
-  --dataset_name=$dataset_name \
+  --dataset_name=$DATASET_NAME \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
@@ -133,11 +133,11 @@ for running distributed training with `accelerate`. Here is an example command:

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export dataset_name="lambdalabs/pokemon-blip-captions"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
-  --dataset_name=$dataset_name \
+  --dataset_name=$DATASET_NAME \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
@@ -274,11 +274,11 @@ pip install -U -r requirements_flax.txt

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export dataset_name="lambdalabs/pokemon-blip-captions"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"

python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
-  --dataset_name=$dataset_name \
+  --dataset_name=$DATASET_NAME \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--mixed_precision="fp16" \
115 changes: 114 additions & 1 deletion examples/text_to_image/train_text_to_image.py
@@ -35,6 +35,7 @@
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
@@ -62,6 +63,92 @@
}


def make_image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols

w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))

for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid


def save_model_card(
args,
repo_id: str,
images=None,
repo_folder=None,
):
img_str = ""
if images is not None:
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"

yaml = f"""
**Contributor:**
FYI huggingface_hub has nice utilities for model card creation (both content and metadata) and templates you could leverage here.
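
For reference, a minimal sketch of the kind of `huggingface_hub` utilities being suggested here; the values below are illustrative placeholders, not this script's actual arguments:

```python
from huggingface_hub import ModelCard, ModelCardData

# Structured YAML metadata for the card header (placeholder values).
card_data = ModelCardData(
    license="creativeml-openrail-m",
    library_name="diffusers",
    tags=["stable-diffusion", "text-to-image"],
    datasets=["lambdalabs/pokemon-blip-captions"],
)

# Combine the metadata block with free-form Markdown content and save it.
content = f"---\n{card_data.to_yaml()}\n---\n\n# Text-to-image finetuning\n\nExample description."
ModelCard(content).save("README.md")
```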

**Member Author:**
Thanks for suggesting!

But I feel like the current way of creating the model card content is a bit more explicit and also more flexible.

I could leverage something like https://huggingface.co/docs/huggingface_hub/guides/model-cards#from-a-jinja-template, but I would rather not here, because:

- It doesn't reduce the code complexity significantly.
- Users would then have to learn about the model-card-related details of `huggingface_hub`, which I'd like to avoid because that's not the objective here.

Cc: @pcuenca

**Collaborator:**

Thanks for the ping @osanseviero. I would also advocate using `huggingface_hub`'s modelcards, for a few reasons:

1. You need to be careful about encoding issues in the README (especially `\n` vs `\r\n` on Windows). It's quite specific and annoying, but we did a few iterations in hfh to handle it correctly. The goal is to avoid big diffs if someone else updates the model card afterwards.
2. Using modelcards + a separate Jinja template makes it really easy for non-developers to review the model card template without looking into the code. This can prove useful if someone from the ethics team (for example) wants to open a PR to complete the model card template without digging into the exact code (i.e. separating code from templates separates the two concerns).
3. While it doesn't reduce much complexity, it doesn't add much either. I don't think users have to understand ModelCard's internal details to use it correctly.

**Collaborator:**

That being said, I don't think the `Path('custom_template.md').write_text(template_text)` line from the docs should be reused. I think it would be best to provide a Jinja template alongside the training script and only have:

```python
card_data = ModelCardData(language='en', license='mit', library_name='keras')
card = ModelCard.from_template(card_data, template_path='custom_template.md', author='nateraw')
card.save('README.md')
```

in the training script.

**Collaborator (@Wauplin, Jun 20, 2023):**

(Thinking out loud here, but) actually, reusing the default modelcard template is also an option.

It simplifies the example training script on your side, at the cost of a more verbose model card with a lot of empty fields at first. We can see this as a way to encourage users to document their models better. The advantage of the default template is that it has been iterated on multiple times to be compliant with what the standard for model cards should be.
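
As a rough illustration of that option (not code from this PR), falling back to the default template could look roughly like this:

```python
from huggingface_hub import ModelCard, ModelCardData

# With no template_path, from_template renders huggingface_hub's default
# model card template; sections without data are left as placeholders
# for the model author to fill in later.
card = ModelCard.from_template(
    ModelCardData(
        license="creativeml-openrail-m",
        library_name="diffusers",
        tags=["stable-diffusion", "text-to-image"],
    )
)
card.save("README.md")
```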

**Member Author:**

I am fine with having a more verbose model card, but having to provide a Jinja template separately is something I am not comfortable doing for our examples.

> Using modelcards + a separate Jinja template makes it really easy for non-developers to review the model card template without looking into the code. This can prove useful if someone from the ethics team (for example) wants to open a PR to complete the model card template without digging into the exact code.

I think they would need to open a PR editing the README file anyway, whose content is straightforward IMO.

Probably the best option is to reuse the default modelcard template, as you mentioned. Happy to accept a PR to see the changes.

---
license: creativeml-openrail-m
base_model: {args.pretrained_model_name_or_path}
datasets:
- {args.dataset_name}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
inference: true
---
"""
model_card = f"""
# Text-to-image finetuning - {repo_id}

This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
{img_str}

## Pipeline usage

You can use the pipeline like so:

```python
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
prompt = "{args.validation_prompts[0]}"
image = pipeline(prompt).images[0]
image.save("my_image.png")
```

## Training info

These are the key hyperparameters used during training:

* Epochs: {args.num_train_epochs}
* Learning rate: {args.learning_rate}
* Batch size: {args.train_batch_size}
* Gradient accumulation steps: {args.gradient_accumulation_steps}
* Image resolution: {args.resolution}
* Mixed-precision: {args.mixed_precision}

"""
wandb_info = """
More information on all the CLI arguments and the environment should be available on the `wandb` run page if you used it via `report_to="wandb"`.
"""
if is_wandb_available():
wandb_run_url = None
if wandb.run is not None:
wandb_run_url = wandb.run.url

if wandb_run_url is not None:
wandb_info += f"Check it out here: {wandb_run_url}."

model_card += wandb_info

with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)


def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
logger.info("Running validation... ")

@@ -112,6 +199,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
del pipeline
torch.cuda.empty_cache()

return images


def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -747,8 +836,10 @@ def collate_fn(examples):
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
args.mixed_precision = accelerator.mixed_precision
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
args.mixed_precision = accelerator.mixed_precision

# Move text_encode and vae to gpu and cast to weight_dtype
text_encoder.to(accelerator.device, dtype=weight_dtype)
@@ -935,12 +1026,13 @@ def collate_fn(examples):
break

if accelerator.is_main_process:
images = None
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters())
-log_validation(
+images = log_validation(
vae,
text_encoder,
tokenizer,
@@ -970,7 +1062,28 @@ def collate_fn(examples):
)
pipeline.save_pretrained(args.output_dir)

# Run a final round of inference.
images = []
if args.validation_prompts is not None:
logger.info("Running inference for collecting generated images...")
pipeline = pipeline.to(accelerator.device)
pipeline.torch_dtype = weight_dtype
pipeline.set_progress_bar_config(disable=True)

if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()

if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

for i in range(len(args.validation_prompts)):
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
images.append(image)

if args.push_to_hub:
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
**Contributor:**
Can you separately create the images that are saved to the model card here? Feel free to just pull the subset of code out of `log_validation` that creates the images into a helper method.
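
For illustration, such a helper could look roughly like the sketch below (hypothetical name and signature, mirroring the final-inference block added at the end of this script):

```python
import torch


def generate_validation_images(pipeline, args, accelerator, num_inference_steps=20):
    """Hypothetical helper: run the pipeline over the validation prompts and return the PIL images."""
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    images = []
    for prompt in args.validation_prompts:
        image = pipeline(prompt, num_inference_steps=num_inference_steps, generator=generator).images[0]
        images.append(image)
    return images
```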

**Member (@pcuenca, Jun 17, 2023):**
Either what @williamberman said, or skip the image block in the model card if validation is disabled. I'd prefer to have the images, but we may not be able to generate them unless we make up a prompt that may not be ideal for the particular fine-tune the user performed.

**Member Author:**
If you look at the logic in the new `save_model_card()` function, that is how it's currently done.

If there are validation images to save to the model card, `img_str` will be crafted accordingly; otherwise, it is left empty.

**Member:**
Yeah, that works for me.

upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,