1313# limitations under the License.
1414"""Utility functions that help in the computation of per-example gradient norms."""
1515
16- from typing import Any , Callable , Dict , Iterable , List , Optional , Text , Tuple , Union
16+ from typing import Any , Callable , Dict , Iterable , List , Optional , Set , Text , Tuple , Union
1717
1818from absl import logging
1919import tensorflow as tf
@@ -36,19 +36,6 @@ def has_internal_compute_graph(input_object: Any):
3636 )
3737
3838
39- def _get_internal_layers (
40- input_layer : tf .keras .layers .Layer ,
41- ) -> List [tf .keras .layers .Layer ]:
42- """Returns a list of layers that are nested within a given layer."""
43- internal_layers = []
44- if isinstance (input_layer , tf .keras .Model ) and hasattr (input_layer , 'layers' ):
45- for layer in input_layer .layers :
46- internal_layers .extend (_get_internal_layers (layer ))
47- else :
48- internal_layers .append (input_layer )
49- return internal_layers
50-
51-
5239def model_forward_pass (
5340 input_model : tf .keras .Model ,
5441 inputs : PackedTensors ,
@@ -114,18 +101,10 @@ def generator_fn(layer_instance, args, kwargs):
114101 generator_outputs_list .extend (node_generator_outputs )
115102 else :
116103 # Otherwise, we parse the node directly.
117- node_layers = _get_internal_layers (node .layer )
118- for layer in node_layers :
119- node_layer_outputs , layer_generator_outputs = generator_fn (
120- layer , args , kwargs
121- )
122- generator_outputs_list .append (layer_generator_outputs )
123- args = (
124- node_layer_outputs
125- if isinstance (node_layer_outputs , tuple )
126- else (node_layer_outputs ,)
127- )
128- kwargs = {}
104+ node_layer_outputs , layer_generator_outputs = generator_fn (
105+ node .layer , args , kwargs
106+ )
107+ generator_outputs_list .append (layer_generator_outputs )
129108
130109 # Update the current dictionary of inputs for the next node.
131110 for x_id , y in zip (
@@ -163,9 +142,8 @@ def all_trainable_layers_are_registered(
163142 False otherwise.
164143 """
165144 for layer in input_model .layers :
166- for sublayer in _get_internal_layers (layer ):
167- if not layer_registry .is_elem (sublayer ) and sublayer .trainable_variables :
168- return False
145+ if not layer_registry .is_elem (layer ) and layer .trainable_variables :
146+ return False
169147 return True
170148
171149
@@ -213,17 +191,53 @@ def add_noise(g):
213191
214192def generate_model_outputs_using_core_keras_layers (
215193 input_model : tf .keras .Model ,
194+ custom_layer_set : Optional [Set [type ]] = None , # pylint: disable=g-bare-generic
216195) -> PackedTensors :
217- """Returns the model outputs generated by only core Keras layers."""
218- cust_obj_dict = dict .copy (tf .keras .utils .get_custom_objects ())
219- cust_hash_set = set ([hash (v ) for v in cust_obj_dict .values ()])
196+ """Returns the model outputs generated by only core Keras layers.
197+
198+ Args:
199+ input_model: A `tf.keras.Model` instance to obtain outputs from.
200+ custom_layer_set: An optional `set` of custom layers to expand. If `None`,
201+ then this is the set of all registered custom Keras layers.
202+
203+ Returns:
204+ A `tf.Tensor` that is the result of `input_model(input_model.inputs)`
205+ using only Keras layers that are not in `custom_layer_set`.
206+ """
207+ # Set up helper variables and functions.
208+ custom_layer_set = (
209+ custom_layer_set or tf .keras .utils .get_custom_objects ().values ()
210+ )
211+
212+ def _is_core (layer_instance ):
213+ return type (layer_instance ) not in custom_layer_set
220214
221215 def generator_fn (layer_instance , args , kwargs ):
222- if hash (layer_instance .__class__ ) in cust_hash_set :
223- # Using `.call()` does not register the layer in the compute graph of
224- # a forward pass.
225- return layer_instance .call (* args , ** kwargs ), None
226- else :
227- return layer_instance (* args , ** kwargs ), None
216+ # Using `.call()` does not register the layer in the compute graph of
217+ # a forward pass.
218+ layer_outputs = (
219+ layer_instance (* args , ** kwargs )
220+ if _is_core (layer_instance )
221+ else layer_instance .call (* args , ** kwargs )
222+ )
223+ return layer_outputs , None
224+
225+ # Return early if all the existing layers contain only core layers.
226+ if all (_is_core (layer ) for layer in input_model .layers ):
227+ return model_forward_pass (input_model , input_model .inputs )[0 ]
228228
229- return model_forward_pass (input_model , input_model .inputs , generator_fn )[0 ]
229+ # Do a forward pass to expand the outermost layers.
230+ candidate_outputs , _ = model_forward_pass (
231+ input_model , input_model .inputs , generator_fn
232+ )
233+
234+ # The following recursion is inefficient because it recursively builds `n`
235+ # Keras model graphs, where `n` is the number of recursive calls. However,
236+ # it appears to be the only valid approach without accessing Keras's internal
237+ # functions (e.g., `keras.engine.functional._map_graph_network()`).
238+ cleaned_model = tf .keras .Model (
239+ inputs = input_model .inputs , outputs = candidate_outputs
240+ )
241+ return generate_model_outputs_using_core_keras_layers (
242+ cleaned_model , custom_layer_set
243+ )
0 commit comments