Closed
Labels: bug (Something isn't working)
Description
Describe the bug
I run into this error when running the Flax DreamBooth example exactly as instructed in https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-with-flaxjax

The error message is:
```
Traceback (most recent call last):
  File "train_dreambooth_flax.py", line 656, in <module>
    main()
  File "train_dreambooth_flax.py", line 363, in main
    images = pipeline(prompt_ids, p_params, sample_rng, jit=True).images
  File "/home/yixu/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py", line 343, in __call__
    guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 793, in _reshape
    return lax.reshape(a, newshape, None)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 882, in reshape
    return reshape_p.bind(
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 699, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 113, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 197, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2081, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2031, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 192, in prim_fun
    out = prim.bind(*args, **params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1716, in process_primitive
    return custom_staging_rules[primitive](self, *tracers, **params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 3305, in _reshape_staging_rule
    return trace.default_process_primitive(reshape_p, (x,), params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1721, in default_process_primitive
    out_avals, effects = primitive.abstract_eval(*avals, **params)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 365, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 66, in standard_abstract_eval
    return core.ShapedArray(shape_rule(*avals, **kwargs),
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 3244, in _reshape_shape_rule
    not core.same_shape_sizes(np.shape(operand), new_sizes)):
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 1834, in same_shape_sizes
    return 1 == divide_shape_sizes(s1, s2)
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 1831, in divide_shape_sizes
    return handler.divide_shape_sizes(ds[:len(s1)], ds[len(s1):])
  File "/home/yixu/.local/lib/python3.8/site-packages/jax/core.py", line 1727, in divide_shape_sizes
    raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (8,) and (8, 4)
```

I think it is caused by this code here:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L334
```python
if isinstance(guidance_scale, float):
    # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
    # shape information, as they may be sharded (when `jit` is `True`), or not.
    guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
    if len(prompt_ids.shape) > 2:
        # Assume sharded
        guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
```

`guidance_scale` has shape `(prompt_ids.shape[0],)`, i.e. one value per device, so the reshape to `prompt_ids.shape[:2]` only works when the per-device batch size (`prompt_ids.shape[1]`) is 1.
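To make the shape mismatch concrete, here is a minimal sketch that mirrors only the float branch of the quoted snippet. It assumes 8 devices and a per-device batch of 4, matching the `(8,)` and `(8, 4)` shapes in the traceback. `shard_like` and `guidance_as_in_issue` are hypothetical helpers written for illustration (not diffusers or Flax APIs), and the sequence length of 77 is just an illustrative value.

```python
import jax.numpy as jnp


def shard_like(x, num_devices):
    # Hypothetical stand-in for sharding: split the leading batch axis into
    # (num_devices, per_device_batch, ...).
    return x.reshape((num_devices, -1) + x.shape[1:])


def guidance_as_in_issue(guidance_scale, prompt_ids):
    # Mirrors the float branch of the quoted pipeline snippet.
    g = jnp.array([guidance_scale] * prompt_ids.shape[0])  # shape (num_devices,)
    if len(prompt_ids.shape) > 2:
        # Assume sharded: needs num_devices * per_device_batch elements.
        g = g.reshape(prompt_ids.shape[:2])
    return g


num_devices = 8
seq_len = 77  # illustrative token sequence length

# Per-device batch of 1: (8, 77) -> (8, 1, 77), and (8,) reshapes to (8, 1).
ids_b1 = shard_like(jnp.zeros((8, seq_len), dtype=jnp.int32), num_devices)
print(ids_b1.shape, guidance_as_in_issue(7.5, ids_b1).shape)

# Per-device batch of 4: (32, 77) -> (8, 4, 77), and (8,) cannot become (8, 4).
ids_b4 = shard_like(jnp.zeros((32, seq_len), dtype=jnp.int32), num_devices)
reshape_failed = False
try:
    guidance_as_in_issue(7.5, ids_b4)
except Exception as err:  # the exact exception type depends on the JAX version
    reshape_failed = True
    print(type(err).__name__, err)
```

The first case prints the shapes `(8, 1, 77) (8, 1)`; the second raises because 8 elements cannot fill a `(8, 4)` array.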
We could probably change it to something like this:

```python
guidance_scale = guidance_scale.reshape(prompt_ids.shape[0], 1)
```

Reproduction
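A sketch of why the proposed `(prompt_ids.shape[0], 1)` shape avoids the problem, under stated assumptions: `jax.vmap` stands in for `jax.pmap` so the example runs on a single CPU device, the noise arrays are random placeholders with a made-up latent shape, and the per-device update is assumed to take the standard classifier-free guidance form `uncond + g * (text - uncond)`. The pipeline's actual denoising loop is not shown in this issue.

```python
import jax
import jax.numpy as jnp

num_devices, per_device_batch = 8, 4
latent_shape = (4, 8, 8)  # hypothetical (channels, height, width), kept small

# Proposed change from the issue: (num_devices,) -> (num_devices, 1).
guidance_scale = jnp.array([7.5] * num_devices)
guidance_scale = guidance_scale.reshape(num_devices, 1)

# Random placeholders for the unconditional and text-conditioned predictions.
key_uncond, key_text = jax.random.split(jax.random.PRNGKey(0))
shape = (num_devices, per_device_batch) + latent_shape
noise_uncond = jax.random.normal(key_uncond, shape)
noise_text = jax.random.normal(key_text, shape)


def guide(g, uncond, text):
    # Per mapped slice: g has shape (1,), uncond/text have shape
    # (per_device_batch, C, H, W); a length-1 axis broadcasts against any size.
    return uncond + g * (text - uncond)


# vmap maps the leading (device) axis, as pmap would with in_axes=0.
guided = jax.vmap(guide)(guidance_scale, noise_uncond, noise_text)
print(guidance_scale.shape, guided.shape)
```

Because each mapped slice of `guidance_scale` is a length-1 vector, this works for any per-device batch size; `guidance_scale[:, None]` would produce the same `(num_devices, 1)` shape.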
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

python train_dreambooth_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --class_data_dir=$CLASS_DIR \
  --output_dir=$OUTPUT_DIR \
  --with_prior_preservation --prior_loss_weight=1.0 \
  --instance_prompt="a photo of sks dog" \
  --class_prompt="a photo of dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --learning_rate=5e-6 \
  --num_class_images=200 \
  --max_train_steps=800
```

Logs
No response
System Info
- `diffusers` version: 0.10.2
- Platform: Linux-5.13.0-1027-gcp-x86_64-with-glibc2.29
- Python version: 3.8.10
- PyTorch version (GPU?): 1.13.1+cu117 (False)
- Huggingface_hub version: 0.10.1
- Transformers version: 4.25.1
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: pmap