@@ -605,6 +605,8 @@ def compile(
605605 num_speculative_tokens : Optional [int ] = None ,
606606 enable_qnn : bool = False ,
607607 qnn_config : Optional [str ] = None ,
608+ skip_vision : Optional [bool ] = False ,
609+ skip_lang : Optional [bool ] = False ,
608610 ** compiler_options ,
609611 ) -> str :
610612 if any (param is not None for param in [full_batch_size , kv_cache_batch_size , num_speculative_tokens ]):
@@ -646,17 +648,18 @@ def compile(
646648 ):
647649 self .export ()
648650
649- self .vision_model ._compile (
650- compile_dir ,
651- compile_only = True ,
652- specializations = specializations ["vision" ],
653- convert_to_fp16 = True ,
654- mxfp6_matmul = mxfp6_matmul ,
655- mdp_ts_num_devices = num_devices ,
656- aic_num_cores = num_cores ,
657- custom_io = custom_io_vision ,
658- ** compiler_options ,
659- )
651+ if not skip_vision :
652+ self .vision_model ._compile (
653+ compile_dir ,
654+ compile_only = True ,
655+ specializations = specializations ["vision" ],
656+ convert_to_fp16 = True ,
657+ mxfp6_matmul = mxfp6_matmul ,
658+ mdp_ts_num_devices = num_devices ,
659+ aic_num_cores = num_cores ,
660+ custom_io = custom_io_vision ,
661+ ** compiler_options ,
662+ )
660663
661664 custom_io_lang = {}
662665 # Inputs
@@ -669,18 +672,18 @@ def compile(
669672 if output_name .endswith ("_RetainedState" ):
670673 custom_io_lang [output_name ] = kv_cache_dtype
671674
672- self .lang_model ._compile (
673- compile_dir ,
674- compile_only = True ,
675- retained_state = True ,
676- specializations = specializations ["lang" ],
677- convert_to_fp16 = True ,
678- mxfp6_matmul = mxfp6_matmul ,
679- mdp_ts_num_devices = num_devices ,
680- aic_num_cores = num_cores ,
681- custom_io = custom_io_lang ,
682- ** compiler_options ,
683- )
675+ self .lang_model ._compile (
676+ compile_dir ,
677+ compile_only = True ,
678+ retained_state = True ,
679+ specializations = specializations ["lang" ],
680+ convert_to_fp16 = True ,
681+ mxfp6_matmul = mxfp6_matmul ,
682+ mdp_ts_num_devices = num_devices ,
683+ aic_num_cores = num_cores ,
684+ custom_io = custom_io_lang ,
685+ ** compiler_options ,
686+ )
684687 return self .qpc_path
685688
686689 def generate (
@@ -1539,6 +1542,7 @@ def compile(
15391542 num_speculative_tokens : Optional [int ] = None ,
15401543 enable_qnn : bool = False ,
15411544 qnn_config : Optional [str ] = None ,
1545+ prefill_only : Optional [bool ] = None ,
15421546 ** compiler_options ,
15431547 ) -> str :
15441548 """
@@ -1562,6 +1566,8 @@ def compile(
15621566 :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
15631567 :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
15641568 :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
1569+ :prefill_only (bool): if ``True`` compile for prefill only and if ``False`` compile for decode only. Defaults to None, which compiles for both ``prefill`` and ``decode``.
1570+ :compiler_options (dict, optional): Any other options that the `qaic-exec` takes. ``Defaults to None``.
15651571
15661572 Returns:
15671573 :str: Path of the compiled ``qpc`` package.
@@ -1583,48 +1589,33 @@ def compile(
15831589 "enable `continuous_batching=True` in `from_pretrained`."
15841590 )
15851591
1586- kv_cache_batch_size = (
1587- kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size )
1588- )
1589- # Define prefill specialization
1590- prefill_specialization = {
1591- # Prefill is always run with single BS for continuous batching.
1592- "batch_size" : 1 if self .continuous_batching else batch_size ,
1593- "seq_len" : prefill_seq_len ,
1594- "ctx_len" : ctx_len ,
1595- # TODO: should be renamed to kv_cache_batch_size in specialization too
1596- }
1597- prefill_specialization .update ({"num_logits_to_keep" : 1 }) if self .is_tlm else ...
1598- if self .continuous_batching :
1599- prefill_specialization .update ({"full_batch_size" : kv_cache_batch_size })
1600- else :
1601- prefill_specialization .update ({"batch_size" : kv_cache_batch_size })
1602- prefill_specialization .update ({"full_batch_exec_size" : full_batch_size }) if full_batch_size else ...
1603- specializations = [
1604- prefill_specialization ,
1605- ]
1592+ # Infer kv_cache_batch_size if not provided
1593+ kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
16061594
1607- # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
1608- if prefill_seq_len != 1 or self .continuous_batching :
1609- decode_specialization = {
1610- "batch_size" : full_batch_size if self .continuous_batching else batch_size ,
1611- "seq_len" : num_speculative_tokens + 1 if self .is_tlm else 1 ,
1612- "ctx_len" : ctx_len ,
1613- }
1614- if self .continuous_batching :
1615- decode_specialization .update ({"full_batch_size" : kv_cache_batch_size })
1616- else :
1617- decode_specialization .update ({"batch_size" : kv_cache_batch_size })
1618- decode_specialization .update ({"num_logits_to_keep" : num_speculative_tokens + 1 }) if self .is_tlm else ...
1619- specializations .append (decode_specialization )
1595+ # --- Specializations ---
1596+ specializations = []
1597+
1598+ if prefill_only is None or prefill_only or prefill_seq_len == 1 :
1599+ specializations .append (
1600+ self .build_prefill_specialization (
1601+ prefill_seq_len , ctx_len , batch_size , kv_cache_batch_size , full_batch_size
1602+ )
1603+ )
1604+ if prefill_only is None or not prefill_only :
1605+ decode_spec = self .build_decode_specialization (
1606+ prefill_seq_len , ctx_len , batch_size , kv_cache_batch_size , full_batch_size , num_speculative_tokens
1607+ )
1608+ if decode_spec :
1609+ specializations .append (decode_spec )
16201610
1611+ # --- Compilation ---
16211612 if enable_qnn :
16221613 if compiler_options :
1623- logger .warning ("Extra arguments to QNN compilation are supported via qnn_config.json only " )
1614+ logger .warning ("Extra arguments to QNN compilation are ignored. Use ` qnn_config.json`. " )
16241615
16251616 qpc_path = self ._qnn_compile (
1626- onnx_path ,
1627- compile_dir ,
1617+ onnx_path = onnx_path ,
1618+ compile_dir = compile_dir ,
16281619 specializations = specializations ,
16291620 prefill_seq_len = prefill_seq_len ,
16301621 ctx_len = ctx_len ,
@@ -1638,17 +1629,17 @@ def compile(
16381629 kv_cache_batch_size = kv_cache_batch_size ,
16391630 )
16401631 else :
1641- # Custom IO
1642- custom_io = {}
16431632 kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
1633+ custom_io = {}
1634+
16441635 for suffix in ["" , "_RetainedState" ]:
16451636 for i in range (self .num_layers ):
16461637 for kv in ["key" , "value" ]:
16471638 custom_io [f"past_{ kv } .{ i } { suffix } " ] = kv_cache_dtype
16481639
16491640 qpc_path = self ._compile (
1650- onnx_path ,
1651- compile_dir ,
1641+ onnx_path = onnx_path ,
1642+ compile_dir = compile_dir ,
16521643 compile_only = True ,
16531644 retained_state = True ,
16541645 specializations = specializations ,
@@ -1660,6 +1651,7 @@ def compile(
16601651 aic_num_cores = num_cores ,
16611652 ** compiler_options ,
16621653 )
1654+
16631655 return qpc_path
16641656
16651657 # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
@@ -1867,22 +1859,8 @@ def compile(
18671859 if num_speculative_tokens :
18681860 logger .warning ("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq" )
18691861
1870- output_names = self .model .get_output_names ()
1871-
1872- kv_cache_dtype = "float16"
1873- custom_io = {}
1874-
1875- custom_io ["input_features" ] = kv_cache_dtype
1876-
1877- # Slice output_names to get input names
1878- for output_name in output_names :
1879- if output_name .endswith ("_RetainedState" ):
1880- custom_io [output_name [: - len ("_RetainedState" )]] = kv_cache_dtype
1881-
1882- # Get output names
1883- for output_name in output_names :
1884- if output_name .endswith ("_RetainedState" ):
1885- custom_io [output_name ] = kv_cache_dtype
1862+ if enable_qnn or qnn_config :
1863+ logger .warning ("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq" )
18861864
18871865 return self ._compile (
18881866 onnx_path ,
0 commit comments