StableDiffusionControlNetPipeline triggers unexpected keyword argument 'callback_on_step_end' #5798

@alexisrolland

Describe the bug

Calling StableDiffusionControlNetPipeline with the new callback_on_step_end argument raises the following error:

TypeError: StableDiffusionControlNetPipeline.__call__() got an unexpected keyword argument 'callback_on_step_end'

Reproduction

import os
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Step-end callback: prints the current timestep and returns the kwargs unchanged
def test_callback(pipeline, step, timestep, kwargs):
    print(timestep)
    return kwargs

# Load ControlNet
MODEL_PATH_CONTROLNET_CANNY = os.getenv('MODEL_PATH_CONTROLNET_CANNY')
controlnet_canny = ControlNetModel.from_pretrained(MODEL_PATH_CONTROLNET_CANNY, torch_dtype=torch.float16)

# Load pipeline
MODEL_PATH = os.getenv('MODEL_PATH')
pipeline = StableDiffusionControlNetPipeline.from_pretrained(MODEL_PATH, controlnet=[controlnet_canny], torch_dtype=torch.float16)

# Generation settings
prompts = ["hello"]
seed = 123456
generators = torch.Generator(device='cuda').manual_seed(seed)
image = load_image("https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/en/imgs/diffusers_library.jpg")

# Run inference
image = pipeline(
    prompt=prompts,
    image=[image],
    controlnet_conditioning_scale=0.5,
    num_inference_steps=20,
    num_images_per_prompt=1,
    generator=generators,
    callback_on_step_end=test_callback,
    callback_on_step_end_tensor_inputs=[]
).images[0]
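
As a possible stopgap, the four-argument callback used in the reproduction could be wrapped so that it fits the older three-argument `callback(step, timestep, latents)` form. The sketch below is not a confirmed fix: the helper name `to_legacy_callback` is hypothetical, and it assumes the installed pipeline still accepts the legacy `callback` and `callback_steps` arguments.

```python
from typing import Any, Callable, Dict


def to_legacy_callback(
    pipeline: Any,
    step_end_callback: Callable[[Any, int, Any, Dict[str, Any]], Dict[str, Any]],
) -> Callable[[int, Any, Any], None]:
    """Adapt a (pipeline, step, timestep, kwargs) callback to (step, timestep, latents).

    Hypothetical helper for illustration; it is not part of diffusers.
    """

    def legacy_callback(step: int, timestep: Any, latents: Any) -> None:
        # The legacy form has no way to hand modified tensors back to the
        # pipeline, so the wrapped callback's return value is discarded.
        step_end_callback(pipeline, step, timestep, {"latents": latents})

    return legacy_callback
```

Under that assumption, the call in the reproduction would pass `callback=to_legacy_callback(pipeline, test_callback)` and `callback_steps=1` in place of the two `callback_on_step_end*` arguments.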

Logs

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
TypeError: StableDiffusionControlNetPipeline.__call__() got an unexpected keyword argument 'callback_on_step_end'
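
The traceback suggests that this pipeline's `__call__` does not declare `callback_on_step_end` in the installed version. A minimal sketch for checking this before calling, using only the standard library's `inspect.signature`, is shown below. The helper `accepts_keyword` and the two stand-in classes are hypothetical and only mimic the two call signatures; they are not diffusers classes.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if `func` can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    if name in params and params[name].kind in keyword_kinds:
        return True
    # A **kwargs parameter accepts any keyword name.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class LegacyStylePipeline:
    """Hypothetical stand-in exposing only the older callback arguments."""

    def __call__(self, prompt, callback=None, callback_steps=1):
        return prompt


class StepEndStylePipeline:
    """Hypothetical stand-in exposing the newer step-end callback arguments."""

    def __call__(self, prompt, callback_on_step_end=None,
                 callback_on_step_end_tensor_inputs=None):
        return prompt


for pipe in (LegacyStylePipeline(), StepEndStylePipeline()):
    supported = accepts_keyword(pipe.__call__, "callback_on_step_end")
    print(type(pipe).__name__, supported)
```

With a real pipeline object, the same check would be `accepts_keyword(pipeline.__call__, "callback_on_step_end")`. This assumes any decorator around `__call__` preserves the wrapped signature, which `inspect.signature` follows by default.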

System Info

  • diffusers version: 0.23.0
  • Platform: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.36
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Huggingface_hub version: 0.17.3
  • Transformers version: 4.35.0
  • Accelerate version: 0.24.1
  • xFormers version: 0.0.22.post7
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@sayakpaul @yiyixuxu @DN6 @patrickvonplaten

Labels: bug (Something isn't working), stale (Issues that haven't received updates)
