@@ -603,8 +603,8 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        skip_vision: Optional[bool] = False,
-        skip_lang: Optional[bool] = False,
+        enable_qnn: bool = False,
+        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
@@ -646,47 +646,41 @@ def compile(
         ):
             self.export()

-        if not skip_vision:
-            self.vision_model._compile(
-                compile_dir,
-                compile_only=True,
-                specializations=specializations["vision"],
-                convert_to_fp16=True,
-                mxfp6_matmul=mxfp6_matmul,
-                mdp_ts_num_devices=num_devices,
-                aic_num_cores=num_cores,
-                custom_io=custom_io_vision,
-                mxint8_kv_cache=mxint8_kv_cache,
-                **compiler_options,
-            )
+        self.vision_model._compile(
+            compile_dir,
+            compile_only=True,
+            specializations=specializations["vision"],
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices,
+            aic_num_cores=num_cores,
+            custom_io=custom_io_vision,
+            **compiler_options,
+        )

-        if not skip_lang:
-            custom_io_lang = {}
-            # Inputs
-            for output_name in output_names["lang"]:
-                if output_name.endswith("_RetainedState"):
-                    custom_io_lang[output_name[: -len("_RetainedState")]] = (
-                        "float16" if "vision_embeds" in output_name else kv_cache_dtype
-                    )
-
-            # outputs
-            for output_name in output_names["lang"]:
-                if output_name.endswith("_RetainedState"):
-                    custom_io_lang[output_name] = "float16" if "vision_embeds" in output_name else kv_cache_dtype
-
-            self.lang_model._compile(
-                compile_dir,
-                compile_only=True,
-                retained_state=True,
-                specializations=specializations["lang"],
-                convert_to_fp16=True,
-                mxfp6_matmul=mxfp6_matmul,
-                mdp_ts_num_devices=num_devices,
-                aic_num_cores=num_cores,
-                custom_io=custom_io_lang,
-                mxint8_kv_cache=mxint8_kv_cache,
-                **compiler_options,
-            )
+        custom_io_lang = {}
+        # Inputs
+        for output_name in output_names["lang"]:
+            if output_name.endswith("_RetainedState"):
+                custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
+
+        # outputs
+        for output_name in output_names["lang"]:
+            if output_name.endswith("_RetainedState"):
+                custom_io_lang[output_name] = kv_cache_dtype
+
+        self.lang_model._compile(
+            compile_dir,
+            compile_only=True,
+            retained_state=True,
+            specializations=specializations["lang"],
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices,
+            aic_num_cores=num_cores,
+            custom_io=custom_io_lang,
+            **compiler_options,
+        )
         return self.qpc_path

     def generate(
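Note on the `+` branch above: with the `skip_lang` gate removed, `custom_io_lang` maps every retained-state KV tensor to the KV-cache dtype on both the input side (suffix stripped) and the output side (suffix kept). A minimal sketch of the resulting mapping, assuming a hypothetical 2-layer language decoder and the `past_{key,value}.{i}` naming used elsewhere in this diff:

```python
# Illustrative only: output names and layer count are assumptions, not taken from the diff.
output_names = {
    "lang": [
        "logits",
        "past_key.0_RetainedState", "past_value.0_RetainedState",
        "past_key.1_RetainedState", "past_value.1_RetainedState",
    ]
}
kv_cache_dtype = "mxint8"  # "float16" when mxint8_kv_cache is False

custom_io_lang = {}
for output_name in output_names["lang"]:
    if output_name.endswith("_RetainedState"):
        # Input side: strip the suffix, e.g. "past_key.0"
        custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
        # Output side: keep the retained-state name as-is
        custom_io_lang[output_name] = kv_cache_dtype

# custom_io_lang == {"past_key.0": "mxint8", "past_key.0_RetainedState": "mxint8", ...}
```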
@@ -1547,7 +1541,8 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        prefill_only: Optional[bool] = None,
+        enable_qnn: bool = False,
+        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1569,14 +1564,8 @@ def compile(
             :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
             :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``.
             :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
-            :prefill_only (bool): If ``True``, compile for prefill only; if ``False``, compile for decode only. Defaults to None, which compiles for both ``prefill`` and ``decode``.
-            :compiler_options (dict, optional): Pass any compiler option as input. ``Defaults to None``.
-                The following flags can be passed in compiler_options to enable the QNN compilation path:
-                    :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False if not passed.``
-                    :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None if not passed.``
-                For the QAIC compilation path, any flag supported by ``qaic-exec`` can be passed. Params are converted to flags as below:
-                    - aic_num_cores=16 -> -aic-num-cores=16
-                    - convert_to_fp16=True -> -convert-to-fp16
+            :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
+            :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``

         Returns:
             :str: Path of the compiled ``qpc`` package.
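For context, the two docstring entries added above correspond to explicit keyword arguments on `compile`. A usage sketch, assuming the `QEFFAutoModelForCausalLM` entry point and a placeholder checkpoint name:

```python
# Hypothetical usage; the checkpoint name and argument values are illustrative.
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model card
qpc_path = model.compile(
    num_cores=16,
    prefill_seq_len=32,
    ctx_len=128,
    mxint8_kv_cache=True,
    enable_qnn=True,               # take the QNN compilation path
    qnn_config="qnn_config.json",  # optional QNN parameters file
)
print(qpc_path)  # path of the compiled qpc package
```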
@@ -1598,50 +1587,83 @@ def compile(
15981587 "enable `continuous_batching=True` in `from_pretrained`."
15991588 )
16001589
1601- # Infer kv_cache_batch_size if not provided
1602- kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
1590+ kv_cache_batch_size = (
1591+ kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size )
1592+ )
1593+ # Define prefill specialization
1594+ prefill_specialization = {
1595+ # Prefill is always run with single BS for continuous batching.
1596+ "batch_size" : 1 if self .continuous_batching else batch_size ,
1597+ "seq_len" : prefill_seq_len ,
1598+ "ctx_len" : ctx_len ,
1599+ # TODO: should be renamed to kv_cache_batch_size in specialization too
1600+ }
1601+ prefill_specialization .update ({"num_logits_to_keep" : 1 }) if self .is_tlm else ...
1602+ if self .continuous_batching :
1603+ prefill_specialization .update ({"full_batch_size" : kv_cache_batch_size })
1604+ else :
1605+ prefill_specialization .update ({"batch_size" : kv_cache_batch_size })
1606+ prefill_specialization .update ({"full_batch_exec_size" : full_batch_size }) if full_batch_size else ...
1607+ specializations = [
1608+ prefill_specialization ,
1609+ ]
16031610
1604- # --- Specializations ---
1605- specializations = []
1611+ # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
1612+ if prefill_seq_len != 1 or self .continuous_batching :
1613+ decode_specialization = {
1614+ "batch_size" : full_batch_size if self .continuous_batching else batch_size ,
1615+ "seq_len" : num_speculative_tokens + 1 if self .is_tlm else 1 ,
1616+ "ctx_len" : ctx_len ,
1617+ }
1618+ if self .continuous_batching :
1619+ decode_specialization .update ({"full_batch_size" : kv_cache_batch_size })
1620+ else :
1621+ decode_specialization .update ({"batch_size" : kv_cache_batch_size })
1622+ decode_specialization .update ({"num_logits_to_keep" : num_speculative_tokens + 1 }) if self .is_tlm else ...
1623+ specializations .append (decode_specialization )
16061624
1607- if prefill_only is None or prefill_only or prefill_seq_len == 1 :
1608- specializations .append (
1609- self .build_prefill_specialization (
1610- prefill_seq_len , ctx_len , batch_size , kv_cache_batch_size , full_batch_size
1611- )
1625+ if enable_qnn :
1626+ if compiler_options :
1627+ logger .warning ("Extra arguments to QNN compilation are supported via qnn_config.json only" )
1628+
1629+ qpc_path = self ._qnn_compile (
1630+ onnx_path ,
1631+ compile_dir ,
1632+ specializations = specializations ,
1633+ prefill_seq_len = prefill_seq_len ,
1634+ ctx_len = ctx_len ,
1635+ batch_size = batch_size ,
1636+ full_batch_size = full_batch_size ,
1637+ mdp_ts_num_devices = num_devices ,
1638+ num_cores = num_cores ,
1639+ mxfp6_matmul = mxfp6_matmul ,
1640+ mxint8_kv_cache = mxint8_kv_cache ,
1641+ qnn_config = qnn_config ,
1642+ kv_cache_batch_size = kv_cache_batch_size ,
16121643 )
1613- if prefill_only is None or not prefill_only :
1614- decode_spec = self .build_decode_specialization (
1615- prefill_seq_len , ctx_len , batch_size , kv_cache_batch_size , full_batch_size , num_speculative_tokens
1644+ else :
1645+ # Custom IO
1646+ custom_io = {}
1647+ kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
1648+ for suffix in ["" , "_RetainedState" ]:
1649+ for i in range (self .num_layers ):
1650+ for kv in ["key" , "value" ]:
1651+ custom_io [f"past_{ kv } .{ i } { suffix } " ] = kv_cache_dtype
1652+
1653+ qpc_path = self ._compile (
1654+ onnx_path ,
1655+ compile_dir ,
1656+ compile_only = True ,
1657+ retained_state = True ,
1658+ specializations = specializations ,
1659+ convert_to_fp16 = True ,
1660+ mxfp6_matmul = mxfp6_matmul ,
1661+ custom_io = custom_io ,
1662+ mdp_ts_num_devices = num_devices ,
1663+ num_speculative_tokens = num_speculative_tokens ,
1664+ aic_num_cores = num_cores ,
1665+ ** compiler_options ,
16161666 )
1617- if decode_spec :
1618- specializations .append (decode_spec )
1619-
1620- # --- Compilation ---
1621- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
1622- custom_io = {}
1623-
1624- for suffix in ["" , "_RetainedState" ]:
1625- for i in range (self .num_layers ):
1626- for kv in ["key" , "value" ]:
1627- custom_io [f"past_{ kv } .{ i } { suffix } " ] = kv_cache_dtype
1628-
1629- qpc_path = self ._compile (
1630- onnx_path = onnx_path ,
1631- compile_dir = compile_dir ,
1632- compile_only = True ,
1633- retained_state = True ,
1634- specializations = specializations ,
1635- convert_to_fp16 = True ,
1636- mxfp6_matmul = mxfp6_matmul ,
1637- custom_io = custom_io ,
1638- mdp_ts_num_devices = num_devices ,
1639- num_speculative_tokens = num_speculative_tokens ,
1640- aic_num_cores = num_cores ,
1641- mxint8_kv_cache = mxint8_kv_cache ,
1642- ** compiler_options ,
1643- )
1644-
16451667 return qpc_path
16461668
16471669 # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
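To make the specialization-building logic in the `+` branch above concrete, here is a worked example of the list it produces, assuming a continuous-batching, non-TLM compile with hypothetical values prefill_seq_len=32, ctx_len=128, full_batch_size=4 and no explicit kv_cache_batch_size:

```python
# Values traced by hand from the added code above; illustrative, not emitted by the library.
kv_cache_batch_size = 4  # falls back to full_batch_size

specializations = [
    # Prefill: always batch_size=1 under continuous batching; full_batch_exec_size is set
    # because full_batch_size was provided.
    {"batch_size": 1, "seq_len": 32, "ctx_len": 128, "full_batch_size": 4, "full_batch_exec_size": 4},
    # Decode: one token per step, batched over the full batch.
    {"batch_size": 4, "seq_len": 1, "ctx_len": 128, "full_batch_size": 4},
]
```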