|
8 | 8 | from vllm.config import ParallelConfig |
9 | 9 | from vllm.executor.msgspec_utils import decode_hook, encode_hook |
10 | 10 | from vllm.logger import init_logger |
| 11 | +from vllm.platforms import current_platform |
11 | 12 | from vllm.sequence import ExecuteModelRequest, IntermediateTensors |
12 | 13 | from vllm.utils import get_ip |
13 | 14 | from vllm.worker.worker_base import WorkerWrapperBase |
@@ -47,7 +48,12 @@ def get_node_ip(self) -> str: |
47 | 48 |
|
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
    """Return this worker's Ray node ID and its assigned accelerator IDs.

    Looks up the platform-specific Ray device key (e.g. "GPU", "TPU",
    "HPU") and queries the Ray runtime context for the accelerator IDs
    scheduled to this worker under that key.

    Returns:
        A tuple of (Ray node ID, list of accelerator IDs on this node).

    Raises:
        RuntimeError: if the current platform exposes no Ray device key,
            i.e. it is not supported under Ray.
    """
    node_id = ray.get_runtime_context().get_node_id()
    device_key = current_platform.ray_device_key
    if not device_key:
        # Use an f-string: exceptions do not do logger-style "%s" lazy
        # interpolation, so the original ("... %s ...", arg) call left
        # the message unformatted with a dangling second argument.
        raise RuntimeError(
            f"current platform {current_platform.device_name} does not "
            "support ray.")
    gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key]
    return node_id, gpu_ids
52 | 58 |
|
53 | 59 | def execute_model_spmd( |
@@ -249,11 +255,12 @@ def initialize_ray_cluster( |
249 | 255 | # Placement group is already set. |
250 | 256 | return |
251 | 257 |
|
252 | | - device_str = "GPU" |
253 | | - if current_platform.is_tpu(): |
254 | | - device_str = "TPU" |
255 | | - elif current_platform.is_hpu(): |
256 | | - device_str = 'HPU' |
| 258 | + device_str = current_platform.ray_device_key |
| 259 | + if not device_str: |
| 260 | + raise ValueError( |
| 261 | + f"current platform {current_platform.device_name} does not " |
| 262 | + "support ray.") |
| 263 | + |
257 | 264 | # Create placement group for worker processes |
258 | 265 | current_placement_group = ray.util.get_current_placement_group() |
259 | 266 | if current_placement_group: |
|
0 commit comments