Skip to content

Commit f429954

Browse files
Debug sd conv gn geglu (#7)
Co-authored-by: liujuncheng <[email protected]>
1 parent f6ac84a commit f429954

File tree

3 files changed

+30
-8
lines changed

3 files changed

+30
-8
lines changed

src/diffusers/models/attention_oneflow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,5 +349,13 @@ def __init__(self, dim_in: int, dim_out: int):
349349
self.proj = nn.Linear(dim_in, dim_out * 2)
350350

351351
def forward(self, hidden_states):
352+
x_shape = hidden_states.shape
353+
if len(x_shape) != 2:
354+
hidden_states = hidden_states.reshape(-1, x_shape[-1])
355+
out = torch._C.fused_geglu(hidden_states, self.proj.weight, self.proj.bias)
356+
if len(x_shape) != 2:
357+
out_shape = x_shape[0:len(x_shape) -1 ] + (-1, )
358+
out = out.reshape(out_shape)
359+
return out
352360
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
353361
return hidden_states * F.gelu(gate)

src/diffusers/pipeline_oneflow_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
353353
class_name = "OneFlow" + class_name
354354
print(f"[oneflow]", f"[{name}]", f"{library_name}.{class_name}")
355355
else:
356-
print(f"[python]", f"[{name}]", f"{library_name}.{class_name}")
356+
print(f"[diffusers]", f"[{name}]", f"{library_name}.{class_name}")
357357
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
358358
if class_name.startswith("Flax"):
359359
class_name = class_name[4:]

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_oneflow.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(self, unet):
3232
self.config.enable_cudnn_conv_heuristic_search_algo(False)
3333

3434
def build(self, latent_model_input, t, text_embeddings):
35+
text_embeddings = torch._C.amp_white_identity(text_embeddings)
3536
return self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
3637

3738
class OneFlowStableDiffusionPipeline(DiffusionPipeline):
@@ -98,6 +99,8 @@ def __init__(
9899
safety_checker=safety_checker,
99100
feature_extractor=feature_extractor,
100101
)
102+
self.unet_graph = UNetGraph(self.unet)
103+
self.unet_compiled = False
101104

102105
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
103106
r"""
@@ -185,6 +188,8 @@ def __call__(
185188
(nsfw) content, according to the `safety_checker`.
186189
"""
187190

191+
from timeit import default_timer as timer
192+
start = timer()
188193
if "torch_device" in kwargs:
189194
device = kwargs.pop("torch_device")
190195
warnings.warn(
@@ -271,12 +276,17 @@ def __call__(
271276
if accepts_eta:
272277
extra_step_kwargs["eta"] = eta
273278

274-
unet_graph = UNetGraph(self.unet)
275-
276-
print("[oneflow]", "compiling unet beforehand to make sure the progress bar is more accurate")
277-
i, t = list(enumerate(self.scheduler.timesteps))[0]
278-
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
279-
unet_graph._compile(latent_model_input, t, text_embeddings)
279+
compilation_start = timer()
280+
compilation_time = 0
281+
if self.unet_compiled == False:
282+
print("[oneflow]", "compiling unet beforehand to make sure the progress bar is more accurate")
283+
i, t = list(enumerate(self.scheduler.timesteps))[0]
284+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
285+
self.unet_graph._compile(latent_model_input, t, text_embeddings)
286+
self.unet_compiled = True
287+
self.unet_graph(latent_model_input, t, text_embeddings) # warmup
288+
compilation_time = timer() - compilation_start
289+
print("[oneflow]", "[elapsed(s)]", "[unet compilation]", compilation_time)
280290

281291
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
282292
torch._oneflow_internal.profiler.RangePush(f"denoise-{i}")
@@ -289,7 +299,7 @@ def __call__(
289299

290300
# predict the noise residual
291301
torch._oneflow_internal.profiler.RangePush(f"denoise-{i}-unet-graph")
292-
noise_pred = unet_graph(latent_model_input, t, text_embeddings)
302+
noise_pred = self.unet_graph(latent_model_input, t, text_embeddings)
293303
torch._oneflow_internal.profiler.RangePop()
294304

295305
# perform guidance
@@ -310,6 +320,8 @@ def __call__(
310320
if isinstance(latents, np.ndarray):
311321
latents = torch.from_numpy(latents)
312322
image = self.vae.decode(latents).sample
323+
print("[oneflow]", "[elapsed(s)]", "[image]", timer() - start - compilation_time)
324+
post_process_start = timer()
313325

314326
image = (image / 2 + 0.5).clamp(0, 1)
315327
image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -328,4 +340,6 @@ def __call__(
328340
return (image, has_nsfw_concept)
329341
import torch as og_torch
330342
assert og_torch.cuda.is_initialized() is False
343+
344+
print("[oneflow]", "[elapsed(s)]", "[post-process]", timer() - post_process_start)
331345
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

0 commit comments

Comments (0)