Skip to content

Commit dc01404

Browse files
authored
Add arg compile_unet (huggingface#17)
1 parent f429954 commit dc01404

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

src/diffusers/models/unet_blocks_oneflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,8 @@ def set_attention_slice(self, slice_size):
371371

372372
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
373373
hidden_states = self.resnets[0](hidden_states, temb)
374-
for attn, resnet in zip(self.attentions, self.resnets[1:]):
374+
resnets_list = [m for m in self.resnets]
375+
for attn, resnet in zip(self.attentions, resnets_list[1:]):
375376
hidden_states = attn(hidden_states, encoder_hidden_states)
376377
hidden_states = resnet(hidden_states, temb)
377378

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_oneflow.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def __call__(
142142
latents: Optional[torch.FloatTensor] = None,
143143
output_type: Optional[str] = "pil",
144144
return_dict: bool = True,
145+
compile_unet: bool = True,
145146
**kwargs,
146147
):
147148
r"""
@@ -179,6 +180,8 @@ def __call__(
179180
return_dict (`bool`, *optional*, defaults to `True`):
180181
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
181182
plain tuple.
183+
compile_unet (`bool`, *optional*, defaults to `True`):
184+
Whether or not to compile unet as nn.graph
182185
183186
Returns:
184187
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
@@ -278,15 +281,16 @@ def __call__(
278281

279282
compilation_start = timer()
280283
compilation_time = 0
281-
if self.unet_compiled == False:
282-
print("[oneflow]", "compiling unet beforehand to make sure the progress bar is more accurate")
283-
i, t = list(enumerate(self.scheduler.timesteps))[0]
284-
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
285-
self.unet_graph._compile(latent_model_input, t, text_embeddings)
286-
self.unet_compiled = True
287-
self.unet_graph(latent_model_input, t, text_embeddings) # warmup
288-
compilation_time = timer() - compilation_start
289-
print("[oneflow]", "[elapsed(s)]", "[unet compilation]", compilation_time)
284+
if compile_unet:
285+
if self.unet_compiled == False:
286+
print("[oneflow]", "compiling unet beforehand to make sure the progress bar is more accurate")
287+
i, t = list(enumerate(self.scheduler.timesteps))[0]
288+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
289+
self.unet_graph._compile(latent_model_input, t, text_embeddings)
290+
self.unet_compiled = True
291+
self.unet_graph(latent_model_input, t, text_embeddings) # warmup
292+
compilation_time = timer() - compilation_start
293+
print("[oneflow]", "[elapsed(s)]", "[unet compilation]", compilation_time)
290294

291295
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
292296
torch._oneflow_internal.profiler.RangePush(f"denoise-{i}")
@@ -298,9 +302,12 @@ def __call__(
298302
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
299303

300304
# predict the noise residual
301-
torch._oneflow_internal.profiler.RangePush(f"denoise-{i}-unet-graph")
302-
noise_pred = self.unet_graph(latent_model_input, t, text_embeddings)
303-
torch._oneflow_internal.profiler.RangePop()
305+
if compile_unet:
306+
torch._oneflow_internal.profiler.RangePush(f"denoise-{i}-unet-graph")
307+
noise_pred = self.unet_graph(latent_model_input, t, text_embeddings)
308+
torch._oneflow_internal.profiler.RangePop()
309+
else:
310+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
304311

305312
# perform guidance
306313
if do_classifier_free_guidance:

0 commit comments

Comments (0)