@@ -524,6 +524,7 @@ def configure_post_pass(self):
 
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
+        vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affect the compilation. if none of the factors change,
@@ -532,7 +533,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
 
             # 1. factors come from the vllm_config (it mainly summarizes how the
             # model is created)
-            vllm_config = self.vllm_config
             config_hash = vllm_config.compute_hash()
 
             # 2. factors come from the code files that are traced by Dynamo (
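
The two hunks above simply hoist `vllm_config = self.vllm_config` out of the `if` branch so the alias is available to the per-rank logic added further down. For context, here is a hedged sketch of what a `compute_hash`-style fingerprint over compilation factors could look like; the helper name and factor set are illustrative, not vLLM's actual implementation:

    import hashlib

    def compute_config_hash(factors: dict) -> str:
        # Illustrative only: vLLM's real compute_hash() lives on VllmConfig.
        # Any change in a factor (model, dtype, parallelism, ...) yields a
        # different hash, and therefore a fresh torch.compile cache directory.
        canonical = repr(sorted(factors.items()))
        return hashlib.md5(canonical.encode()).hexdigest()
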
@@ -556,20 +556,26 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             hash_key = hashlib.md5(
                 f"{config_hash}_{code_hash}".encode()).hexdigest()[:10]
             cache_dir = os.path.join(
-                envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key,
-                f"rank_{vllm_config.parallel_config.rank}")
-        else:
-            cache_dir = self.compilation_config.cache_dir
+                envs.VLLM_CACHE_ROOT,
+                "torch_compile_cache",
+                hash_key,
+            )
+            self.compilation_config.cache_dir = cache_dir
+
+        cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
+        local_cache_dir = os.path.join(
+            cache_dir, f"rank_{vllm_config.parallel_config.rank}")
+        self.compilation_config.local_cache_dir = local_cache_dir
 
         disabled = envs.VLLM_DISABLE_COMPILE_CACHE
         self.inductor_hash_cache: InductorHashCache = InductorHashCache(
-            cache_dir, disabled=disabled)
+            local_cache_dir, disabled=disabled)
         if disabled:
             logger.info("vLLM's torch.compile cache is disabled.")
         else:
             logger.info("Using cache directory: %s for vLLM's torch.compile",
-                        local_cache_dir)
 
         # when dynamo calls the backend, it means the bytecode
         # transform and analysis are done
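
The hunk above is the core of the change: the shared `cache_dir` no longer embeds the rank, and a per-rank `local_cache_dir` is carved out underneath it for the `InductorHashCache`. A minimal sketch of the resulting resolution logic, assuming hypothetical values for the cache root, hash key, and rank:

    import os

    def resolve_cache_dirs(cache_root: str, hash_key: str, rank: int):
        # Shared across ranks: <cache_root>/torch_compile_cache/<hash_key>/
        cache_dir = os.path.join(cache_root, "torch_compile_cache", hash_key)
        # Per-rank subdirectory: .../<hash_key>/rank_<r>/
        local_cache_dir = os.path.join(cache_dir, f"rank_{rank}")
        os.makedirs(local_cache_dir, exist_ok=True)
        return cache_dir, local_cache_dir

    # e.g. rank 0 with a cache root of ~/.cache/vllm yields
    #   ~/.cache/vllm/torch_compile_cache/0123abcdef
    #   ~/.cache/vllm/torch_compile_cache/0123abcdef/rank_0
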
@@ -609,6 +615,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
                                     self.vllm_config, self.graph_pool,
                                     self).run(*example_inputs)
 
+        graph_path = os.path.join(local_cache_dir, "computation_graph.py")
+        if not os.path.exists(graph_path):
+            # code adapted from https://github.com/thuml/depyf/blob/dab831108a752d1facc00acdd6d4243891845c37/depyf/explain/patched_lazy_format_graph_code.py#L30 # noqa
+            # use `print_readable` because it can include submodules
+            src = "from __future__ import annotations\nimport torch\n" + \
+                self.split_gm.print_readable(print_output=False)
+            src = src.replace("<lambda>", "GraphModule")
+            with open(graph_path, "w") as f:
+                f.write(src)
+
+        logger.debug("Computation graph saved to %s", graph_path)
+
         self._called = True
 
         if not self.compilation_config.use_cudagraph or \
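
The final hunk dumps the split graph to `computation_graph.py` inside the per-rank cache directory. A self-contained sketch of the same `print_readable`-based dump on a toy module; the module and output file name here are placeholders, not part of the commit:

    import torch
    import torch.fx as fx

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1

    gm: fx.GraphModule = fx.symbolic_trace(Toy())
    # print_output=False returns the readable source as a string instead of
    # printing it, and includes the code of any submodules as well.
    src = "from __future__ import annotations\nimport torch\n" + \
        gm.print_readable(print_output=False)
    src = src.replace("<lambda>", "GraphModule")
    with open("computation_graph.py", "w") as f:
        f.write(src)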