@@ -285,6 +285,11 @@ def flatten(node):
285285    return  '' .join (acc )
286286
287287
288+ def  make_xml (text ):
289+     xml  =  ET .XML ('<xml>%s</xml>'  %  text )
290+     return  xml 
291+ 
292+ 
288293def  normalize_xpath (path ):
289294    path  =  path .replace ("{{channel}}" , channel )
290295    if  path .startswith ('//' ):
@@ -401,7 +406,7 @@ def get_tree_count(tree, path):
401406    return  len (tree .findall (path ))
402407
403408
404- def  check_snapshot (snapshot_name , tree , normalize_to_text ):
409+ def  check_snapshot (snapshot_name , actual_tree , normalize_to_text ):
405410    assert  rust_test_path .endswith ('.rs' )
406411    snapshot_path  =  '{}.{}.{}' .format (rust_test_path [:- 3 ], snapshot_name , 'html' )
407412    try :
@@ -414,11 +419,15 @@ def check_snapshot(snapshot_name, tree, normalize_to_text):
414419            raise  FailedCheck ('No saved snapshot value' )
415420
416421    if  not  normalize_to_text :
417-         actual_str  =  ET .tostring (tree ).decode ('utf-8' )
422+         actual_str  =  ET .tostring (actual_tree ).decode ('utf-8' )
418423    else :
419-         actual_str  =  flatten (tree )
424+         actual_str  =  flatten (actual_tree )
425+ 
426+     if  not  expected_str  \
427+         or  (not  normalize_to_text  and 
428+             not  compare_tree (make_xml (actual_str ), make_xml (expected_str ), stderr )) \
429+         or  (normalize_to_text  and  actual_str  !=  expected_str ):
420430
421-     if  expected_str  !=  actual_str :
422431        if  bless :
423432            with  open (snapshot_path , 'w' ) as  snapshot_file :
424433                snapshot_file .write (actual_str )
@@ -430,6 +439,59 @@ def check_snapshot(snapshot_name, tree, normalize_to_text):
430439            print ()
431440            raise  FailedCheck ('Actual snapshot value is different than expected' )
432441
442+ 
443+ # Adapted from https://github.com/formencode/formencode/blob/3a1ba9de2fdd494dd945510a4568a3afeddb0b2e/formencode/doctest_xml_compare.py#L72-L120 
444+ def  compare_tree (x1 , x2 , reporter = None ):
445+     if  x1 .tag  !=  x2 .tag :
446+         if  reporter :
447+             reporter ('Tags do not match: %s and %s'  %  (x1 .tag , x2 .tag ))
448+         return  False 
449+     for  name , value  in  x1 .attrib .items ():
450+         if  x2 .attrib .get (name ) !=  value :
451+             if  reporter :
452+                 reporter ('Attributes do not match: %s=%r, %s=%r' 
453+                          %  (name , value , name , x2 .attrib .get (name )))
454+             return  False 
455+     for  name  in  x2 .attrib :
456+         if  name  not  in x1 .attrib :
457+             if  reporter :
458+                 reporter ('x2 has an attribute x1 is missing: %s' 
459+                          %  name )
460+             return  False 
461+     if  not  text_compare (x1 .text , x2 .text ):
462+         if  reporter :
463+             reporter ('text: %r != %r'  %  (x1 .text , x2 .text ))
464+         return  False 
465+     if  not  text_compare (x1 .tail , x2 .tail ):
466+         if  reporter :
467+             reporter ('tail: %r != %r'  %  (x1 .tail , x2 .tail ))
468+         return  False 
469+     cl1  =  list (x1 )
470+     cl2  =  list (x2 )
471+     if  len (cl1 ) !=  len (cl2 ):
472+         if  reporter :
473+             reporter ('children length differs, %i != %i' 
474+                      %  (len (cl1 ), len (cl2 )))
475+         return  False 
476+     i  =  0 
477+     for  c1 , c2  in  zip (cl1 , cl2 ):
478+         i  +=  1 
479+         if  not  compare_tree (c1 , c2 , reporter = reporter ):
480+             if  reporter :
481+                 reporter ('children %i do not match: %s' 
482+                          %  (i , c1 .tag ))
483+             return  False 
484+     return  True 
485+ 
486+ 
487+ def  text_compare (t1 , t2 ):
488+     if  not  t1  and  not  t2 :
489+         return  True 
490+     if  t1  ==  '*'  or  t2  ==  '*' :
491+         return  True 
492+     return  (t1  or  '' ).strip () ==  (t2  or  '' ).strip ()
493+ 
494+ 
433495def  stderr (* args ):
434496    if  sys .version_info .major  <  3 :
435497        file  =  codecs .getwriter ('utf-8' )(sys .stderr )
0 commit comments