Skip to content

Commit 7507641

Browse files
patil-surajsliard
authored andcommitted
[UnCLIPPipeline] fix num_images_per_prompt (huggingface#1762)
duplicate maks for num_images_per_prompt
1 parent 0fb53e4 commit 7507641

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/diffusers/pipelines/unclip/pipeline_unclip.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
143143

144144
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
145145
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
146+
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
146147

147148
if do_classifier_free_guidance:
148149
uncond_tokens = [""] * batch_size
@@ -172,6 +173,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
172173
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
173174
batch_size * num_images_per_prompt, seq_len, -1
174175
)
176+
uncond_text_mask = uncond_text_mask.repeat(1, num_images_per_prompt)
175177

176178
# done duplicates
177179

0 commit comments

Comments
 (0)