Skip to content

Commit c0f50d4

Browse files
authored
[Dreambooth] flax fixes (huggingface#1765)
* Fail if there are less images than the effective batch size. * Remove lr-scheduler arg as it's currently ignored. * Make guidance_scale work for batch_size > 1.
1 parent 52f27aa commit c0f50d4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def __call__(
337337
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
338338
if len(prompt_ids.shape) > 2:
339339
# Assume sharded
340-
guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
340+
guidance_scale = guidance_scale[:, None]
341341

342342
if jit:
343343
images = _p_generate(

0 commit comments

Comments
 (0)