Add Shap-E #3742
Conversation
This PR:
- adds the conversion script
- adds the pipeline
- adds step_index from the pipeline, + removes permute
- adds a zero pad token
- removes the "copy from" statement for the betas_for_alpha_bar function
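For reference, a minimal usage sketch of the new pipeline, assuming the repo id (`YiYiXu/shap-e`) and the call signature exercised in the equivalency script further down; argument names may still change before merge:

```python
import torch
from diffusers import ShapEPipeline

# Load the converted Shap-E text-to-3D pipeline (repo id taken from the test
# script in this PR; it may move once the weights are published officially).
pipe = ShapEPipeline.from_pretrained("YiYiXu/shap-e").to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
# Same keyword arguments as in the equivalency test below; the pipeline returns
# one flattened implicit-function latent per generated sample.
latents = pipe(
    "a shark",
    num_images_per_prompt=1,
    generator=generator,
    guidance_scale=15.0,
    num_inference_steps=64,
).latents
print(latents.shape)
```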
@patrickvonplaten When I compared the model forward pass (see the equivalency test for the model forward pass below), the results matched nicely; the max element differences are noted in the comments of that script. Not sure what to do here and would appreciate any feedback/advice :)

equivalency test for pipeline outputs (this script returns the max latent difference):

```python
import torch
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 4
guidance_scale = 15.
sigma_min = 1e-3
prompt = "a shark"
# diffusers
from diffusers import ShapEPipeline
repo = "YiYiXu/shap-e"
pipe = ShapEPipeline.from_pretrained(repo)
pipe = pipe.to(device)
generator = torch.Generator(device="cuda").manual_seed(0)
latents_d = pipe(prompt, num_images_per_prompt=batch_size, generator=generator, guidance_scale=guidance_scale, num_inference_steps=64, sigma_min=sigma_min).latents
# original
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
model = load_model('text300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))
latents, _ = sample_latents(
batch_size=batch_size,
model=model,
diffusion=diffusion,
guidance_scale=guidance_scale,
model_kwargs=dict(texts=[prompt] * batch_size),
progress=True,
clip_denoised=True,
use_fp16=False,
use_karras=True,
karras_steps=64,
sigma_min=sigma_min,
sigma_max=160,
s_churn=0,
)
# compare
print("max diff latents:")
print(np.abs(latents.reshape(4, 1024, 1024).detach().cpu().numpy() - latents_d.detach().cpu().numpy()).max())
```

equivalency test for model forward pass:

```python
import torch
import numpy as np
import clip
from diffusers.models.prior_transformer import PriorTransformer
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from shap_e.models.download import load_model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# create original model
model = load_model('text300M', device=device)
transformer = model.wrapped
# create diffusers model
path_shape = "YiYiXu/shap-e"
transformer_d = PriorTransformer.from_pretrained(path_shape, subfolder="prior").to(device)
# inputs
batch_size = 1
torch.manual_seed(0)
x = torch.randn([batch_size, 1024, 1024], device=device)
t = torch.tensor([0] * batch_size, device=device)
prompt = ["a shark"] * batch_size
# create embeddings using original clip model
clip_name = "ViT-L/14"
download_root= "/home/yiyi_huggingface_co/shap-e/shap_e_model_cache"
clip_model, _ = clip.load(clip_name, device=device, download_root=download_root)
tokenize = clip.tokenize
embeddings = clip_model.encode_text(
tokenize(list(prompt), truncate=True).to(device)
).float()
embeddings = embeddings / torch.linalg.norm(embeddings, dim=-1, keepdim=True)
# create embeddings using transformer clip
repo = "openai/clip-vit-large-patch14"
d_text_encoder = CLIPTextModelWithProjection.from_pretrained(repo).to(device)
d_tokenizer = CLIPTokenizer.from_pretrained(repo)
tokens = d_tokenizer(prompt, padding="max_length", max_length=d_tokenizer.model_max_length, truncation=True, return_tensors="pt",).input_ids
embeddings_d = d_text_encoder(tokens.to(device)).text_embeds.float()
embeddings_d = embeddings_d / torch.linalg.norm(embeddings_d, dim=-1, keepdim=True)
# compare the embeddings : 0.00019
print(f" compare embeddings: {np.abs(embeddings.detach().cpu().numpy() - embeddings_d.detach().cpu().numpy()).max()}")
# TEST1: compare the output using respective embeddings: 0.0012
# original output
out = transformer(x, t, embeddings=embeddings)
# diffusers output
out_d = transformer_d(x.permute(0, 2, 1), 0, embeddings_d, return_dict=False)[0]
print(" ")
print(" test1 result") # 0.0012
print((out_d.permute(0, 2, 1) - out).abs().max())
# TEST2: compare the outputs using same embedding: 4.6790e-06
# original output
out = transformer(x, t, embeddings=embeddings)
# diffusers output
out_d = transformer_d(x.permute(0, 2, 1), 0, embeddings, return_dict=False)[0]
print(" ")
print(" test2 result") # 4.6790e-06
print((out_d.permute(0, 2, 1) - out).abs().max())
```

testing script: compare the pipeline output with
```python
    return t

    # YiYi Notes: taken from the original repo, will refactor and not introduce a dependency on scipy
    def _sigma_to_t_yiyi(self, sigma):
```
Sounds good!
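On the scipy note above: a hedged sketch of how the sigma-to-timestep conversion could be done without scipy, assuming the scheduler keeps its per-step Karras sigmas in a `self.sigmas` array (as other k-diffusion-style schedulers in diffusers do); this is not the exact code from the PR:

```python
import numpy as np

def _sigma_to_t(self, sigma):
    # Map a continuous sigma back onto the (fractional) timestep grid by
    # interpolating in log-sigma space, avoiding scipy.interpolate.
    log_sigma = np.log(sigma)
    # self.sigmas is an assumed attribute holding the positive per-step sigmas,
    # typically stored in descending order.
    log_sigmas = np.log(np.asarray(self.sigmas))
    # np.interp expects increasing x values, so reverse the descending arrays.
    t = np.interp(log_sigma, log_sigmas[::-1], np.arange(len(log_sigmas))[::-1])
    return float(t)
```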
In general the design looks good to me! I just noticed that we don't have any prior transformer tests, so I added them here: #3796. That PR also allows disabling the PT 2 attention processor, which should help with precision issues. Could you maybe merge #3796 into your PR once it's merged and then make use of it? Think we're on a good way here to have a powerful new model class in `diffusers`! |
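If precision differences persist, something along these lines could force the vanilla attention processor instead of the PyTorch 2 scaled-dot-product one during the equivalency check (a sketch reusing the names from the forward-pass script above and assuming `PriorTransformer` exposes the standard diffusers attention-processor API after #3796):

```python
from diffusers.models.attention_processor import AttnProcessor

# Fall back to the classic attention processor so the comparison is not affected
# by the fused PT 2 scaled-dot-product kernel.
transformer_d.set_attn_processor(AttnProcessor())
out_d = transformer_d(x.permute(0, 2, 1), 0, embeddings, return_dict=False)[0]
```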
The documentation is not available anymore as the PR was closed or merged.
Co-authored-by: Patrick von Platen <[email protected]>
```python
255.0,
255.0,
255.0,
```
👍

original repo: https://github.com/openai/shap-e
text-to-3D
generated from original code

image-to-3D
image:

3D:

As a reference, this is the 3D render generated with the original repo using the same inputs and seed:

To-do: