
Commit ebaa844

Gasoonjia authored and facebook-github-bot committed
Graduate debug handle in torchao (#2452)
Summary: This diff makes torchao's debugging infra fully leverage node["from_node"] info and gets rid of debug handle. Debug handle, we will miss you 🫡

Reviewed By: jerryzh168

Differential Revision: D76628702
1 parent ac14d92 commit ebaa844

File tree

3 files changed: +191, -148 lines


test/quantization/pt2e/test_numeric_debugger.py

Lines changed: 70 additions & 53 deletions
@@ -37,10 +37,12 @@ def test_simple(self):
         example_inputs = m.example_inputs()
         ep = export_for_training(m, example_inputs, strict=True)
         m = ep.module()
-        self._assert_each_node_has_debug_handle(m)
-        debug_handle_map = self._extract_debug_handles(m)
+        self._assert_each_node_has_from_node_source(m)
+        from_node_source_map = self._extract_from_node_source(m)

-        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))
+        self.assertEqual(
+            len(set(from_node_source_map.values())), len(from_node_source_map)
+        )

     @unittest.skip("debug flow not working on model with conditional control flow")
     def test_control_flow(self):
@@ -49,37 +51,42 @@ def test_control_flow(self):
         ep = export_for_training(m, example_inputs, strict=True)
         m = ep.module()

-        self._assert_each_node_has_debug_handle(m)
-        debug_handle_map = self._extract_debug_handles(m)
+        self._assert_each_node_has_from_node_source(m)
+        from_node_source_map = self._extract_from_node_source(m)

-        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))
+        self.assertEqual(
+            len(set(from_node_source_map.values())), len(from_node_source_map)
+        )

     def test_copy_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
         ep = torch.export.export(m, example_inputs, strict=True)
         m = ep.module()

-        self._assert_each_node_has_debug_handle(m)
-        debug_handle_map_ref = self._extract_debug_handles(m)
+        self._assert_each_node_has_from_node_source(m)
+        from_node_source_map_ref = self._extract_from_node_source(m)

         ep_copy = copy.copy(ep)
-        debug_handle_map = self._extract_debug_handles(ep_copy.module())
+        from_node_source_map = self._extract_from_node_source(ep_copy.module())

-        self._assert_each_node_has_debug_handle(ep)
-        self.assertEqual(debug_handle_map, debug_handle_map_ref)
+        self._assert_each_node_has_from_node_source(ep)
+        self.assertEqual(from_node_source_map, from_node_source_map_ref)

     def test_deepcopy_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
         ep = torch.export.export(m, example_inputs, strict=True)

-        debug_handle_map_ref = self._extract_debug_handles(ep.module())
+        from_node_source_map_ref = self._extract_from_node_source(ep.module())
         ep_copy = copy.deepcopy(ep)
-        debug_handle_map = self._extract_debug_handles(ep_copy.module())
+        from_node_source_map = self._extract_from_node_source(ep_copy.module())

-        self._assert_each_node_has_debug_handle(ep.module())
-        self.assertEqual(debug_handle_map, debug_handle_map_ref)
+        self._assert_each_node_has_from_node_source(ep.module())
+        self.assertEqual(from_node_source_map, from_node_source_map_ref)
+        self.assertEqual(
+            set(from_node_source_map.values()), set(from_node_source_map_ref.values())
+        )

     @unittest.skip(
         "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
@@ -90,16 +97,16 @@ def test_re_export_preserve_handle(self):
         ep = export_for_training(m, example_inputs, strict=True)
         m = ep.module()

-        self._assert_each_node_has_debug_handle(m)
-        debug_handle_map_ref = self._extract_debug_handles(m)
+        self._assert_each_node_has_from_node_source(m)
+        from_node_source_map_ref = self._extract_from_node_source(m)

         ep_reexport = export_for_training(m, example_inputs, strict=True)
         m_reexport = ep_reexport.module()

-        self._assert_each_node_has_debug_handle(m_reexport)
-        debug_handle_map = self._extract_debug_handles(m_reexport)
+        self._assert_each_node_has_from_node_source(m_reexport)
+        from_node_source_map = self._extract_from_node_source(m_reexport)

-        self.assertEqual(debug_handle_map, debug_handle_map_ref)
+        self.assertEqual(from_node_source_map, from_node_source_map_ref)

     @unittest.skip(
         "torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
@@ -110,19 +117,19 @@ def test_run_decompositions_same_handle_id(self):
         ep = export_for_training(m, example_inputs, strict=True)
         m = ep.module()

-        self._assert_each_node_has_debug_handle(m)
-        debug_handle_map_ref = self._extract_debug_handles(m)
+        self._assert_each_node_has_from_node_source(m)
+        from_node_source_map_ref = self._extract_from_node_source(m)

         ep_copy = copy.copy(ep)
         ep_copy = ep_copy.run_decompositions()
         m_decomposed = ep_copy.module()

-        self._assert_each_node_has_debug_handle(m_decomposed)
-        debug_handle_map = self._extract_debug_handles(m_decomposed)
+        self._assert_each_node_has_from_node_source(m_decomposed)
+        from_node_source_map = self._extract_from_node_source(m_decomposed)

         # checking the map still has the same ids, the node may change
         self.assertEqual(
-            set(debug_handle_map.values()), set(debug_handle_map_ref.values())
+            set(from_node_source_map.values()), set(from_node_source_map_ref.values())
         )

     @unittest.skip(
@@ -139,22 +146,23 @@ def test_run_decompositions_map_handle_to_new_nodes(self):
         ep = export_for_training(m, example_inputs, strict=True)
         m = ep.module()

-        self._assert_each_node_has_debug_handle(m)
-        pre_decomp_to_debug_handle_map_ref = (
-            self._extract_debug_handles_with_prev_decomp_op(m)
+        self._assert_each_node_has_from_node_source(m)
+        pre_decomp_to_from_node_source_map_ref = (
+            self._extract_from_node_source_with_prev_decomp_op(m)
         )

         ep_copy = copy.copy(ep)
         ep_copy = ep_copy.run_decompositions()
         m_decomposed = ep_copy.module()
-        self._assert_each_node_has_debug_handle(m_decomposed)
-        pre_decomp_to_debug_handle_map = (
-            self._extract_debug_handles_with_prev_decomp_op(m_decomposed)
+        self._assert_each_node_has_from_node_source(m_decomposed)
+        pre_decomp_to_from_node_source_map = (
+            self._extract_from_node_source_with_prev_decomp_op(m_decomposed)
         )

-        # checking the map still has the same ids, the node may change
+        # checking the map still has the same infos, the node may change
         self.assertEqual(
-            pre_decomp_to_debug_handle_map, pre_decomp_to_debug_handle_map_ref
+            pre_decomp_to_from_node_source_map,
+            pre_decomp_to_from_node_source_map_ref,
         )

     def test_prepare_for_propagation_comparison(self):
@@ -178,18 +186,18 @@ def test_added_node_gets_unique_id(self) -> None:
         example_inputs = m.example_inputs()
         ep = export_for_training(m, example_inputs, strict=True)

-        ref_handles = self._extract_debug_handles(ep.module())
-        ref_counter = Counter(ref_handles.values())
+        ref_from_node_source = self._extract_from_node_source(ep.module())
+        ref_counter = Counter(ref_from_node_source.values())

         for k, v in ref_counter.items():
             self.assertEqual(
                 v,
                 1,
-                msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
+                msg=f"For from_node info {k}, there were {v} nodes with that info, but expected only 1",
             )

-        # Now that we have unique ids, add a new node into the graph and re-generate
-        # to make sure that the new node gets a unique id.
+        # Now that we have unique infos, add a new node into the graph and re-generate
+        # to make sure that the new node gets a unique info.
         last_node = next(iter(reversed(ep.graph.nodes)))
         with ep.graph.inserting_before(last_node):
             arg = last_node.args[0]
@@ -200,30 +208,39 @@ def test_added_node_gets_unique_id(self) -> None:
             arg.replace_all_uses_with(n, lambda x: x != n)
         ep.graph_module.recompile()

-        # Regenerate handles, make sure only the new relu node has a new id, and
-        # it doesn't clash with any of the existing ids.
+        # Regenerate from_node info, make sure only the new relu node has a new info, and
+        # it doesn't clash with any of the existing infos.

         m = ep.module()
-        self._assert_each_node_has_debug_handle(m)
-        handles_after_modification = self._extract_debug_handles(m)
-        handles_counter = Counter(handles_after_modification.values())
-        for name, handle in ref_handles.items():
-            self.assertIn(name, handles_after_modification)
-            # Check that handle was unchanged.
-            self.assertEqual(handles_after_modification[name], handle)
+        self._assert_each_node_has_from_node_source(m)
+        from_node_source_after_modification = self._extract_from_node_source(m)
+        from_node_source_counter = Counter(from_node_source_after_modification.values())
+        for name, from_node_source in ref_from_node_source.items():
+            self.assertIn(name, from_node_source_after_modification)
+            # Check that from_node info was unchanged.
+            self.assertEqual(
+                from_node_source_after_modification[name], from_node_source
+            )
             # Check that total count was unchanged.
-            ref_count = ref_counter[handle]
-            after_count = handles_counter[handle]
+            ref_count = ref_counter[from_node_source]
+            after_count = from_node_source_counter[from_node_source]
             self.assertEqual(
                 after_count,
                 ref_count,
-                msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
+                msg=f"For from_node info {from_node_source}, there were {after_count} nodes with that info, but expected only {ref_count}",
             )

-        # Check for relu specifically. Avoid hardcoding the handle id since it
+        # Check for relu specifically. Avoid hardcoding the from_node info since it
         # may change with future node ordering changes.
-        self.assertNotIn(handles_after_modification["relu_default"], ref_counter)
-        self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)
+        self.assertNotIn(
+            from_node_source_after_modification["relu_default"], ref_counter
+        )
+        self.assertEqual(
+            from_node_source_counter[
+                from_node_source_after_modification["relu_default"]
+            ],
+            1,
+        )


 if __name__ == "__main__":
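The _assert_each_node_has_from_node_source and _extract_from_node_source helpers used above are defined on the test case outside this hunk. As a rough, hypothetical sketch of what an extraction of that kind could look like (the name, signature, and treatment of the "from_node" payload are assumptions, not the torchao implementation):

import torch


def extract_from_node_source(module: torch.fx.GraphModule) -> dict:
    # Hypothetical sketch: map each call_function node's name to a printable
    # form of the "from_node" provenance metadata recorded by torch.export.
    result = {}
    for node in module.graph.nodes:
        if node.op != "call_function":
            continue
        # The element type of "from_node" varies by PyTorch version, so keep
        # it opaque by storing its repr.
        result[node.name] = repr(node.meta.get("from_node"))
    return result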
