Commit 89ce62a

[platform] add ray_device_key (#11948)

Authored by youkaichao
Signed-off-by: youkaichao <[email protected]>
Parent: c3f05b0

This commit replaces the hard-coded Ray device strings ("GPU", "TPU", "HPU") in the Ray executors with a per-platform attribute, Platform.ray_device_key. Each platform declares the resource name Ray uses for its accelerator; an empty string means the platform does not support Ray, and the executors now raise an error in that case.

9 files changed: 38 additions, 8 deletions

vllm/executor/ray_utils.py

Lines changed: 13 additions & 6 deletions
@@ -8,6 +8,7 @@
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -47,7 +48,12 @@ def get_node_ip(self) -> str:

     def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
         node_id = ray.get_runtime_context().get_node_id()
-        gpu_ids = ray.get_gpu_ids()
+        device_key = current_platform.ray_device_key
+        if not device_key:
+            raise RuntimeError("current platform "
+                               f"{current_platform.device_name} does not support ray.")
+        gpu_ids = ray.get_runtime_context().get_accelerator_ids(
+        )[device_key]
         return node_id, gpu_ids

     def execute_model_spmd(
@@ -249,11 +255,12 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return

-    device_str = "GPU"
-    if current_platform.is_tpu():
-        device_str = "TPU"
-    elif current_platform.is_hpu():
-        device_str = 'HPU'
+    device_str = current_platform.ray_device_key
+    if not device_str:
+        raise ValueError(
+            f"current platform {current_platform.device_name} does not "
+            "support ray.")
+
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
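
For context, a minimal sketch (not part of the commit) of the Ray API the new worker code relies on, assuming Ray is installed and one GPU is available; the task name and printed output are illustrative:

import ray

ray.init()

@ray.remote(num_gpus=1)
def show_accelerator_ids():
    ctx = ray.get_runtime_context()
    # get_accelerator_ids() maps Ray resource names to the IDs assigned to
    # this task, e.g. {"GPU": ["0"], "TPU": [], ...}; the patched
    # get_node_and_gpu_ids() indexes that dict with
    # current_platform.ray_device_key instead of calling ray.get_gpu_ids().
    return ctx.get_accelerator_ids()["GPU"]

print(ray.get(show_accelerator_ids.remote()))  # e.g. ['0']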

vllm/platforms/cuda.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ class CudaPlatformBase(Platform):
     device_name: str = "cuda"
     device_type: str = "cuda"
     dispatch_key: str = "CUDA"
+    ray_device_key: str = "GPU"

     @classmethod
     def get_device_capability(cls,

vllm/platforms/hpu.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ class HpuPlatform(Platform):
     device_name: str = "hpu"
     device_type: str = "hpu"
     dispatch_key: str = "HPU"
+    ray_device_key: str = "HPU"

     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,

vllm/platforms/interface.py

Lines changed: 4 additions & 0 deletions
@@ -82,6 +82,10 @@ class Platform:
     # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
     # use "CPU" as a fallback for platforms not registered in PyTorch
     dispatch_key: str = "CPU"
+    # available ray device keys:
+    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
+    # empty string means the device does not support ray
+    ray_device_key: str = ""
     # The torch.compile backend for compiling simple and
     # standalone functions. The default value is "inductor" to keep
     # the same behavior as PyTorch.
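
For illustration, a platform would opt into Ray support by overriding this attribute; the class name, field values, and enum choice below are assumptions, not code from this commit:

from vllm.platforms.interface import Platform, PlatformEnum

class MyAcceleratorPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED  # illustrative; a real platform picks its own member
    device_name: str = "my_accel"
    device_type: str = "my_accel"
    # Must be one of Ray's accelerator resource names; leaving the default ""
    # makes the Ray executors raise instead of trying to schedule workers.
    ray_device_key: str = "GPU"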

vllm/platforms/neuron.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ class NeuronPlatform(Platform):
     _enum = PlatformEnum.NEURON
     device_name: str = "neuron"
     device_type: str = "neuron"
+    ray_device_key: str = "neuron_cores"
     supported_quantization: list[str] = ["neuron_quant"]

     @classmethod

vllm/platforms/rocm.py

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@ class RocmPlatform(Platform):
     device_name: str = "rocm"
     device_type: str = "cuda"
     dispatch_key: str = "CUDA"
+    ray_device_key: str = "GPU"
+
     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
         "fbgemm_fp8", "gguf"

vllm/platforms/tpu.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ class TpuPlatform(Platform):
     device_name: str = "tpu"
     device_type: str = "tpu"
     dispatch_key: str = "XLA"
+    ray_device_key: str = "TPU"
+
     supported_quantization: list[str] = [
         "tpu_int8", "compressed-tensors", "compressed_tensors"
     ]

vllm/platforms/xpu.py

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@ class XPUPlatform(Platform):
     device_name: str = "xpu"
     device_type: str = "xpu"
     dispatch_key: str = "XPU"
+    # Intel XPU's device key is "GPU" for Ray.
+    # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
+    ray_device_key: str = "GPU"

     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,

vllm/v1/executor/ray_utils.py

Lines changed: 11 additions & 2 deletions
@@ -41,7 +41,12 @@ def get_node_ip(self) -> str:

     def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
         node_id = ray.get_runtime_context().get_node_id()
-        gpu_ids = ray.get_gpu_ids()
+        device_key = current_platform.ray_device_key
+        if not device_key:
+            raise RuntimeError("current platform "
+                               f"{current_platform.device_name} does not support ray.")
+        gpu_ids = ray.get_runtime_context().get_accelerator_ids(
+        )[device_key]
         return node_id, gpu_ids

     def setup_device_if_necessary(self):
@@ -211,7 +216,11 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return

-    device_str = "GPU" if not current_platform.is_tpu() else "TPU"
+    device_str = current_platform.ray_device_key
+    if not device_str:
+        raise ValueError(
+            f"current platform {current_platform.device_name} does not "
+            "support ray.")
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
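
Downstream of this check, initialize_ray_cluster uses device_str to request the platform's Ray resource in placement-group bundles. A simplified sketch of that pattern, assuming a running Ray cluster; the two-bundle layout is an illustration, not the function's exact logic:

import ray
from vllm.platforms import current_platform

ray.init()

device_str = current_platform.ray_device_key
if not device_str:
    raise ValueError(
        f"current platform {current_platform.device_name} does not "
        "support ray.")

# One bundle per worker, each reserving one accelerator of the platform's
# Ray resource type ("GPU", "TPU", "HPU", or "neuron_cores").
pg = ray.util.placement_group([{device_str: 1.0}] * 2, strategy="PACK")
ray.get(pg.ready())  # resolves once the cluster can satisfy both bundles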
