Skip to content

Commit c5e9513

Browse files
quic-amitraj and eplatero97
authored and committed
Disaggregated serving (quic#365)
Adding support of: 1. `prefill_only` 2. `compile_for` for VLM 3. `mdp_ts_json_path` --------- Signed-off-by: Rishin Raj <[email protected]> Signed-off-by: Amit Raj <[email protected]> Signed-off-by: Onkar Chougule <[email protected]> Signed-off-by: Onkar Chougule <[email protected]> Co-authored-by: Rishin Raj <[email protected]> Co-authored-by: Onkar Chougule <[email protected]> Co-authored-by: Onkar Chougule <[email protected]> Signed-off-by: eplatero <[email protected]>
1 parent 25830a2 commit c5e9513

File tree

3 files changed

+132
-178
lines changed

3 files changed

+132
-178
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -254,22 +254,6 @@ def _compile(
254254
qpc_path = compile_dir / "qpc"
255255
if not onnx_path.is_file():
256256
raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")
257-
258-
if enable_qnn:
259-
self.qpc_path = qnn_compile(
260-
onnx_path=onnx_path,
261-
qpc_base_path=compile_dir,
262-
specializations=specializations,
263-
custom_io=custom_io,
264-
device_group=list(range(mdp_ts_num_devices)),
265-
num_cores=compiler_options.get("aic_num_cores", 16),
266-
mxfp6=compiler_options.get("mxfp6_matmul", False),
267-
mxint8=mxint8_kv_cache,
268-
qnn_config=qnn_config,
269-
)
270-
271-
return self.qpc_path
272-
273257
command = constants.COMPILER + [f"-m={onnx_path}"]
274258
if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
275259
mdp_ts_num_devices = None

QEfficient/transformers/models/modeling_auto.py

Lines changed: 113 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -603,8 +603,8 @@ def compile(
603603
mxfp6_matmul: bool = False,
604604
mxint8_kv_cache: bool = False,
605605
num_speculative_tokens: Optional[int] = None,
606-
skip_vision: Optional[bool] = False,
607-
skip_lang: Optional[bool] = False,
606+
enable_qnn: bool = False,
607+
qnn_config: Optional[str] = None,
608608
**compiler_options,
609609
) -> str:
610610
if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
@@ -646,47 +646,41 @@ def compile(
646646
):
647647
self.export()
648648

649-
if not skip_vision:
650-
self.vision_model._compile(
651-
compile_dir,
652-
compile_only=True,
653-
specializations=specializations["vision"],
654-
convert_to_fp16=True,
655-
mxfp6_matmul=mxfp6_matmul,
656-
mdp_ts_num_devices=num_devices,
657-
aic_num_cores=num_cores,
658-
custom_io=custom_io_vision,
659-
mxint8_kv_cache=mxint8_kv_cache,
660-
**compiler_options,
661-
)
649+
self.vision_model._compile(
650+
compile_dir,
651+
compile_only=True,
652+
specializations=specializations["vision"],
653+
convert_to_fp16=True,
654+
mxfp6_matmul=mxfp6_matmul,
655+
mdp_ts_num_devices=num_devices,
656+
aic_num_cores=num_cores,
657+
custom_io=custom_io_vision,
658+
**compiler_options,
659+
)
662660

663-
if not skip_lang:
664-
custom_io_lang = {}
665-
# Inputs
666-
for output_name in output_names["lang"]:
667-
if output_name.endswith("_RetainedState"):
668-
custom_io_lang[output_name[: -len("_RetainedState")]] = (
669-
"float16" if "vision_embeds" in output_name else kv_cache_dtype
670-
)
671-
672-
# outputs
673-
for output_name in output_names["lang"]:
674-
if output_name.endswith("_RetainedState"):
675-
custom_io_lang[output_name] = "float16" if "vision_embeds" in output_name else kv_cache_dtype
676-
677-
self.lang_model._compile(
678-
compile_dir,
679-
compile_only=True,
680-
retained_state=True,
681-
specializations=specializations["lang"],
682-
convert_to_fp16=True,
683-
mxfp6_matmul=mxfp6_matmul,
684-
mdp_ts_num_devices=num_devices,
685-
aic_num_cores=num_cores,
686-
custom_io=custom_io_lang,
687-
mxint8_kv_cache=mxint8_kv_cache,
688-
**compiler_options,
689-
)
661+
custom_io_lang = {}
662+
# Inputs
663+
for output_name in output_names["lang"]:
664+
if output_name.endswith("_RetainedState"):
665+
custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
666+
667+
# outputs
668+
for output_name in output_names["lang"]:
669+
if output_name.endswith("_RetainedState"):
670+
custom_io_lang[output_name] = kv_cache_dtype
671+
672+
self.lang_model._compile(
673+
compile_dir,
674+
compile_only=True,
675+
retained_state=True,
676+
specializations=specializations["lang"],
677+
convert_to_fp16=True,
678+
mxfp6_matmul=mxfp6_matmul,
679+
mdp_ts_num_devices=num_devices,
680+
aic_num_cores=num_cores,
681+
custom_io=custom_io_lang,
682+
**compiler_options,
683+
)
690684
return self.qpc_path
691685

692686
def generate(
@@ -1547,7 +1541,8 @@ def compile(
15471541
mxfp6_matmul: bool = False,
15481542
mxint8_kv_cache: bool = False,
15491543
num_speculative_tokens: Optional[int] = None,
1550-
prefill_only: Optional[bool] = None,
1544+
enable_qnn: bool = False,
1545+
qnn_config: Optional[str] = None,
15511546
**compiler_options,
15521547
) -> str:
15531548
"""
@@ -1569,14 +1564,8 @@ def compile(
15691564
:num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
15701565
:mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``.
15711566
:aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
1572-
:prefill_only (bool): if ``True`` compile for prefill only and if ``False`` compile for decode only. Defaults to None, which compiles for both ``prefill and ``decode``.
1573-
:compiler_options (dict, optional): Pass any compiler option as input. ``Defaults to None``.
1574-
Following flag can be passed in compiler_options to enable QNN Compilation path.
1575-
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False. if not passed.``
1576-
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None. if not passed``
1577-
for QAIC compilation path, any flag that is supported by ``qaic-exec`` can be passed. Params are converted to flags as below:
1578-
- aic_num_cores=16 -> -aic-num-cores=16
1579-
- convert_to_fp16=True -> -convert-to-fp16
1567+
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
1568+
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
15801569
15811570
Returns:
15821571
:str: Path of the compiled ``qpc`` package.
@@ -1598,50 +1587,83 @@ def compile(
15981587
"enable `continuous_batching=True` in `from_pretrained`."
15991588
)
16001589

1601-
# Infer kv_cache_batch_size if not provided
1602-
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
1590+
kv_cache_batch_size = (
1591+
kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
1592+
)
1593+
# Define prefill specialization
1594+
prefill_specialization = {
1595+
# Prefill is always run with single BS for continuous batching.
1596+
"batch_size": 1 if self.continuous_batching else batch_size,
1597+
"seq_len": prefill_seq_len,
1598+
"ctx_len": ctx_len,
1599+
# TODO: should be renamed to kv_cache_batch_size in specialization too
1600+
}
1601+
prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
1602+
if self.continuous_batching:
1603+
prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
1604+
else:
1605+
prefill_specialization.update({"batch_size": kv_cache_batch_size})
1606+
prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
1607+
specializations = [
1608+
prefill_specialization,
1609+
]
16031610

1604-
# --- Specializations ---
1605-
specializations = []
1611+
# Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
1612+
if prefill_seq_len != 1 or self.continuous_batching:
1613+
decode_specialization = {
1614+
"batch_size": full_batch_size if self.continuous_batching else batch_size,
1615+
"seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
1616+
"ctx_len": ctx_len,
1617+
}
1618+
if self.continuous_batching:
1619+
decode_specialization.update({"full_batch_size": kv_cache_batch_size})
1620+
else:
1621+
decode_specialization.update({"batch_size": kv_cache_batch_size})
1622+
decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
1623+
specializations.append(decode_specialization)
16061624

1607-
if prefill_only is None or prefill_only or prefill_seq_len == 1:
1608-
specializations.append(
1609-
self.build_prefill_specialization(
1610-
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
1611-
)
1625+
if enable_qnn:
1626+
if compiler_options:
1627+
logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
1628+
1629+
qpc_path = self._qnn_compile(
1630+
onnx_path,
1631+
compile_dir,
1632+
specializations=specializations,
1633+
prefill_seq_len=prefill_seq_len,
1634+
ctx_len=ctx_len,
1635+
batch_size=batch_size,
1636+
full_batch_size=full_batch_size,
1637+
mdp_ts_num_devices=num_devices,
1638+
num_cores=num_cores,
1639+
mxfp6_matmul=mxfp6_matmul,
1640+
mxint8_kv_cache=mxint8_kv_cache,
1641+
qnn_config=qnn_config,
1642+
kv_cache_batch_size=kv_cache_batch_size,
16121643
)
1613-
if prefill_only is None or not prefill_only:
1614-
decode_spec = self.build_decode_specialization(
1615-
prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
1644+
else:
1645+
# Custom IO
1646+
custom_io = {}
1647+
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
1648+
for suffix in ["", "_RetainedState"]:
1649+
for i in range(self.num_layers):
1650+
for kv in ["key", "value"]:
1651+
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
1652+
1653+
qpc_path = self._compile(
1654+
onnx_path,
1655+
compile_dir,
1656+
compile_only=True,
1657+
retained_state=True,
1658+
specializations=specializations,
1659+
convert_to_fp16=True,
1660+
mxfp6_matmul=mxfp6_matmul,
1661+
custom_io=custom_io,
1662+
mdp_ts_num_devices=num_devices,
1663+
num_speculative_tokens=num_speculative_tokens,
1664+
aic_num_cores=num_cores,
1665+
**compiler_options,
16161666
)
1617-
if decode_spec:
1618-
specializations.append(decode_spec)
1619-
1620-
# --- Compilation ---
1621-
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
1622-
custom_io = {}
1623-
1624-
for suffix in ["", "_RetainedState"]:
1625-
for i in range(self.num_layers):
1626-
for kv in ["key", "value"]:
1627-
custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
1628-
1629-
qpc_path = self._compile(
1630-
onnx_path=onnx_path,
1631-
compile_dir=compile_dir,
1632-
compile_only=True,
1633-
retained_state=True,
1634-
specializations=specializations,
1635-
convert_to_fp16=True,
1636-
mxfp6_matmul=mxfp6_matmul,
1637-
custom_io=custom_io,
1638-
mdp_ts_num_devices=num_devices,
1639-
num_speculative_tokens=num_speculative_tokens,
1640-
aic_num_cores=num_cores,
1641-
mxint8_kv_cache=mxint8_kv_cache,
1642-
**compiler_options,
1643-
)
1644-
16451667
return qpc_path
16461668

16471669
# FIXME: Update this method to match with transformers AutoModelForCausalLM.generate

0 commit comments

Comments (0)