@@ -20,6 +20,7 @@
 from neuronx_distributed_inference.models.config import (
     FusedSpecNeuronConfig,
     OnDeviceSamplingConfig,
+    ChunkedPrefillConfig,
     to_torch_dtype,
 )
 from neuronx_distributed_inference.models.dbrx.modeling_dbrx import NeuronDbrxForCausalLM
@@ -38,6 +39,7 @@
 from neuronx_distributed_inference.utils.exceptions import LogitMatchingValidationError
 from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config
 from neuronx_distributed_inference.utils.random import set_random_seed
+from neuronx_distributed_inference.utils.constants import BENCHMARK_REPORT_PATH
 
 set_random_seed(0)
 
@@ -120,6 +122,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--max-new-tokens", type=int)
     run_parser.add_argument("--max-length", type=int)
     run_parser.add_argument("--rpl-reduce-dtype", type=to_torch_dtype)
+    run_parser.add_argument("--attention-dtype", type=to_torch_dtype)
     run_parser.add_argument("--output-logits", action="store_true")
     run_parser.add_argument("--vocab-parallel", action="store_true")
     run_parser.add_argument("--layer-boundary-markers", action="store_true", default=False)
@@ -148,6 +151,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument("--enable-bucketing", action="store_true")
     run_parser.add_argument("--bucket-n-active-tokens", action="store_true")
     run_parser.add_argument("--context-encoding-buckets", nargs="+", type=int)
+    run_parser.add_argument("--prefix-buckets", nargs="+", type=int)
     run_parser.add_argument("--token-generation-buckets", nargs="+", type=int)
 
     # Quantization
@@ -166,6 +170,13 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
 
     # MoE
     run_parser.add_argument("--capacity-factor", type=float)
+    run_parser.add_argument("--early-expert-affinity-modulation", action="store_true")
+    run_parser.add_argument("--disable-normalize-top-k-affinities", action="store_true")
+    run_parser.add_argument("--fused-shared-experts", action="store_true")
+
+    # Router Config
+    run_parser.add_argument("--router-act-fn", type=str)
+    run_parser.add_argument("--router-dtype", type=str)
 
     # Speculative decoding
     run_parser.add_argument("--draft-model-path", type=str)
@@ -189,6 +200,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
 
     # Parallelism
     run_parser.add_argument("--tp-degree", type=int)
+    run_parser.add_argument("--cp-degree", type=int)
     run_parser.add_argument("--pp-degree", type=int)
     run_parser.add_argument("--ep-degree", type=int)
     run_parser.add_argument("--world-size", type=int)
@@ -224,8 +236,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     run_parser.add_argument(
         "--enable-prefix-caching", dest="is_prefix_caching", action="store_true"
     )
-    run_parser.add_argument("--cp-max-num-seqs", type=int)
-    run_parser.add_argument("--cp-num-active-blocks", type=int)
+    run_parser.add_argument("--max-num-seqs", type=int)
 
     # Async
     run_parser.add_argument("--async-mode", action="store_true")
@@ -242,7 +253,7 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
     # Kernels
     run_parser.add_argument("--qkv-kernel-enabled", action="store_true")
     run_parser.add_argument("--qkv-kernel-nbsd-layout", action="store_true")
-    run_parser.add_argument("--attn-kernel-enabled", action="store_true")
+    run_parser.add_argument("--attn-kernel-enabled", action=argparse.BooleanOptionalAction, default=None)
     run_parser.add_argument("--mlp-kernel-enabled", action="store_true")
     run_parser.add_argument("--quantized-mlp-kernel-enabled", action="store_true")
     run_parser.add_argument("--fused-rmsnorm-skip-gamma", action="store_true")
@@ -270,10 +281,19 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
 
     # Compiler Args
     run_parser.add_argument("--cc-pipeline-tiling-factor", type=int, default=2)
+    run_parser.add_argument("--enable-spill-reload-dge", action="store_true")
 
     # CPU
     run_parser.add_argument("--on-cpu", action="store_true")
 
+    # Report generation
+    run_parser.add_argument(
+        "--benchmark-report-path",
+        type=str,
+        default=BENCHMARK_REPORT_PATH,
+        help="File path to save benchmark report."
+    )
+
     # Debugging
     run_parser.add_argument(
         "--capture-indices",
@@ -283,6 +303,10 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         help=f"Specify '{argparse_utils.AUTO}' when using check accuracy mode with {CheckAccuracyMode.LOGIT_MATCHING} for inferring capture indices when the test fails, and use the indices to capture inputs. Otherwise, provide any number of integer values for capturing inputs at those indices.")
     run_parser.add_argument("--input-capture-save-dir", type=str, default=None)
 
+    run_parser.add_argument("--cast-type", choices=["config", "as-declared"], default="config",
+                            help="If set to 'config', all parameters will be cast to neuron_config.torch_dtype. "
+                                 "If set to 'as-declared', casting will be done based on the dtype set for each parameter.")
+
     # Optional demo arguments
     run_parser.add_argument(
         "--skip-warmup",
@@ -312,6 +336,20 @@ def setup_run_parser(run_parser: argparse.ArgumentParser):
         help="Adds metadata into the generated HLO. This metadata maps the HLO "
         "operators to the corresponding lines in the PyTorch code",
     )
+    run_parser.add_argument(
+        "--apply-seq-ids-mask",
+        action='store_true',
+        help="Avoid KV cache update on inactive (padded) seq_ids"
+    )
+    run_parser.add_argument(
+        "--input-start-offsets",
+        nargs="+",
+        default=None,
+        type=int,
+        help="Shift the input right by an offset. There can be multiple offsets, one per sequence. "
+        "If only 1 value is provided, all sequences will be shifted by this amount. "
+        "This flag can be used to test chunked attention."
+    )
 
 
 def validate_file_exists(path):
@@ -339,7 +377,7 @@ def get_modules_to_not_convert_json(json_path):
     return modules_to_not_convert, draft_model_modules_to_not_convert
 
 
-def run_inference(model_cls: Type[NeuronApplicationBase], args):
+def create_neuron_config(model_cls, args):
     # Initialize configs.
     print("Loading configs...")
 
@@ -348,6 +386,11 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     config_kwargs = {k: v for k, v in config_kwargs.items() if v is not None}
     if args.on_device_sampling:
         config_kwargs["on_device_sampling_config"] = OnDeviceSamplingConfig(**config_kwargs)
+    if args.is_chunked_prefill:
+        max_num_seqs = config_kwargs.pop("max_num_seqs", 0)
+        config_kwargs["chunked_prefill_config"] = ChunkedPrefillConfig(
+            max_num_seqs=max_num_seqs,
+        )
 
     if (args.quantized and args.quantization_dtype == "f8e4m3") or args.kv_cache_quant:
         os.environ["XLA_HANDLE_SPECIAL_SCALAR"] = "1"
@@ -371,6 +414,11 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     )
     adapter_ids = args.adapter_ids
     neuron_config = model_cls.get_neuron_config_cls()(**config_kwargs)
+    return adapter_ids, neuron_config
+
+
+def run_inference(model_cls: Type[NeuronApplicationBase], args):
+    adapter_ids, neuron_config = create_neuron_config(model_cls, args)
 
     config = model_cls.get_config_cls()(
         neuron_config, load_config=load_pretrained_config(args.model_path)
@@ -395,7 +443,6 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     # Set eagle specific config changes
     if neuron_config.enable_eagle_speculation:
         draft_neuron_config.is_eagle_draft = True
-        draft_neuron_config.sequence_parallel_enabled = False
 
     if args.draft_model_tp_degree is not None:
         draft_neuron_config.tp_degree = args.draft_model_tp_degree
@@ -415,6 +462,8 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
         draft_model = model_cls(args.draft_model_path, draft_config)
 
     model = model_cls(args.model_path, config)
+    if args.input_start_offsets:
+        assert len(args.input_start_offsets) == 1 or len(args.input_start_offsets) == args.batch_size, "The number of input offsets must be either 1 or equal to the batch size."
 
     # Quantize model.
     if neuron_config.quantized:
@@ -481,7 +530,10 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
     generation_config_kwargs = {
         k: getattr(args, k) for k in generation_config_args if getattr(args, k) is not None
     }
-    generation_config.update(**generation_config_kwargs)
+    remaining_kwargs = generation_config.update(**generation_config_kwargs)
+    # add any remaining ones (this can happen when the model generation config is missing some entries)
+    for k, v in remaining_kwargs.items():
+        generation_config.__dict__[k] = v
 
     # With Medusa, the model is also the draft model.
     if neuron_config.is_medusa:
@@ -504,6 +556,7 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
             num_tokens_to_check=args.num_tokens_to_check,
             draft_model=draft_model,
             expected_outputs_path=args.expected_outputs_path,
+            input_start_offsets=args.input_start_offsets,
         )
     except LogitMatchingValidationError as e:
         logit_error = e
@@ -530,14 +583,15 @@ def run_inference(model_cls: Type[NeuronApplicationBase], args):
         draft_model=draft_model,
         adapter_ids=adapter_ids,
         input_capture_hook=input_capture_hook,
+        input_start_offsets=args.input_start_offsets,
     )
 
     if logit_error is not None:
         raise logit_error
 
     # Benchmarking.
     if args.benchmark:
-        benchmark_sampling(model, draft_model, generation_config)
+        benchmark_sampling(model, draft_model, generation_config, benchmark_report_path=args.benchmark_report_path)
 
 
 def load_tokenizer(model_path, compiled_model_path, neuron_config):
@@ -555,9 +609,12 @@ def run_generation(
     draft_model=None,
     adapter_ids=None,
     input_capture_hook=None,
+    input_start_offsets=None,
 ):
     print("\nGenerating outputs...")
     print(f"Prompts: {prompts}")
+    if len(prompts) == 1 and model.config.neuron_config.batch_size > 1:
+        prompts = prompts * model.config.neuron_config.batch_size
 
     _, output_tokens = get_generate_outputs(
         model,
@@ -569,6 +626,7 @@ def run_generation(
         adapter_ids=adapter_ids,
         max_length=model.neuron_config.max_length,
         input_capture_hook=input_capture_hook,
+        input_start_offsets=input_start_offsets
     )
 
     print("Generated outputs:")
@@ -587,13 +645,15 @@ def run_accuracy_check(
     num_tokens_to_check=None,
     draft_model=None,
     expected_outputs_path=None,
+    input_start_offsets=None,
 ):
     if model.neuron_config.is_medusa:
         # Medusa doesn't use greedy sampling, so check accuracy doesn't work.
         assert (
             check_accuracy_mode == CheckAccuracyMode.SKIP_ACCURACY_CHECK
         ), "Accuracy checking not supported for Medusa"
-
+    if input_start_offsets:
+        assert all(offset < model.config.neuron_config.max_context_length for offset in input_start_offsets), "Input offsets must be less than the max context length"
     if check_accuracy_mode == CheckAccuracyMode.SKIP_ACCURACY_CHECK:
         print("\nSkipping accuracy check")
         return
@@ -612,6 +672,7 @@ def run_accuracy_check(
             draft_model=draft_model,
             expected_token_ids=expected_outputs,
             num_tokens_to_check=num_tokens_to_check,
+            input_start_offsets=input_start_offsets,
         )
     elif check_accuracy_mode == CheckAccuracyMode.LOGIT_MATCHING:
         assert draft_model is None, "Logit matching not supported for speculation"
@@ -633,6 +694,7 @@ def run_accuracy_check(
             divergence_difference_tol=divergence_difference_tol,
             tol_map=tol_map,
             num_tokens_to_check=num_tokens_to_check,
+            input_start_offsets=input_start_offsets,
         )
     else:
         raise ValueError(f"Unsupported check accuracy mode: {check_accuracy_mode}")