-
Notifications
You must be signed in to change notification settings - Fork 331
Introduces MaybeApply layer. #435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
49b9206
618829b
5136a9d
39820d4
7fa1fe7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| # Copyright 2022 The KerasCV Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.keras.utils.register_keras_serializable(package="keras_cv") | ||
| class MaybeApply(tf.keras.__internal__.layers.BaseImageAugmentationLayer): | ||
| """Apply provided layer to random elements in a batch. | ||
|
|
||
| Args: | ||
| layer: a keras Layer or BaseImageAugmentationLayer. This layer will be applied | ||
| to randomly chosen samples in a batch. | ||
| rate: controls the frequency of applying the layer. 1.0 means all elements in | ||
| a batch will be modified. 0.0 means no elements will be modified. | ||
| Defaults to 0.5. | ||
| auto_vectorize: bool, whether to use tf.vectorized_map or tf.map_fn for | ||
| batched input. Setting this to True might give better performance but | ||
| currently doesn't work with XLA. Defaults to False. | ||
| seed: integer, controls random behaviour. | ||
|
|
||
| Example usage: | ||
LukeWood marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # Let's declare an example layer that will set all image pixels to zero. | ||
| zero_out = tf.keras.layers.Lambda(lambda x: 0 * x) | ||
|
|
||
| # Create a small batch of random, single-channel, 2x2 images: | ||
| images = tf.random.stateless_uniform(shape=(5, 2, 2, 1), seed=[0, 1]) | ||
| print(images[..., 0]) | ||
| # <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy= | ||
| # array([[[0.08216608, 0.40928006], | ||
| # [0.39318466, 0.3162533 ]], | ||
| # | ||
| # [[0.34717774, 0.73199546], | ||
| # [0.56369007, 0.9769211 ]], | ||
| # | ||
| # [[0.55243933, 0.13101244], | ||
| # [0.2941643 , 0.5130266 ]], | ||
| # | ||
| # [[0.38977218, 0.80855536], | ||
| # [0.6040567 , 0.10502195]], | ||
| # | ||
| # [[0.51828027, 0.12730157], | ||
| # [0.288486 , 0.252975 ]]], dtype=float32)> | ||
|
|
||
| # Apply the layer with 50% probability: | ||
| maybe_apply = MaybeApply(layer=zero_out, rate=0.5, seed=1234) | ||
| outputs = maybe_apply(images) | ||
| print(outputs[..., 0]) | ||
| # <tf.Tensor: shape=(5, 2, 2), dtype=float32, numpy= | ||
| # array([[[0. , 0. ], | ||
| # [0. , 0. ]], | ||
| # | ||
| # [[0.34717774, 0.73199546], | ||
| # [0.56369007, 0.9769211 ]], | ||
| # | ||
| # [[0.55243933, 0.13101244], | ||
| # [0.2941643 , 0.5130266 ]], | ||
| # | ||
| # [[0.38977218, 0.80855536], | ||
| # [0.6040567 , 0.10502195]], | ||
| # | ||
| # [[0. , 0. ], | ||
| # [0. , 0. ]]], dtype=float32)> | ||
|
|
||
| # We can observe that the layer has been randomly applied to 2 out of 5 batches. | ||
| """ | ||
|
|
||
| def __init__(self, layer, rate=0.5, auto_vectorize=False, seed=None, **kwargs): | ||
| super().__init__(seed=seed, **kwargs) | ||
|
|
||
| if not (0 <= rate <= 1.0): | ||
| raise ValueError(f"rate must be in range [0, 1]. Received rate: {rate}") | ||
|
|
||
| self._layer = layer | ||
| self._rate = rate | ||
| self.auto_vectorize = auto_vectorize | ||
| self.seed = seed | ||
|
|
||
| def augment_image(self, image, transformation=None): | ||
| should_apply = transformation > 1.0 - self._rate | ||
| return tf.cond(should_apply, lambda: self._layer(image), lambda: image) | ||
LukeWood marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| def get_random_transformation(self, image=None, label=None, bounding_box=None): | ||
| return self._random_generator.random_uniform(shape=()) | ||
|
|
||
| def augment_label(self, label, transformation=None): | ||
| should_apply = transformation > 1.0 - self._rate | ||
| return tf.cond( | ||
| should_apply, lambda: self._layer.augment_label(label), lambda: label | ||
| ) | ||
|
|
||
| def augment_bounding_box(self, bounding_box, transformation=None): | ||
| should_apply = transformation > 1.0 - self._rate | ||
| return tf.cond( | ||
| should_apply, | ||
| lambda: self._layer.augment_bounding_box(bounding_box), | ||
| lambda: bounding_box, | ||
| ) | ||
|
|
||
| def get_config(self): | ||
| config = super().get_config() | ||
| config.update( | ||
| { | ||
| "rate": self._rate, | ||
| "layer": self._layer, | ||
| "seed": self.seed, | ||
| "auto_vectorize": self.auto_vectorize, | ||
| } | ||
| ) | ||
| return config | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| # Copyright 2022 The KerasCV Authors | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import tensorflow as tf | ||
| from absl.testing import parameterized | ||
|
|
||
| from keras_cv.layers.preprocessing.maybe_apply import MaybeApply | ||
|
|
||
|
|
||
| class ZeroOut(tf.keras.__internal__.layers.BaseImageAugmentationLayer): | ||
| """Zero out all entries, for testing purposes.""" | ||
|
|
||
| def __init__(self): | ||
| super(ZeroOut, self).__init__() | ||
|
|
||
| def augment_image(self, image, transformation=None): | ||
| return 0 * image | ||
|
|
||
| def augment_label(self, label, transformation=None): | ||
| return 0 * label | ||
|
|
||
| def augment_bounding_box(self, bounding_box, transformation=None): | ||
| return 0 * bounding_box | ||
|
|
||
|
|
||
| class MaybeApplyTest(tf.test.TestCase, parameterized.TestCase): | ||
| rng = tf.random.Generator.from_non_deterministic_state() | ||
|
|
||
| @parameterized.parameters([-0.5, 1.7]) | ||
| def test_raises_error_on_invalid_rate_parameter(self, invalid_rate): | ||
| with self.assertRaises(ValueError): | ||
| MaybeApply(rate=invalid_rate, layer=ZeroOut()) | ||
|
|
||
| def test_works_with_batched_input(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you pass a seed so that this test is not potentially flaky? Given, it is 1/2^32 flakiness, but still may as well seed it.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added seed to
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! does this seed the layer too?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right. Added |
||
| batch_size = 32 | ||
| dummy_inputs = self.rng.uniform(shape=(batch_size, 224, 224, 3)) | ||
| layer = MaybeApply(rate=0.5, layer=ZeroOut()) | ||
|
|
||
| outputs = layer(dummy_inputs) | ||
| num_zero_inputs = self._num_zero_batches(dummy_inputs) | ||
| num_zero_outputs = self._num_zero_batches(outputs) | ||
|
|
||
| self.assertEqual(num_zero_inputs, 0) | ||
| self.assertLess(num_zero_outputs, batch_size) | ||
| self.assertGreater(num_zero_outputs, 0) | ||
|
|
||
| @staticmethod | ||
| def _num_zero_batches(images): | ||
| num_batches = tf.shape(images)[0] | ||
| num_non_zero_batches = tf.math.count_nonzero( | ||
| tf.math.count_nonzero(images, axis=[1, 2, 3]), dtype=tf.int32 | ||
| ) | ||
| return num_batches - num_non_zero_batches | ||
|
|
||
| def test_inputs_unchanged_with_zero_rate(self): | ||
| dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
| layer = MaybeApply(rate=0.0, layer=ZeroOut()) | ||
|
|
||
| outputs = layer(dummy_inputs) | ||
|
|
||
| self.assertAllClose(outputs, dummy_inputs) | ||
|
|
||
| def test_all_inputs_changed_with_rate_equal_to_one(self): | ||
| dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
| layer = MaybeApply(rate=1.0, layer=ZeroOut()) | ||
|
|
||
| outputs = layer(dummy_inputs) | ||
|
|
||
| self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs)) | ||
|
|
||
| def test_works_with_single_image(self): | ||
| dummy_inputs = self.rng.uniform(shape=(224, 224, 3)) | ||
| layer = MaybeApply(rate=1.0, layer=ZeroOut()) | ||
|
|
||
| outputs = layer(dummy_inputs) | ||
|
|
||
| self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs)) | ||
|
|
||
| def test_can_modify_label(self): | ||
| dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
| dummy_labels = tf.ones(shape=(32, 2)) | ||
| layer = MaybeApply(rate=1.0, layer=ZeroOut()) | ||
|
|
||
| outputs = layer({"images": dummy_inputs, "labels": dummy_labels}) | ||
|
|
||
| self.assertAllEqual(outputs["labels"], tf.zeros_like(dummy_labels)) | ||
|
|
||
| def test_can_modify_bounding_box(self): | ||
| dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
| dummy_boxes = tf.ones(shape=(32, 4)) | ||
| layer = MaybeApply(rate=1.0, layer=ZeroOut()) | ||
|
|
||
| outputs = layer({"images": dummy_inputs, "bounding_boxes": dummy_boxes}) | ||
|
|
||
| self.assertAllEqual(outputs["bounding_boxes"], tf.zeros_like(dummy_boxes)) | ||
|
|
||
| def test_works_with_native_keras_layers(self): | ||
| dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
| zero_out = tf.keras.layers.Lambda(lambda x: 0 * x) | ||
| layer = MaybeApply(rate=1.0, layer=zero_out) | ||
|
|
||
| outputs = layer(dummy_inputs) | ||
|
|
||
| self.assertAllEqual(outputs, tf.zeros_like(dummy_inputs)) | ||
|
|
||
| def test_works_with_xla(self): | ||
| dummy_inputs = self.rng.uniform(shape=(32, 224, 224, 3)) | ||
| # auto_vectorize=True will crash XLA | ||
| layer = MaybeApply(rate=0.5, layer=ZeroOut(), auto_vectorize=False) | ||
|
|
||
| @tf.function(jit_compile=True) | ||
| def apply(x): | ||
| return layer(x) | ||
|
|
||
| apply(dummy_inputs) | ||
Uh oh!
There was an error while loading. Please reload this page.