Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions docs/source/api/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,3 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module

## AutoencoderKL
[[autodoc]] AutoencoderKL

## FlaxModelMixin
[[autodoc]] FlaxModelMixin

## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput

## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel

## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput

## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput

## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
2 changes: 1 addition & 1 deletion docs/source/api/schedulers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
To this end, the design of schedulers is such that:

- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).
- Schedulers are currently by default in PyTorch.


## API
Expand Down
12 changes: 2 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,10 @@
"datasets",
"filelock",
"flake8>=3.8.3",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.10.0",
"importlib_metadata",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards>=0.1.4",
"numpy",
"onnxruntime",
Expand Down Expand Up @@ -188,15 +185,9 @@ def run(self):
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch")

if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
extras["flax"] = deps_list("jax", "jaxlib", "flax")

extras["dev"] = (
extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
)

install_requires = [
Expand All @@ -207,6 +198,7 @@ def run(self):
deps["regex"],
deps["requests"],
deps["Pillow"],
deps["torch"]
]

setup(
Expand Down
23 changes: 0 additions & 23 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .utils import (
is_flax_available,
is_inflect_available,
is_onnx_available,
is_scipy_available,
Expand Down Expand Up @@ -61,25 +60,3 @@
from .pipelines import StableDiffusionOnnxPipeline
else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403

if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipeline_flax_utils import FlaxDiffusionPipeline
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
FlaxSchedulerMixin,
FlaxScoreSdeVeScheduler,
)
else:
from .utils.dummy_flax_objects import * # noqa F403

if is_flax_available() and is_transformers_available():
from .pipelines import FlaxStableDiffusionPipeline
else:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
3 changes: 0 additions & 3 deletions src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@
"datasets": "datasets",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.10.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
Expand Down
6 changes: 1 addition & 5 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_flax_available, is_torch_available
from ..utils import is_torch_available


if is_torch_available():
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel

if is_flax_available():
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
3 changes: 0 additions & 3 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,3 @@

if is_transformers_available() and is_onnx_available():
from .stable_diffusion import StableDiffusionOnnxPipeline

if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
26 changes: 1 addition & 25 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import PIL
from PIL import Image

from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
from ...utils import BaseOutput, is_onnx_available, is_torch_available, is_transformers_available


@dataclass
Expand Down Expand Up @@ -35,27 +35,3 @@ class StableDiffusionPipelineOutput(BaseOutput):

if is_transformers_available() and is_onnx_available():
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline

if is_transformers_available() and is_flax_available():
import flax

@flax.struct.dataclass
class FlaxStableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.

Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""

images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]

from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
13 changes: 1 addition & 12 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.


from ..utils import is_flax_available, is_scipy_available, is_torch_available
from ..utils import is_scipy_available, is_torch_available


if is_torch_available():
Expand All @@ -27,17 +27,6 @@
else:
from ..utils.dummy_pt_objects import * # noqa F403

if is_flax_available():
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import FlaxSchedulerMixin
else:
from ..utils.dummy_flax_objects import * # noqa F403


if is_scipy_available() and is_torch_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
Expand Down