@@ -561,7 +561,12 @@ def onnx_path(self):
 
     @property
     def qpc_path(self):
-        return [self.vision_model.qpc_path, self.lang_model.qpc_path]
+        if self.vision_model.qpc_path and self.lang_model.qpc_path:
+            return [self.vision_model.qpc_path, self.lang_model.qpc_path]
+        elif self.vision_model.qpc_path:
+            return self.vision_model.qpc_path
+        else:
+            return self.lang_model.qpc_path
 
     def export(
         self,
@@ -600,6 +605,8 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        skip_vision: Optional[bool] = False,
+        skip_lang: Optional[bool] = False,
         **compiler_options,
     ) -> str:
         if (
@@ -615,6 +622,9 @@ def compile(
                 f"enable_qnn={enable_qnn}, qnn_config={qnn_config}"
             )
 
+        if skip_lang and skip_vision:
+            raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision' to be False")
+
         output_names = self.model.get_output_names(kv_offload=True)
 
         specializations, compiler_options = self.model.get_specializations(
@@ -642,41 +652,43 @@ def compile(
         ):
             self.export()
 
-        self.vision_model._compile(
-            compile_dir,
-            compile_only=True,
-            specializations=specializations["vision"],
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=num_devices,
-            aic_num_cores=num_cores,
-            custom_io=custom_io_vision,
-            **compiler_options,
-        )
+        if not skip_vision:
+            self.vision_model._compile(
+                compile_dir,
+                compile_only=True,
+                specializations=specializations["vision"],
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                mdp_ts_num_devices=num_devices,
+                aic_num_cores=num_cores,
+                custom_io=custom_io_vision,
+                **compiler_options,
+            )
 
-        custom_io_lang = {}
-        # Inputs
-        for output_name in output_names["lang"]:
-            if output_name.endswith("_RetainedState"):
-                custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
+        if not skip_lang:
+            custom_io_lang = {}
+            # Inputs
+            for output_name in output_names["lang"]:
+                if output_name.endswith("_RetainedState"):
+                    custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
 
-        # outputs
-        for output_name in output_names["lang"]:
-            if output_name.endswith("_RetainedState"):
-                custom_io_lang[output_name] = kv_cache_dtype
+            # outputs
+            for output_name in output_names["lang"]:
+                if output_name.endswith("_RetainedState"):
+                    custom_io_lang[output_name] = kv_cache_dtype
 
-        self.lang_model._compile(
-            compile_dir,
-            compile_only=True,
-            retained_state=True,
-            specializations=specializations["lang"],
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=num_devices,
-            aic_num_cores=num_cores,
-            custom_io=custom_io_lang,
-            **compiler_options,
-        )
+            self.lang_model._compile(
+                compile_dir,
+                compile_only=True,
+                retained_state=True,
+                specializations=specializations["lang"],
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                mdp_ts_num_devices=num_devices,
+                aic_num_cores=num_cores,
+                custom_io=custom_io_lang,
+                **compiler_options,
+            )
         return self.qpc_path
 
     def generate(
@@ -711,6 +723,9 @@ def kv_offload_generate(
         device_ids: List[int] = None,
         generation_len: int = None,
     ):
+        if not self.vision_model.qpc_path or not self.lang_model.qpc_path:
+            raise TypeError("Please run compile API for vision and language model first!")
+
         lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False)
 
         vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids)
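The hunks above let the vision and language QPCs be compiled independently, make `qpc_path` return whichever artifact(s) exist, and guard `kv_offload_generate` against missing QPCs. A minimal usage sketch of the split compile follows; the class name `QEFFAutoModelForImageTextToText`, the `kv_offload=True` constructor flag, and the checkpoint are assumptions for illustration and are not part of this diff:

```python
from QEfficient import QEFFAutoModelForImageTextToText  # assumed import path

# Assumed dual-QPC (kv_offload) image-text model; the checkpoint is illustrative only.
model = QEFFAutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", kv_offload=True
)
model.export()

# Compile only the vision encoder; qpc_path now returns the single vision QPC path.
model.compile(num_cores=16, skip_lang=True)

# Later, compile the language model; qpc_path returns [vision_qpc, lang_qpc].
model.compile(num_cores=16, skip_vision=True)

# Calling kv_offload_generate before both QPCs exist raises:
# TypeError: Please run compile API for vision and language model first!
```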
@@ -1461,6 +1476,51 @@ def export(self, export_dir: Optional[str] = None) -> str:
             export_dir=export_dir,
         )
 
+    def build_prefill_specialization(
+        self,
+        prefill_seq_len: int = 32,
+        ctx_len: int = 128,
+        batch_size: int = 1,
+        kv_cache_batch_size: Optional[int] = None,
+        full_batch_size: Optional[int] = None,
+    ):
+        spec = {
+            "batch_size": 1 if self.continuous_batching else batch_size,
+            "seq_len": prefill_seq_len,
+            "ctx_len": ctx_len,
+            "num_logits_to_keep": 1 if self.is_tlm else None,
+        }
+        if self.continuous_batching:
+            spec["full_batch_size"] = kv_cache_batch_size
+        else:
+            spec["batch_size"] = kv_cache_batch_size
+        if full_batch_size:
+            spec["full_batch_exec_size"] = full_batch_size
+        return {k: v for k, v in spec.items() if v is not None}
+
+    def build_decode_specialization(
+        self,
+        prefill_seq_len: int = 32,
+        ctx_len: int = 128,
+        batch_size: int = 1,
+        kv_cache_batch_size: Optional[int] = None,
+        full_batch_size: Optional[int] = None,
+        num_speculative_tokens: Optional[int] = None,
+    ):
+        if prefill_seq_len == 1 and not self.continuous_batching:
+            return None  # Avoid duplication with prefill
+        spec = {
+            "batch_size": full_batch_size if self.continuous_batching else batch_size,
+            "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
+            "ctx_len": ctx_len,
+            "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
+        }
+        if self.continuous_batching:
+            spec["full_batch_size"] = kv_cache_batch_size
+        else:
+            spec["batch_size"] = kv_cache_batch_size
+        return {k: v for k, v in spec.items() if v is not None}
+
     def compile(
         self,
         onnx_path: Optional[str] = None,
@@ -1478,6 +1538,7 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        prefill_only: Optional[bool] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1501,74 +1562,63 @@ def compile(
             :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
             :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
             :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+            :prefill_only (bool, optional): If ``True``, compile only for the prefill specialization; if ``False``, compile only for decode. ``Defaults to None``, which compiles for both prefill and decode.
+            :compiler_options (dict, optional): Any other options that ``qaic-exec`` accepts. ``Defaults to None``.
 
         Returns:
             :str: Path of the compiled ``qpc`` package.
         """
+        # --- Validation ---
+        if prefill_only is not None and not isinstance(prefill_only, bool):
+            raise TypeError("`prefill_only` must be a boolean.")
+
         if self.is_tlm:
-            # assert num_speculative_tokens cfg is acceptable if defined
             if num_speculative_tokens is None:
-                raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")
-            if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2:
-                ValueError(
-                    f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
-                )
-            num_logits_to_keep = num_speculative_tokens + 1
-            if prefill_seq_len < num_logits_to_keep:
+                raise TypeError("`num_speculative_tokens` is required when `is_tlm=True`.")
+            if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
+                raise ValueError("`num_speculative_tokens` must be an integer >= 2.")
+            if prefill_seq_len < (num_speculative_tokens + 1):
                 raise ValueError(
-                    f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
+                    f"`prefill_seq_len` must be at least `num_speculative_tokens + 1` "
+                    f"({num_speculative_tokens + 1}), got {prefill_seq_len}."
                 )
 
         if self.continuous_batching and full_batch_size is None:
-            raise TypeError("missing required argument: 'full_batch_size'")
+            raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
 
         if kv_cache_batch_size and not full_batch_size:
             raise ValueError(
-                "Prefix caching is enabled only for continuous batching as of now. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
+                "Prefix caching requires continuous batching. Please pass `full_batch_size` and "
+                "set `continuous_batching=True` in the `from_pretrained` call."
             )
 
-        kv_cache_batch_size = (
-            kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
-        )
-        # Define prefill specialization
-        prefill_specialization = {
-            # Prefill is always run with single BS for continuous batching.
-            "batch_size": 1 if self.continuous_batching else batch_size,
-            "seq_len": prefill_seq_len,
-            "ctx_len": ctx_len,
-            # TODO: should be renamed to kv_cache_batch_size in specialization too
-        }
-        prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
-        if self.continuous_batching:
-            prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
-        else:
-            prefill_specialization.update({"batch_size": kv_cache_batch_size})
-        prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
-        specializations = [
-            prefill_specialization,
-        ]
+        # Infer kv_cache_batch_size if not provided
+        kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
 
-        # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
-        if prefill_seq_len != 1 or self.continuous_batching:
-            decode_specialization = {
-                "batch_size": full_batch_size if self.continuous_batching else batch_size,
-                "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
-                "ctx_len": ctx_len,
-            }
-            if self.continuous_batching:
-                decode_specialization.update({"full_batch_size": kv_cache_batch_size})
-            else:
-                decode_specialization.update({"batch_size": kv_cache_batch_size})
-            decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
-            specializations.append(decode_specialization)
+        # --- Specializations ---
+        specializations = []
+
+        if prefill_only is None or prefill_only or prefill_seq_len == 1:
+            specializations.append(
+                self.build_prefill_specialization(
+                    prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
+                )
+            )
+        if prefill_only is None or not prefill_only:
+            decode_spec = self.build_decode_specialization(
+                prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
+            )
+            if decode_spec:
+                specializations.append(decode_spec)
 
+        # --- Compilation ---
         if enable_qnn:
             if compiler_options:
-                logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
+                logger.warning("Extra arguments to QNN compilation are ignored; pass them via `qnn_config.json` instead.")
 
             qpc_path = self._qnn_compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 specializations=specializations,
                 prefill_seq_len=prefill_seq_len,
                 ctx_len=ctx_len,
@@ -1582,17 +1632,17 @@ def compile(
                 kv_cache_batch_size=kv_cache_batch_size,
             )
         else:
-            # Custom IO
-            custom_io = {}
             kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+            custom_io = {}
+
             for suffix in ["", "_RetainedState"]:
                 for i in range(self.num_layers):
                     for kv in ["key", "value"]:
                         custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
 
             qpc_path = self._compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 compile_only=True,
                 retained_state=True,
                 specializations=specializations,
@@ -1604,6 +1654,7 @@ def compile(
                 aic_num_cores=num_cores,
                 **compiler_options,
             )
+
         return qpc_path
 
     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
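For the causal-LM `compile` path, the specialization dictionaries are now built by `build_prefill_specialization` and `build_decode_specialization`, and the new `prefill_only` flag selects which phase to compile. Below is a hedged sketch of how the flag might be exercised; the class name `QEFFAutoModelForCausalLM` and the checkpoint are assumptions for illustration, not shown in this diff:

```python
from QEfficient import QEFFAutoModelForCausalLM  # assumed import path

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
model.export()

# Default (prefill_only=None): one QPC carrying both prefill and decode specializations.
qpc_both = model.compile(prefill_seq_len=32, ctx_len=128, num_cores=16)

# Prefill-only QPC, e.g. for a disaggregated prefill worker.
qpc_prefill = model.compile(prefill_seq_len=32, ctx_len=128, num_cores=16, prefill_only=True)

# Decode-only QPC; build_decode_specialization only returns None when
# prefill_seq_len == 1 without continuous batching, so this still compiles.
qpc_decode = model.compile(prefill_seq_len=32, ctx_len=128, num_cores=16, prefill_only=False)

# Non-boolean values are rejected by the new validation:
# model.compile(prefill_only="yes")  -> TypeError: `prefill_only` must be a boolean.
```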