-
Notifications
You must be signed in to change notification settings - Fork 563
Upgrade to 0.11.1 newest vllm commit #3762
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
7c50b3e
Upgrade to 0.11.1 newest vllm commit
wxsIcey 94c9125
change commit and fix send_delta_data
wxsIcey c2dc165
fix init_with_cudagraph_sizes
wxsIcey 6ba3f39
skit embed aclgraph e2e
wxsIcey e8849b4
fix init_with_cudagraph_sizes
wxsIcey 0ca98f5
change commit id to 0.11.1
wxsIcey e8f87f6
tiny fix
wxsIcey 8e82843
fix eagle
wxsIcey c166742
fix aclgraph
wxsIcey 89db007
skip test_embedding_aclgraph test
wxsIcey 28b1306
tiny fix
wxsIcey 445650b
fix vl
wxsIcey 5b09cc0
tiny fix
wxsIcey File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,7 +33,8 @@ | |
| delete_torchair_cache_file) | ||
| from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p, | ||
| prefill_context_parallel_enable, | ||
| update_aclgraph_sizes, vllm_version_is) | ||
| update_aclgraph_sizes, | ||
| update_cudagraph_capture_sizes, vllm_version_is) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.config import ModelConfig, VllmConfig | ||
|
|
@@ -142,24 +143,47 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: | |
| "Non-MLA LLMs forcibly disable the chunked prefill feature," | ||
| "as the performance of operators supporting this feature " | ||
| "functionality is currently suboptimal.") | ||
| if not model_config.is_multimodal_model and \ | ||
| structured_outputs_config.backend == "auto" and \ | ||
| not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \ | ||
| scheduler_config.policy == "fcfs": | ||
| ascend_scheduler_config.enabled = True | ||
| chunked_prefill_enabled_in_ascend_scheduler = getattr( | ||
| ascend_scheduler_config, "enable_chunked_prefill", False) | ||
| if chunked_prefill_enabled_in_ascend_scheduler: | ||
| logger.warning( | ||
| "Chunked prefill feature is enabled in ascend_scheduler," | ||
| "but note that the operator supporting this feature " | ||
| "would lead to performance degradation.") | ||
| # In this situation, max_num_batched_tokens would have been rewritten. | ||
| # So we must make sure max_num_batched_tokens is not smaller than max_model_len. | ||
| if (scheduler_config.max_num_batched_tokens | ||
| < scheduler_config.max_model_len | ||
| and not chunked_prefill_enabled_in_ascend_scheduler): | ||
| scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len | ||
| if vllm_version_is("0.11.0"): | ||
| if not model_config.is_multimodal_model and \ | ||
| structured_outputs_config.backend == "auto" and \ | ||
| not scheduler_config.send_delta_data and \ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. getattr(scheduler_config, "send_delta_data", False) |
||
| not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \ | ||
| scheduler_config.policy == "fcfs": | ||
| ascend_scheduler_config.enabled = True | ||
| chunked_prefill_enabled_in_ascend_scheduler = getattr( | ||
| ascend_scheduler_config, "enable_chunked_prefill", | ||
| False) | ||
| if chunked_prefill_enabled_in_ascend_scheduler: | ||
| logger.warning( | ||
| "Chunked prefill feature is enabled in ascend_scheduler," | ||
| "but note that the operator supporting this feature " | ||
| "would lead to performance degradation.") | ||
| # In this situation, max_num_batched_tokens would have been rewritten. | ||
| # So we must make sure max_num_batched_tokens is not smaller than max_model_len. | ||
| if (scheduler_config.max_num_batched_tokens | ||
| < scheduler_config.max_model_len and | ||
| not chunked_prefill_enabled_in_ascend_scheduler): | ||
| scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len | ||
| else: | ||
| if not model_config.is_multimodal_model and \ | ||
| structured_outputs_config.backend == "auto" and \ | ||
| not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \ | ||
| scheduler_config.policy == "fcfs": | ||
| ascend_scheduler_config.enabled = True | ||
| chunked_prefill_enabled_in_ascend_scheduler = getattr( | ||
| ascend_scheduler_config, "enable_chunked_prefill", | ||
| False) | ||
| if chunked_prefill_enabled_in_ascend_scheduler: | ||
| logger.warning( | ||
| "Chunked prefill feature is enabled in ascend_scheduler," | ||
| "but note that the operator supporting this feature " | ||
| "would lead to performance degradation.") | ||
| # In this situation, max_num_batched_tokens would have been rewritten. | ||
| # So we must make sure max_num_batched_tokens is not smaller than max_model_len. | ||
| if (scheduler_config.max_num_batched_tokens | ||
| < scheduler_config.max_model_len and | ||
| not chunked_prefill_enabled_in_ascend_scheduler): | ||
| scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len | ||
|
|
||
| kv_cache_dtype = vllm_config.additional_config.get( | ||
| "kv_cache_dtype", None) | ||
|
|
@@ -237,8 +261,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: | |
| f"{vllm_config.parallel_config.tensor_parallel_size}") | ||
| if len(sp_aclgraph_sizes) != len(original_sizes): | ||
| compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes | ||
| vllm_config.compilation_config.init_with_cudagraph_sizes( | ||
| sp_aclgraph_sizes) | ||
| if vllm_version_is("0.11.0"): | ||
| compilation_config.init_with_cudagraph_sizes( | ||
| sp_aclgraph_sizes) | ||
| else: | ||
| update_cudagraph_capture_sizes(vllm_config, | ||
| sp_aclgraph_sizes) | ||
|
|
||
| # TODO: Full graph is fully supported later, and the default value will be set to full graph. | ||
| if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -311,6 +311,41 @@ def _rec_find(d): | |
| return max(layer_counts) | ||
|
|
||
|
|
||
| # Update cudagraph capture sizes for vllm config | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is maybe not correct. I'll look more |
||
| def update_cudagraph_capture_sizes(vllm_config: VllmConfig, | ||
| cudagraph_capture_sizes: List[int]): | ||
|
|
||
| valid_max_size = (cudagraph_capture_sizes[-1] | ||
| if cudagraph_capture_sizes else 0) | ||
| if (vllm_config.compilation_config.max_cudagraph_capture_size is not None | ||
| and vllm_config.compilation_config.max_cudagraph_capture_size | ||
| != valid_max_size): | ||
| if vllm_config.compilation_config.cudagraph_capture_sizes is not None: | ||
| raise ValueError( | ||
| "customized max_cudagraph_capture_size" | ||
| f"(={vllm_config.compilation_config.max_cudagraph_capture_size}) " | ||
| "should be consistent with the max value of " | ||
| f"cudagraph_capture_sizes(={valid_max_size})") | ||
| logger.warning( | ||
| "Truncating max_cudagraph_capture_size to %d", | ||
| valid_max_size, | ||
| ) | ||
|
|
||
| vllm_config.compilation_config.max_cudagraph_capture_size = valid_max_size | ||
|
|
||
| if vllm_config.compilation_config.cudagraph_capture_sizes is not None and len( | ||
| cudagraph_capture_sizes) < len( | ||
| vllm_config.compilation_config.cudagraph_capture_sizes): | ||
| logger.warning( | ||
| ("cudagraph_capture_sizes specified in compilation_config" | ||
| " %s is overridden by config %s"), | ||
| vllm_config.compilation_config.cudagraph_capture_sizes, | ||
| cudagraph_capture_sizes, | ||
| ) | ||
| vllm_config.compilation_config.cudagraph_capture_sizes = cudagraph_capture_sizes | ||
| vllm_config.compilation_config.post_init_cudagraph_sizes() | ||
|
|
||
|
|
||
| def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: | ||
| """Update ACL graph capture sizes based on hardware limitations""" | ||
| # NOTE: Currently, we can only capture 1800 graphs at most, | ||
|
|
@@ -402,7 +437,10 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: | |
| indices[0], indices[-1] = 0, len(original_sizes) - 1 | ||
|
|
||
| sampled_sizes = [original_sizes[i] for i in indices] | ||
| compilation_config.init_with_cudagraph_sizes(sampled_sizes) | ||
| if vllm_version_is("0.11.0"): | ||
| compilation_config.init_with_cudagraph_sizes(sampled_sizes) | ||
| else: | ||
| update_cudagraph_capture_sizes(vllm_config, sampled_sizes) | ||
|
|
||
| logger.info( | ||
| "Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes", | ||
|
|
@@ -433,7 +471,10 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: | |
| if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs: | ||
| enlarged_sizes = [(num_speculative_tokens + 1) * size | ||
| for size in original_sizes] | ||
| compilation_config.init_with_cudagraph_sizes(enlarged_sizes) | ||
| if vllm_version_is("0.11.0"): | ||
| compilation_config.init_with_cudagraph_sizes(enlarged_sizes) | ||
| else: | ||
| update_cudagraph_capture_sizes(vllm_config, enlarged_sizes) | ||
| logger.info( | ||
| "Adjusted ACL graphs: %s → %s for speculative decoding", | ||
| original_sizes, enlarged_sizes) | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.