
Commit baf7506

[worker] fix: support for vllm V0 deprecation version (#3687)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Related to:

- vllm-project/vllm#25901
- vllm-project/vllm#25345

We now first try to import `WorkerWrapperBase` from `vllm.worker.worker_base`; if that raises `ModuleNotFoundError`, we fall back to importing it from `vllm.v1.worker.worker_base`. For the `compute_logits` patch, we can simply drop the import of `SamplingMetadata` and create a wrapper that accepts any arguments via `*args, **kwargs` and passes them through to the original method, which makes the patch more flexible and future-proof (a standalone sketch of this pattern follows below).

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [X] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [X] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [X] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

Signed-off-by: Hollow Man <[email protected]>
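The pass-through wrapper pattern described above can be sketched in isolation. The following is a minimal, self-contained illustration, not verl's actual code: `FakeModel`, its method signature, and the list-based logits are made up for demonstration.

```python
import types


class FakeModel:
    """Stand-in for a vLLM model class; name and signature are illustrative only."""

    def compute_logits(self, hidden_states, sampling_metadata=None):
        # Pretend these are logits over a padded vocabulary of size 4.
        return [1.0, 2.0, 3.0, 4.0]


def monkey_patch_compute_logits(model, vocab_size):
    original_compute_logits = model.compute_logits  # keep the bound original

    def compute_logits(self, *args, **kwargs):
        # Forward whatever arguments the caller used, so the patch keeps
        # working even if the patched method gains, loses, or renames
        # parameters across vLLM versions.
        logits = original_compute_logits(*args, **kwargs)
        # Mask everything beyond the real vocab size.
        for i in range(vocab_size, len(logits)):
            logits[i] = float("-inf")
        return logits

    model.compute_logits = types.MethodType(compute_logits, model)


model = FakeModel()
monkey_patch_compute_logits(model, vocab_size=2)
print(model.compute_logits("hidden"))  # [1.0, 2.0, -inf, -inf]
```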
1 parent: 798a6f8 · commit: baf7506

File tree

1 file changed: +9 −5 lines

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 9 additions & 5 deletions
```diff
@@ -52,8 +52,12 @@
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationLevel, LoRAConfig
 from vllm.lora.request import LoRARequest
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.worker.worker_base import WorkerWrapperBase
+
+try:
+    from vllm.worker.worker_base import WorkerWrapperBase
+except ModuleNotFoundError:
+    # https://github.com/vllm-project/vllm/commit/6a113d9aed8221a9c234535958e70e34ab6cac5b
+    from vllm.v1.worker.worker_base import WorkerWrapperBase
 
 from verl import DataProto
 from verl.third_party.vllm import VLLM_SLEEP_LEVEL
@@ -459,10 +463,10 @@ def _monkey_patch_compute_logits(model, vocab_size: int):
 
     def compute_logits(
         self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
-        logits = original_compute_logits(hidden_states, sampling_metadata)
+        logits = original_compute_logits(*args, **kwargs)
         logits[..., vocab_size:] = float("-inf")
         return logits
```

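Masking with `float("-inf")` (rather than a large negative constant) guarantees that padded vocabulary entries receive exactly zero probability after softmax, so they can never be sampled. A standalone illustration with made-up shapes, assuming only that `torch` is available:

```python
import torch

vocab_size = 5              # number of real tokens
logits = torch.randn(2, 8)  # (batch, padded_vocab); sizes are arbitrary

logits[..., vocab_size:] = float("-inf")
probs = torch.softmax(logits, dim=-1)

# exp(-inf) == 0, so the padding positions get exactly zero probability.
assert torch.all(probs[..., vocab_size:] == 0)
```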