Skip to content

Commit a643c63

Browse files
[K Diffusion] Add k diffusion sampler natively (#1603)
* uP * uP
1 parent 326de41 commit a643c63

File tree

13 files changed

+602
-2
lines changed

13 files changed

+602
-2
lines changed

examples/community/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
686686
pipe = pipe.to("cuda")
687687

688688
prompt = "an astronaut riding a horse on mars"
689-
pipe.set_sampler("sample_heun")
689+
pipe.set_scheduler("sample_heun")
690690
generator = torch.Generator(device="cuda").manual_seed(seed)
691691
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
692692

@@ -721,7 +721,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
721721
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
722722
pipe = pipe.to("cuda")
723723

724-
pipe.set_sampler("sample_euler")
724+
pipe.set_scheduler("sample_euler")
725725
generator = torch.Generator(device="cuda").manual_seed(seed)
726726
image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
727727
```

examples/community/sd_text2img_k_diffusion.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import importlib
16+
import warnings
1617
from typing import Callable, List, Optional, Union
1718

1819
import torch
@@ -111,6 +112,10 @@ def __init__(
111112
self.k_diffusion_model = CompVisDenoiser(model)
112113

113114
def set_sampler(self, scheduler_type: str):
115+
warnings.warn("The `set_sampler` method is deprecated, please use `set_scheduler` instead.")
116+
return self.set_scheduler(scheduler_type)
117+
118+
def set_scheduler(self, scheduler_type: str):
114119
library = importlib.import_module("k_diffusion")
115120
sampling = getattr(library, "sampling")
116121
self.sampler = getattr(sampling, scheduler_type)

hi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
"isort>=5.5.4",
9292
"jax>=0.2.8,!=0.3.2",
9393
"jaxlib>=0.1.65",
94+
"k-diffusion",
9495
"librosa",
9596
"modelcards>=0.1.4",
9697
"numpy",
@@ -182,6 +183,7 @@ def run(self):
182183
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
183184
extras["test"] = deps_list(
184185
"datasets",
186+
"k-diffusion",
185187
"librosa",
186188
"parameterized",
187189
"pytest",

src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .utils import (
66
is_flax_available,
77
is_inflect_available,
8+
is_k_diffusion_available,
89
is_onnx_available,
910
is_scipy_available,
1011
is_torch_available,
@@ -90,6 +91,11 @@
9091
else:
9192
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
9293

94+
if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
95+
from .pipelines import StableDiffusionKDiffusionPipeline
96+
else:
97+
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
98+
9399
if is_torch_available() and is_transformers_available() and is_onnx_available():
94100
from .pipelines import (
95101
OnnxStableDiffusionImg2ImgPipeline,

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"isort": "isort>=5.5.4",
1616
"jax": "jax>=0.2.8,!=0.3.2",
1717
"jaxlib": "jaxlib>=0.1.65",
18+
"k-diffusion": "k-diffusion",
1819
"librosa": "librosa",
1920
"modelcards": "modelcards>=0.1.4",
2021
"numpy": "numpy",

src/diffusers/pipelines/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ..utils import (
22
is_flax_available,
3+
is_k_diffusion_available,
34
is_librosa_available,
45
is_onnx_available,
56
is_torch_available,
@@ -56,5 +57,8 @@
5657
StableDiffusionOnnxPipeline,
5758
)
5859

60+
if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
61+
from .stable_diffusion import StableDiffusionKDiffusionPipeline
62+
5963
if is_transformers_available() and is_flax_available():
6064
from .stable_diffusion import FlaxStableDiffusionPipeline

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ...utils import (
1010
BaseOutput,
1111
is_flax_available,
12+
is_k_diffusion_available,
1213
is_onnx_available,
1314
is_torch_available,
1415
is_transformers_available,
@@ -48,6 +49,9 @@ class StableDiffusionPipelineOutput(BaseOutput):
4849
else:
4950
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
5051

52+
if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
53+
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
54+
5155
if is_transformers_available() and is_onnx_available():
5256
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
5357
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline

0 commit comments

Comments
 (0)