@@ -37,10 +37,12 @@ def test_simple(self):
37
37
example_inputs = m .example_inputs ()
38
38
ep = export_for_training (m , example_inputs , strict = True )
39
39
m = ep .module ()
40
- self ._assert_each_node_has_debug_handle (m )
41
- debug_handle_map = self ._extract_debug_handles (m )
40
+ self ._assert_each_node_has_from_node_source (m )
41
+ from_node_source_map = self ._extract_from_node_source (m )
42
42
43
- self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
43
+ self .assertEqual (
44
+ len (set (from_node_source_map .values ())), len (from_node_source_map )
45
+ )
44
46
45
47
@unittest .skip ("debug flow not working on model with conditional control flow" )
46
48
def test_control_flow (self ):
@@ -49,37 +51,42 @@ def test_control_flow(self):
49
51
ep = export_for_training (m , example_inputs , strict = True )
50
52
m = ep .module ()
51
53
52
- self ._assert_each_node_has_debug_handle (m )
53
- debug_handle_map = self ._extract_debug_handles (m )
54
+ self ._assert_each_node_has_from_node_source (m )
55
+ from_node_source_map = self ._extract_from_node_source (m )
54
56
55
- self .assertEqual (len (set (debug_handle_map .values ())), len (debug_handle_map ))
57
+ self .assertEqual (
58
+ len (set (from_node_source_map .values ())), len (from_node_source_map )
59
+ )
56
60
57
61
def test_copy_preserve_handle (self ):
58
62
m = TestHelperModules .Conv2dThenConv1d ()
59
63
example_inputs = m .example_inputs ()
60
64
ep = torch .export .export (m , example_inputs , strict = True )
61
65
m = ep .module ()
62
66
63
- self ._assert_each_node_has_debug_handle (m )
64
- debug_handle_map_ref = self ._extract_debug_handles (m )
67
+ self ._assert_each_node_has_from_node_source (m )
68
+ from_node_source_map_ref = self ._extract_from_node_source (m )
65
69
66
70
ep_copy = copy .copy (ep )
67
- debug_handle_map = self ._extract_debug_handles (ep_copy .module ())
71
+ from_node_source_map = self ._extract_from_node_source (ep_copy .module ())
68
72
69
- self ._assert_each_node_has_debug_handle (ep )
70
- self .assertEqual (debug_handle_map , debug_handle_map_ref )
73
+ self ._assert_each_node_has_from_node_source (ep )
74
+ self .assertEqual (from_node_source_map , from_node_source_map_ref )
71
75
72
76
def test_deepcopy_preserve_handle (self ):
73
77
m = TestHelperModules .Conv2dThenConv1d ()
74
78
example_inputs = m .example_inputs ()
75
79
ep = torch .export .export (m , example_inputs , strict = True )
76
80
77
- debug_handle_map_ref = self ._extract_debug_handles (ep .module ())
81
+ from_node_source_map_ref = self ._extract_from_node_source (ep .module ())
78
82
ep_copy = copy .deepcopy (ep )
79
- debug_handle_map = self ._extract_debug_handles (ep_copy .module ())
83
+ from_node_source_map = self ._extract_from_node_source (ep_copy .module ())
80
84
81
- self ._assert_each_node_has_debug_handle (ep .module ())
82
- self .assertEqual (debug_handle_map , debug_handle_map_ref )
85
+ self ._assert_each_node_has_from_node_source (ep .module ())
86
+ self .assertEqual (from_node_source_map , from_node_source_map_ref )
87
+ self .assertEqual (
88
+ set (from_node_source_map .values ()), set (from_node_source_map_ref .values ())
89
+ )
83
90
84
91
@unittest .skip (
85
92
"torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
@@ -90,16 +97,16 @@ def test_re_export_preserve_handle(self):
90
97
ep = export_for_training (m , example_inputs , strict = True )
91
98
m = ep .module ()
92
99
93
- self ._assert_each_node_has_debug_handle (m )
94
- debug_handle_map_ref = self ._extract_debug_handles (m )
100
+ self ._assert_each_node_has_from_node_source (m )
101
+ from_node_source_map_ref = self ._extract_from_node_source (m )
95
102
96
103
ep_reexport = export_for_training (m , example_inputs , strict = True )
97
104
m_reexport = ep_reexport .module ()
98
105
99
- self ._assert_each_node_has_debug_handle (m_reexport )
100
- debug_handle_map = self ._extract_debug_handles (m_reexport )
106
+ self ._assert_each_node_has_from_node_source (m_reexport )
107
+ from_node_source_map = self ._extract_from_node_source (m_reexport )
101
108
102
- self .assertEqual (debug_handle_map , debug_handle_map_ref )
109
+ self .assertEqual (from_node_source_map , from_node_source_map_ref )
103
110
104
111
@unittest .skip (
105
112
"torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recom..."
@@ -110,19 +117,19 @@ def test_run_decompositions_same_handle_id(self):
110
117
ep = export_for_training (m , example_inputs , strict = True )
111
118
m = ep .module ()
112
119
113
- self ._assert_each_node_has_debug_handle (m )
114
- debug_handle_map_ref = self ._extract_debug_handles (m )
120
+ self ._assert_each_node_has_from_node_source (m )
121
+ from_node_source_map_ref = self ._extract_from_node_source (m )
115
122
116
123
ep_copy = copy .copy (ep )
117
124
ep_copy = ep_copy .run_decompositions ()
118
125
m_decomposed = ep_copy .module ()
119
126
120
- self ._assert_each_node_has_debug_handle (m_decomposed )
121
- debug_handle_map = self ._extract_debug_handles (m_decomposed )
127
+ self ._assert_each_node_has_from_node_source (m_decomposed )
128
+ from_node_source_map = self ._extract_from_node_source (m_decomposed )
122
129
123
130
# checking the map still has the same ids, the node may change
124
131
self .assertEqual (
125
- set (debug_handle_map .values ()), set (debug_handle_map_ref .values ())
132
+ set (from_node_source_map .values ()), set (from_node_source_map_ref .values ())
126
133
)
127
134
128
135
@unittest .skip (
@@ -139,22 +146,23 @@ def test_run_decompositions_map_handle_to_new_nodes(self):
139
146
ep = export_for_training (m , example_inputs , strict = True )
140
147
m = ep .module ()
141
148
142
- self ._assert_each_node_has_debug_handle (m )
143
- pre_decomp_to_debug_handle_map_ref = (
144
- self ._extract_debug_handles_with_prev_decomp_op (m )
149
+ self ._assert_each_node_has_from_node_source (m )
150
+ pre_decomp_to_from_node_source_map_ref = (
151
+ self ._extract_from_node_source_with_prev_decomp_op (m )
145
152
)
146
153
147
154
ep_copy = copy .copy (ep )
148
155
ep_copy = ep_copy .run_decompositions ()
149
156
m_decomposed = ep_copy .module ()
150
- self ._assert_each_node_has_debug_handle (m_decomposed )
151
- pre_decomp_to_debug_handle_map = (
152
- self ._extract_debug_handles_with_prev_decomp_op (m_decomposed )
157
+ self ._assert_each_node_has_from_node_source (m_decomposed )
158
+ pre_decomp_to_from_node_source_map = (
159
+ self ._extract_from_node_source_with_prev_decomp_op (m_decomposed )
153
160
)
154
161
155
- # checking the map still has the same ids , the node may change
162
+ # checking the map still has the same infos , the node may change
156
163
self .assertEqual (
157
- pre_decomp_to_debug_handle_map , pre_decomp_to_debug_handle_map_ref
164
+ pre_decomp_to_from_node_source_map ,
165
+ pre_decomp_to_from_node_source_map_ref ,
158
166
)
159
167
160
168
def test_prepare_for_propagation_comparison (self ):
@@ -178,18 +186,18 @@ def test_added_node_gets_unique_id(self) -> None:
178
186
example_inputs = m .example_inputs ()
179
187
ep = export_for_training (m , example_inputs , strict = True )
180
188
181
- ref_handles = self ._extract_debug_handles (ep .module ())
182
- ref_counter = Counter (ref_handles .values ())
189
+ ref_from_node_source = self ._extract_from_node_source (ep .module ())
190
+ ref_counter = Counter (ref_from_node_source .values ())
183
191
184
192
for k , v in ref_counter .items ():
185
193
self .assertEqual (
186
194
v ,
187
195
1 ,
188
- msg = f"For handle { k } , there were { v } nodes with that handle , but expected only 1" ,
196
+ msg = f"For from_node info { k } , there were { v } nodes with that info , but expected only 1" ,
189
197
)
190
198
191
- # Now that we have unique ids , add a new node into the graph and re-generate
192
- # to make sure that the new node gets a unique id .
199
+ # Now that we have unique infos , add a new node into the graph and re-generate
200
+ # to make sure that the new node gets a unique info .
193
201
last_node = next (iter (reversed (ep .graph .nodes )))
194
202
with ep .graph .inserting_before (last_node ):
195
203
arg = last_node .args [0 ]
@@ -200,30 +208,39 @@ def test_added_node_gets_unique_id(self) -> None:
200
208
arg .replace_all_uses_with (n , lambda x : x != n )
201
209
ep .graph_module .recompile ()
202
210
203
- # Regenerate handles , make sure only the new relu node has a new id , and
204
- # it doesn't clash with any of the existing ids .
211
+ # Regenerate from_node info , make sure only the new relu node has a new info , and
212
+ # it doesn't clash with any of the existing infos .
205
213
206
214
m = ep .module ()
207
- self ._assert_each_node_has_debug_handle (m )
208
- handles_after_modification = self ._extract_debug_handles (m )
209
- handles_counter = Counter (handles_after_modification .values ())
210
- for name , handle in ref_handles .items ():
211
- self .assertIn (name , handles_after_modification )
212
- # Check that handle was unchanged.
213
- self .assertEqual (handles_after_modification [name ], handle )
215
+ self ._assert_each_node_has_from_node_source (m )
216
+ from_node_source_after_modification = self ._extract_from_node_source (m )
217
+ from_node_source_counter = Counter (from_node_source_after_modification .values ())
218
+ for name , from_node_source in ref_from_node_source .items ():
219
+ self .assertIn (name , from_node_source_after_modification )
220
+ # Check that from_node info was unchanged.
221
+ self .assertEqual (
222
+ from_node_source_after_modification [name ], from_node_source
223
+ )
214
224
# Check that total count was unchanged.
215
- ref_count = ref_counter [handle ]
216
- after_count = handles_counter [ handle ]
225
+ ref_count = ref_counter [from_node_source ]
226
+ after_count = from_node_source_counter [ from_node_source ]
217
227
self .assertEqual (
218
228
after_count ,
219
229
ref_count ,
220
- msg = f"For handle { handle } , there were { after_count } nodes with that handle , but expected only { ref_count } " ,
230
+ msg = f"For from_node info { from_node_source } , there were { after_count } nodes with that info , but expected only { ref_count } " ,
221
231
)
222
232
223
- # Check for relu specifically. Avoid hardcoding the handle id since it
233
+ # Check for relu specifically. Avoid hardcoding the from_node info since it
224
234
# may change with future node ordering changes.
225
- self .assertNotIn (handles_after_modification ["relu_default" ], ref_counter )
226
- self .assertEqual (handles_counter [handles_after_modification ["relu_default" ]], 1 )
235
+ self .assertNotIn (
236
+ from_node_source_after_modification ["relu_default" ], ref_counter
237
+ )
238
+ self .assertEqual (
239
+ from_node_source_counter [
240
+ from_node_source_after_modification ["relu_default" ]
241
+ ],
242
+ 1 ,
243
+ )
227
244
228
245
229
246
if __name__ == "__main__" :
0 commit comments