Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

# A single tensor, an iterable of tensors, or a text-keyed dict of tensors.
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]

# NOTE(review): this `Optional[...]` binding is immediately rebound by the
# assignment on the next line, so it has no effect. It looks like the removed
# side of a diff -- confirm before relying on it.
GeneratorFunction = Optional[Callable[[Any, Tuple, Dict], Tuple[Any, Any]]]
# Callable taking `(Any, Tuple, Dict)` and returning a 2-tuple. Optionality is
# expressed at the use site rather than baked into the alias.
GeneratorFunction = Callable[[Any, Tuple, Dict], Tuple[Any, Any]]

# Callable that consumes a Keras layer and returns nothing.
LayerFunction = Callable[[tf.keras.layers.Layer], None]


def has_internal_compute_graph(input_object: Any):
Expand Down Expand Up @@ -52,7 +54,7 @@ def _get_internal_layers(
def model_forward_pass(
input_model: tf.keras.Model,
inputs: PackedTensors,
generator_fn: GeneratorFunction = None,
generator_fn: Optional[GeneratorFunction] = None,
) -> Tuple[PackedTensors, List[Any]]:
"""Does a forward pass of a model and returns useful intermediates.

Expand Down Expand Up @@ -211,6 +213,55 @@ def add_noise(g):
return tf.nest.map_structure(add_noise, clipped_grads)


def depth_first_backward_pass(
    outputs: PackedTensors, layer_function: Optional[LayerFunction] = None
):
  """Walks the graph that produced `outputs` depth-first, outputs to inputs.

  A pared-down analogue of `tf.keras.engine.functional._build_map()` that can
  additionally run a caller-supplied hook on the layer behind each node it
  visits.

  NOTE: The behavior, name, and implementation details of this function may
  change in future versions. Users should avoid using it outside of this module.

  Args:
    outputs: A `PackedTensor` that should be generated by calling a
      `tf.keras.Model` on a set of non-eager inputs.
    layer_function: Optional callable taking a `tf.keras.layers.Layer`. When
      given, it is called with the layer of every node the first time that
      node is reached, before the node's own inputs are explored.

  Raises:
    ValueError: If a node is reached again while it is still being expanded,
      i.e. the tensor is part of a cycle.
  """
  # Nodes whose whole upstream subgraph has already been explored.
  visited = set()
  # Nodes currently on the recursion stack; used for cycle detection.
  active = set()

  def _visit(current: tf.Tensor):
    # Read the layer and node index recorded on the tensor, then look up the
    # matching inbound node on that layer.
    owner, call_index, _ = current._keras_history  # pylint: disable=protected-access
    call_node = owner._inbound_nodes[call_index]  # pylint: disable=protected-access
    # Shared subgraphs are only explored once.
    if call_node in visited:
      return
    # Seeing a node that is still being expanded means there is a cycle.
    if call_node in active:
      raise ValueError(
          f'Tensor {current} from layer "{owner.name}" is part of a cycle.'
      )
    # Pre-order: run the hook first, then descend into the node's inputs.
    if layer_function is not None:
      layer_function(owner)
    active.add(call_node)
    if not call_node.is_input:
      for parent in call_node.keras_inputs:
        _visit(parent)
    visited.add(call_node)
    active.remove(call_node)

  # Start one traversal from every tensor in the (possibly nested) outputs.
  for root in tf.nest.flatten(outputs):
    _visit(root)


def generate_model_outputs_using_core_keras_layers(
input_model: tf.keras.Model,
) -> PackedTensors:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,5 +75,65 @@ def test_outputs_are_consistent(
self.assertAllClose(computed_outputs, true_outputs)


class DepthFirstBackwardPassTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `gradient_clipping_utils.depth_first_backward_pass`."""

  @parameterized.product(
      depth=[1, 2],
      input_packing_type=[None, tuple, list, dict],
      output_packing_type=[None, tuple, list, dict],
  )
  def test_layer_function(self, depth, input_packing_type, output_packing_type):
    """Checks which layers the layer function sees during the traversal.

    Builds `num_outputs` branches, each a chain of `depth` Dense layers on top
    of a scaled stack of the Keras inputs, and expects the layer function to
    collect `num_outputs * depth` layers with trainable variables, all Dense.

    Args:
      depth: Number of Dense layers stacked on each output branch.
      input_packing_type: `None` for a single input, otherwise two inputs are
        created. Only the number of inputs matters here because the traversal
        is driven solely by the outputs.
      output_packing_type: `None` for a single bare output tensor, otherwise
        the container type (`tuple`, `list` or `dict`) holding two outputs.
    """
    num_dims = 3
    num_units = 5
    num_inputs = 1 if input_packing_type is None else 2
    num_outputs = 1 if output_packing_type is None else 2

    sample_inputs = [tf.keras.Input((num_dims,)) for _ in range(num_inputs)]
    stacked_inputs = tf.stack(sample_inputs, axis=0)
    # One scaled copy of the stacked inputs per output branch.
    sample_outputs = [
        tf.multiply(stacked_inputs, float(i + 1)) for i in range(num_outputs)
    ]
    for _ in range(depth):
      sample_outputs = [
          tf.keras.layers.Dense(num_units)(t) for t in sample_outputs
      ]

    # Pack outputs.
    if output_packing_type is None:
      outputs = sample_outputs[0]
    elif output_packing_type is dict:
      outputs = {str(i): t for i, t in enumerate(sample_outputs)}
    else:
      outputs = output_packing_type(sample_outputs)

    # Append the trainable layers into a list.
    layer_list = []

    def layer_function(layer):
      if layer.trainable_variables:
        layer_list.append(layer)

    # Run the traversal and verify the outputs that are relevant to
    # the above layer function.
    gradient_clipping_utils.depth_first_backward_pass(outputs, layer_function)
    self.assertLen(layer_list, num_outputs * depth)
    for layer in layer_list:
      self.assertIsInstance(layer, tf.keras.layers.Dense)


# Run the test cases in this file when it is executed as a script.
if __name__ == '__main__':
  tf.test.main()