Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras_cv/layers/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from keras_cv.layers.preprocessing.fourier_mix import FourierMix
from keras_cv.layers.preprocessing.grayscale import Grayscale
from keras_cv.layers.preprocessing.grid_mask import GridMask
from keras_cv.layers.preprocessing.maybe_apply import MaybeApply
from keras_cv.layers.preprocessing.mix_up import MixUp
from keras_cv.layers.preprocessing.posterization import Posterization
from keras_cv.layers.preprocessing.rand_augment import RandAugment
Expand Down
120 changes: 120 additions & 0 deletions keras_cv/layers/preprocessing/maybe_apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class MaybeApply(tf.keras.__internal__.layers.BaseImageAugmentationLayer):
"""Apply provided layer to random elements in a batch.

Args:
layer: a keras Layer or BaseImageAugmentationLayer. This layer will be applied
to randomly chosen samples in a batch.
rate: controls the frequency of applying the layer. 1.0 means all elements in
a batch will be modified. 0.0 means no elements will be modified.
Defaults to 0.5.
auto_vectorize: bool, whether to use tf.vectorized_map or tf.map_fn for
batched input. Setting this to True might give better performance but
currently doesn't work with XLA. Defaults to False.
seed: integer, controls random behaviour.

Example usage:
# Let's declare an example layer that will set all image pixels to zero.
zero_out = tf.keras.layers.Lambda(lambda x: 0 * x)

# Create a small batch of random, single-channel, 2x2 images:
images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1])
print(images[..., 0])
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
# array([[[0.08216608, 0.40928006],
# [0.39318466, 0.3162533 ]],
#
# [[0.34717774, 0.73199546],
# [0.56369007, 0.9769211 ]],
#
# [[0.55243933, 0.13101244],
# [0.2941643 , 0.5130266 ]],
#
# [[0.38977218, 0.80855536],
# [0.6040567 , 0.10502195]],
#
# [[0.51828027, 0.12730157],
# [0.288486 , 0.252975 ]]], dtype=float32)>

# Apply the layer with 50% probability:
maybe_apply = MaybeApply(layer=zero_out, rate=0.5, seed=1234)
outputs = maybe_apply(images)
print(outputs[..., 0])
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
# array([[[0. , 0. ],
# [0. , 0. ]],
#
# [[0.34717774, 0.73199546],
# [0.56369007, 0.9769211 ]],
#
# [[0.55243933, 0.13101244],
# [0.2941643 , 0.5130266 ]],
#
# [[0.38977218, 0.80855536],
# [0.6040567 , 0.10502195]],
#
# [[0. , 0. ],
# [0. , 0. ]]], dtype=float32)>

# We can observe that the layer has been randomly applied to 2 out of 5 batches.
"""

def __init__(self, layer, rate=0.5, auto_vectorize=False, seed=None, **kwargs):
super().__init__(seed=seed, **kwargs)

if not (0 <= rate <= 1.0):
raise ValueError(f"rate must be in range [0, 1]. Received rate: {rate}")

self._layer = layer
self._rate = rate
self.auto_vectorize = auto_vectorize
self.seed = seed

def augment_image(self, image, transformation=None):
should_apply = transformation > 1.0 - self._rate
return tf.cond(should_apply, lambda: self._layer(image), lambda: image)

def get_random_transformation(self, image=None, label=None, bounding_box=None):
return self._random_generator.random_uniform(shape=())

def augment_label(self, label, transformation=None):
should_apply = transformation > 1.0 - self._rate
return tf.cond(
should_apply, lambda: self._layer.augment_label(label), lambda: label
)

def augment_bounding_box(self, bounding_box, transformation=None):
should_apply = transformation > 1.0 - self._rate
return tf.cond(
should_apply,
lambda: self._layer.augment_bounding_box(bounding_box),
lambda: bounding_box,
)

def get_config(self):
config = super().get_config()
config.update(
{
"rate": self._rate,
"layer": self._layer,
"seed": self.seed,
"auto_vectorize": self.auto_vectorize,
}
)
return config
125 changes: 125 additions & 0 deletions keras_cv/layers/preprocessing/maybe_apply_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from absl.testing import parameterized

from keras_cv.layers.preprocessing.maybe_apply import MaybeApply


class ZeroOut(tf.keras.__internal__.layers.BaseImageAugmentationLayer):
"""Zero out all entries, for testing purposes."""

def __init__(self):
super(ZeroOut, self).__init__()

def augment_image(self, image, transformation=None):
return 0 * image

def augment_label(self, label, transformation=None):
return 0 * label

def augment_bounding_box(self, bounding_box, transformation=None):
return 0 * bounding_box


class MaybeApplyTest(tf.test.TestCase, parameterized.TestCase):
rng = tf.random.Generator.from_non_deterministic_state()

@parameterized.parameters([-0.5, 1.7])
def test_raises_error_on_invalid_rate_parameter(self, invalid_rate):
with self.assertRaises(ValueError):
MaybeApply(rate=invalid_rate, layer=ZeroOut())

def test_works_with_batched_input(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you pass a seed so that this test is not potentially flaky? Given, it is 1/2^32 flakiness, but still may as well seed it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added seed to rng on line 37.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! does this seed the layer too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. Added seed param to layer as well.

batch_size = 32
dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3))
layer = MaybeApply(rate=0.5, layer=ZeroOut())

outputs = layer(dummy_inputs)
num_zero_inputs = self._num_zero_batches(dummy_inputs)
num_zero_outputs = self._num_zero_batches(outputs)

self.assertEqual(num_zero_inputs, 0)
self.assertLess(num_zero_outputs, batch_size)
self.assertGreater(num_zero_outputs, 0)

@staticmethod
def _num_zero_batches(images):
num_batches = tf.shape(images)[0]
num_non_zero_batches = tf.math.count_nonzero(
tf.math.count_nonzero(images, axis=[1, 2, 3]), dtype=tf.int32
)
return num_batches - num_non_zero_batches

def test_inputs_unchanged_with_zero_rate(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
layer = MaybeApply(rate=0.0, layer=ZeroOut())

outputs = layer(dummy_inputs)

self.assertAllClose(outputs, dummy_inputs)

def test_all_inputs_changed_with_rate_equal_to_one(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
layer = MaybeApply(rate=1.0, layer=ZeroOut())

outputs = layer(dummy_inputs)

self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs))

def test_works_with_single_image(self):
dummy_inputs = self.rng.uniform(shape=(224, 224, 3))
layer = MaybeApply(rate=1.0, layer=ZeroOut())

outputs = layer(dummy_inputs)

self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs))

def test_can_modify_label(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
dummy_labels = tf.ones(shape=(32, 2))
layer = MaybeApply(rate=1.0, layer=ZeroOut())

outputs = layer({"images": dummy_inputs, "labels": dummy_labels})

self.assertAllEqual(outputs["labels"], tf.zeros_like(dummy_labels))

def test_can_modify_bounding_box(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
dummy_boxes = tf.ones(shape=(32, 4))
layer = MaybeApply(rate=1.0, layer=ZeroOut())

outputs = layer({"images": dummy_inputs, "bounding_boxes": dummy_boxes})

self.assertAllEqual(outputs["bounding_boxes"], tf.zeros_like(dummy_boxes))

def test_works_with_native_keras_layers(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
zero_out = tf.keras.layers.Lambda(lambda x: 0 * x)
layer = MaybeApply(rate=1.0, layer=zero_out)

outputs = layer(dummy_inputs)

self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs))

def test_works_with_xla(self):
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
# auto_vectorize=True will crash XLA
layer = MaybeApply(rate=0.5, layer=ZeroOut(), auto_vectorize=False)

@tf.function(jit_compile=True)
def apply(x):
return layer(x)

apply(dummy_inputs)
9 changes: 9 additions & 0 deletions keras_cv/layers/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase):
"seed": 1234,
},
),
(
"MaybeApply",
preprocessing.MaybeApply,
{
"rate": 0.5,
"layer": None,
"seed": 1234,
},
),
)
def test_layer_serialization(self, layer_cls, init_args):
layer = layer_cls(**init_args)
Expand Down