Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe = pipe.to("cuda")

prompt = "an astronaut riding a horse on mars"
pipe.set_sampler("sample_heun")
pipe.set_scheduler("sample_heun")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]

Expand Down Expand Up @@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

pipe.set_sampler("sample_euler")
pipe.set_scheduler("sample_euler")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
```
Expand Down
5 changes: 5 additions & 0 deletions examples/community/sd_text2img_k_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import importlib
import warnings
from typing import Callable, List, Optional, Union

import torch
Expand Down Expand Up @@ -111,6 +112,10 @@ def __init__(
self.k_diffusion_model = CompVisDenoiser(model)

def set_sampler(self, scheduler_type: str):
warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
return self.set_scheduler(scheduler_type)

def set_scheduler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")
sampling = getattr(library, "sampling")
self.sampler = getattr(sampling, scheduler_type)
Expand Down
1 change: 1 addition & 0 deletions hi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65",
"k-diffusion",
"librosa",
"modelcards>=0.1.4",
"numpy",
Expand Down Expand Up @@ -182,6 +183,7 @@ def run(self):
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
"datasets",
"k-diffusion",
"librosa",
"parameterized",
"pytest",
Expand Down
6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .utils import (
is_flax_available,
is_inflect_available,
is_k_diffusion_available,
is_onnx_available,
is_scipy_available,
is_torch_available,
Expand Down Expand Up @@ -90,6 +91,11 @@
else:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403

if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
from .pipelines import StableDiffusionKDiffusionPipeline
else:
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403

if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65",
"k-diffusion": "k-diffusion",
"librosa": "librosa",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ..utils import (
is_flax_available,
is_k_diffusion_available,
is_librosa_available,
is_onnx_available,
is_torch_available,
Expand Down Expand Up @@ -56,5 +57,8 @@
StableDiffusionOnnxPipeline,
)

if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
from .stable_diffusion import StableDiffusionKDiffusionPipeline

if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ...utils import (
BaseOutput,
is_flax_available,
is_k_diffusion_available,
is_onnx_available,
is_torch_available,
is_transformers_available,
Expand Down Expand Up @@ -48,6 +49,9 @@ class StableDiffusionPipelineOutput(BaseOutput):
else:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline

if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline

if is_transformers_available() and is_onnx_available():
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
Expand Down
Loading