[Examples] Improve the model card pushed from the train_text_to_image.py script
#3810
Changes from 9 commits
Diff of `train_text_to_image.py`:
```diff
@@ -35,6 +35,7 @@
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
```
````diff
@@ -62,6 +63,92 @@
 }


+def make_image_grid(imgs, rows, cols):
+    assert len(imgs) == rows * cols
+
+    w, h = imgs[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
+
+
+def save_model_card(
+    args,
+    repo_id: str,
+    images=None,
+    repo_folder=None,
+):
+    img_str = ""
+    if images is not None:
+        image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+        image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+        img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+    yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {args.pretrained_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+    """
+    model_card = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+    """
+    wandb_info = """
+More information on all the CLI arguments and the environment should be available on the `wandb` run page if you used it via `report_to="wandb"`.
+    """
+
+    if is_wandb_available():
+        wandb_run_url = None
+        if wandb.run is not None:
+            wandb_run_url = wandb.run.url
+
+        if wandb_run_url is not None:
+            wandb_info += f"Check it out here: {wandb_run_url}."
+
+    model_card += wandb_info
+
+    with open(os.path.join(repo_folder, "README.md"), "w") as f:
+        f.write(yaml + model_card)
+
+
 def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
     logger.info("Running validation... ")
````
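For readers skimming the diff: `make_image_grid` places image `i` at column `i % cols` and row `i // cols` of the grid canvas. A small standalone sketch with dummy solid-color images (the sizes and colors here are illustrative, not from the PR):

```python
from PIL import Image

def make_image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Image i is pasted at column (i % cols) and row (i // cols).
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

# Four 64x64 dummy images laid out as a 1x4 strip, mirroring how
# save_model_card arranges one validation image per prompt.
imgs = [Image.new("RGB", (64, 64), c) for c in ("red", "green", "blue", "white")]
grid = make_image_grid(imgs, rows=1, cols=4)
print(grid.size)  # (256, 64)
```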
```diff
@@ -112,6 +199,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
     del pipeline
     torch.cuda.empty_cache()

+    return images
+

 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
```
```diff
@@ -747,8 +836,10 @@ def collate_fn(examples):
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
         weight_dtype = torch.float16
+        args.mixed_precision = accelerator.mixed_precision
     elif accelerator.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
+        args.mixed_precision = accelerator.mixed_precision

     # Move text_encode and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
```
```diff
@@ -935,12 +1026,13 @@ def collate_fn(examples):
                 break

         if accelerator.is_main_process:
+            images = None
             if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
                 if args.use_ema:
                     # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                     ema_unet.store(unet.parameters())
                     ema_unet.copy_to(unet.parameters())
-                log_validation(
+                images = log_validation(
                     vae,
                     text_encoder,
                     tokenizer,
```
```diff
@@ -970,7 +1062,28 @@ def collate_fn(examples):
         )
         pipeline.save_pretrained(args.output_dir)

+        # Run a final round of inference.
+        images = None
+        if args.validation_prompts is not None:
+            logger.info("Running inference for collecting generated images...")
+            pipeline = pipeline.to(accelerator.device)
+            pipeline.torch_dtype = weight_dtype
+            pipeline.set_progress_bar_config(disable=True)
+
+            if args.enable_xformers_memory_efficient_attention:
+                pipeline.enable_xformers_memory_efficient_attention()
+
+            if args.seed is None:
+                generator = None
+            else:
+                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+            images = []  # collect one generated image per validation prompt
+            for i in range(len(args.validation_prompts)):
+                image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+                images.append(image)
+
         if args.push_to_hub:
+            save_model_card(args, repo_id, images, repo_folder=args.output_dir)
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,
```

Contributor:
Can you separately create the images that are saved to the model card here? Feel free to just pull the subset of code out of the `log_validation` function.

Member:
Either what @williamberman said, or skip the image block in the model card if validation is disabled. I'd prefer to have the images, but we may not be able to generate them unless we make up a prompt that may not be ideal for the particular fine-tune the user performed.

Member (Author):
If you see the logic in the new `save_model_card()`: if there are validation images to save to the model card, they are embedded as a grid; otherwise, the image block is skipped.

Member:
Yeah, that works for me.
Member:
FYI, `huggingface_hub` has nice utilities for model card creation (both content and metadata) and templates you could leverage here.

See https://huggingface.co/docs/huggingface_hub/guides/model-cards
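For context, a minimal sketch of what those `huggingface_hub` utilities look like (the license and tags below are illustrative, copied from the card this script already writes):

```python
from huggingface_hub import ModelCard, ModelCardData

# Structured metadata for the YAML header, instead of a hand-assembled string.
card_data = ModelCardData(
    license="creativeml-openrail-m",
    tags=["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers"],
)

# Build the card content manually; ModelCard.from_template() is the templated alternative.
content = f"---\n{card_data.to_yaml()}\n---\n\n# Text-to-image finetuning\n"
card = ModelCard(content)
card.save("README.md")
```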
Member (Author):
Thanks for suggesting! But I feel the current way of creating the content in the model card is a bit more explicit and also more flexible. I could leverage something like https://huggingface.co/docs/huggingface_hub/guides/model-cards#from-a-jinja-template, but I would not here because it would couple the script to the templating workflow in `huggingface_hub`, which I'd like to avoid because that's not the objective here.

Cc: @pcuenca
Member:
cc @Wauplin
Member:
Thanks for the ping @osanseviero. I would also advocate using `huggingface_hub`'s modelcards, for a few reasons. One of them is consistent handling of line endings (`\n` vs `\r\n` on Windows): it's quite specific and annoying, but we did a few iterations in `hfh` to handle that correctly, the goal being to avoid big diffs if someone else updates the model card afterwards.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That being said, I don't think the
Path('custom_template.md').write_text(template_text)line from the docs should be reused. I think it would be best to provide a jinja template alongside the training script and only have:in the training script
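The snippet elided from the comment above presumably looked roughly like this (a sketch, assuming a hypothetical `model_card_template.md` Jinja file shipped next to the training script; the argument names mirror the existing `save_model_card`):

```python
import os

from huggingface_hub import ModelCard, ModelCardData

def save_model_card(args, repo_id, images=None, repo_folder=None):
    card = ModelCard.from_template(
        card_data=ModelCardData(  # becomes the YAML header of the card
            license="creativeml-openrail-m",
            base_model=args.pretrained_model_name_or_path,
            datasets=args.dataset_name,
            tags=["stable-diffusion", "text-to-image", "diffusers"],
        ),
        template_path="model_card_template.md",  # hypothetical Jinja template next to the script
        # Extra kwargs are exposed as variables inside the template.
        repo_id=repo_id,
        validation_prompts=args.validation_prompts,
    )
    card.save(os.path.join(repo_folder, "README.md"))
```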
Member:
(Thinking out loud here, but) actually, reusing the default modelcard template is also an option. It simplifies the example training script on your side, at the cost of a more verbose model card, with a lot of empty fields at first. We can see this as a way to encourage users to document their models better. The advantage of the default template is that it has been iterated on multiple times to be compliant with what the standard for model cards should be.
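The default-template option mentioned here would reduce to something like this sketch (the `model_id` value is hypothetical; sections with no supplied value remain as placeholders for the user to fill in):

```python
from huggingface_hub import ModelCard, ModelCardData

# With no template_path, from_template() renders huggingface_hub's default
# model card template, which includes the standard sections (uses,
# limitations, training details, ...) with placeholders where empty.
card = ModelCard.from_template(
    ModelCardData(license="creativeml-openrail-m", library_name="diffusers"),
    model_id="my-finetuned-text-to-image-model",  # hypothetical repo name
)
card.save("README.md")
```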
Member (Author):
I am fine with having a more verbose model card, but having to provide a Jinja template separately is something I am not comfortable doing for our examples. I think users would need to open a PR editing the README file anyway, and its content is straightforward IMO.

Probably the best option is to reuse the default modelcard template, as you mentioned. Happy to accept a PR to see the changes.