2 changes: 2 additions & 0 deletions docs/design/torch_compile.md
@@ -27,6 +27,8 @@ With all these factors taken into consideration, usually we can guarantee that t

A unique aspect of vLLM's `torch.compile` integration is that we guarantee all compilation finishes before we serve any requests. No request will trigger new compilations. Otherwise, the engine would block on that request, and response times would have unexpected spikes.

By default, the cache saves compiled artifacts as binary files. If you would like to inspect the generated code for debugging purposes, set `compile_cache_save_format=unpacked` in the compilation config, or leave that field unset and set the environment variable `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked`.
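
For example, a minimal sketch using the offline `LLM` entry point (the model name is just a placeholder; `compile_cache_save_format` is the config field described above):

```python
from vllm import LLM
from vllm.config import CompilationConfig

# Save Inductor artifacts as an inspectable directory tree instead of a
# single binary file (not multiprocess safe; intended for debugging only).
llm = LLM(
    model="facebook/opt-125m",
    compilation_config=CompilationConfig(compile_cache_save_format="unpacked"),
)
```

Alternatively, export `VLLM_COMPILE_CACHE_SAVE_FORMAT=unpacked` before launching and leave the config field unset.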

## Python Code Compilation

In the very verbose logs, we can see:
4 changes: 3 additions & 1 deletion vllm/compilation/backends.py
@@ -51,7 +51,9 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
and hasattr(torch._inductor, "standalone_compile")
):
logger.debug("Using InductorStandaloneAdaptor")
return InductorStandaloneAdaptor()
return InductorStandaloneAdaptor(
compilation_config.compile_cache_save_format
)
else:
logger.debug("Using InductorAdaptor")
return InductorAdaptor()
9 changes: 6 additions & 3 deletions vllm/compilation/compiler_interface.py
@@ -6,7 +6,7 @@
import os
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
from typing import Any, Literal
from unittest.mock import patch

import torch
@@ -175,6 +175,9 @@ class InductorStandaloneAdaptor(CompilerInterface):

name = "inductor_standalone"

def __init__(self, save_format: Literal["binary", "unpacked"]):
self.save_format = save_format

def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str = hashlib.md5(
@@ -220,7 +223,7 @@ def compile(
assert key is not None
path = os.path.join(self.cache_dir, key)
if not envs.VLLM_DISABLE_COMPILE_CACHE:
compiled_graph.save(path=path, format="unpacked")
compiled_graph.save(path=path, format=self.save_format)
compilation_counter.num_compiled_artifacts_saved += 1
return compiled_graph, (key, path)

@@ -237,7 +240,7 @@ def load(
assert isinstance(handle[1], str)
path = handle[1]
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
path=path, format="unpacked"
path=path, format=self.save_format
)
from torch._inductor.compile_fx import graph_returns_tuple

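For context, a rough sketch of the save/load round trip the adaptor performs (this assumes a PyTorch build where `torch._inductor.standalone_compile` exists, which is exactly what the `hasattr` check in `backends.py` guards; the traced function, inputs, and path are illustrative):

```python
import torch

def f(x):
    return x * 2 + 1

gm = torch.fx.symbolic_trace(f)
example_inputs = [torch.randn(4)]

# Compile once, then persist the artifact; "unpacked" writes a directory
# tree whose generated code can be read and breakpointed, while "binary"
# writes a single multiprocess-safe file.
artifact = torch._inductor.standalone_compile(gm, example_inputs)
artifact.save(path="/tmp/compile_cache_demo", format="unpacked")

# Later (e.g. on a cache hit), reload without recompiling.
loaded = torch._inductor.CompiledArtifact.load(
    path="/tmp/compile_cache_demo", format="unpacked"
)
```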
23 changes: 22 additions & 1 deletion vllm/config/compilation.py
@@ -7,11 +7,12 @@
from collections.abc import Callable
from dataclasses import asdict, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal

from pydantic import TypeAdapter, field_validator
from pydantic.dataclasses import dataclass

import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config
from vllm.logger import init_logger
@@ -208,6 +209,15 @@ class CompilationConfig:
"""The directory to store the compiled graph, to accelerate Inductor
compilation. By default, it will use model-related information to generate
a cache directory."""
compile_cache_save_format: Literal["binary", "unpacked"] = field(
default_factory=lambda: envs.VLLM_COMPILE_CACHE_SAVE_FORMAT
)
> **Review comment (Collaborator, on lines +212 to +214):** I like this as a config, but generally env variables override CLI options (so that users can modify from the outside for debugging); see `debug_dump_path`. Could you change this in a follow-up PR?

"""Format for saving torch compile cache:\n
- "binary": saves as binary file (multiprocess safe)\n
- "unpacked": saves as directory structure for inspection/debugging
(NOT multiprocess safe)\n
Defaults to `VLLM_COMPILE_CACHE_SAVE_FORMAT` if not specified.
"""
backend: str = ""
"""The backend for compilation. It needs to be a string:

@@ -478,6 +488,7 @@ def compute_hash(self) -> str:
factors.append(self.inductor_compile_config)
factors.append(self.inductor_passes)
factors.append(self.pass_config.uuid())
factors.append(self.compile_cache_save_format)
return hashlib.sha256(str(factors).encode()).hexdigest()

def __repr__(self) -> str:
@@ -519,6 +530,16 @@ def validate_cudagraph_mode_before(cls, value: Any) -> Any:
return CUDAGraphMode[value.upper()]
return value

@field_validator("compile_cache_save_format")
@classmethod
def validate_compile_cache_save_format(cls, value: str) -> str:
if value not in ("binary", "unpacked"):
raise ValueError(
f"compile_cache_save_format must be 'binary' or 'unpacked', "
f"got: {value}"
)
return value

def __post_init__(self) -> None:
if self.level is not None:
logger.warning(
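A quick usage sketch of the new field and its validator (illustrative; `CompilationConfig` is a pydantic dataclass, so an invalid value is rejected at construction time by the `Literal` annotation and the validator above, and when the field is left unset the `default_factory` falls back to `envs.VLLM_COMPILE_CACHE_SAVE_FORMAT`):

```python
from vllm.config import CompilationConfig

cfg = CompilationConfig(compile_cache_save_format="unpacked")
assert cfg.compile_cache_save_format == "unpacked"

# An unknown format is rejected at construction time.
try:
    CompilationConfig(compile_cache_save_format="zip")
except Exception as exc:  # pydantic ValidationError wrapping the ValueError
    print(exc)
```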
10 changes: 10 additions & 0 deletions vllm/envs.py
@@ -218,6 +218,7 @@
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"


def get_default_cache_root():
@@ -1408,6 +1409,15 @@ def get_vllm_port() -> int | None:
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
"VLLM_DISABLE_SHARED_EXPERTS_STREAM", False
),
# Format for saving torch.compile cache artifacts
# - "binary": saves as binary file
# Safe for multiple vllm serve processes accessing the same torch compile cache.
# - "unpacked": saves as directory structure (for inspection/debugging)
# NOT multiprocess safe - race conditions may occur with multiple processes.
# Allows viewing and setting breakpoints in Inductor's code output files.
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
),
}

# --8<-- [end:env-vars-definition]
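Finally, a minimal sketch of driving the same switch from the environment. Because entries in `vllm.envs` are lambdas evaluated lazily on attribute access (as the `env_with_choices` registration above shows), the value only needs to be set before the config reads it:

```python
import os

# Must be set before CompilationConfig picks up its default_factory value.
os.environ["VLLM_COMPILE_CACHE_SAVE_FORMAT"] = "unpacked"

import vllm.envs as envs

assert envs.VLLM_COMPILE_CACHE_SAVE_FORMAT == "unpacked"
```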