@@ -173,6 +173,9 @@ def leave(self, node, key, parent, path, ancestors):
173
173
# Provide special return values as attributes
174
174
BREAK , SKIP , REMOVE , IDLE = BREAK , SKIP , REMOVE , IDLE
175
175
176
+ def __init__ (self ):
177
+ self ._visit_fns = {}
178
+
176
179
def __init_subclass__ (cls ) -> None :
177
180
"""Verify that all defined handlers are valid."""
178
181
super ().__init_subclass__ ()
@@ -197,11 +200,12 @@ def __init_subclass__(cls) -> None:
197
200
198
201
def get_visit_fn (self , kind : str , is_leaving : bool = False ) -> Callable :
199
202
"""Get the visit function for the given node kind and direction."""
200
- method = "leave" if is_leaving else "enter"
201
- visit_fn = getattr (self , f"{ method } _{ kind } " , None )
202
- if not visit_fn :
203
- visit_fn = getattr (self , method , None )
204
- return visit_fn
203
+ key = (kind , is_leaving )
204
+ if key not in self ._visit_fns :
205
+ method = "leave" if is_leaving else "enter"
206
+ fn = getattr (self , f"{ method } _{ kind } " , None )
207
+ self ._visit_fns [key ] = fn or getattr (self , method , None )
208
+ return self ._visit_fns [key ]
205
209
206
210
207
211
class Stack (NamedTuple ):
@@ -367,14 +371,22 @@ class ParallelVisitor(Visitor):
367
371
368
372
def __init__ (self , visitors : Collection [Visitor ]):
369
373
"""Create a new visitor from the given list of parallel visitors."""
374
+ super ().__init__ ()
370
375
self .visitors = visitors
371
376
self .skipping : List [Any ] = [None ] * len (visitors )
377
+ self ._enter_visit_fns = {}
378
+ self ._leave_visit_fns = {}
372
379
373
380
def enter (self , node : Node , * args : Any ) -> Optional [VisitorAction ]:
381
+ visit_fns = self ._enter_visit_fns .get (node .kind )
382
+ if visit_fns is None :
383
+ visit_fns = [v .get_visit_fn (node .kind ) for v in self .visitors ]
384
+ self ._enter_visit_fns [node .kind ] = visit_fns
385
+
374
386
skipping = self .skipping
375
387
for i , visitor in enumerate (self .visitors ):
376
388
if not skipping [i ]:
377
- fn = visitor . get_visit_fn ( node . kind )
389
+ fn = visit_fns [ i ]
378
390
if fn :
379
391
result = fn (node , * args )
380
392
if result is SKIP or result is False :
@@ -386,10 +398,15 @@ def enter(self, node: Node, *args: Any) -> Optional[VisitorAction]:
386
398
return None
387
399
388
400
def leave (self , node : Node , * args : Any ) -> Optional [VisitorAction ]:
401
+ visit_fns = self ._leave_visit_fns .get (node .kind )
402
+ if visit_fns is None :
403
+ visit_fns = [v .get_visit_fn (node .kind , is_leaving = True ) for v in self .visitors ]
404
+ self ._leave_visit_fns [node .kind ] = visit_fns
405
+
389
406
skipping = self .skipping
390
407
for i , visitor in enumerate (self .visitors ):
391
408
if not skipping [i ]:
392
- fn = visitor . get_visit_fn ( node . kind , is_leaving = True )
409
+ fn = visit_fns [ i ]
393
410
if fn :
394
411
result = fn (node , * args )
395
412
if result is BREAK or result is True :
0 commit comments