Skip to content

Commit c45c3e9

Browse files
quic-amitraj, quic-rishinro, ochougul
authored
Disaggregated serving (quic#365)
Adding support of: 1. `prefill_only` 2. `compile_for` for VLM 3. `mdp_ts_json_path` --------- Signed-off-by: Rishin Raj <[email protected]> Signed-off-by: Amit Raj <[email protected]> Signed-off-by: Onkar Chougule <[email protected]> Signed-off-by: Onkar Chougule <[email protected]> Co-authored-by: Rishin Raj <[email protected]> Co-authored-by: Onkar Chougule <[email protected]> Co-authored-by: Onkar Chougule <[email protected]> Signed-off-by: eplatero <[email protected]>
1 parent 5ab45f8 commit c45c3e9

File tree

4 files changed

+171
-218
lines changed

4 files changed

+171
-218
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,11 @@ def _compile(
245245
qpc_path = compile_dir / "qpc"
246246
if not onnx_path.is_file():
247247
raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")
248-
249248
command = constants.COMPILER + [f"-m={onnx_path}"]
249+
if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
250+
mdp_ts_num_devices = None
251+
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
252+
250253
for key, value in compiler_options.items():
251254
option = "-" + key.replace("_", "-")
252255
if isinstance(value, bool):
@@ -262,9 +265,6 @@ def _compile(
262265
if custom_io is not None:
263266
compile_hash.update(to_hashable(custom_io))
264267

265-
if mdp_ts_num_devices > 1:
266-
compile_hash.update(to_hashable({"mdp_ts_num_devices": mdp_ts_num_devices}))
267-
268268
if num_speculative_tokens:
269269
compile_hash.update(to_hashable({"num_speculative_tokens": num_speculative_tokens}))
270270

@@ -300,7 +300,7 @@ def _compile(
300300
command.append(f"-custom-IO-list-file={custom_io_yaml}")
301301

302302
# Write mdp_config.json file
303-
if mdp_ts_num_devices > 1:
303+
if not mdp_ts_json_path and mdp_ts_num_devices > 1:
304304
num_cores = compiler_options.get("aic_num_cores", 16)
305305
mdp_ts_json = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
306306
with open(mdp_ts_json, "w") as fp:

QEfficient/transformers/models/modeling_auto.py

Lines changed: 135 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,12 @@ def onnx_path(self):
561561

562562
@property
563563
def qpc_path(self):
564-
return [self.vision_model.qpc_path, self.lang_model.qpc_path]
564+
if self.vision_model.qpc_path and self.lang_model.qpc_path:
565+
return [self.vision_model.qpc_path, self.lang_model.qpc_path]
566+
elif self.vision_model.qpc_path:
567+
return self.vision_model.qpc_path
568+
else:
569+
return self.lang_model.qpc_path
565570

566571
def export(
567572
self,
@@ -600,6 +605,8 @@ def compile(
600605
num_speculative_tokens: Optional[int] = None,
601606
enable_qnn: bool = False,
602607
qnn_config: Optional[str] = None,
608+
skip_vision: Optional[bool] = False,
609+
skip_lang: Optional[bool] = False,
603610
**compiler_options,
604611
) -> str:
605612
if (
@@ -615,6 +622,9 @@ def compile(
615622
f"enable_qnn={enable_qnn}, qnn_config={qnn_config}"
616623
)
617624

625+
if skip_lang and skip_vision:
626+
raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False")
627+
618628
output_names = self.model.get_output_names(kv_offload=True)
619629

620630
specializations, compiler_options = self.model.get_specializations(
@@ -642,41 +652,43 @@ def compile(
642652
):
643653
self.export()
644654

645-
self.vision_model._compile(
646-
compile_dir,
647-
compile_only=True,
648-
specializations=specializations["vision"],
649-
convert_to_fp16=True,
650-
mxfp6_matmul=mxfp6_matmul,
651-
mdp_ts_num_devices=num_devices,
652-
aic_num_cores=num_cores,
653-
custom_io=custom_io_vision,
654-
**compiler_options,
655-
)
655+
if not skip_vision:
656+
self.vision_model._compile(
657+
compile_dir,
658+
compile_only=True,
659+
specializations=specializations["vision"],
660+
convert_to_fp16=True,
661+
mxfp6_matmul=mxfp6_matmul,
662+
mdp_ts_num_devices=num_devices,
663+
aic_num_cores=num_cores,
664+
custom_io=custom_io_vision,
665+
**compiler_options,
666+
)
656667

657-
custom_io_lang = {}
658-
# Inputs
659-
for output_name in output_names["lang"]:
660-
if output_name.endswith("_RetainedState"):
661-
custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
668+
if not skip_lang:
669+
custom_io_lang = {}
670+
# Inputs
671+
for output_name in output_names["lang"]:
672+
if output_name.endswith("_RetainedState"):
673+
custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
662674

663-
# outputs
664-
for output_name in output_names["lang"]:
665-
if output_name.endswith("_RetainedState"):
666-
custom_io_lang[output_name] = kv_cache_dtype
675+
# outputs
676+
for output_name in output_names["lang"]:
677+
if output_name.endswith("_RetainedState"):
678+
custom_io_lang[output_name] = kv_cache_dtype
667679

668-
self.lang_model._compile(
669-
compile_dir,
670-
compile_only=True,
671-
retained_state=True,
672-
specializations=specializations["lang"],
673-
convert_to_fp16=True,
674-
mxfp6_matmul=mxfp6_matmul,
675-
mdp_ts_num_devices=num_devices,
676-
aic_num_cores=num_cores,
677-
custom_io=custom_io_lang,
678-
**compiler_options,
679-
)
680+
self.lang_model._compile(
681+
compile_dir,
682+
compile_only=True,
683+
retained_state=True,
684+
specializations=specializations["lang"],
685+
convert_to_fp16=True,
686+
mxfp6_matmul=mxfp6_matmul,
687+
mdp_ts_num_devices=num_devices,
688+
aic_num_cores=num_cores,
689+
custom_io=custom_io_lang,
690+
**compiler_options,
691+
)
680692
return self.qpc_path
681693

682694
def generate(
@@ -711,6 +723,9 @@ def kv_offload_generate(
711723
device_ids: List[int] = None,
712724
generation_len: int = None,
713725
):
726+
if not self.vision_model.qpc_path or not self.lang_model.qpc_path:
727+
raise TypeError("Please run compile API for vision and language model first!")
728+
714729
lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False)
715730

716731
vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids)
@@ -1461,6 +1476,51 @@ def export(self, export_dir: Optional[str] = None) -> str:
14611476
export_dir=export_dir,
14621477
)
14631478

1479+
def build_prefill_specialization(
1480+
self,
1481+
prefill_seq_len: int = 32,
1482+
ctx_len: int = 128,
1483+
batch_size: int = 1,
1484+
kv_cache_batch_size: Optional[int] = None,
1485+
full_batch_size: Optional[int] = None,
1486+
):
1487+
spec = {
1488+
"batch_size": 1 if self.continuous_batching else batch_size,
1489+
"seq_len": prefill_seq_len,
1490+
"ctx_len": ctx_len,
1491+
"num_logits_to_keep": 1 if self.is_tlm else None,
1492+
}
1493+
if self.continuous_batching:
1494+
spec["full_batch_size"] = kv_cache_batch_size
1495+
else:
1496+
spec["batch_size"] = kv_cache_batch_size
1497+
if full_batch_size:
1498+
spec["full_batch_exec_size"] = full_batch_size
1499+
return {k: v for k, v in spec.items() if v is not None}
1500+
1501+
def build_decode_specialization(
1502+
self,
1503+
prefill_seq_len: int = 32,
1504+
ctx_len: int = 128,
1505+
batch_size: int = 1,
1506+
kv_cache_batch_size: Optional[int] = None,
1507+
full_batch_size: Optional[int] = None,
1508+
num_speculative_tokens: Optional[int] = None,
1509+
):
1510+
if prefill_seq_len == 1 and not self.continuous_batching:
1511+
return None # Avoid duplication with prefill
1512+
spec = {
1513+
"batch_size": full_batch_size if self.continuous_batching else batch_size,
1514+
"seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
1515+
"ctx_len": ctx_len,
1516+
"num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
1517+
}
1518+
if self.continuous_batching:
1519+
spec["full_batch_size"] = kv_cache_batch_size
1520+
else:
1521+
spec["batch_size"] = kv_cache_batch_size
1522+
return {k: v for k, v in spec.items() if v is not None}
1523+
14641524
def compile(
14651525
self,
14661526
onnx_path: Optional[str] = None,
@@ -1478,6 +1538,7 @@ def compile(
14781538
num_speculative_tokens: Optional[int] = None,
14791539
enable_qnn: bool = False,
14801540
qnn_config: Optional[str] = None,
1541+
prefill_only: Optional[bool] = None,
14811542
**compiler_options,
14821543
) -> str:
14831544
"""
@@ -1501,74 +1562,63 @@ def compile(
15011562
:aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
15021563
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
15031564
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
1565+
:prefill_only (bool): if ``True`` compile for prefill only and if ``False`` compile for decode only. Defaults to None, which compiles for both ``prefill`` and ``decode``.
1566+
:compiler_options (dict, optional): Any other options that the `qaic-exec` takes. ``Defaults to None``.
15041567
15051568
Returns:
15061569
:str: Path of the compiled ``qpc`` package.
15071570
"""
1571+
# --- Validation ---
1572+
if prefill_only is not None and not isinstance(prefill_only, bool):
1573+
raise TypeError("`prefill_only` must be a boolean.")
1574+
15081575
if self.is_tlm:
1509-
# assert num_speculative_tokens cfg is acceptable if defined
15101576
if num_speculative_tokens is None:
1511-
raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")
1512-
if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2:
1513-
ValueError(
1514-
f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
1515-
)
1516-
num_logits_to_keep = num_speculative_tokens + 1
1517-
if prefill_seq_len < num_logits_to_keep:
1577+
raise TypeError("`num_speculative_tokens` is required when `is_tlm=True`.")
1578+
if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
1579+
raise ValueError("`num_speculative_tokens` must be an integer >= 2.")
1580+
if prefill_seq_len < (num_speculative_tokens + 1):
15181581
raise ValueError(
1519-
f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
1582+
f"`prefill_seq_len` must be at least `num_speculative_tokens + 1` "
1583+
f"({num_speculative_tokens + 1}), got {prefill_seq_len}."
15201584
)
15211585

15221586
if self.continuous_batching and full_batch_size is None:
1523-
raise TypeError("missing required argument: 'full_batch_size'")
1587+
raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
15241588

15251589
if kv_cache_batch_size and not full_batch_size:
15261590
raise ValueError(
1527-
"Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
1591+
"KV caching requires continuous batching. Please set `full_batch_size` and "
1592+
"enable `continuous_batching=True` in `from_pretrained`."
15281593
)
15291594

1530-
kv_cache_batch_size = (
1531-
kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
1532-
)
1533-
# Define prefill specialization
1534-
prefill_specialization = {
1535-
# Prefill is always run with single BS for continuous batching.
1536-
"batch_size": 1 if self.continuous_batching else batch_size,
1537-
"seq_len": prefill_seq_len,
1538-
"ctx_len": ctx_len,
1539-
# TODO: should be renamed to kv_cache_batch_size in specialization too
1540-
}
1541-
prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
1542-
if self.continuous_batching:
1543-
prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
1544-
else:
1545-
prefill_specialization.update({"batch_size": kv_cache_batch_size})
1546-
prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
1547-
specializations = [
1548-
prefill_specialization,
1549-
]
1595+
# Infer kv_cache_batch_size if not provided
1596+
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
15501597

1551-
# Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
1552-
if prefill_seq_len != 1 or self.continuous_batching:
1553-
decode_specialization = {
1554-
"batch_size": full_batch_size if self.continuous_batching else batch_size,
1555-
"seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
1556-
"ctx_len": ctx_len,
1557-
}
1558-
if self.continuous_batching:
1559-
decode_specialization.update({"full_batch_size": kv_cache_batch_size})
1560-
else:
1561-
decode_specialization.update({"batch_size": kv_cache_batch_size})
1562-
decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
1563-
specializations.append(decode_specialization)
1598+
# --- Specializations ---
1599+
specializations = []
1600+
1601+
if prefill_only is None or prefill_only or prefill_seq_len == 1:
1602+
specializations.append(
1603+
self.build_prefill_specialization(
1604+
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
1605+
)
1606+
)
1607+
if prefill_only is None or not prefill_only:
1608+
decode_spec = self.build_decode_specialization(
1609+
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
1610+
)
1611+
if decode_spec:
1612+
specializations.append(decode_spec)
15641613

1614+
# --- Compilation ---
15651615
if enable_qnn:
15661616
if compiler_options:
1567-
logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
1617+
logger.warning("Extra arguments to QNN compilation are ignored. Use `qnn_config.json`.")
15681618

15691619
qpc_path = self._qnn_compile(
1570-
onnx_path,
1571-
compile_dir,
1620+
onnx_path=onnx_path,
1621+
compile_dir=compile_dir,
15721622
specializations=specializations,
15731623
prefill_seq_len=prefill_seq_len,
15741624
ctx_len=ctx_len,
@@ -1582,17 +1632,17 @@ def compile(
15821632
kv_cache_batch_size=kv_cache_batch_size,
15831633
)
15841634
else:
1585-
# Custom IO
1586-
custom_io = {}
15871635
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
1636+
custom_io = {}
1637+
15881638
for suffix in ["", "_RetainedState"]:
15891639
for i in range(self.num_layers):
15901640
for kv in ["key", "value"]:
15911641
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
15921642

15931643
qpc_path = self._compile(
1594-
onnx_path,
1595-
compile_dir,
1644+
onnx_path=onnx_path,
1645+
compile_dir=compile_dir,
15961646
compile_only=True,
15971647
retained_state=True,
15981648
specializations=specializations,
@@ -1604,6 +1654,7 @@ def compile(
16041654
aic_num_cores=num_cores,
16051655
**compiler_options,
16061656
)
1657+
16071658
return qpc_path
16081659

16091660
# FIXME: Update this method to match with transformers AutoModelForCausalLM.generate

QEfficient/utils/_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,8 @@ def wrapper(self, *args, **kwargs):
466466
**{
467467
k: v
468468
for k, v in kwargs.items()
469-
if k not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io"]
469+
if k
470+
not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"]
470471
},
471472
)
472473
return result

0 commit comments

Comments (0)