@@ -21,7 +21,7 @@ class Context:
2121 lines : set [int ]
2222
2323
24- class RegionFinder ( ast . NodeVisitor ) :
24+ class RegionFinder :
2525 """An ast visitor that will find and track regions of code.
2626
2727 Functions and classes are tracked by name. Results are in the .regions
@@ -34,13 +34,27 @@ def __init__(self) -> None:
3434
3535 def parse_source (self , source : str ) -> None :
3636 """Parse `source` and walk the ast to populate the .regions attribute."""
37- self .visit (ast .parse (source ))
37+ self .handle_node (ast .parse (source ))
3838
3939 def fq_node_name (self ) -> str :
4040 """Get the current fully qualified name we're processing."""
4141 return "." .join (c .name for c in self .context )
4242
43- def visit_FunctionDef (self , node : ast .FunctionDef ) -> None :
43+ def handle_node (self , node : ast .AST ) -> None :
44+ """Recursively handle any node."""
45+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef )):
46+ self .handle_FunctionDef (node )
47+ elif isinstance (node , ast .ClassDef ):
48+ self .handle_ClassDef (node )
49+ else :
50+ self .handle_node_body (node )
51+
52+ def handle_node_body (self , node : ast .AST ) -> None :
53+ """Recursively handle the nodes in this node's body, if any."""
54+ for body_node in getattr (node , "body" , ()):
55+ self .handle_node (body_node )
56+
57+ def handle_FunctionDef (self , node : ast .FunctionDef | ast .AsyncFunctionDef ) -> None :
4458 """Called for `def` or `async def`."""
4559 lines = set (range (node .body [0 ].lineno , cast (int , node .body [- 1 ].end_lineno ) + 1 ))
4660 if self .context and self .context [- 1 ].kind == "class" :
@@ -60,12 +74,10 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
6074 lines = lines ,
6175 )
6276 )
63- self .generic_visit (node )
77+ self .handle_node_body (node )
6478 self .context .pop ()
6579
66- visit_AsyncFunctionDef = visit_FunctionDef # type: ignore[assignment]
67-
68- def visit_ClassDef (self , node : ast .ClassDef ) -> None :
80+ def handle_ClassDef (self , node : ast .ClassDef ) -> None :
6981 """Called for `class`."""
7082 # The lines for a class are the lines in the methods of the class.
7183 # We start empty, and count on visit_FunctionDef to add the lines it
@@ -80,7 +92,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
8092 lines = lines ,
8193 )
8294 )
83- self .generic_visit (node )
95+ self .handle_node_body (node )
8496 self .context .pop ()
8597 # Class bodies should be excluded from the enclosing classes.
8698 for ancestor in reversed (self .context ):
0 commit comments