 from diffusers.schedulers import UnCLIPScheduler
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer
 
-from ...utils import logging
+from ...utils import is_accelerate_available, logging
 from .text_proj import UnCLIPTextProjModel
 
 
@@ -115,7 +115,7 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         latents = latents * scheduler.init_noise_sigma
         return latents
 
-    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
         batch_size = len(prompt) if isinstance(prompt, list) else 1
 
         # get prompt text embeddings
@@ -126,7 +126,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        text_mask = text_inputs.attention_mask.bool().to(self.device)
+        text_mask = text_inputs.attention_mask.bool().to(device)
 
         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
@@ -136,7 +136,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
 
-        text_encoder_output = self.text_encoder(text_input_ids.to(self.device))
+        text_encoder_output = self.text_encoder(text_input_ids.to(device))
 
         text_embeddings = text_encoder_output.text_embeds
         text_encoder_hidden_states = text_encoder_output.last_hidden_state
@@ -156,8 +156,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_text_mask = uncond_input.attention_mask.bool().to(self.device)
-            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(self.device))
+            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
 
             uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
             uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
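Worth spelling out why `device` is now threaded through `_encode_prompt` instead of reading `self.device`: once the sub-models are offloaded, their parameters sit on the meta device, so `self.device` no longer names a real accelerator that input tensors can live on. A minimal illustration of that behavior (a sketch, not code from this PR):

```python
import torch

# A parameter on the meta device has a device but no data behind it; a
# pipeline whose modules were offloaded therefore cannot use self.device
# as the home for input tensors, and a concrete device must be passed in.
linear = torch.nn.Linear(4, 4, device="meta")
print(linear.weight.device)  # device(type='meta')
```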
@@ -187,6 +187,49 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
 
         return text_embeddings, text_encoder_hidden_states, text_mask
 
+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU
+        only when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
+        models = [
+            self.decoder,
+            self.text_proj,
+            self.text_encoder,
+            self.super_res_first,
+            self.super_res_last,
+        ]
+        for cpu_offloaded_model in models:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
+            return self.device
+        for module in self.decoder.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     @torch.no_grad()
     def __call__(
         self,
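For anyone trying the branch, this is roughly how the new offload path is meant to be exercised; a sketch assuming a CUDA machine and the public Karlo UnCLIP weights (the checkpoint id and dtype are assumptions of this example, not stated in the diff):

```python
import torch
from diffusers import UnCLIPPipeline

# Hypothetical usage sketch: load the pipeline, then enable offloading.
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipe.enable_sequential_cpu_offload(gpu_id=0)

# Each sub-model is materialized on cuda:0 only while its forward runs,
# so peak GPU memory stays near the largest single sub-model.
image = pipe("a photo of an astronaut riding a horse").images[0]
```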
@@ -254,25 +297,26 @@ def __call__(
             batch_size = len(prompt)
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        device = self._execution_device
 
         batch_size = batch_size * num_images_per_prompt
 
         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
 
         text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
-            prompt, num_images_per_prompt, do_classifier_free_guidance
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance
         )
 
         # prior
 
-        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=self.device)
+        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
         prior_timesteps_tensor = self.prior_scheduler.timesteps
 
         embedding_dim = self.prior.config.embedding_dim
         prior_latents = self.prepare_latents(
             (batch_size, embedding_dim),
             text_embeddings.dtype,
-            self.device,
+            device,
             generator,
             prior_latents,
             self.prior_scheduler,
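The scheduler changes follow from the same constraint: schedulers are plain Python objects, not `nn.Module`s, so accelerate's hooks never move their tensors, and `set_timesteps` has to receive the execution device explicitly. A small sketch (assumes a CUDA device is available):

```python
import torch
from diffusers.schedulers import UnCLIPScheduler

# Scheduler state is held as plain tensors invisible to offload hooks,
# so the timesteps tensor must be placed on the target device by hand.
scheduler = UnCLIPScheduler()
scheduler.set_timesteps(25, device=torch.device("cuda:0"))
print(scheduler.timesteps.device)  # cuda:0
```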
@@ -326,7 +370,7 @@ def __call__(
 
         decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
 
-        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=self.device)
+        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps
 
         num_channels_latents = self.decoder.in_channels
@@ -335,7 +379,7 @@ def __call__(
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
             text_encoder_hidden_states.dtype,
-            self.device,
+            device,
             generator,
             decoder_latents,
             self.decoder_scheduler,
@@ -378,7 +422,7 @@ def __call__(
 
         # super res
 
-        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=self.device)
+        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps
 
         channels = self.super_res_first.in_channels // 2
@@ -387,7 +431,7 @@ def __call__(
         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
             image_small.dtype,
-            self.device,
+            device,
             generator,
             super_res_latents,
             self.super_res_scheduler,
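As background on `_execution_device`: `accelerate.cpu_offload` attaches a hook to each wrapped module under the `_hf_hook` attribute, and that hook records the `execution_device` forward passes will run on. A toy version of the same lookup the property performs (the stand-in module is an assumption of this sketch, and it needs a CUDA device to run):

```python
import torch
from accelerate import cpu_offload

# Toy stand-in for the pipeline's decoder; any nn.Module works here.
decoder = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
cpu_offload(decoder, torch.device("cuda:0"))

# Scan the submodules for the first hook that knows its execution device,
# mirroring what the new _execution_device property does.
for module in decoder.modules():
    hook = getattr(module, "_hf_hook", None)
    if hook is not None and getattr(hook, "execution_device", None) is not None:
        print(torch.device(hook.execution_device))  # cuda:0
        break
```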