Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 64 additions & 28 deletions src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,33 @@
import warnings
from typing import Optional, Tuple

import numpy as np

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.struct import field
from transformers import CLIPVisionConfig
from transformers import CLIPConfig, FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule

from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...modeling_flax_utils import FlaxModelMixin


def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
return jnp.matmul(norm_emb_1, norm_emb_2.T)


@flax_register_to_config
class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
projection_dim: int = 768
# CLIPVisionConfig fields
vision_config: dict = field(default_factory=dict)
class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
config: CLIPConfig
dtype: jnp.dtype = jnp.float32

def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
# init input tensor
input_shape = (
1,
self.vision_config["image_size"],
self.vision_config["image_size"],
self.vision_config["num_channels"],
)
pixel_values = jax.random.normal(rng, input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.init(rngs, pixel_values)["params"]

def setup(self):
clip_vision_config = CLIPVisionConfig(**self.vision_config)
self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype)
self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype)
self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False)

self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim))
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
self.special_care_embeds = self.param(
"special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)
"special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
)

self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
Expand Down Expand Up @@ -109,3 +89,59 @@ def filtered_with_scores(self, special_cos_dist, cos_dist, images):
)

return images, has_nsfw_concepts


class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
config_class = CLIPConfig
main_input_name = "clip_input"
module_class = FlaxStableDiffusionSafetyCheckerModule

def __init__(
self,
config: CLIPConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
if input_shape is None:
input_shape = (1, 224, 224, 3)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
clip_input = jax.random.normal(rng, input_shape)

params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

random_params = self.module.init(rngs, clip_input)["params"]

return random_params

def __call__(
self,
clip_input,
params: dict = None,
):
clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))

return self.module.apply(
{"params": params or self.params},
jnp.array(clip_input, dtype=jnp.float32),
rngs={},
)

def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
return module.filtered_with_scores(special_cos_dist, cos_dist, images)

return self.module.apply(
{"params": params or self.params},
special_cos_dist,
cos_dist,
images,
method=_filtered_with_scores,
)