|
1 | 1 | from vllm.logger import init_logger |
2 | 2 | from vllm.platforms import current_platform |
| 3 | +from vllm.utils import resolve_obj_by_qualname |
3 | 4 |
|
4 | 5 | from .punica_base import PunicaWrapperBase |
5 | 6 |
|
6 | 7 | logger = init_logger(__name__) |
7 | 8 |
|
8 | 9 |
|
9 | 10 | def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: |
10 | | - if current_platform.is_cuda_alike(): |
11 | | - # Lazy import to avoid ImportError |
12 | | - from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU |
13 | | - logger.info_once("Using PunicaWrapperGPU.") |
14 | | - return PunicaWrapperGPU(*args, **kwargs) |
15 | | - elif current_platform.is_cpu(): |
16 | | - # Lazy import to avoid ImportError |
17 | | - from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU |
18 | | - logger.info_once("Using PunicaWrapperCPU.") |
19 | | - return PunicaWrapperCPU(*args, **kwargs) |
20 | | - elif current_platform.is_hpu(): |
21 | | - # Lazy import to avoid ImportError |
22 | | - from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU |
23 | | - logger.info_once("Using PunicaWrapperHPU.") |
24 | | - return PunicaWrapperHPU(*args, **kwargs) |
25 | | - else: |
26 | | - raise NotImplementedError |
| 11 | + punica_wrapper_qualname = current_platform.get_punica_wrapper() |
| 12 | + punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname) |
| 13 | + punica_wrapper = punica_wrapper_cls(*args, **kwargs) |
| 14 | + assert punica_wrapper is not None, \ |
| 15 | + "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong." |
| 16 | + logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] + |
| 17 | + ".") |
| 18 | + return punica_wrapper |
0 commit comments