error in pipeline_flax_stable_diffusion #1746

@yiyixuxu

Describe the bug

I ran into this error when running the Flax DreamBooth example exactly as instructed at https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-with-flaxjax

The error message is:

Traceback (most recent call last):
  File "train_dreambooth_flax.py", line 656, in <module>
    main()
  File "train_dreambooth_flax.py", line 363, in main
    images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
  File "/home/yixu/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py", line 343, in __call__
    guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 793, in _reshape
    return lax.reshape(a, newshape, None)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 882, in reshape
    return reshape_p.bind(
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 699, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 113, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 197, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2081, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2031, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 192, in prim_fun
    out = prim.bind(*args, **params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1716, in process_primitive
    return custom_staging_rules[primitive](self, *tracers, **params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 3305, in _reshape_staging_rule
    return trace.default_process_primitive(reshape_p, (x,), params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1721, in default_process_primitive
    out_avals, effects = primitive.abstract_eval(*avals, **params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 365, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 66, in standard_abstract_eval
    return core.ShapedArray(shape_rule(*avals, **kwargs),
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 3244, in _reshape_shape_rule
    not core.same_shape_sizes(np.shape(operand), new_sizes)):
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 1834, in same_shape_sizes
    return 1 == divide_shape_sizes(s1, s2)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 1831, in divide_shape_sizes
    return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 1727, in divide_shape_sizes
    raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (8,) and (8, 4)

I think it is caused by this code here
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L334

if isinstance(guidance_scale, float):
    # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
    # shape information, as they may be sharded (when `jit` is `True`), or not.
    guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
    if len(prompt_ids.shape) > 2:
        # Assume sharded
        guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])

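The reshape only works when the per-device batch size is 1. In the failing run, guidance_scale has shape (8,), one entry per leading index of prompt_ids, while prompt_ids.shape[:2] is (8, 4), so 8 elements cannot be reshaped into 32 slots.
Could we change it to something like this?

guidance_scale = guidance_scale.reshape(prompt_ids.shape[0], 1)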
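
To make the shape arithmetic concrete, here is a minimal sketch of the failure and of the proposed change. It uses NumPy instead of jax.numpy so it runs without accelerators; the assumption is that both reject a reshape whose element counts differ (NumPy raises ValueError, while JAX raised InconclusiveDimensionOperation in the traceback above). The helper name `broadcast_guidance_scale` and the sequence length of 77 are hypothetical and chosen only for illustration; the (8, 4) leading dimensions come from the error message.

```python
import numpy as np


def broadcast_guidance_scale(guidance_scale, prompt_ids_shape):
    """Hypothetical helper mirroring the proposed fix (not the library's code)."""
    # One entry per leading index of prompt_ids.
    scale = np.array([guidance_scale] * prompt_ids_shape[0])
    if len(prompt_ids_shape) > 2:
        # Assume sharded input: (num_devices, per_device_batch, seq_len).
        # Give each device a length-1 array instead of matching the batch size.
        scale = scale.reshape(prompt_ids_shape[0], 1)
    return scale


# Leading dimensions (8, 4) taken from the error message; 77 is an assumed length.
sharded_shape = (8, 4, 77)

# Original logic: 8 elements cannot be reshaped into an (8, 4) array.
flat = np.array([7.5] * sharded_shape[0])
try:
    flat.reshape(sharded_shape[:2])
    original_failed = False
except ValueError:
    original_failed = True

fixed = broadcast_guidance_scale(7.5, sharded_shape)
unsharded = broadcast_guidance_scale(7.5, (4, 77))

print("original reshape failed:", original_failed)
print("fixed sharded shape:", fixed.shape)
print("unsharded shape:", unsharded.shape)
```

With a leading axis of size 8 and a trailing axis of size 1, each device would receive a one-element array that can broadcast against its local batch, whatever the per-device batch size is. Whether the rest of the pipeline accepts that shape is not verified here.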

Reproduction

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

python train_dreambooth_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
  --num_class_images=200 \
  --max_train_steps=800

Logs

No response

System Info

  • diffusers version: 0.10.2
  • Platform: Linux-5.13.0-1027-gcp-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • PyTorch version (GPU?): 1.13.1+cu117 (False)
  • Huggingface_hub version: 0.10.1
  • Transformers version: 4.25.1
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: pmap

Labels: bug