
Commit 71f2349

fix style/quality for code
1 parent 8ee23f2 commit 71f2349

9 files changed: 40 additions, 19 deletions


examples/community/composable_stable_diffusion.py

Lines changed: 7 additions & 5 deletions
@@ -481,13 +481,13 @@ def __call__(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
-        if '|' in prompt:
-            prompt = [x.strip() for x in prompt.split('|')]
+        if "|" in prompt:
+            prompt = [x.strip() for x in prompt.split("|")]
             print(f"composing {prompt}...")
 
             if not weights:
                 # specify weights for prompts (excluding the unconditional score)
-                print('using equal positive weights (conjunction) for all prompts...')
+                print("using equal positive weights (conjunction) for all prompts...")
                 weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)
             else:
                 # set prompt weight for each
@@ -546,7 +546,9 @@ def __call__(
             # perform guidance
            if do_classifier_free_guidance:
                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(dim=0, keepdims=True)
+                noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(
+                    dim=0, keepdims=True
+                )
 
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
@@ -570,4 +572,4 @@ def __call__(
         if not return_dict:
             return (image, has_nsfw_concept)
 
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
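For context on the line rewrapped in the second hunk: the pipeline composes several prompts by weighting each prompt's guided delta against the unconditional prediction and summing. A minimal standalone sketch of that combination rule, with made-up tensor shapes (2 composed prompts, one 4x64x64 latent), not the pipeline itself:

import torch

# Made-up shapes: 2 composed prompts, a single 4x64x64 latent.
noise_pred_uncond = torch.randn(1, 4, 64, 64)            # unconditional prediction
noise_pred_text = torch.randn(2, 4, 64, 64)              # one prediction per prompt
weights = torch.tensor([7.5, 7.5]).reshape(-1, 1, 1, 1)  # per-prompt guidance weights

# Weighted sum of the guided deltas, as in the rewrapped line above
# (the pipeline spells the keyword `keepdims`; `keepdim` is the documented name).
noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(dim=0, keepdim=True)
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])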

examples/textual_inversion/textual_inversion.py

Lines changed: 4 additions & 1 deletion
@@ -336,7 +336,10 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (
+                h,
+                w,
+            ) = (
                 img.shape[0],
                 img.shape[1],
             )
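The hunk above only rewraps the tuple unpacking at the start of the dataset's center-crop branch. For context, a hedged sketch of what that branch does end to end, on a made-up HxWxC NumPy image (the slicing mirrors the usual center crop but is an illustration, not a quote of the training script):

import numpy as np

img = np.zeros((512, 768, 3), dtype=np.uint8)  # made-up H x W x C image

crop = min(img.shape[0], img.shape[1])
(
    h,
    w,
) = (
    img.shape[0],
    img.shape[1],
)
# Keep the centered crop x crop square.
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
print(img.shape)  # (512, 512, 3)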

examples/textual_inversion/textual_inversion_flax.py

Lines changed: 4 additions & 1 deletion
@@ -306,7 +306,10 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (
+                h,
+                w,
+            ) = (
                 img.shape[0],
                 img.shape[1],
             )

src/diffusers/modeling_flax_pytorch_utils.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ def rename_key(key):
 # PyTorch => Flax #
 #####################
 
+
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
 def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):

src/diffusers/models/attention.py

Lines changed: 8 additions & 4 deletions
@@ -288,8 +288,10 @@ def __init__(
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
         if not is_xformers_available():
             raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
+                (
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers"
+                ),
                 name="xformers",
             )
         elif not torch.cuda.is_available():
@@ -450,8 +452,10 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten
         if not is_xformers_available():
             print("Here is how to install it")
             raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
+                (
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers"
+                ),
                 name="xformers",
             )
         elif not torch.cuda.is_available():
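Both hunks wrap the implicitly concatenated error message of an optional-dependency guard in parentheses. A self-contained sketch of the same guard pattern, using importlib instead of diffusers' internal is_xformers_available helper so it runs outside the library:

import importlib.util


def require_xformers():
    # Probe for the optional package; raise the same style of error if it is missing.
    if importlib.util.find_spec("xformers") is None:
        raise ModuleNotFoundError(
            (
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers"
            ),
            name="xformers",
        )


try:
    require_xformers()
except ModuleNotFoundError as err:
    print(err.name, "-", err)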

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

Lines changed: 5 additions & 3 deletions
@@ -189,9 +189,11 @@ def step(
             or isinstance(timestep, torch.LongTensor)
         ):
             raise ValueError(
-                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
-                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
-                " one of the `scheduler.timesteps` as a timestep.",
+                (
+                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                    " one of the `scheduler.timesteps` as a timestep."
+                ),
             )
 
         if not self.is_scale_input_called:
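The error being rewrapped fires when callers loop with enumerate(timesteps) and pass the loop index to step(). A hedged usage sketch of the intended call pattern, with random tensors standing in for a real UNet and latents:

import torch
from diffusers import EulerAncestralDiscreteScheduler

scheduler = EulerAncestralDiscreteScheduler()
scheduler.set_timesteps(30)

sample = torch.randn(1, 4, 64, 64)  # stand-in latent

for t in scheduler.timesteps:  # pass the timestep value itself, not enumerate's index
    model_input = scheduler.scale_model_input(sample, t)
    noise_pred = torch.randn_like(model_input)  # stand-in for the UNet prediction
    sample = scheduler.step(noise_pred, t, sample).prev_sample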

src/diffusers/schedulers/scheduling_euler_discrete.py

Lines changed: 5 additions & 3 deletions
@@ -198,9 +198,11 @@ def step(
             or isinstance(timestep, torch.LongTensor)
         ):
             raise ValueError(
-                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
-                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
-                " one of the `scheduler.timesteps` as a timestep.",
+                (
+                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                    " one of the `scheduler.timesteps` as a timestep."
+                ),
             )
 
         if not self.is_scale_input_called:

tests/test_scheduler.py

Lines changed: 4 additions & 2 deletions
@@ -535,8 +535,10 @@ def test_scheduler_public_api(self):
         )
         self.assertTrue(
             hasattr(scheduler, "scale_model_input"),
-            f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
-            " timestep)`",
+            (
+                f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
+                " timestep)`"
+            ),
         )
         self.assertTrue(
             hasattr(scheduler, "step"),

utils/custom_init_isort.py

Lines changed: 2 additions & 0 deletions
@@ -96,6 +96,7 @@ def _inner(x):
 
 def sort_objects(objects, key=None):
     "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
+
     # If no key is provided, we use a noop.
     def noop(x):
         return x
@@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement):
     """
     Return the same `import_statement` but with objects properly sorted.
     """
+
     # This inner function sort imports between [ ].
     def _replace(match):
         imports = match.groups()[0]
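Both hunks here only add a blank line after a docstring. For reference, a simplified, hedged sketch of the kind of grouping sort_objects performs (constants, then classes, then functions, each alphabetized); the real utility also handles leading underscores and other isort details:

def sort_objects(objects, key=None):
    "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."

    # If no key is provided, we use a noop.
    def noop(x):
        return x

    if key is None:
        key = noop
    # Simplified ordering: constants first, then classes, then functions.
    constants = sorted([obj for obj in objects if key(obj).isupper()], key=key)
    classes = sorted([obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()], key=key)
    functions = sorted([obj for obj in objects if not key(obj)[0].isupper()], key=key)
    return constants + classes + functions


print(sort_objects(["b_func", "AClass", "A_CONST", "a_func"]))
# ['A_CONST', 'AClass', 'a_func', 'b_func']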
