diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index 80648f6c77..10a68858f2 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -37,10 +37,12 @@ def test_simple(self): example_inputs = m.example_inputs() ep = export_for_training(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map = self._extract_from_node_source(m) - self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map)) + self.assertEqual( + len(set(from_node_source_map.values())), len(from_node_source_map) + ) @unittest.skip("debug flow not working on model with conditional control flow") def test_control_flow(self): @@ -49,10 +51,12 @@ def test_control_flow(self): ep = export_for_training(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map = self._extract_from_node_source(m) - self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map)) + self.assertEqual( + len(set(from_node_source_map.values())), len(from_node_source_map) + ) def test_copy_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() @@ -60,26 +64,29 @@ def test_copy_preserve_handle(self): ep = torch.export.export(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map_ref = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map_ref = self._extract_from_node_source(m) ep_copy = copy.copy(ep) - debug_handle_map = self._extract_debug_handles(ep_copy.module()) + from_node_source_map = self._extract_from_node_source(ep_copy.module()) - self._assert_each_node_has_debug_handle(ep) - self.assertEqual(debug_handle_map, debug_handle_map_ref) + self._assert_each_node_has_from_node_source(ep) + self.assertEqual(from_node_source_map, from_node_source_map_ref) def test_deepcopy_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() ep = torch.export.export(m, example_inputs, strict=True) - debug_handle_map_ref = self._extract_debug_handles(ep.module()) + from_node_source_map_ref = self._extract_from_node_source(ep.module()) ep_copy = copy.deepcopy(ep) - debug_handle_map = self._extract_debug_handles(ep_copy.module()) + from_node_source_map = self._extract_from_node_source(ep_copy.module()) - self._assert_each_node_has_debug_handle(ep.module()) - self.assertEqual(debug_handle_map, debug_handle_map_ref) + self._assert_each_node_has_from_node_source(ep.module()) + self.assertEqual(from_node_source_map, from_node_source_map_ref) + self.assertEqual( + set(from_node_source_map.values()), set(from_node_source_map_ref.values()) + ) @unittest.skip( "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..." @@ -90,16 +97,16 @@ def test_re_export_preserve_handle(self): ep = export_for_training(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map_ref = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map_ref = self._extract_from_node_source(m) ep_reexport = export_for_training(m, example_inputs, strict=True) m_reexport = ep_reexport.module() - self._assert_each_node_has_debug_handle(m_reexport) - debug_handle_map = self._extract_debug_handles(m_reexport) + self._assert_each_node_has_from_node_source(m_reexport) + from_node_source_map = self._extract_from_node_source(m_reexport) - self.assertEqual(debug_handle_map, debug_handle_map_ref) + self.assertEqual(from_node_source_map, from_node_source_map_ref) @unittest.skip( "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..." @@ -110,19 +117,19 @@ def test_run_decompositions_same_handle_id(self): ep = export_for_training(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - debug_handle_map_ref = self._extract_debug_handles(m) + self._assert_each_node_has_from_node_source(m) + from_node_source_map_ref = self._extract_from_node_source(m) ep_copy = copy.copy(ep) ep_copy = ep_copy.run_decompositions() m_decomposed = ep_copy.module() - self._assert_each_node_has_debug_handle(m_decomposed) - debug_handle_map = self._extract_debug_handles(m_decomposed) + self._assert_each_node_has_from_node_source(m_decomposed) + from_node_source_map = self._extract_from_node_source(m_decomposed) # checking the map still has the same ids, the node may change self.assertEqual( - set(debug_handle_map.values()), set(debug_handle_map_ref.values()) + set(from_node_source_map.values()), set(from_node_source_map_ref.values()) ) @unittest.skip( @@ -139,22 +146,23 @@ def test_run_decompositions_map_handle_to_new_nodes(self): ep = export_for_training(m, example_inputs, strict=True) m = ep.module() - self._assert_each_node_has_debug_handle(m) - pre_decomp_to_debug_handle_map_ref = ( - self._extract_debug_handles_with_prev_decomp_op(m) + self._assert_each_node_has_from_node_source(m) + pre_decomp_to_from_node_source_map_ref = ( + self._extract_from_node_source_with_prev_decomp_op(m) ) ep_copy = copy.copy(ep) ep_copy = ep_copy.run_decompositions() m_decomposed = ep_copy.module() - self._assert_each_node_has_debug_handle(m_decomposed) - pre_decomp_to_debug_handle_map = ( - self._extract_debug_handles_with_prev_decomp_op(m_decomposed) + self._assert_each_node_has_from_node_source(m_decomposed) + pre_decomp_to_from_node_source_map = ( + self._extract_from_node_source_with_prev_decomp_op(m_decomposed) ) - # checking the map still has the same ids, the node may change + # checking the map still has the same infos, the node may change self.assertEqual( - pre_decomp_to_debug_handle_map, pre_decomp_to_debug_handle_map_ref + pre_decomp_to_from_node_source_map, + pre_decomp_to_from_node_source_map_ref, ) def test_prepare_for_propagation_comparison(self): @@ -178,18 +186,18 @@ def test_added_node_gets_unique_id(self) -> None: example_inputs = m.example_inputs() ep = export_for_training(m, example_inputs, strict=True) - ref_handles = self._extract_debug_handles(ep.module()) - ref_counter = Counter(ref_handles.values()) + ref_from_node_source = self._extract_from_node_source(ep.module()) + ref_counter = Counter(ref_from_node_source.values()) for k, v in ref_counter.items(): self.assertEqual( v, 1, - msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1", + msg=f"For from_node info {k}, there were {v} nodes with that info, but expected only 1", ) - # Now that we have unique ids, add a new node into the graph and re-generate - # to make sure that the new node gets a unique id. + # Now that we have unique infos, add a new node into the graph and re-generate + # to make sure that the new node gets a unique info. last_node = next(iter(reversed(ep.graph.nodes))) with ep.graph.inserting_before(last_node): arg = last_node.args[0] @@ -200,30 +208,39 @@ def test_added_node_gets_unique_id(self) -> None: arg.replace_all_uses_with(n, lambda x: x != n) ep.graph_module.recompile() - # Regenerate handles, make sure only the new relu node has a new id, and - # it doesn't clash with any of the existing ids. + # Regenerate from_node info, make sure only the new relu node has a new info, and + # it doesn't clash with any of the existing infos. m = ep.module() - self._assert_each_node_has_debug_handle(m) - handles_after_modification = self._extract_debug_handles(m) - handles_counter = Counter(handles_after_modification.values()) - for name, handle in ref_handles.items(): - self.assertIn(name, handles_after_modification) - # Check that handle was unchanged. - self.assertEqual(handles_after_modification[name], handle) + self._assert_each_node_has_from_node_source(m) + from_node_source_after_modification = self._extract_from_node_source(m) + from_node_source_counter = Counter(from_node_source_after_modification.values()) + for name, from_node_source in ref_from_node_source.items(): + self.assertIn(name, from_node_source_after_modification) + # Check that from_node info was unchanged. + self.assertEqual( + from_node_source_after_modification[name], from_node_source + ) # Check that total count was unchanged. - ref_count = ref_counter[handle] - after_count = handles_counter[handle] + ref_count = ref_counter[from_node_source] + after_count = from_node_source_counter[from_node_source] self.assertEqual( after_count, ref_count, - msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}", + msg=f"For from_node info {from_node_source}, there were {after_count} nodes with that info, but expected only {ref_count}", ) - # Check for relu specifically. Avoid hardcoding the handle id since it + # Check for relu specifically. Avoid hardcoding the from_node info since it # may change with future node ordering changes. - self.assertNotIn(handles_after_modification["relu_default"], ref_counter) - self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1) + self.assertNotIn( + from_node_source_after_modification["relu_default"], ref_counter + ) + self.assertEqual( + from_node_source_counter[ + from_node_source_after_modification["relu_default"] + ], + 1, + ) if __name__ == "__main__": diff --git a/torchao/quantization/pt2e/_numeric_debugger.py b/torchao/quantization/pt2e/_numeric_debugger.py index de1e1eee84..0346981391 100644 --- a/torchao/quantization/pt2e/_numeric_debugger.py +++ b/torchao/quantization/pt2e/_numeric_debugger.py @@ -30,6 +30,21 @@ log = logging.getLogger(__name__) +@dataclass(frozen=True) +class NodeSourceDebugInfo: + """ + Contains node source information for locating the node in the original graph. + This replaces the numeric debug handle approach with direct node source info. + """ + + # The name of the node in the graph, e.g. "conv2d" + name: str + + # The unique id of the graph that the node belongs to. + graph_id: int + + +# This function is no longer used for torchao debug flow, but is kept here for backward compatibility. def generate_numeric_debug_handle(ep: ExportedProgram) -> None: """ Attach numeric_debug_handle_id for all nodes in the graph module of the given @@ -84,53 +99,48 @@ def _assign_debug_handle(node: torch.fx.Node) -> None: bfs_trace_with_node_process(ep, _assign_debug_handle) -def _get_greatest_ancestor_node_source(node: Node) -> Optional["NodeSource"]: - if (node_source := node.meta.get(FROM_NODE_KEY)) is None: - return None +def _extract_node_source_debug_info(node: Node) -> Optional[NodeSourceDebugInfo]: + """ + Extract node source debug info from a node, or return None if the node + does not need to be traced. - node_source = node_source[-1] + Returns NodeSourceDebugInfo containing the name and graph_id from the + node's greatest ancestor node source, or None if the node is not in + the original graph. + """ - while len(node_source.from_node) > 0: - node_source = node_source.from_node[-1] + def _get_greatest_ancestor_node_source(node: Node) -> "NodeSource": + node_source = node.meta.get(FROM_NODE_KEY)[-1] - return node_source + while len(node_source.from_node) > 0: + node_source = node_source.from_node[-1] + return node_source -def _generate_debug_handle_from_node(node: Node) -> Optional[int]: - """ - Generate a debug handle based on node's oldest ancestor node's name - and graph id, or return None if the node does not need to be traced. + def _is_node_in_original_graph(node: Node) -> bool: + if ( + FROM_NODE_KEY not in node.meta + or node.meta[FROM_NODE_KEY] is None + or node.meta[FROM_NODE_KEY][-1].pass_name + == "ExportedProgram.module().unlift()" + ): + # This node is not part of the ExportedProgram.module().graph, so it doesn't have a debug handle + return False - This is a temporary function for migrating node tracing infra from - using debug handle to node.meta["from_node"]. The infrastructure will - depend on node.meta["from_node"] directly in the future, without the need - of debug handle as intermediate variable. - """ + return True if node.op == "placeholder" or node.op == "output": - # placeholder and output nodes don't have debug handle + # placeholder and output nodes don't have debug info return None - if ( - FROM_NODE_KEY not in node.meta - or node.meta[FROM_NODE_KEY] is None - or node.meta[FROM_NODE_KEY][-1].pass_name == "ExportedProgram.module().unlift()" - ): - # This node is not part of the ExportedProgram.module().graph, so it doesn't have a debug handle + if not _is_node_in_original_graph(node): return None greatest_ancestor_node_source = _get_greatest_ancestor_node_source(node) - if greatest_ancestor_node_source is None: - # This node is not part of the ExportedProgram.module().graph, so it doesn't have a debug handle - return None - - if greatest_ancestor_node_source.pass_name == "ExportedProgram.module().unlift()": - # uplifted nodes don't have debug handle - return None - - return hash( - greatest_ancestor_node_source.name + str(greatest_ancestor_node_source.graph_id) + return NodeSourceDebugInfo( + name=greatest_ancestor_node_source.name, + graph_id=greatest_ancestor_node_source.graph_id, ) @@ -192,14 +202,14 @@ class OutputLogger(torch.nn.Module): def __init__( self, - debug_handle: int, + debug_info: NodeSourceDebugInfo, node_name: Optional[str] = None, nn_module_stack: Optional[object] = None, ) -> None: super().__init__() self.node_name = node_name self.nn_module_stack = nn_module_stack - self.debug_handle = debug_handle + self.debug_info = debug_info self.stats: list[object] = [] def forward(self, x: object) -> object: @@ -208,15 +218,17 @@ def forward(self, x: object) -> object: def __extra_repr__(self) -> str: return ( - f"debug_handle={self.debug_handle}, node_name={self.node_name}, " + f"debug_info={self.debug_info}, node_name={self.node_name}, " "nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)})" ) -def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node: +def _insert_logger( + model: GraphModule, node: Node, debug_info: NodeSourceDebugInfo +) -> Node: """For a given node, adds an OutputLogger that observes the output of that node, and all its users use the OutputLogger output instead. - The OutputLogger will contain the debug_handle which can be used to compare + The OutputLogger will contain the debug_info which can be used to compare graphs after transforms""" # to avoid circular dep @@ -229,7 +241,7 @@ def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node: setattr( model, logger_name, - OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")), + OutputLogger(debug_info, node.name, node.meta.get("nn_module_stack")), ) logger_node = model.graph.call_module(logger_name, (node,), {}) @@ -259,8 +271,8 @@ def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule: # don't change the original model model = copy.deepcopy(model) for n in model.graph.nodes: - if (numeric_debug_handle := _generate_debug_handle_from_node(n)) is not None: - _insert_logger(model, n, numeric_debug_handle) + if (debug_info := _extract_node_source_debug_info(n)) is not None: + _insert_logger(model, n, debug_info) model.recompile() return model @@ -310,7 +322,7 @@ def __post_init__(self) -> None: @dataclass(frozen=True) class NodeAccuracySummary: - handle: int + debug_info: NodeSourceDebugInfo actual_node_name: str actual_module_stack: str ref_node_name: str @@ -334,21 +346,21 @@ def _module_stack_to_str(module_stack: object) -> str: def extract_results_from_loggers( model: GraphModule, -) -> dict[int, tuple[Optional[str], object, list[object]]]: - """For a given model, extract the tensors stats and related information for each debug handle. +) -> dict[NodeSourceDebugInfo, tuple[Optional[str], object, list[object]]]: + """For a given model, extract the tensors stats and related information for each debug info. The reason we have a list of object, instead of Tensor is because the output of node may not be a Tensor, it could be (nested) list, tuple or dict as well. Returns: - A dict is keyed by the debug_handle id and the values are a list of object recorded + A dict is keyed by the NodeSourceDebugInfo and the values are a list of object recorded in loggers """ - # Results maps debug handle to a tensor list for each model being compared. - handles: dict[int, tuple[Optional[str], object, list[object]]] = {} - for _name, module in model.named_children(): + # Results maps debug info to a tensor list for each model being compared. + handles: dict[NodeSourceDebugInfo, tuple[Optional[str], object, list[object]]] = {} + for _, module in model.named_children(): if isinstance(module, OutputLogger) and len(module.stats) > 0: - handles[module.debug_handle] = ( + handles[module.debug_info] = ( module.node_name, module.nn_module_stack, module.stats, @@ -358,29 +370,33 @@ def extract_results_from_loggers( def compare_results( - ref_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]], - actual_results: dict[int, tuple[Optional[str], object, list[torch.Tensor]]], -) -> dict[int, NodeAccuracySummary]: - """Given two dict mapping from `debug_handle_id` (int) to list of tensors - return a map from `debug_handle_id` to `NodeAccuracySummary` that contains + ref_results: dict[ + NodeSourceDebugInfo, tuple[Optional[str], object, list[torch.Tensor]] + ], + actual_results: dict[ + NodeSourceDebugInfo, tuple[Optional[str], object, list[torch.Tensor]] + ], +) -> dict[NodeSourceDebugInfo, NodeAccuracySummary]: + """Given two dict mapping from `NodeSourceDebugInfo` to list of tensors + return a map from `NodeSourceDebugInfo` to `NodeAccuracySummary` that contains comparison information like SQNR, MSE etc. Args: - ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id - actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id + ref_results (Dict[NodeSourceDebugInfo, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug info + actual_results (Dict[NodeSourceDebugInfo, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug info Returns: - Dict[int, NodeAccuracySummary] + Dict[NodeSourceDebugInfo, NodeAccuracySummary] """ comparisons = {} - for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items(): - if debug_handle not in actual_results: + for debug_info, (ref_name, ref_stack, ref_stats) in ref_results.items(): + if debug_info not in actual_results: log.debug( - "Cannot compare for handle %s because it wasn't found in the transformed model", - debug_handle, + "Cannot compare for debug info %s because it wasn't found in the transformed model", + debug_info, ) continue - actual_name, actual_stack, actual_stats = actual_results[debug_handle] + actual_name, actual_stack, actual_stats = actual_results[debug_info] try: results = [ QuantizationComparisonResult(actual=a, ref=b) @@ -388,13 +404,13 @@ def compare_results( ] except Exception as e: # Add extra information for an exception from QuantizationComparisonResult - # if the shapes didn't match, to include the handle and the node names. + # if the shapes didn't match, to include the debug info and the node names. raise ValueError( - f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}" + f"For debug_info={debug_info} from ref node {ref_name} and actual node {actual_name}" ) from e - comparisons[debug_handle] = NodeAccuracySummary( - handle=debug_handle, + comparisons[debug_info] = NodeAccuracySummary( + debug_info=debug_info, actual_node_name=actual_name or "", actual_module_stack=_module_stack_to_str(actual_stack), ref_node_name=ref_name or "", diff --git a/torchao/testing/pt2e/utils.py b/torchao/testing/pt2e/utils.py index 5d903a4a15..8ba6480835 100644 --- a/torchao/testing/pt2e/utils.py +++ b/torchao/testing/pt2e/utils.py @@ -6,7 +6,6 @@ import copy import unittest -from typing import Dict import torch from torch.ao.quantization.backend_config import ( @@ -23,7 +22,7 @@ from torch.testing._internal.common_utils import TestCase from torchao.quantization.pt2e import FROM_NODE_KEY -from torchao.quantization.pt2e._numeric_debugger import _generate_debug_handle_from_node +from torchao.quantization.pt2e._numeric_debugger import _extract_node_source_debug_info from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, @@ -147,48 +146,59 @@ class PT2ENumericDebuggerTestCase(TestCase): for numeric debugging functionality. """ - def _assert_each_node_has_debug_handle(self, model) -> None: - """Assert that each node in the model has a debug handle.""" - - def _assert_node_has_debug_handle(node): + def _assert_each_node_has_from_node_source(self, model) -> None: + def _assert_node_has_from_node_source(node): + if node.op == "placeholder" or node.op == "output": + return self.assertIn( FROM_NODE_KEY, node.meta, f"Node {node} doesn't have from_node info", ) - bfs_trace_with_node_process(model, _assert_node_has_debug_handle) + bfs_trace_with_node_process(model, _assert_node_has_from_node_source) + + def _extract_from_node_source(self, model) -> dict[str, any]: + from_node_source_map: dict[str, any] = {} - def _extract_debug_handles(self, model) -> Dict[str, int]: - """Extract debug handles from all nodes in the model.""" - debug_handle_map: Dict[str, int] = {} + def _extract_from_node_source_from_node(node): + nonlocal from_node_source_map + if (root_node_source := _extract_node_source_debug_info(node)) is not None: + from_node_source_map[str(node)] = ( + root_node_source.name, + root_node_source.graph_id, + ) - def _extract_debug_handles_from_node(node): - nonlocal debug_handle_map - if (dh := _generate_debug_handle_from_node(node)) is not None: - debug_handle_map[str(node)] = dh + bfs_trace_with_node_process(model, _extract_from_node_source_from_node) - bfs_trace_with_node_process(model, _extract_debug_handles_from_node) - return debug_handle_map + return from_node_source_map - def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]: - prev_decomp_op_to_debug_handle_map: dict[str, int] = {} + def _extract_from_node_source_with_prev_decomp_op(self, model) -> dict[str, any]: + prev_decomp_op_to_from_node_source_map: dict[str, any] = {} - def _extract_debug_handles_with_prev_decomp_op_from_node(node): - nonlocal prev_decomp_op_to_debug_handle_map - if FROM_NODE_KEY in node.meta: + def _extract_from_node_source_with_prev_decomp_op_from_node(node): + nonlocal prev_decomp_op_to_from_node_source_map + if FROM_NODE_KEY in node.meta and node.meta[FROM_NODE_KEY] is not None: prev_decomp_op = str(node.meta.get("nn_module_stack")) - debug_handle = _generate_debug_handle_from_node(node) - if prev_decomp_op not in prev_decomp_op_to_debug_handle_map: - prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle + from_node_source = node.meta[FROM_NODE_KEY] + if prev_decomp_op not in prev_decomp_op_to_from_node_source_map: + prev_decomp_op_to_from_node_source_map[prev_decomp_op] = ( + from_node_source + ) else: assert ( - prev_decomp_op_to_debug_handle_map[prev_decomp_op] - == debug_handle - ), f"Node {node} has different debug handle {debug_handle}" + prev_decomp_op_to_from_node_source_map[prev_decomp_op] + == from_node_source + ), f"Node {node} has different from_node info {from_node_source}" "than previous node sharing the same decomp op {prev_decomp_op}" bfs_trace_with_node_process( - model, _extract_debug_handles_with_prev_decomp_op_from_node + model, _extract_from_node_source_with_prev_decomp_op_from_node + ) + return prev_decomp_op_to_from_node_source_map + + def assertNodeSourcesEqual(self, node_source_1, node_source_2): + self.assertTrue( + node_source_1.name == node_source_2.name + and node_source_1.graph_id == node_source_2.graph_id ) - return prev_decomp_op_to_debug_handle_map