Skip to content

Commit 2c9ae8a

Browse files
authored
Introduces MaybeApply layer. (keras-team#435)
* Added MaybeApply layer. * Changed MaybeApply to override _augment method. * Added seed to maybe_apply_test random generator. * Added seed to layer in batched input test. * Fixed MaybeApply docs.
1 parent a6ece4a commit 2c9ae8a

File tree

4 files changed

+243
-0
lines changed

4 files changed

+243
-0
lines changed

keras_cv/layers/preprocessing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from keras_cv.layers.preprocessing.fourier_mix import FourierMix
3737
from keras_cv.layers.preprocessing.grayscale import Grayscale
3838
from keras_cv.layers.preprocessing.grid_mask import GridMask
39+
from keras_cv.layers.preprocessing.maybe_apply import MaybeApply
3940
from keras_cv.layers.preprocessing.mix_up import MixUp
4041
from keras_cv.layers.preprocessing.posterization import Posterization
4142
from keras_cv.layers.preprocessing.rand_augment import RandAugment
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2022 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import tensorflow as tf
15+
16+
17+
@tf.keras.utils.register_keras_serializable(package="keras_cv")
18+
class MaybeApply(tf.keras.__internal__.layers.BaseImageAugmentationLayer):
19+
"""Apply provided layer to random elements in a batch.
20+
21+
Args:
22+
layer: a keras `Layer` or `BaseImageAugmentationLayer`. This layer will be
23+
applied to randomly chosen samples in a batch. Layer should not modify the
24+
size of provided inputs.
25+
rate: controls the frequency of applying the layer. 1.0 means all elements in
26+
a batch will be modified. 0.0 means no elements will be modified.
27+
Defaults to 0.5.
28+
auto_vectorize: bool, whether to use tf.vectorized_map or tf.map_fn for
29+
batched input. Setting this to True might give better performance but
30+
currently doesn't work with XLA. Defaults to False.
31+
seed: integer, controls random behaviour.
32+
33+
Example usage:
34+
```
35+
# Let's declare an example layer that will set all image pixels to zero.
36+
zero_out = tf.keras.layers.Lambda(lambda x: {"images": 0 * x["images"]})
37+
38+
# Create a small batch of random, single-channel, 2x2 images:
39+
images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1])
40+
print(images[..., 0])
41+
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
42+
# array([[[0.08216608, 0.40928006],
43+
# [0.39318466, 0.3162533 ]],
44+
#
45+
# [[0.34717774, 0.73199546],
46+
# [0.56369007, 0.9769211 ]],
47+
#
48+
# [[0.55243933, 0.13101244],
49+
# [0.2941643 , 0.5130266 ]],
50+
#
51+
# [[0.38977218, 0.80855536],
52+
# [0.6040567 , 0.10502195]],
53+
#
54+
# [[0.51828027, 0.12730157],
55+
# [0.288486 , 0.252975 ]]], dtype=float32)>
56+
57+
# Apply the layer with 50% probability:
58+
maybe_apply = MaybeApply(layer=zero_out, rate=0.5, seed=1234)
59+
outputs = maybe_apply(images)
60+
print(outputs[..., 0])
61+
# <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy=
62+
# array([[[0. , 0. ],
63+
# [0. , 0. ]],
64+
#
65+
# [[0.34717774, 0.73199546],
66+
# [0.56369007, 0.9769211 ]],
67+
#
68+
# [[0.55243933, 0.13101244],
69+
# [0.2941643 , 0.5130266 ]],
70+
#
71+
# [[0.38977218, 0.80855536],
72+
# [0.6040567 , 0.10502195]],
73+
#
74+
# [[0. , 0. ],
75+
# [0. , 0. ]]], dtype=float32)>
76+
77+
# We can observe that the layer has been randomly applied to 2 out of 5 samples.
78+
```
79+
"""
80+
81+
def __init__(self, layer, rate=0.5, auto_vectorize=False, seed=None, **kwargs):
82+
super().__init__(seed=seed, **kwargs)
83+
84+
if not (0 <= rate <= 1.0):
85+
raise ValueError(f"rate must be in range [0, 1]. Received rate: {rate}")
86+
87+
self._layer = layer
88+
self._rate = rate
89+
self.auto_vectorize = auto_vectorize
90+
self.seed = seed
91+
92+
def _augment(self, inputs):
93+
if self._random_generator.random_uniform(shape=()) > 1.0 - self._rate:
94+
return self._layer(inputs)
95+
else:
96+
return inputs
97+
98+
def get_config(self):
99+
config = super().get_config()
100+
config.update(
101+
{
102+
"rate": self._rate,
103+
"layer": self._layer,
104+
"seed": self.seed,
105+
"auto_vectorize": self.auto_vectorize,
106+
}
107+
)
108+
return config
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2022 The KerasCV Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import tensorflow as tf
15+
from absl.testing import parameterized
16+
17+
from keras_cv.layers.preprocessing.maybe_apply import MaybeApply
18+
19+
20+
class ZeroOut(tf.keras.__internal__.layers.BaseImageAugmentationLayer):
21+
"""Zero out all entries, for testing purposes."""
22+
23+
def __init__(self):
24+
super(ZeroOut, self).__init__()
25+
26+
def augment_image(self, image, transformation=None):
27+
return 0 * image
28+
29+
def augment_label(self, label, transformation=None):
30+
return 0 * label
31+
32+
def augment_bounding_box(self, bounding_box, transformation=None):
33+
return 0 * bounding_box
34+
35+
36+
class MaybeApplyTest(tf.test.TestCase, parameterized.TestCase):
37+
rng = tf.random.Generator.from_seed(seed=1234)
38+
39+
@parameterized.parameters([-0.5, 1.7])
40+
def test_raises_error_on_invalid_rate_parameter(self, invalid_rate):
41+
with self.assertRaises(ValueError):
42+
MaybeApply(rate=invalid_rate, layer=ZeroOut())
43+
44+
def test_works_with_batched_input(self):
45+
batch_size = 32
46+
dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3))
47+
layer = MaybeApply(rate=0.5, layer=ZeroOut(), seed=1234)
48+
49+
outputs = layer(dummy_inputs)
50+
num_zero_inputs = self._num_zero_batches(dummy_inputs)
51+
num_zero_outputs = self._num_zero_batches(outputs)
52+
53+
self.assertEqual(num_zero_inputs, 0)
54+
self.assertLess(num_zero_outputs, batch_size)
55+
self.assertGreater(num_zero_outputs, 0)
56+
57+
@staticmethod
58+
def _num_zero_batches(images):
59+
num_batches = tf.shape(images)[0]
60+
num_non_zero_batches = tf.math.count_nonzero(
61+
tf.math.count_nonzero(images, axis=[1, 2, 3]), dtype=tf.int32
62+
)
63+
return num_batches - num_non_zero_batches
64+
65+
def test_inputs_unchanged_with_zero_rate(self):
66+
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
67+
layer = MaybeApply(rate=0.0, layer=ZeroOut())
68+
69+
outputs = layer(dummy_inputs)
70+
71+
self.assertAllClose(outputs, dummy_inputs)
72+
73+
def test_all_inputs_changed_with_rate_equal_to_one(self):
74+
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
75+
layer = MaybeApply(rate=1.0, layer=ZeroOut())
76+
77+
outputs = layer(dummy_inputs)
78+
79+
self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs))
80+
81+
def test_works_with_single_image(self):
82+
dummy_inputs = self.rng.uniform(shape=(224, 224, 3))
83+
layer = MaybeApply(rate=1.0, layer=ZeroOut())
84+
85+
outputs = layer(dummy_inputs)
86+
87+
self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs))
88+
89+
def test_can_modify_label(self):
90+
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
91+
dummy_labels = tf.ones(shape=(32, 2))
92+
layer = MaybeApply(rate=1.0, layer=ZeroOut())
93+
94+
outputs = layer({"images": dummy_inputs, "labels": dummy_labels})
95+
96+
self.assertAllEqual(outputs["labels"], tf.zeros_like(dummy_labels))
97+
98+
def test_can_modify_bounding_box(self):
99+
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
100+
dummy_boxes = tf.ones(shape=(32, 4))
101+
layer = MaybeApply(rate=1.0, layer=ZeroOut())
102+
103+
outputs = layer({"images": dummy_inputs, "bounding_boxes": dummy_boxes})
104+
105+
self.assertAllEqual(outputs["bounding_boxes"], tf.zeros_like(dummy_boxes))
106+
107+
def test_works_with_native_keras_layers(self):
108+
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
109+
zero_out = tf.keras.layers.Lambda(lambda x: {"images": 0 * x["images"]})
110+
layer = MaybeApply(rate=1.0, layer=zero_out)
111+
112+
outputs = layer(dummy_inputs)
113+
114+
self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs))
115+
116+
def test_works_with_xla(self):
117+
dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3))
118+
# auto_vectorize=True will crash XLA
119+
layer = MaybeApply(rate=0.5, layer=ZeroOut(), auto_vectorize=False)
120+
121+
@tf.function(jit_compile=True)
122+
def apply(x):
123+
return layer(x)
124+
125+
apply(dummy_inputs)

keras_cv/layers/serialization_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,15 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase):
134134
"seed": 1234,
135135
},
136136
),
137+
(
138+
"MaybeApply",
139+
preprocessing.MaybeApply,
140+
{
141+
"rate": 0.5,
142+
"layer": None,
143+
"seed": 1234,
144+
},
145+
),
137146
)
138147
def test_layer_serialization(self, layer_cls, init_args):
139148
layer = layer_cls(**init_args)

0 commit comments

Comments
 (0)