diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index db9e72a4ea20..470d8c5c189d 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -226,6 +226,8 @@
         title: Self-Attention Guidance
       - local: api/pipelines/semantic_stable_diffusion
         title: Semantic Guidance
+      - local: api/pipelines/shap_e
+        title: Shap-E
       - local: api/pipelines/spectrogram_diffusion
         title: Spectrogram Diffusion
       - sections:
diff --git a/docs/source/en/api/pipelines/shap_e.mdx b/docs/source/en/api/pipelines/shap_e.mdx
new file mode 100644
index 000000000000..fcb32da31bca
--- /dev/null
+++ b/docs/source/en/api/pipelines/shap_e.mdx
@@ -0,0 +1,139 @@
+
+
+# Shap-E
+
+## Overview
+
+The Shap-E model was proposed in [Shap-E: Generating Conditional 3D Implicit Functions](https://arxiv.org/abs/2305.02463) by Alex Nichol and Heewoo Jun from [OpenAI](https://github.com/openai).
+
+The abstract of the paper is the following:
+
+*We present Shap-E, a conditional generative model for 3D assets. Unlike recent work on 3D generative models which produce a single output representation, Shap-E directly generates the parameters of implicit functions that can be rendered as both textured meshes and neural radiance fields. We train Shap-E in two stages: first, we train an encoder that deterministically maps 3D assets into the parameters of an implicit function; second, we train a conditional diffusion model on outputs of the encoder. When trained on a large dataset of paired 3D and text data, our resulting models are capable of generating complex and diverse 3D assets in a matter of seconds. When compared to Point-E, an explicit generative model over point clouds, Shap-E converges faster and reaches comparable or better sample quality despite modeling a higher-dimensional, multi-representation output space.*
+
+The original codebase can be found [here](https://github.com/openai/shap-e).
+
+## Available Pipelines:
+
+| Pipeline | Tasks |
+|---|---|
+| [pipeline_shap_e.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e.py) | *Text-to-Image Generation* |
+| [pipeline_shap_e_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py) | *Image-to-Image Generation* |
+
+## Available checkpoints
+
+* [`openai/shap-e`](https://huggingface.co/openai/shap-e)
+* [`openai/shap-e-img2img`](https://huggingface.co/openai/shap-e-img2img)
+
+## Usage Examples
+
+In the following, we will walk you through some examples of how to use the Shap-E pipelines to create 3D objects in GIF format.
+
+### Text-to-3D image generation
+
+We can use [`ShapEPipeline`] to create a 3D object based on a text prompt. In this example, we will make a birthday cupcake for the :firecracker: diffusers library's first birthday. The workflow for the Shap-E text-to-image pipeline is the same as for any other text-to-image pipeline in diffusers.
+
+```python
+import torch
+
+from diffusers import DiffusionPipeline
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+repo = "openai/shap-e"
+pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
+pipe = pipe.to(device)
+
+guidance_scale = 15.0
+prompt = ["A firecracker", "A birthday cupcake"]
+
+images = pipe(
+    prompt,
+    guidance_scale=guidance_scale,
+    num_inference_steps=64,
+    frame_size=256,
+).images
+```
+
+The output of [`ShapEPipeline`] is a list of lists of image frames. Each list of frames can be used to create a 3D object.
+Let's use the `export_to_gif` utility function in diffusers to make a 3D cupcake!
+
+```python
+from diffusers.utils import export_to_gif
+
+export_to_gif(images[0], "firecracker_3d.gif")
+export_to_gif(images[1], "cake_3d.gif")
+```
+![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/firecracker_out.gif)
+![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/cake_out.gif)
+
+
+### Image-to-Image generation
+
+You can use [`ShapEImg2ImgPipeline`] along with other text-to-image pipelines in diffusers and turn your 2D generation into 3D.
+
+In this example, we will first generate a cheeseburger with a simple prompt, "A cheeseburger, white background".
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16)
+pipe_prior.to("cuda")
+
+t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
+t2i_pipe.to("cuda")
+
+prompt = "A cheeseburger, white background"
+
+image_embeds, negative_image_embeds = pipe_prior(prompt, guidance_scale=1.0).to_tuple()
+image = t2i_pipe(
+    prompt,
+    image_embeds=image_embeds,
+    negative_image_embeds=negative_image_embeds,
+).images[0]
+
+image.save("burger.png")
+```
+
+![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_in.png)
+
+We will then use the Shap-E image-to-image pipeline to turn it into a 3D cheeseburger :)
+
+```python
+from PIL import Image
+from diffusers.utils import export_to_gif
+
+repo = "openai/shap-e-img2img"
+pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+guidance_scale = 3.0
+image = Image.open("burger.png").resize((256, 256))
+
+images = pipe(
+    image,
+    guidance_scale=guidance_scale,
+    num_inference_steps=64,
+    frame_size=256,
+).images
+
+gif_path = export_to_gif(images[0], "burger_3d.gif")
+```
+![img](https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/burger_out.gif)
+
+## ShapEPipeline
+[[autodoc]] ShapEPipeline
+	- all
+	- __call__
+
+## ShapEImg2ImgPipeline
+[[autodoc]] ShapEImg2ImgPipeline
+	- all
+	- __call__
\ No newline at end of file
diff --git a/scripts/convert_shap_e_to_diffusers.py b/scripts/convert_shap_e_to_diffusers.py
new file mode 100644
index 000000000000..d92db176f422
--- /dev/null
+++ b/scripts/convert_shap_e_to_diffusers.py
@@ -0,0 +1,594 @@
+import argparse
+import tempfile
+
+import torch
+from accelerate import load_checkpoint_and_dispatch
+
+from diffusers.models.prior_transformer import PriorTransformer
+from diffusers.pipelines.shap_e import ShapERenderer
+
+
+"""
+Example - From the diffusers root directory:
+
+Download weights:
+```sh
+$ wget "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt"
+```
+
+Convert the model:
+```sh
+$ python scripts/convert_shap_e_to_diffusers.py \
+  --prior_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/text_cond.pt \
+  --prior_image_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/image_cond.pt \
+  --transmitter_checkpoint_path /home/yiyi_huggingface_co/shap-e/shap_e_model_cache/transmitter.pt\
+  --dump_path /home/yiyi_huggingface_co/model_repo/shap-e/renderer\
+  --debug renderer
+```
+"""
+
+
+# prior
+
+PRIOR_ORIGINAL_PREFIX = "wrapped"
+
+PRIOR_CONFIG = {
+    "num_attention_heads": 16,
+    "attention_head_dim": 1024 // 16,
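+    # 16 heads x 64-dim heads -> a 1024-wide transformer (PriorTransformer sets inner_dim = num_attention_heads * attention_head_dim),
+    # matching the hyperparameters of the original text-conditioned Shap-E prior checkpoint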
"num_layers": 24, + "embedding_dim": 1024, + "num_embeddings": 1024, + "additional_embeddings": 0, + "time_embed_act_fn": "gelu", + "norm_in_type": "layer", + "encoder_hid_proj_type": None, + "added_emb_type": None, + "time_embed_dim": 1024 * 4, + "embedding_proj_dim": 768, + "clip_embed_dim": 1024 * 2, +} + + +def prior_model_from_original_config(): + model = PriorTransformer(**PRIOR_CONFIG) + + return model + + +def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # .time_embed.c_fc -> .time_embedding.linear_1 + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_fc.bias"], + } + ) + + # .time_embed.c_proj -> .time_embedding.linear_2 + diffusers_checkpoint.update( + { + "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.c_proj.bias"], + } + ) + + # .input_proj -> .proj_in + diffusers_checkpoint.update( + { + "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.weight"], + "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.input_proj.bias"], + } + ) + + # .clip_emb -> .embedding_proj + diffusers_checkpoint.update( + { + "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.weight"], + "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_embed.bias"], + } + ) + + # .pos_emb -> .positional_embedding + diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.pos_emb"][None, :]}) + + # .ln_pre -> .norm_in + diffusers_checkpoint.update( + { + "norm_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.weight"], + "norm_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_pre.bias"], + } + ) + + # .backbone.resblocks. -> .transformer_blocks. 
+ for idx in range(len(model.transformer_blocks)): + diffusers_transformer_prefix = f"transformer_blocks.{idx}" + original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.backbone.resblocks.{idx}" + + # .attn -> .attn1 + diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1" + original_attention_prefix = f"{original_transformer_prefix}.attn" + diffusers_checkpoint.update( + prior_attention_to_diffusers( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + original_attention_prefix=original_attention_prefix, + attention_head_dim=model.attention_head_dim, + ) + ) + + # .mlp -> .ff + diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff" + original_ff_prefix = f"{original_transformer_prefix}.mlp" + diffusers_checkpoint.update( + prior_ff_to_diffusers( + checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix + ) + ) + + # .ln_1 -> .norm1 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[ + f"{original_transformer_prefix}.ln_1.weight" + ], + f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"], + } + ) + + # .ln_2 -> .norm3 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[ + f"{original_transformer_prefix}.ln_2.weight" + ], + f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"], + } + ) + + # .ln_post -> .norm_out + diffusers_checkpoint.update( + { + "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.weight"], + "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.ln_post.bias"], + } + ) + + # .output_proj -> .proj_to_clip_embeddings + diffusers_checkpoint.update( + { + "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.weight"], + "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.output_proj.bias"], + } + ) + + return diffusers_checkpoint + + +def prior_attention_to_diffusers( + checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim +): + diffusers_checkpoint = {} + + # .c_qkv -> .{to_q, to_k, to_v} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"], + bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"], + split=3, + chunk_size=attention_head_dim, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .c_proj -> .to_out.0 + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"], + } + ) + + return diffusers_checkpoint + + +def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix): + diffusers_checkpoint = { + # .c_fc -> .net.0.proj + f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"], + f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"], + # .c_proj -> .net.2 + f"{diffusers_ff_prefix}.net.{2}.weight": 
checkpoint[f"{original_ff_prefix}.c_proj.weight"], + f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"], + } + + return diffusers_checkpoint + + +# done prior + + +# prior_image (only slightly different from prior) + + +PRIOR_IMAGE_ORIGINAL_PREFIX = "wrapped" + +# Uses default arguments +PRIOR_IMAGE_CONFIG = { + "num_attention_heads": 8, + "attention_head_dim": 1024 // 8, + "num_layers": 24, + "embedding_dim": 1024, + "num_embeddings": 1024, + "additional_embeddings": 0, + "time_embed_act_fn": "gelu", + "norm_in_type": "layer", + "embedding_proj_norm_type": "layer", + "encoder_hid_proj_type": None, + "added_emb_type": None, + "time_embed_dim": 1024 * 4, + "embedding_proj_dim": 1024, + "clip_embed_dim": 1024 * 2, +} + + +def prior_image_model_from_original_config(): + model = PriorTransformer(**PRIOR_IMAGE_CONFIG) + + return model + + +def prior_image_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # .time_embed.c_fc -> .time_embedding.linear_1 + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_fc.bias"], + } + ) + + # .time_embed.c_proj -> .time_embedding.linear_2 + diffusers_checkpoint.update( + { + "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.time_embed.c_proj.bias"], + } + ) + + # .input_proj -> .proj_in + diffusers_checkpoint.update( + { + "proj_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.weight"], + "proj_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.input_proj.bias"], + } + ) + + # .clip_embed.0 -> .embedding_proj_norm + diffusers_checkpoint.update( + { + "embedding_proj_norm.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.weight"], + "embedding_proj_norm.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.0.bias"], + } + ) + + # ..clip_embed.1 -> .embedding_proj + diffusers_checkpoint.update( + { + "embedding_proj.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.weight"], + "embedding_proj.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.clip_embed.1.bias"], + } + ) + + # .pos_emb -> .positional_embedding + diffusers_checkpoint.update( + {"positional_embedding": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.pos_emb"][None, :]} + ) + + # .ln_pre -> .norm_in + diffusers_checkpoint.update( + { + "norm_in.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.weight"], + "norm_in.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_pre.bias"], + } + ) + + # .backbone.resblocks. -> .transformer_blocks. 
+ for idx in range(len(model.transformer_blocks)): + diffusers_transformer_prefix = f"transformer_blocks.{idx}" + original_transformer_prefix = f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.backbone.resblocks.{idx}" + + # .attn -> .attn1 + diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1" + original_attention_prefix = f"{original_transformer_prefix}.attn" + diffusers_checkpoint.update( + prior_attention_to_diffusers( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + original_attention_prefix=original_attention_prefix, + attention_head_dim=model.attention_head_dim, + ) + ) + + # .mlp -> .ff + diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff" + original_ff_prefix = f"{original_transformer_prefix}.mlp" + diffusers_checkpoint.update( + prior_ff_to_diffusers( + checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix + ) + ) + + # .ln_1 -> .norm1 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[ + f"{original_transformer_prefix}.ln_1.weight" + ], + f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"], + } + ) + + # .ln_2 -> .norm3 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[ + f"{original_transformer_prefix}.ln_2.weight" + ], + f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"], + } + ) + + # .ln_post -> .norm_out + diffusers_checkpoint.update( + { + "norm_out.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.weight"], + "norm_out.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.ln_post.bias"], + } + ) + + # .output_proj -> .proj_to_clip_embeddings + diffusers_checkpoint.update( + { + "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.weight"], + "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_IMAGE_ORIGINAL_PREFIX}.output_proj.bias"], + } + ) + + return diffusers_checkpoint + + +# done prior_image + + +# renderer + +RENDERER_CONFIG = {} + + +def renderer_model_from_original_config(): + model = ShapERenderer(**RENDERER_CONFIG) + + return model + + +RENDERER_MLP_ORIGINAL_PREFIX = "renderer.nerstf" + +RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX = "encoder.params_proj" + + +def renderer_model_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + diffusers_checkpoint.update( + {f"mlp.{k}": checkpoint[f"{RENDERER_MLP_ORIGINAL_PREFIX}.{k}"] for k in model.mlp.state_dict().keys()} + ) + + diffusers_checkpoint.update( + { + f"params_proj.{k}": checkpoint[f"{RENDERER_PARAMS_PROJ_ORIGINAL_PREFIX}.{k}"] + for k in model.params_proj.state_dict().keys() + } + ) + + diffusers_checkpoint.update({"void.background": torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)}) + + return diffusers_checkpoint + + +# done renderer + + +# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?) 
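+# `weight` and `bias` come from a fused projection (the `c_qkv` layers converted above) whose rows are laid
+# out in `chunk_size`-sized blocks that alternate between the `split` targets (per-head q, k and v blocks at
+# the call sites in this script). Dealing the blocks out round-robin therefore regroups the rows into `split`
+# separate full weight/bias tensors.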
+def split_attentions(*, weight, bias, split, chunk_size): + weights = [None] * split + biases = [None] * split + + weights_biases_idx = 0 + + for starting_row_index in range(0, weight.shape[0], chunk_size): + row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size) + + weight_rows = weight[row_indices, :] + bias_rows = bias[row_indices] + + if weights[weights_biases_idx] is None: + assert weights[weights_biases_idx] is None + weights[weights_biases_idx] = weight_rows + biases[weights_biases_idx] = bias_rows + else: + assert weights[weights_biases_idx] is not None + weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows]) + biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows]) + + weights_biases_idx = (weights_biases_idx + 1) % split + + return weights, biases + + +# done unet utils + + +# Driver functions + + +def prior(*, args, checkpoint_map_location): + print("loading prior") + + prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location) + + prior_model = prior_model_from_original_config() + + prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint) + + del prior_checkpoint + + load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model) + + print("done loading prior") + + return prior_model + + +def prior_image(*, args, checkpoint_map_location): + print("loading prior_image") + + print(f"load checkpoint from {args.prior_image_checkpoint_path}") + prior_checkpoint = torch.load(args.prior_image_checkpoint_path, map_location=checkpoint_map_location) + + prior_model = prior_image_model_from_original_config() + + prior_diffusers_checkpoint = prior_image_original_checkpoint_to_diffusers_checkpoint(prior_model, prior_checkpoint) + + del prior_checkpoint + + load_prior_checkpoint_to_model(prior_diffusers_checkpoint, prior_model) + + print("done loading prior_image") + + return prior_model + + +def renderer(*, args, checkpoint_map_location): + print(" loading renderer") + + renderer_checkpoint = torch.load(args.transmitter_checkpoint_path, map_location=checkpoint_map_location) + + renderer_model = renderer_model_from_original_config() + + renderer_diffusers_checkpoint = renderer_model_original_checkpoint_to_diffusers_checkpoint( + renderer_model, renderer_checkpoint + ) + + del renderer_checkpoint + + load_checkpoint_to_model(renderer_diffusers_checkpoint, renderer_model, strict=True) + + print("done loading renderer") + + return renderer_model + + +# prior model will expect clip_mean and clip_std, whic are missing from the state_dict +PRIOR_EXPECTED_MISSING_KEYS = ["clip_mean", "clip_std"] + + +def load_prior_checkpoint_to_model(checkpoint, model): + with tempfile.NamedTemporaryFile() as file: + torch.save(checkpoint, file.name) + del checkpoint + missing_keys, unexpected_keys = model.load_state_dict(torch.load(file.name), strict=False) + missing_keys = list(set(missing_keys) - set(PRIOR_EXPECTED_MISSING_KEYS)) + + if len(unexpected_keys) > 0: + raise ValueError(f"Unexpected keys when loading prior model: {unexpected_keys}") + if len(missing_keys) > 0: + raise ValueError(f"Missing keys when loading prior model: {missing_keys}") + + +def load_checkpoint_to_model(checkpoint, model, strict=False): + with tempfile.NamedTemporaryFile() as file: + torch.save(checkpoint, file.name) + del checkpoint + if strict: + model.load_state_dict(torch.load(file.name), strict=True) + else: + load_checkpoint_and_dispatch(model, 
file.name, device_map="auto") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--prior_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the prior checkpoint to convert.", + ) + + parser.add_argument( + "--prior_image_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the prior_image checkpoint to convert.", + ) + + parser.add_argument( + "--transmitter_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the transmitter checkpoint to convert.", + ) + + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + parser.add_argument( + "--debug", + default=None, + type=str, + required=False, + help="Only run a specific stage of the convert script. Used for debugging", + ) + + args = parser.parse_args() + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + if args.debug is not None: + print(f"debug: only executing {args.debug}") + + if args.debug is None: + print("YiYi TO-DO") + elif args.debug == "prior": + prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) + prior_model.save_pretrained(args.dump_path) + elif args.debug == "prior_image": + prior_model = prior_image(args=args, checkpoint_map_location=checkpoint_map_location) + prior_model.save_pretrained(args.dump_path) + elif args.debug == "renderer": + renderer_model = renderer(args=args, checkpoint_map_location=checkpoint_map_location) + renderer_model.save_pretrained(args.dump_path) + else: + raise ValueError(f"unknown debug value : {args.debug}") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f425dc13ec2c..a7fc9a36f271 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -149,6 +149,8 @@ LDMTextToImagePipeline, PaintByExamplePipeline, SemanticStableDiffusionPipeline, + ShapEImg2ImgPipeline, + ShapEPipeline, StableDiffusionAttendAndExcitePipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 47785a93e939..9f3c61dd7561 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -34,14 +34,33 @@ class PriorTransformer(ModelMixin, ConfigMixin): num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. - embedding_dim (`int`, *optional*, defaults to 768): - The dimension of the CLIP embeddings. Image embeddings and text embeddings are both the same dimension. - num_embeddings (`int`, *optional*, defaults to 77): The max number of CLIP embeddings allowed (the - length of the prompt after it has been tokenized). 
+        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
+        num_embeddings (`int`, *optional*, defaults to 77):
+            The number of embeddings of the model input `hidden_states`
         additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
             projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
             additional_embeddings`.
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
+            The activation function to use to create timestep embeddings.
+        norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on the hidden states
+            before passing them to the Transformer blocks. Set it to `None` if normalization is not needed.
+        embedding_proj_norm_type (`str`, *optional*, defaults to None):
+            The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
+            needed.
+        encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
+            The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
+            `encoder_hidden_states` is `None`.
+        added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
+            Choose from `prd` or `None`. If `prd`, a token indicating the (quantized) dot product between the text
+            embedding and the image embedding is prepended, as proposed in the unCLIP paper
+            (https://arxiv.org/abs/2204.06125). If it is `None`, no additional embeddings will be prepended.
+        time_embed_dim (`int`, *optional*, defaults to None): The dimension of the timestep embeddings.
+            If None, it will be set to `num_attention_heads * attention_head_dim`.
+        embedding_proj_dim (`int`, *optional*, defaults to None):
+            The dimension of `proj_embedding`. If None, it will be set to `embedding_dim`.
+        clip_embed_dim (`int`, *optional*, defaults to None):
+            The dimension of the output. If None, it will be set to `embedding_dim`.
""" @register_to_config @@ -54,6 +73,14 @@ def __init__( num_embeddings=77, additional_embeddings=4, dropout: float = 0.0, + time_embed_act_fn: str = "silu", + norm_in_type: Optional[str] = None, # layer + embedding_proj_norm_type: Optional[str] = None, # layer + encoder_hid_proj_type: Optional[str] = "linear", # linear + added_emb_type: Optional[str] = "prd", # prd + time_embed_dim: Optional[int] = None, + embedding_proj_dim: Optional[int] = None, + clip_embed_dim: Optional[int] = None, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -61,17 +88,41 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim self.additional_embeddings = additional_embeddings + time_embed_dim = time_embed_dim or inner_dim + embedding_proj_dim = embedding_proj_dim or embedding_dim + clip_embed_dim = clip_embed_dim or embedding_dim + self.time_proj = Timesteps(inner_dim, True, 0) - self.time_embedding = TimestepEmbedding(inner_dim, inner_dim) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) self.proj_in = nn.Linear(embedding_dim, inner_dim) - self.embedding_proj = nn.Linear(embedding_dim, inner_dim) - self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + if embedding_proj_norm_type is None: + self.embedding_proj_norm = None + elif embedding_proj_norm_type == "layer": + self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim) + else: + raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") + + self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim) + + if encoder_hid_proj_type is None: + self.encoder_hidden_states_proj = None + elif encoder_hid_proj_type == "linear": + self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim) + else: + raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim)) - self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + if added_emb_type == "prd": + self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim)) + elif added_emb_type is None: + self.prd_embedding = None + else: + raise ValueError( + f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." + ) self.transformer_blocks = nn.ModuleList( [ @@ -87,8 +138,16 @@ def __init__( ] ) + if norm_in_type == "layer": + self.norm_in = nn.LayerNorm(inner_dim) + elif norm_in_type is None: + self.norm_in = None + else: + raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") + self.norm_out = nn.LayerNorm(inner_dim) - self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim) + + self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim) causal_attention_mask = torch.full( [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 @@ -97,8 +156,8 @@ def __init__( causal_attention_mask = causal_attention_mask[None, ...] 
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False) - self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim)) - self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim)) + self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim)) + self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim)) @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors @@ -172,7 +231,7 @@ def forward( hidden_states, timestep: Union[torch.Tensor, float, int], proj_embedding: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.BoolTensor] = None, return_dict: bool = True, ): @@ -217,23 +276,61 @@ def forward( timesteps_projected = timesteps_projected.to(dtype=self.dtype) time_embeddings = self.time_embedding(timesteps_projected) + if self.embedding_proj_norm is not None: + proj_embedding = self.embedding_proj_norm(proj_embedding) + proj_embeddings = self.embedding_proj(proj_embedding) - encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") + hidden_states = self.proj_in(hidden_states) - prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) + positional_embeddings = self.positional_embedding.to(hidden_states.dtype) + additional_embeds = [] + additional_embeddings_len = 0 + + if encoder_hidden_states is not None: + additional_embeds.append(encoder_hidden_states) + additional_embeddings_len += encoder_hidden_states.shape[1] + + if len(proj_embeddings.shape) == 2: + proj_embeddings = proj_embeddings[:, None, :] + + if len(hidden_states.shape) == 2: + hidden_states = hidden_states[:, None, :] + + additional_embeds = additional_embeds + [ + proj_embeddings, + time_embeddings[:, None, :], + hidden_states, + ] + + if self.prd_embedding is not None: + prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1) + additional_embeds.append(prd_embedding) + hidden_states = torch.cat( - [ - encoder_hidden_states, - proj_embeddings[:, None, :], - time_embeddings[:, None, :], - hidden_states[:, None, :], - prd_embedding, - ], + additional_embeds, dim=1, ) + # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens + additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 + if positional_embeddings.shape[1] < hidden_states.shape[1]: + positional_embeddings = F.pad( + positional_embeddings, + ( + 0, + 0, + additional_embeddings_len, + self.prd_embedding.shape[1] if self.prd_embedding is not None else 0, + ), + value=0.0, + ) + hidden_states = hidden_states + positional_embeddings if attention_mask is not None: @@ -242,11 +339,19 @@ def forward( attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0) + if self.norm_in is not None: + hidden_states = self.norm_in(hidden_states) + for block in self.transformer_blocks: hidden_states = block(hidden_states, attention_mask=attention_mask) hidden_states 
= self.norm_out(hidden_states) - hidden_states = hidden_states[:, -1] + + if self.prd_embedding is not None: + hidden_states = hidden_states[:, -1] + else: + hidden_states = hidden_states[:, additional_embeddings_len:] + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) if not return_dict: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b8bee3299aff..c3968406ed90 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -77,6 +77,7 @@ from .latent_diffusion import LDMTextToImagePipeline from .paint_by_example import PaintByExamplePipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline + from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_diffusion import ( CycleDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, diff --git a/src/diffusers/pipelines/shap_e/__init__.py b/src/diffusers/pipelines/shap_e/__init__.py new file mode 100644 index 000000000000..04aa1f2f6d78 --- /dev/null +++ b/src/diffusers/pipelines/shap_e/__init__.py @@ -0,0 +1,27 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline +else: + from .camera import create_pan_cameras + from .pipeline_shap_e import ShapEPipeline + from .pipeline_shap_e_img2img import ShapEImg2ImgPipeline + from .renderer import ( + BoundingBoxVolume, + ImportanceRaySampler, + MLPNeRFModelOutput, + MLPNeRSTFModel, + ShapEParamsProjModel, + ShapERenderer, + StratifiedRaySampler, + VoidNeRFModel, + ) diff --git a/src/diffusers/pipelines/shap_e/camera.py b/src/diffusers/pipelines/shap_e/camera.py new file mode 100644 index 000000000000..7ef0d6607022 --- /dev/null +++ b/src/diffusers/pipelines/shap_e/camera.py @@ -0,0 +1,147 @@ +# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
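+# Camera utilities for the Shap-E renderer: a minimal differentiable pinhole camera
+# (`DifferentiableProjectiveCamera`) and `create_pan_cameras`, which places a ring of cameras
+# around the object for 360-degree turntable renders.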
+ +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import torch + + +@dataclass +class DifferentiableProjectiveCamera: + """ + Implements a batch, differentiable, standard pinhole camera + """ + + origin: torch.Tensor # [batch_size x 3] + x: torch.Tensor # [batch_size x 3] + y: torch.Tensor # [batch_size x 3] + z: torch.Tensor # [batch_size x 3] + width: int + height: int + x_fov: float + y_fov: float + shape: Tuple[int] + + def __post_init__(self): + assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0] + assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3 + assert len(self.x.shape) == len(self.y.shape) == len(self.z.shape) == len(self.origin.shape) == 2 + + def resolution(self): + return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32)) + + def fov(self): + return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32)) + + def get_image_coords(self) -> torch.Tensor: + """ + :return: coords of shape (width * height, 2) + """ + pixel_indices = torch.arange(self.height * self.width) + coords = torch.stack( + [ + pixel_indices % self.width, + torch.div(pixel_indices, self.width, rounding_mode="trunc"), + ], + axis=1, + ) + return coords + + @property + def camera_rays(self): + batch_size, *inner_shape = self.shape + inner_batch_size = int(np.prod(inner_shape)) + + coords = self.get_image_coords() + coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape]) + rays = self.get_camera_rays(coords) + + rays = rays.view(batch_size, inner_batch_size * self.height * self.width, 2, 3) + + return rays + + def get_camera_rays(self, coords: torch.Tensor) -> torch.Tensor: + batch_size, *shape, n_coords = coords.shape + assert n_coords == 2 + assert batch_size == self.origin.shape[0] + + flat = coords.view(batch_size, -1, 2) + + res = self.resolution() + fov = self.fov() + + fracs = (flat.float() / (res - 1)) * 2 - 1 + fracs = fracs * torch.tan(fov / 2) + + fracs = fracs.view(batch_size, -1, 2) + directions = ( + self.z.view(batch_size, 1, 3) + + self.x.view(batch_size, 1, 3) * fracs[:, :, :1] + + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:] + ) + directions = directions / directions.norm(dim=-1, keepdim=True) + rays = torch.stack( + [ + torch.broadcast_to(self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3]), + directions, + ], + dim=2, + ) + return rays.view(batch_size, *shape, 2, 3) + + def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera": + """ + Creates a new camera for the resized view assuming the aspect ratio does not change. + """ + assert width * self.height == height * self.width, "The aspect ratio should not change." 
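+        # only the pixel resolution changes; the camera origin, basis vectors and field of view are reused as-is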
+ return DifferentiableProjectiveCamera( + origin=self.origin, + x=self.x, + y=self.y, + z=self.z, + width=width, + height=height, + x_fov=self.x_fov, + y_fov=self.y_fov, + ) + + +def create_pan_cameras(size: int) -> DifferentiableProjectiveCamera: + origins = [] + xs = [] + ys = [] + zs = [] + for theta in np.linspace(0, 2 * np.pi, num=20): + z = np.array([np.sin(theta), np.cos(theta), -0.5]) + z /= np.sqrt(np.sum(z**2)) + origin = -z * 4 + x = np.array([np.cos(theta), -np.sin(theta), 0.0]) + y = np.cross(z, x) + origins.append(origin) + xs.append(x) + ys.append(y) + zs.append(z) + return DifferentiableProjectiveCamera( + origin=torch.from_numpy(np.stack(origins, axis=0)).float(), + x=torch.from_numpy(np.stack(xs, axis=0)).float(), + y=torch.from_numpy(np.stack(ys, axis=0)).float(), + z=torch.from_numpy(np.stack(zs, axis=0)).float(), + width=size, + height=size, + x_fov=0.7, + y_fov=0.7, + shape=(1, len(xs)), + ) diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py new file mode 100644 index 000000000000..5d96fc7bb9f4 --- /dev/null +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -0,0 +1,390 @@ +# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...models import PriorTransformer +from ...pipelines import DiffusionPipeline +from ...schedulers import HeunDiscreteScheduler +from ...utils import ( + BaseOutput, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from .renderer import ShapERenderer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import export_to_gif + + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + >>> repo = "openai/shap-e" + >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> guidance_scale = 15.0 + >>> prompt = "a shark" + + >>> images = pipe( + ... prompt, + ... guidance_scale=guidance_scale, + ... num_inference_steps=64, + ... frame_size=256, + ... ).images + + >>> gif_path = export_to_gif(images[0], "shark_3d.gif") + ``` +""" + + +@dataclass +class ShapEPipelineOutput(BaseOutput): + """ + Output class for ShapEPipeline. + + Args: + images (`torch.FloatTensor`) + a list of images for 3D rendering + """ + + images: Union[List[List[PIL.Image.Image]], List[List[np.ndarray]]] + + +class ShapEPipeline(DiffusionPipeline): + """ + Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`HeunDiscreteScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + renderer ([`ShapERenderer`]): + Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects + with the NeRF rendering method + """ + + def __init__( + self, + prior: PriorTransformer, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: HeunDiscreteScheduler, + renderer: ShapERenderer, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + renderer=renderer, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [self.text_encoder, self.prior] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
+ """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.prior, self.renderer]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): + return self.device + for module in self.text_encoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + ): + len(prompt) if isinstance(prompt, list) else 1 + + # YiYi Notes: set pad_token_id to be 0, not sure why I can't set in the config file + self.tokenizer.pad_token_id = 0 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + prompt_embeds = text_encoder_output.text_embeds + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + # in Shap-E it normalize the prompt_embeds and then later rescale it + prompt_embeds = prompt_embeds / torch.linalg.norm(prompt_embeds, dim=-1, keepdim=True) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # Rescale the features to have unit variance + prompt_embeds = math.sqrt(prompt_embeds.shape[1]) * prompt_embeds + + return prompt_embeds + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 4.0, + frame_size: int = 64, + output_type: Optional[str] = "pil", # pil, np, latent + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + frame_size (`int`, *optional*, default to 64): + the width and height of each image frame of the generated 3d output + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 
+ + Examples: + + Returns: + [`ShapEPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + + # prior + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_embeddings = self.prior.config.num_embeddings + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim + latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.prior( + scaled_model_input, + timestep=t, + proj_embedding=prompt_embeds, + ).predicted_image_embedding + + # remove the variance + noise_pred, _ = noise_pred.split( + scaled_model_input.shape[2], dim=2 + ) # batch_size, num_embeddings, embedding_dim + + if do_classifier_free_guidance is not None: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + timestep=t, + sample=latents, + ).prev_sample + + if output_type == "latent": + return ShapEPipelineOutput(images=latents) + + images = [] + for i, latent in enumerate(latents): + image = self.renderer.decode( + latent[None, :], + device, + size=frame_size, + ray_batch_size=4096, + n_coarse_samples=64, + n_fine_samples=128, + ) + images.append(image) + + images = torch.stack(images) + + if output_type not in ["np", "pil"]: + raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}") + + images = images.cpu().numpy() + + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (images,) + + return ShapEPipelineOutput(images=images) diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py new file mode 100644 index 000000000000..b99b808e5953 --- /dev/null +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -0,0 +1,349 @@ +# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPVisionModel + +from ...models import PriorTransformer +from ...pipelines import DiffusionPipeline +from ...schedulers import HeunDiscreteScheduler +from ...utils import ( + BaseOutput, + is_accelerate_available, + logging, + randn_tensor, + replace_example_docstring, +) +from .renderer import ShapERenderer + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from PIL import Image + >>> import torch + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import export_to_gif, load_image + + >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + >>> repo = "openai/shap-e-img2img" + >>> pipe = DiffusionPipeline.from_pretrained(repo, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> guidance_scale = 3.0 + >>> image_url = "https://hf.co/datasets/diffusers/docs-images/resolve/main/shap-e/corgi.png" + >>> image = load_image(image_url).convert("RGB") + + >>> images = pipe( + ... image, + ... guidance_scale=guidance_scale, + ... num_inference_steps=64, + ... frame_size=256, + ... ).images + + >>> gif_path = export_to_gif(images[0], "corgi_3d.gif") + ``` +""" + + +@dataclass +class ShapEPipelineOutput(BaseOutput): + """ + Output class for ShapEPipeline. + + Args: + images (`torch.FloatTensor`) + a list of images for 3D rendering + """ + + images: Union[PIL.Image.Image, np.ndarray] + + +class ShapEImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for generating latent representation of a 3D asset and rendering with NeRF method with Shap-E + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`HeunDiscreteScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. 
+ renderer ([`ShapERenderer`]): + Shap-E renderer projects the generated latents into parameters of a MLP that's used to create 3D objects + with the NeRF rendering method + """ + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + scheduler: HeunDiscreteScheduler, + renderer: ShapERenderer, + ): + super().__init__() + + self.register_modules( + prior=prior, + image_encoder=image_encoder, + image_processor=image_processor, + scheduler=scheduler, + renderer=renderer, + ) + + # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [self.image_encoder, self.prior] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_encoder, "_hf_hook"): + return self.device + for module in self.image_encoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_image( + self, + image, + device, + num_images_per_prompt, + do_classifier_free_guidance, + ): + if isinstance(image, List) and isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + if not isinstance(image, torch.Tensor): + image = self.image_processor(image, return_tensors="pt").pixel_values[0].unsqueeze(0) + + image = image.to(dtype=self.image_encoder.dtype, device=device) + + image_embeds = self.image_encoder(image)["last_hidden_state"] + image_embeds = image_embeds[:, 1:, :].contiguous() # batch_size, dim, 256 + + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.zeros_like(image_embeds) + + # For classifier free guidance, we need to do two forward passes. 
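+            # The doubled batch is split again in `__call__`, where the guided prediction is
+            # computed as noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond).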
+            # Here we concatenate the unconditional and image embeddings into a single batch
+            # to avoid doing two forward passes
+            image_embeds = torch.cat([negative_image_embeds, image_embeds])
+
+        return image_embeds
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        image: Union[PIL.Image.Image, List[PIL.Image.Image]],
+        num_images_per_prompt: int = 1,
+        num_inference_steps: int = 25,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        guidance_scale: float = 4.0,
+        frame_size: int = 64,
+        output_type: Optional[str] = "pil",  # pil, np, latent
+        return_dict: bool = True,
+    ):
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            image (`PIL.Image.Image`, `List[PIL.Image.Image]`, `torch.Tensor` or `List[torch.Tensor]`):
+                The image or images to guide the 3D generation.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            num_inference_steps (`int`, *optional*, defaults to 25):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different inputs. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            guidance_scale (`float`, *optional*, defaults to 4.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating samples that are closely linked to the input
+                `image`, usually at the expense of lower image quality.
+            frame_size (`int`, *optional*, defaults to 64):
+                The width and height of each image frame of the generated 3D output.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated images. Choose between `"pil"` (`PIL.Image.Image`), `"np"`
+                (`np.array`) or `"latent"`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`ShapEPipelineOutput`] instead of a plain tuple.
+ + Examples: + + Returns: + [`ShapEPipelineOutput`] or `tuple` + """ + + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + elif isinstance(image, list) and isinstance(image[0], (torch.Tensor, PIL.Image.Image)): + batch_size = len(image) + else: + raise ValueError( + f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `List[PIL.Image.Image]` or `List[torch.Tensor]` but is {type(image)}" + ) + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + image_embeds = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) + + # prior + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_embeddings = self.prior.config.num_embeddings + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, num_embeddings * embedding_dim), + image_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim + latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim) + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + scaled_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + noise_pred = self.prior( + scaled_model_input, + timestep=t, + proj_embedding=image_embeds, + ).predicted_image_embedding + + # remove the variance + noise_pred, _ = noise_pred.split( + scaled_model_input.shape[2], dim=2 + ) # batch_size, num_embeddings, embedding_dim + + if do_classifier_free_guidance is not None: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + timestep=t, + sample=latents, + ).prev_sample + + if output_type == "latent": + return ShapEPipelineOutput(images=latents) + + images = [] + for i, latent in enumerate(latents): + print() + image = self.renderer.decode( + latent[None, :], + device, + size=frame_size, + ray_batch_size=4096, + n_coarse_samples=64, + n_fine_samples=128, + ) + + images.append(image) + + images = torch.stack(images) + + if output_type not in ["np", "pil"]: + raise ValueError(f"Only the output types `pil` and `np` are supported not output_type={output_type}") + + images = images.cpu().numpy() + + if output_type == "pil": + images = [self.numpy_to_pil(image) for image in images] + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (images,) + + return ShapEPipelineOutput(images=images) diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py new file mode 100644 index 000000000000..8b075e671f63 --- /dev/null +++ b/src/diffusers/pipelines/shap_e/renderer.py @@ -0,0 +1,709 @@ +# Copyright 2023 Open AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...utils import BaseOutput +from .camera import create_pan_cameras + + +def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor: + r""" + Sample from the given discrete probability distribution with replacement. + + The i-th bin is assumed to have mass pmf[i]. + + Args: + pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all() + n_samples: number of samples + + Return: + indices sampled with replacement + """ + + *shape, support_size, last_dim = pmf.shape + assert last_dim == 1 + + cdf = torch.cumsum(pmf.view(-1, support_size), dim=1) + inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device)) + + return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1) + + +def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor: + """ + Concatenate x and its positional encodings, following NeRF. + + Reference: https://arxiv.org/pdf/2210.04628.pdf + """ + if min_deg == max_deg: + return x + + scales = 2.0 ** torch.arange(min_deg, max_deg, dtype=x.dtype, device=x.device) + *shape, dim = x.shape + xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1) + assert xb.shape[-1] == dim * (max_deg - min_deg) + emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin() + return torch.cat([x, emb], dim=-1) + + +def encode_position(position): + return posenc_nerf(position, min_deg=0, max_deg=15) + + +def encode_direction(position, direction=None): + if direction is None: + return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8)) + else: + return posenc_nerf(direction, min_deg=0, max_deg=8) + + +def _sanitize_name(x: str) -> str: + return x.replace(".", "__") + + +def integrate_samples(volume_range, ts, density, channels): + r""" + Function integrating the model output. + + Args: + volume_range: Specifies the integral range [t0, t1] + ts: timesteps + density: torch.Tensor [batch_size, *shape, n_samples, 1] + channels: torch.Tensor [batch_size, *shape, n_samples, n_channels] + returns: + channels: integrated rgb output weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density + *transmittance)[i] weight for each rgb output at [..., i, :]. transmittance: transmittance of this volume + ) + """ + + # 1. Calculate the weights + _, _, dt = volume_range.partition(ts) + ddensity = density * dt + + mass = torch.cumsum(ddensity, dim=-2) + transmittance = torch.exp(-mass[..., -1, :]) + + alphas = 1.0 - torch.exp(-ddensity) + Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2)) + # This is the probability of light hitting and reflecting off of + # something at depth [..., i, :]. + weights = alphas * Ts + + # 2. 
Integrate channels + channels = torch.sum(channels * weights, dim=-2) + + return channels, weights, transmittance + + +class VoidNeRFModel(nn.Module): + """ + Implements the default empty space model where all queries are rendered as background. + """ + + def __init__(self, background, channel_scale=255.0): + super().__init__() + background = nn.Parameter(torch.from_numpy(np.array(background)).to(dtype=torch.float32) / channel_scale) + + self.register_buffer("background", background) + + def forward(self, position): + background = self.background[None].to(position.device) + + shape = position.shape[:-1] + ones = [1] * (len(shape) - 1) + n_channels = background.shape[-1] + background = torch.broadcast_to(background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]) + + return background + + +@dataclass +class VolumeRange: + t0: torch.Tensor + t1: torch.Tensor + intersected: torch.Tensor + + def __post_init__(self): + assert self.t0.shape == self.t1.shape == self.intersected.shape + + def partition(self, ts): + """ + Partitions t0 and t1 into n_samples intervals. + + Args: + ts: [batch_size, *shape, n_samples, 1] + + Return: + + lower: [batch_size, *shape, n_samples, 1] upper: [batch_size, *shape, n_samples, 1] delta: [batch_size, + *shape, n_samples, 1] + + where + ts \\in [lower, upper] deltas = upper - lower + """ + + mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5 + lower = torch.cat([self.t0[..., None, :], mids], dim=-2) + upper = torch.cat([mids, self.t1[..., None, :]], dim=-2) + delta = upper - lower + assert lower.shape == upper.shape == delta.shape == ts.shape + return lower, upper, delta + + +class BoundingBoxVolume(nn.Module): + """ + Axis-aligned bounding box defined by the two opposite corners. + """ + + def __init__( + self, + *, + bbox_min, + bbox_max, + min_dist: float = 0.0, + min_t_range: float = 1e-3, + ): + """ + Args: + bbox_min: the left/bottommost corner of the bounding box + bbox_max: the other corner of the bounding box + min_dist: all rays should start at least this distance away from the origin. + """ + super().__init__() + + self.min_dist = min_dist + self.min_t_range = min_t_range + + self.bbox_min = torch.tensor(bbox_min) + self.bbox_max = torch.tensor(bbox_max) + self.bbox = torch.stack([self.bbox_min, self.bbox_max]) + assert self.bbox.shape == (2, 3) + assert min_dist >= 0.0 + assert min_t_range > 0.0 + + def intersect( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0_lower: Optional[torch.Tensor] = None, + epsilon=1e-6, + ): + """ + Args: + origin: [batch_size, *shape, 3] + direction: [batch_size, *shape, 3] + t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. + params: Optional meta parameters in case Volume is parametric + epsilon: to stabilize calculations + + Return: + A tuple of (t0, t1, intersected) where each has a shape [batch_size, *shape, 1]. If a ray intersects with + the volume, `o + td` is in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed to + be on the boundary of the volume. + """ + + batch_size, *shape, _ = origin.shape + ones = [1] * len(shape) + bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device) + + def _safe_divide(a, b, epsilon=1e-6): + return a / torch.where(b < 0, b - epsilon, b + epsilon) + + ts = _safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) + + # Cases to think about: + # + # 1. t1 <= t0: the ray does not pass through the AABB. + # 2. 
t0 < t1 <= 0: the ray intersects but the BB is behind the origin. + # 3. t0 <= 0 <= t1: the ray starts from inside the BB + # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice. + # + # 1 and 4 are clearly handled from t0 < t1 below. + # Making t0 at least min_dist (>= 0) takes care of 2 and 3. + t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist) + t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values + assert t0.shape == t1.shape == (batch_size, *shape, 1) + if t0_lower is not None: + assert t0.shape == t0_lower.shape + t0 = torch.maximum(t0, t0_lower) + + intersected = t0 + self.min_t_range < t1 + t0 = torch.where(intersected, t0, torch.zeros_like(t0)) + t1 = torch.where(intersected, t1, torch.ones_like(t1)) + + return VolumeRange(t0=t0, t1=t1, intersected=intersected) + + +class StratifiedRaySampler(nn.Module): + """ + Instead of fixed intervals, a sample is drawn uniformly at random from each interval. + """ + + def __init__(self, depth_mode: str = "linear"): + """ + :param depth_mode: linear samples ts linearly in depth. harmonic ensures + closer points are sampled more densely. + """ + self.depth_mode = depth_mode + assert self.depth_mode in ("linear", "geometric", "harmonic") + + def sample( + self, + t0: torch.Tensor, + t1: torch.Tensor, + n_samples: int, + epsilon: float = 1e-3, + ) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + ones = [1] * (len(t0.shape) - 1) + ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device) + + if self.depth_mode == "linear": + ts = t0 * (1.0 - ts) + t1 * ts + elif self.depth_mode == "geometric": + ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp() + elif self.depth_mode == "harmonic": + # The original NeRF recommends this interpolation scheme for + # spherical scenes, but there could be some weird edge cases when + # the observer crosses from the inner to outer volume. + ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts) + + mids = 0.5 * (ts[..., 1:] + ts[..., :-1]) + upper = torch.cat([mids, t1], dim=-1) + lower = torch.cat([t0, mids], dim=-1) + # yiyi notes: add a random seed here for testing, don't forget to remove + torch.manual_seed(0) + t_rand = torch.rand_like(ts) + + ts = lower + (upper - lower) * t_rand + return ts.unsqueeze(-1) + + +class ImportanceRaySampler(nn.Module): + """ + Given the initial estimate of densities, this samples more from regions/bins expected to have objects. + """ + + def __init__( + self, + volume_range: VolumeRange, + ts: torch.Tensor, + weights: torch.Tensor, + blur_pool: bool = False, + alpha: float = 1e-5, + ): + """ + Args: + volume_range: the range in which a ray intersects the given volume. + ts: earlier samples from the coarse rendering step + weights: discretized version of density * transmittance + blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF. + alpha: small value to add to weights. 
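+
+        A minimal usage sketch (shapes are arbitrary and chosen only for illustration):
+
+        ```py
+        >>> t0, t1 = torch.zeros(1, 1, 1), torch.ones(1, 1, 1)
+        >>> vrange = VolumeRange(t0=t0, t1=t1, intersected=torch.ones(1, 1, 1).bool())
+        >>> coarse_ts = torch.linspace(0, 1, 16).view(1, 1, 16, 1)
+        >>> weights = torch.rand(1, 1, 16, 1)
+        >>> sampler = ImportanceRaySampler(vrange, ts=coarse_ts, weights=weights)
+        >>> fine_ts = sampler.sample(t0, t1, n_samples=32)  # [1, 1, 32, 1], sorted along dim -2
+        ```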
+ """ + self.volume_range = volume_range + self.ts = ts.clone().detach() + self.weights = weights.clone().detach() + self.blur_pool = blur_pool + self.alpha = alpha + + @torch.no_grad() + def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor: + """ + Args: + t0: start time has shape [batch_size, *shape, 1] + t1: finish time has shape [batch_size, *shape, 1] + n_samples: number of ts to sample + Return: + sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + lower, upper, _ = self.volume_range.partition(self.ts) + + batch_size, *shape, n_coarse_samples, _ = self.ts.shape + + weights = self.weights + if self.blur_pool: + padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2) + maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :]) + weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :]) + weights = weights + self.alpha + pmf = weights / weights.sum(dim=-2, keepdim=True) + inds = sample_pmf(pmf, n_samples) + assert inds.shape == (batch_size, *shape, n_samples, 1) + assert (inds >= 0).all() and (inds < n_coarse_samples).all() + + t_rand = torch.rand(inds.shape, device=inds.device) + lower_ = torch.gather(lower, -2, inds) + upper_ = torch.gather(upper, -2, inds) + + ts = lower_ + (upper_ - lower_) * t_rand + ts = torch.sort(ts, dim=-2).values + return ts + + +@dataclass +class MLPNeRFModelOutput(BaseOutput): + density: torch.Tensor + signed_distance: torch.Tensor + channels: torch.Tensor + ts: torch.Tensor + + +class MLPNeRSTFModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + ): + super().__init__() + + # Instantiate the MLP + + # Find out the dimension of encoded position and direction + dummy = torch.eye(1, 3) + d_posenc_pos = encode_position(position=dummy).shape[-1] + d_posenc_dir = encode_direction(position=dummy).shape[-1] + + mlp_widths = [d_hidden] * n_hidden_layers + input_widths = [d_posenc_pos] + mlp_widths + output_widths = mlp_widths + [n_output] + + if insert_direction_at is not None: + input_widths[insert_direction_at] += d_posenc_dir + + self.mlp = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(input_widths, output_widths)]) + + if act_fn == "swish": + # self.activation = swish + # yiyi testing: + self.activation = lambda x: F.silu(x) + else: + raise ValueError(f"Unsupported activation function {act_fn}") + + self.sdf_activation = torch.tanh + self.density_activation = torch.nn.functional.relu + self.channel_activation = torch.sigmoid + + def map_indices_to_keys(self, output): + h_map = { + "sdf": (0, 1), + "density_coarse": (1, 2), + "density_fine": (2, 3), + "stf": (3, 6), + "nerf_coarse": (6, 9), + "nerf_fine": (9, 12), + } + + mapped_output = {k: output[..., start:end] for k, (start, end) in h_map.items()} + + return mapped_output + + def forward(self, *, position, direction, ts, nerf_level="coarse"): + h = encode_position(position) + + h_preact = h + h_directionless = None + for i, layer in enumerate(self.mlp): + if i == self.config.insert_direction_at: # 4 in the config + h_directionless = h_preact + h_direction = encode_direction(position, direction=direction) + h = torch.cat([h, h_direction], dim=-1) + + h = layer(h) + + h_preact = h + + if i < len(self.mlp) - 1: + h = self.activation(h) + + h_final = h + if h_directionless is None: + h_directionless = h_preact + + activation = self.map_indices_to_keys(h_final) + + if 
nerf_level == "coarse": + h_density = activation["density_coarse"] + h_channels = activation["nerf_coarse"] + else: + h_density = activation["density_fine"] + h_channels = activation["nerf_fine"] + + density = self.density_activation(h_density) + signed_distance = self.sdf_activation(activation["sdf"]) + channels = self.channel_activation(h_channels) + + # yiyi notes: I think signed_distance is not used + return MLPNeRFModelOutput(density=density, signed_distance=signed_distance, channels=channels, ts=ts) + + +class ChannelsProj(nn.Module): + def __init__( + self, + *, + vectors: int, + channels: int, + d_latent: int, + ): + super().__init__() + self.proj = nn.Linear(d_latent, vectors * channels) + self.norm = nn.LayerNorm(channels) + self.d_latent = d_latent + self.vectors = vectors + self.channels = channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_bvd = x + w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent) + b_vc = self.proj.bias.view(1, self.vectors, self.channels) + h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd) + h = self.norm(h) + + h = h + b_vc + return h + + +class ShapEParamsProjModel(ModelMixin, ConfigMixin): + """ + project the latent representation of a 3D asset to obtain weights of a multi-layer perceptron (MLP). + + For more details, see the original paper: + """ + + @register_to_config + def __init__( + self, + *, + param_names: Tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: Tuple[Tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + ): + super().__init__() + + # check inputs + if len(param_names) != len(param_shapes): + raise ValueError("Must provide same number of `param_names` as `param_shapes`") + self.projections = nn.ModuleDict({}) + for k, (vectors, channels) in zip(param_names, param_shapes): + self.projections[_sanitize_name(k)] = ChannelsProj( + vectors=vectors, + channels=channels, + d_latent=d_latent, + ) + + def forward(self, x: torch.Tensor): + out = {} + start = 0 + for k, shape in zip(self.config.param_names, self.config.param_shapes): + vectors, _ = shape + end = start + vectors + x_bvd = x[:, start:end] + out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape) + start = end + return out + + +class ShapERenderer(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + *, + param_names: Tuple[str] = ( + "nerstf.mlp.0.weight", + "nerstf.mlp.1.weight", + "nerstf.mlp.2.weight", + "nerstf.mlp.3.weight", + ), + param_shapes: Tuple[Tuple[int]] = ( + (256, 93), + (256, 256), + (256, 256), + (256, 256), + ), + d_latent: int = 1024, + d_hidden: int = 256, + n_output: int = 12, + n_hidden_layers: int = 6, + act_fn: str = "swish", + insert_direction_at: int = 4, + background: Tuple[float] = ( + 255.0, + 255.0, + 255.0, + ), + ): + super().__init__() + + self.params_proj = ShapEParamsProjModel( + param_names=param_names, + param_shapes=param_shapes, + d_latent=d_latent, + ) + self.mlp = MLPNeRSTFModel(d_hidden, n_output, n_hidden_layers, act_fn, insert_direction_at) + self.void = VoidNeRFModel(background=background, channel_scale=255.0) + self.volume = BoundingBoxVolume(bbox_max=[1.0, 1.0, 1.0], bbox_min=[-1.0, -1.0, -1.0]) + + @torch.no_grad() + def render_rays(self, rays, sampler, n_samples, prev_model_out=None, render_with_direction=False): + """ + Perform volumetric rendering over a partition of possible t's in the union of rendering volumes (written 
below + with some abuse of notations) + + C(r) := sum( + transmittance(t[i]) * integrate( + lambda t: density(t) * channels(t) * transmittance(t), [t[i], t[i + 1]], + ) for i in range(len(parts)) + ) + transmittance(t[-1]) * void_model(t[-1]).channels + + where + + 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the probability of light passing through + the volume specified by [t[0], s]. (transmittance of 1 means light can pass freely) 2) density and channels are + obtained by evaluating the appropriate part.model at time t. 3) [t[i], t[i + 1]] is defined as the range of t + where the ray intersects (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface of the + shell (if bounded). If the ray does not intersect, the integral over this segment is evaluated as 0 and + transmittance(t[i + 1]) := transmittance(t[i]). 4) The last term is integration to infinity (e.g. [t[-1], + math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). + + args: + rays: [batch_size x ... x 2 x 3] origin and direction. sampler: disjoint volume integrals. n_samples: + number of ts to sample. prev_model_outputs: model outputs from the previous rendering step, including + + :return: A tuple of + - `channels` + - A importance samplers for additional fine-grained rendering + - raw model output + """ + origin, direction = rays[..., 0, :], rays[..., 1, :] + + # Integrate over [t[i], t[i + 1]] + + # 1 Intersect the rays with the current volume and sample ts to integrate along. + vrange = self.volume.intersect(origin, direction, t0_lower=None) + ts = sampler.sample(vrange.t0, vrange.t1, n_samples) + ts = ts.to(rays.dtype) + + if prev_model_out is not None: + # Append the previous ts now before fprop because previous + # rendering used a different model and we can't reuse the output. + ts = torch.sort(torch.cat([ts, prev_model_out.ts], dim=-2), dim=-2).values + + batch_size, *_shape, _t0_dim = vrange.t0.shape + _, *ts_shape, _ts_dim = ts.shape + + # 2. Get the points along the ray and query the model + directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3]) + positions = origin.unsqueeze(-2) + ts * directions + + directions = directions.to(self.mlp.dtype) + positions = positions.to(self.mlp.dtype) + + optional_directions = directions if render_with_direction else None + + model_out = self.mlp( + position=positions, + direction=optional_directions, + ts=ts, + nerf_level="coarse" if prev_model_out is None else "fine", + ) + + # 3. Integrate the model results + channels, weights, transmittance = integrate_samples( + vrange, model_out.ts, model_out.density, model_out.channels + ) + + # 4. Clean up results that do not intersect with the volume. + transmittance = torch.where(vrange.intersected, transmittance, torch.ones_like(transmittance)) + channels = torch.where(vrange.intersected, channels, torch.zeros_like(channels)) + # 5. integration to infinity (e.g. [t[-1], math.inf]) that is evaluated by the void_model (i.e. we consider this space to be empty). 
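+        # The leftover transmittance weights the void (background) color: e.g. a ray with a final
+        # transmittance of 0.2 picks up 20% of the background color here.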
+ channels = channels + transmittance * self.void(origin) + + weighted_sampler = ImportanceRaySampler(vrange, ts=model_out.ts, weights=weights) + + return channels, weighted_sampler, model_out + + @torch.no_grad() + def decode( + self, + latents, + device, + size: int = 64, + ray_batch_size: int = 4096, + n_coarse_samples=64, + n_fine_samples=128, + ): + # project the the paramters from the generated latents + projected_params = self.params_proj(latents) + + # update the mlp layers of the renderer + for name, param in self.mlp.state_dict().items(): + if f"nerstf.{name}" in projected_params.keys(): + param.copy_(projected_params[f"nerstf.{name}"].squeeze(0)) + + # create cameras object + camera = create_pan_cameras(size) + rays = camera.camera_rays + rays = rays.to(device) + n_batches = rays.shape[1] // ray_batch_size + + coarse_sampler = StratifiedRaySampler() + + images = [] + + for idx in range(n_batches): + rays_batch = rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size] + + # render rays with coarse, stratified samples. + _, fine_sampler, coarse_model_out = self.render_rays(rays_batch, coarse_sampler, n_coarse_samples) + # Then, render with additional importance-weighted ray samples. + channels, _, _ = self.render_rays( + rays_batch, fine_sampler, n_fine_samples, prev_model_out=coarse_model_out + ) + + images.append(channels) + + images = torch.cat(images, dim=1) + images = images.view(*camera.shape, camera.height, camera.width, -1).squeeze(0) + + return images diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 99602d14038b..a93255ca600e 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -47,7 +47,11 @@ class DDIMSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 2c9fc036a027..c04aabe035b5 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -46,7 +46,11 @@ class DDIMSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -59,19 +63,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py index 8875aa73208b..db3ea0e1cca5 100644 --- a/src/diffusers/schedulers/scheduling_ddim_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py @@ -47,7 +47,11 @@ class DDIMParallelSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
@@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index ddf27d409d88..a1b7d7aaa9c2 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -44,7 +44,11 @@ class DDPMSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.FloatTensor] = None -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -57,19 +61,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index e4d858efde8f..a92e175877d2 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -46,7 +46,11 @@ class DDPMParallelSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -59,19 +63,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index c504fb19231a..36947294922b 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -26,7 +26,11 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
@@ -39,19 +43,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 528b7b838b1c..d7516fa601e1 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -26,7 +26,11 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -39,19 +43,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index b424ebbff262..a6736b354419 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -26,7 +26,11 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -39,19 +43,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index da8b71788b75..a31e97b69651 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -13,6 +13,7 @@ # limitations under the License. 
import math +from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -76,7 +77,11 @@ def __call__(self, sigma, sigma_next): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -89,19 +94,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) @@ -190,10 +206,16 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): indices = (schedule_timesteps == timestep).nonzero() - if self.state_in_first_order: - pos = -1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 else: - pos = 0 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + pos = self._index_counter[timestep_int] + return indices[pos].item() @property @@ -292,6 +314,10 @@ def set_timesteps( self.sample = None self.mid_point_sigma = None + # for exp beta schedules, such as the one for `pipeline_shap_e.py` + # we need an index counter + self._index_counter = defaultdict(int) + def _second_order_timesteps(self, sigmas, log_sigmas): def sigma_fn(_t): return np.exp(-_t) @@ -373,6 +399,10 @@ def step( """ step_index = self.index_for_timestep(timestep) + # advance index counter by 1 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + self._index_counter[timestep_int] += 1 + # Create a noise sampler if it hasn't been created yet if self.noise_sampler is None: min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 721dd5e5bb85..93975a27fc6e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -29,7 +29,11 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -42,19 +46,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 6b8c2f1a8a28..065f657032e6 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -47,7 +47,11 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index fc52c50ebc7f..cb126d4b953c 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -47,7 +47,11 @@ class EulerDiscreteSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -60,19 +64,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 28f29067a544..5f694fd60fc9 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -13,6 +13,7 @@ # limitations under the License. 
import math +from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -23,7 +24,11 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -36,19 +41,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) @@ -74,6 +90,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf). + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. use_karras_sigmas (`bool`, *optional*, defaults to `False`): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. 
If True, the sigmas will be determined according to a sequence @@ -100,6 +120,8 @@ def __init__( trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", use_karras_sigmas: Optional[bool] = False, + clip_sample: Optional[bool] = False, + clip_sample_range: float = 1.0, timestep_spacing: str = "linspace", steps_offset: int = 0, ): @@ -114,7 +136,9 @@ def __init__( ) elif beta_schedule == "squaredcos_cap_v2": # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine") + elif beta_schedule == "exp": + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp") else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") @@ -131,10 +155,16 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): indices = (schedule_timesteps == timestep).nonzero() - if self.state_in_first_order: - pos = -1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 else: - pos = 0 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + pos = self._index_counter[timestep_int] + return indices[pos].item() @property @@ -207,7 +237,7 @@ def set_timesteps( log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - if self.use_karras_sigmas: + if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) @@ -228,6 +258,10 @@ def set_timesteps( self.prev_derivative = None self.dt = None + # for exp beta schedules, such as the one for `pipeline_shap_e.py` + # we need an index counter + self._index_counter = defaultdict(int) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma @@ -292,6 +326,10 @@ def step( """ step_index = self.index_for_timestep(timestep) + # advance index counter by 1 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + self._index_counter[timestep_int] += 1 + if self.state_in_first_order: sigma = self.sigmas[step_index] sigma_next = self.sigmas[step_index + 1] @@ -316,12 +354,17 @@ def step( sample / (sigma_input**2 + 1) ) elif self.config.prediction_type == "sample": - raise NotImplementedError("prediction_type not implemented yet: sample") + pred_original_sample = model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) + if self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + if self.state_in_first_order: # 2. 
Convert to an ODE derivative for 1st order derivative = (sample - pred_original_sample) / sigma_hat diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index d4a35ab82502..bdf9379b9b90 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -24,7 +25,11 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -37,19 +42,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) @@ -130,10 +146,16 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): indices = (schedule_timesteps == timestep).nonzero() - if self.state_in_first_order: - pos = -1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 else: - pos = 0 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + pos = self._index_counter[timestep_int] + return indices[pos].item() @property @@ -245,6 +267,10 @@ def set_timesteps( self.sample = None + # for exp beta schedules, such as the one for `pipeline_shap_e.py` + # we need an index counter + self._index_counter = defaultdict(int) + def sigma_to_t(self, sigma): # get log sigma log_sigma = sigma.log() @@ -295,6 +321,10 @@ def step( """ step_index = self.index_for_timestep(timestep) + # advance index counter by 1 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + self._index_counter[timestep_int] += 1 + if self.state_in_first_order: sigma = self.sigmas[step_index] sigma_interpol = self.sigmas_interpol[step_index] diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 39079fde10d2..a6a1b4e6640d 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -23,7 +24,11 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -36,19 +41,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) @@ -129,10 +145,16 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): indices = (schedule_timesteps == timestep).nonzero() - if self.state_in_first_order: - pos = -1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + if len(self._index_counter) == 0: + pos = 1 if len(indices) > 1 else 0 else: - pos = 0 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + pos = self._index_counter[timestep_int] + return indices[pos].item() @property @@ -234,6 +256,10 @@ def set_timesteps( self.sample = None + # for exp beta schedules, such as the one for `pipeline_shap_e.py` + # we need an index counter + self._index_counter = defaultdict(int) + def sigma_to_t(self, sigma): # get log sigma log_sigma = sigma.log() @@ -283,6 +309,10 @@ def step( """ step_index = self.index_for_timestep(timestep) + # advance index counter by 1 + timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep + self._index_counter[timestep_int] += 1 + if self.state_in_first_order: sigma = self.sigmas[step_index] sigma_interpol = self.sigmas_interpol[step_index + 1] diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 1256660b843c..d58d4ce45bd1 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -45,7 +45,11 @@ class LMSDiscreteSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -58,19 +62,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 70ee1301129c..794eb3674c1b 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -25,7 +25,11 @@ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
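
The other recurring change — replacing the `state_in_first_order` position flag with a `defaultdict` counter in `index_for_timestep` — is easiest to see in isolation. Below is a toy sketch of that lookup (the schedule tensor and helper functions are hypothetical stand-ins, not the scheduler classes themselves): because Heun-style schedules repeat timesteps, the counter tracks how many times each timestep has already been stepped, so duplicated entries resolve to successive sigma indices, and the very first call picks the second match so a mid-schedule start (image-to-image) does not skip a sigma.

```python
from collections import defaultdict

import torch

# Hypothetical interleaved schedule: the first timestep appears once, later ones twice.
schedule_timesteps = torch.tensor([999, 500, 500, 1, 1])
_index_counter = defaultdict(int)


def index_for_timestep(timestep):
    indices = (schedule_timesteps == timestep).nonzero()
    if len(_index_counter) == 0:
        # Very first `step`: take the second match (or the only one) so that starting
        # in the middle of the schedule does not skip a sigma.
        pos = 1 if len(indices) > 1 else 0
    else:
        # Later steps: how often this timestep was already visited picks the occurrence.
        pos = _index_counter[timestep]
    return indices[pos].item()


def step(timestep):
    step_index = index_for_timestep(timestep)
    _index_counter[timestep] += 1  # the patched `step` advances the counter the same way
    return step_index


print([step(t) for t in [999, 500, 500, 1, 1]])  # full schedule -> [0, 1, 2, 3, 4]

_index_counter.clear()
print(step(500))  # image-to-image style start mid-schedule -> 2 (second occurrence)
```

Run over the full schedule, the lookup visits every sigma index exactly once; after resetting the counter and starting at an interior timestep, it lands on that timestep's second occurrence, which is the behaviour the new comment in the patch describes.
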
@@ -38,19 +42,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index f2f97b38f3d3..41e7450d2df6 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -43,7 +43,11 @@ class RePaintSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -56,19 +60,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py index d7f927658c8a..fd23e48bad00 100644 --- a/src/diffusers/schedulers/scheduling_unclip.py +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -44,7 +44,11 @@ class UnCLIPSchedulerOutput(BaseOutput): # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. @@ -57,19 +61,30 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 
+ Choose from `cosine` or `exp` Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ + if alpha_transform_type == "cosine": - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index bdb7f020a0aa..7449df99ba80 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -104,7 +104,7 @@ ) from .torch_utils import maybe_allow_in_graph -from .testing_utils import export_to_video +from .testing_utils import export_to_gif, export_to_video logger = get_logger(__name__) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 6d39c0c67d9d..164206d776fa 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -377,6 +377,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ShapEImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ShapEPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionAttendAndExcitePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index dcb80169de74..64eb3ac925e9 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -300,6 +300,21 @@ def preprocess_image(image: PIL.Image, batch_size: int): return 2.0 * image - 1.0 +def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str: + if output_gif_path is None: + output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name + + image[0].save( + output_gif_path, + save_all=True, + append_images=image[1:], + optimize=False, + duration=100, + loop=0, + ) + return output_gif_path + + def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str: if is_opencv_available(): import cv2 diff --git a/tests/pipelines/shap_e/__init__.py b/tests/pipelines/shap_e/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/shap_e/test_shap_e.py 
b/tests/pipelines/shap_e/test_shap_e.py new file mode 100644 index 000000000000..d095dd9d49b9 --- /dev/null +++ b/tests/pipelines/shap_e/test_shap_e.py @@ -0,0 +1,265 @@ +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEPipeline +from diffusers.pipelines.shap_e import ShapERenderer +from diffusers.utils import load_numpy, slow +from diffusers.utils.testing_utils import require_torch_gpu, torch_device + +from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference + + +class ShapEPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = ShapEPipeline + params = ["prompt"] + batch_params = ["prompt"] + required_optional_params = [ + "num_images_per_prompt", + "num_inference_steps", + "generator", + "latents", + "guidance_scale", + "frame_size", + "output_type", + "return_dict", + ] + test_xformers_attention = False + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def renderer_dim(self): + return 8 + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + projection_dim=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModelWithProjection(config) + + @property + def dummy_prior(self): + torch.manual_seed(0) + + model_kwargs = { + "num_attention_heads": 2, + "attention_head_dim": 16, + "embedding_dim": self.time_input_dim, + "num_embeddings": 32, + "embedding_proj_dim": self.text_embedder_hidden_size, + "time_embed_dim": self.time_embed_dim, + "num_layers": 1, + "clip_embed_dim": self.time_input_dim * 2, + "additional_embeddings": 0, + "time_embed_act_fn": "gelu", + "norm_in_type": "layer", + "encoder_hid_proj_type": None, + "added_emb_type": None, + } + + model = PriorTransformer(**model_kwargs) + return model + + @property + def dummy_renderer(self): + torch.manual_seed(0) + + model_kwargs = { + "param_shapes": ( + (self.renderer_dim, 93), + (self.renderer_dim, 8), + (self.renderer_dim, 8), + (self.renderer_dim, 8), + ), + "d_latent": self.time_input_dim, + "d_hidden": self.renderer_dim, + "n_output": 12, + "background": ( + 0.1, + 0.1, + 0.1, + ), + } + model = ShapERenderer(**model_kwargs) + return model + + def get_dummy_components(self): + prior = self.dummy_prior + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + 
renderer = self.dummy_renderer + + scheduler = HeunDiscreteScheduler( + beta_schedule="exp", + num_train_timesteps=1024, + prediction_type="sample", + use_karras_sigmas=True, + clip_sample=True, + clip_sample_range=1.0, + ) + components = { + "prior": prior, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "renderer": renderer, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "generator": generator, + "num_inference_steps": 1, + "frame_size": 32, + "output_type": "np", + } + return inputs + + def test_shap_e(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images[0] + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (20, 32, 32, 3) + + expected_slice = np.array( + [ + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_inference_batch_consistent(self): + # NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches + self._test_inference_batch_consistent(batch_sizes=[1, 2]) + + def test_inference_batch_single_identical(self): + test_max_difference = torch_device == "cpu" + relax_max_difference = True + + self._test_inference_batch_single_identical( + batch_size=2, + test_max_difference=test_max_difference, + relax_max_difference=relax_max_difference, + ) + + def test_num_images_per_prompt(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + batch_size = 1 + num_images_per_prompt = 2 + + inputs = self.get_dummy_inputs(torch_device) + + for key in inputs.keys(): + if key in self.batch_params: + inputs[key] = batch_size * [inputs[key]] + + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0] + + assert images.shape[0] == batch_size * num_images_per_prompt + + +@slow +@require_torch_gpu +class ShapEPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_shap_e(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/shap_e/test_shap_e_np_out.npy" + ) + pipe = ShapEPipeline.from_pretrained("openai/shap-e") + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + + images = pipe( + "a shark", + generator=generator, + guidance_scale=15.0, + num_inference_steps=64, + frame_size=64, + output_type="np", + ).images[0] + + assert images.shape == (20, 64, 64, 3) + + assert_mean_pixel_difference(images, expected_image) diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py new file mode 100644 index 000000000000..f6638a994fdd --- /dev/null +++ b/tests/pipelines/shap_e/test_shap_e_img2img.py @@ -0,0 +1,281 @@ +# Copyright 2023 HuggingFace Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch +from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel + +from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEImg2ImgPipeline +from diffusers.pipelines.shap_e import ShapERenderer +from diffusers.utils import floats_tensor, load_image, load_numpy, slow +from diffusers.utils.testing_utils import require_torch_gpu, torch_device + +from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference + + +class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = ShapEImg2ImgPipeline + params = ["image"] + batch_params = ["image"] + required_optional_params = [ + "num_images_per_prompt", + "num_inference_steps", + "generator", + "latents", + "guidance_scale", + "frame_size", + "output_type", + "return_dict", + ] + test_xformers_attention = False + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def renderer_dim(self): + return 8 + + @property + def dummy_image_encoder(self): + torch.manual_seed(0) + config = CLIPVisionConfig( + hidden_size=self.text_embedder_hidden_size, + image_size=64, + projection_dim=self.text_embedder_hidden_size, + intermediate_size=37, + num_attention_heads=4, + num_channels=3, + num_hidden_layers=5, + patch_size=1, + ) + + model = CLIPVisionModel(config) + return model + + @property + def dummy_image_processor(self): + image_processor = CLIPImageProcessor( + crop_size=224, + do_center_crop=True, + do_normalize=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, + size=224, + ) + + return image_processor + + @property + def dummy_prior(self): + torch.manual_seed(0) + + model_kwargs = { + "num_attention_heads": 2, + "attention_head_dim": 16, + "embedding_dim": self.time_input_dim, + "num_embeddings": 32, + "embedding_proj_dim": self.text_embedder_hidden_size, + "time_embed_dim": self.time_embed_dim, + "num_layers": 1, + "clip_embed_dim": self.time_input_dim * 2, + "additional_embeddings": 0, + "time_embed_act_fn": "gelu", + "norm_in_type": "layer", + "embedding_proj_norm_type": "layer", + "encoder_hid_proj_type": None, + "added_emb_type": None, + } + + model = PriorTransformer(**model_kwargs) + return model + + @property + def dummy_renderer(self): + torch.manual_seed(0) + + model_kwargs = { + "param_shapes": ( + (self.renderer_dim, 93), + (self.renderer_dim, 8), + (self.renderer_dim, 8), + (self.renderer_dim, 8), + ), + "d_latent": self.time_input_dim, + "d_hidden": self.renderer_dim, + "n_output": 12, + "background": ( + 0.1, + 0.1, + 0.1, + ), + } + model = ShapERenderer(**model_kwargs) + return model + + def get_dummy_components(self): + prior = self.dummy_prior + image_encoder = self.dummy_image_encoder + 
image_processor = self.dummy_image_processor + renderer = self.dummy_renderer + + scheduler = HeunDiscreteScheduler( + beta_schedule="exp", + num_train_timesteps=1024, + prediction_type="sample", + use_karras_sigmas=True, + clip_sample=True, + clip_sample_range=1.0, + ) + components = { + "prior": prior, + "image_encoder": image_encoder, + "image_processor": image_processor, + "renderer": renderer, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + input_image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "image": input_image, + "generator": generator, + "num_inference_steps": 1, + "frame_size": 32, + "output_type": "np", + } + return inputs + + def test_shap_e(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images[0] + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (20, 32, 32, 3) + + expected_slice = np.array( + [ + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + 0.00039216, + ] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_inference_batch_consistent(self): + # NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches + self._test_inference_batch_consistent(batch_sizes=[1, 2]) + + def test_inference_batch_single_identical(self): + test_max_difference = torch_device == "cpu" + relax_max_difference = True + self._test_inference_batch_single_identical( + batch_size=2, + test_max_difference=test_max_difference, + relax_max_difference=relax_max_difference, + ) + + def test_num_images_per_prompt(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + batch_size = 1 + num_images_per_prompt = 2 + + inputs = self.get_dummy_inputs(torch_device) + + for key in inputs.keys(): + if key in self.batch_params: + inputs[key] = batch_size * [inputs[key]] + + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0] + + assert images.shape[0] == batch_size * num_images_per_prompt + + +@slow +@require_torch_gpu +class ShapEImg2ImgPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_shap_e_img2img(self): + input_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/shap_e/corgi.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/shap_e/test_shap_e_img2img_out.npy" + ) + pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img") + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + + images = pipe( + input_image, + generator=generator, + guidance_scale=3.0, + num_inference_steps=64, + frame_size=64, + output_type="np", + ).images[0] + + assert images.shape == (20, 64, 64, 3) + + assert_mean_pixel_difference(images, expected_image) diff --git 
a/tests/schedulers/test_scheduler_heun.py b/tests/schedulers/test_scheduler_heun.py index 2fd50425938f..ae0fe26b11ba 100644 --- a/tests/schedulers/test_scheduler_heun.py +++ b/tests/schedulers/test_scheduler_heun.py @@ -30,11 +30,15 @@ def test_betas(self): self.check_over_configs(beta_start=beta_start, beta_end=beta_end) def test_schedules(self): - for schedule in ["linear", "scaled_linear"]: + for schedule in ["linear", "scaled_linear", "exp"]: self.check_over_configs(beta_schedule=schedule) + def test_clip_sample(self): + for clip_sample_range in [1.0, 2.0, 3.0]: + self.check_over_configs(clip_sample_range=clip_sample_range, clip_sample=True) + def test_prediction_type(self): - for prediction_type in ["epsilon", "v_prediction"]: + for prediction_type in ["epsilon", "v_prediction", "sample"]: self.check_over_configs(prediction_type=prediction_type) def test_full_loop_no_noise(self):
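
Taken together, the scheduler test additions above reflect how the pieces fit together for Shap-E: an `exp` beta schedule, `sample` prediction, and sample clipping on `HeunDiscreteScheduler`. A minimal configuration sketch, mirroring the values used in the fast tests in this diff (the scheduler config shipped with the `openai/shap-e` checkpoints may differ):

```python
from diffusers import HeunDiscreteScheduler

# Values mirror the fast tests in this diff; the shipped checkpoint config may differ.
scheduler = HeunDiscreteScheduler(
    num_train_timesteps=1024,
    beta_schedule="exp",        # new: exponential alpha_bar transform
    prediction_type="sample",   # new: the model predicts the denoised sample directly
    use_karras_sigmas=True,
    clip_sample=True,           # new: clamp the predicted sample for numerical stability
    clip_sample_range=1.0,
)

scheduler.set_timesteps(64, device="cpu")
print(scheduler.timesteps.shape)  # Heun interleaves timesteps, so expect roughly 2 * 64 entries
```

With `prediction_type="sample"`, the branch that previously raised `NotImplementedError` now passes the model output through as `pred_original_sample`, which is why the `clip_sample` and `clip_sample_range` options were added alongside it.
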