diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py index cb955a688643..dfe60d9794e1 100644 --- a/examples/community/lpw_stable_diffusion_xl.py +++ b/examples/community/lpw_stable_diffusion_xl.py @@ -250,6 +250,7 @@ def get_weighted_text_embeddings_sdxl( neg_prompt: str = "", neg_prompt_2: str = None, num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, ): """ This function can process long prompt with weights, no length limitation @@ -262,10 +263,13 @@ def get_weighted_text_embeddings_sdxl( neg_prompt (str) neg_prompt_2 (str) num_images_per_prompt (int) + device (torch.device) Returns: prompt_embeds (torch.Tensor) neg_prompt_embeds (torch.Tensor) """ + device = device or pipe._execution_device + if prompt_2: prompt = f"{prompt} {prompt_2}" @@ -330,17 +334,17 @@ def get_weighted_text_embeddings_sdxl( # get prompt embeddings one by one is not working. for i in range(len(prompt_token_groups)): # get positive prompt embeddings with weights - token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device) - weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device) + token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device) + weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device) - token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device) + token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device) # use first text encoder - prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True) + prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True) prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2] # use second text encoder - prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True) + prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True) prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2] pooled_prompt_embeds = prompt_embeds_2[0] @@ -357,16 +361,16 @@ def get_weighted_text_embeddings_sdxl( embeds.append(token_embedding) # get negative prompt embeddings with weights - neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device) - neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device) - neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device) + neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device) + neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device) + neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device) # use first text encoder - neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True) + neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True) neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2] # use second text encoder - neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True) + neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True) neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2] negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]