Commit bc458b4

Add CPU offloading to UnCLIP (huggingface#1761)
* Add CPU offloading to UnCLIP
* use fp32 for testing the offload
1 parent 45eac8b commit bc458b4

File tree

1 file changed: +57 -13 lines changed

pipelines/unclip/pipeline_unclip.py

Lines changed: 57 additions & 13 deletions
@@ -23,7 +23,7 @@
 from diffusers.schedulers import UnCLIPScheduler
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer

-from ...utils import logging
+from ...utils import is_accelerate_available, logging
 from .text_proj import UnCLIPTextProjModel


@@ -115,7 +115,7 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         latents = latents * scheduler.init_noise_sigma
         return latents

-    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
         batch_size = len(prompt) if isinstance(prompt, list) else 1

         # get prompt text embeddings

@@ -126,7 +126,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        text_mask = text_inputs.attention_mask.bool().to(self.device)
+        text_mask = text_inputs.attention_mask.bool().to(device)

         if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
             removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])

@@ -136,7 +136,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
             )
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

-        text_encoder_output = self.text_encoder(text_input_ids.to(self.device))
+        text_encoder_output = self.text_encoder(text_input_ids.to(device))

         text_embeddings = text_encoder_output.text_embeds
         text_encoder_hidden_states = text_encoder_output.last_hidden_state

@@ -156,8 +156,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
                 truncation=True,
                 return_tensors="pt",
             )
-            uncond_text_mask = uncond_input.attention_mask.bool().to(self.device)
-            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(self.device))
+            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+            uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))

             uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
             uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state

@@ -187,6 +187,49 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida

         return text_embeddings, text_encoder_hidden_states, text_mask

+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+        models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
+        when their specific submodule has its `forward` method called.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
+        models = [
+            self.decoder,
+            self.text_proj,
+            self.text_encoder,
+            self.super_res_first,
+            self.super_res_last,
+        ]
+        for cpu_offloaded_model in models:
+            if cpu_offloaded_model is not None:
+                cpu_offload(cpu_offloaded_model, device)
+
+    @property
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
+            return self.device
+        for module in self.decoder.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
     @torch.no_grad()
     def __call__(
         self,

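The two methods added above lean on accelerate's hook mechanism: `cpu_offload` keeps a module's state dict on CPU, moves its parameters to the `meta` device, and attaches a hook at `module._hf_hook` that copies the weights onto the hook's `execution_device` right before `forward` runs. The `_execution_device` property simply walks the decoder's submodules looking for such a hook. Below is a minimal standalone sketch of that lookup, assuming `accelerate` is installed and a CUDA GPU is present; the toy `nn.Linear` is only a stand-in for one of the pipeline's sub-models.

import torch
from accelerate import cpu_offload

# Toy stand-in for one of the pipeline's sub-models (decoder, text_encoder, ...).
model = torch.nn.Linear(8, 8)

# Keep the weights on CPU and copy them to cuda:0 only while forward() runs.
cpu_offload(model, torch.device("cuda:0"))

# The same lookup _execution_device performs: find the accelerate hook and read
# the device the module will actually execute on.
for module in model.modules():
    hook = getattr(module, "_hf_hook", None)
    if hook is not None and getattr(hook, "execution_device", None) is not None:
        print(torch.device(hook.execution_device))  # cuda:0
        break
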
@@ -254,25 +297,26 @@ def __call__(
             batch_size = len(prompt)
         else:
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        device = self._execution_device

         batch_size = batch_size * num_images_per_prompt

         do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0

         text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
-            prompt, num_images_per_prompt, do_classifier_free_guidance
+            prompt, device, num_images_per_prompt, do_classifier_free_guidance
         )

         # prior

-        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=self.device)
+        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
         prior_timesteps_tensor = self.prior_scheduler.timesteps

         embedding_dim = self.prior.config.embedding_dim
         prior_latents = self.prepare_latents(
             (batch_size, embedding_dim),
             text_embeddings.dtype,
-            self.device,
+            device,
             generator,
             prior_latents,
             self.prior_scheduler,

@@ -326,7 +370,7 @@ def __call__(

         decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)

-        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=self.device)
+        self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
         decoder_timesteps_tensor = self.decoder_scheduler.timesteps

         num_channels_latents = self.decoder.in_channels

@@ -335,7 +379,7 @@ def __call__(
         decoder_latents = self.prepare_latents(
             (batch_size, num_channels_latents, height, width),
             text_encoder_hidden_states.dtype,
-            self.device,
+            device,
             generator,
             decoder_latents,
             self.decoder_scheduler,

@@ -378,7 +422,7 @@ def __call__(

         # super res

-        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=self.device)
+        self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
         super_res_timesteps_tensor = self.super_res_scheduler.timesteps

         channels = self.super_res_first.in_channels // 2

@@ -387,7 +431,7 @@ def __call__(
         super_res_latents = self.prepare_latents(
             (batch_size, channels, height, width),
             image_small.dtype,
-            self.device,
+            device,
             generator,
             super_res_latents,
             self.super_res_scheduler,
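
With the offload hooks in place, enabling the feature is a single call on the pipeline; no explicit `pipe.to("cuda")` is needed because the execution device is resolved through the hooks. A rough usage sketch, assuming `accelerate` is installed and a CUDA GPU is available; the checkpoint id `kakaobrain/karlo-v1-alpha` is the usual UnCLIP (Karlo) checkpoint and is only an example here.

from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")

# Sub-models stay on CPU and are moved to cuda:0 one forward() call at a time,
# keeping peak GPU memory low at the cost of extra transfer time.
pipe.enable_sequential_cpu_offload()

image = pipe("a high-resolution photo of a red panda").images[0]
image.save("red_panda.png")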

0 commit comments
