prompt-lookup decoding example #235
Changes from all commits
@@ -19,7 +19,7 @@
 @dataclass
-class PerfMetrics:
+class SpDPerfMetrics:
     """
     Holds all performance metrics
@@ -31,6 +31,11 @@ class PerfMetrics:
     :mean_num_accepted_tokens (float): Average number of accepted tokens.
     :max_gen_len (int): Max generation length.
     :generated_tokens_per_prompt (List[int]): Total generated tokens per prompt.
+    :e2e_time (float): Total end-to-end time.
+    :decode_time (float): Total decode time.
+    :decode_draft_time (float): Total draft time.
+    :decode_target_time (float): Total target time.
+    :decode_iterations (int): Total decode iterations.
     """

     mean_ttft: float
@@ -40,10 +45,15 @@ class PerfMetrics:
     mean_num_accepted_tokens: float
     max_gen_len: int
     generated_tokens_per_prompt: List[int]
+    e2e_time: float
+    decode_time: float
+    decode_draft_time: float
+    decode_target_time: float
+    decode_iterations: int


 @dataclass
-class CloudAI100ExecInfo:
+class SpDCloudAI100ExecInfo:
     """
     Holds all the information about Cloud AI 100 execution
@@ -52,7 +62,7 @@ class CloudAI100ExecInfo:
     :batch_size (int): Batch size of the QPC compilation.
     :generated_texts (Union[List[List[str]], List[str]]): Generated text(s).
     :generated_ids (Union[List[np.ndarray], np.ndarray]): Generated IDs.
-    :perf_metrics (PerfMetrics): Performance metrics.
+    :perf_metrics (SpDPerfMetrics): Performance metrics.
     :num_speculative_tokens (int): Number of speculative tokens.
     :prefill_seq_len (int): Prefill sequence length.
     :ctx_len (int): Context length.
@@ -66,7 +76,7 @@ class CloudAI100ExecInfo:
     batch_size: int
     generated_texts: Union[List[str], List[List[str]]]
     generated_ids: Union[List[np.ndarray], np.ndarray]
-    perf_metrics: PerfMetrics
+    perf_metrics: SpDPerfMetrics
     num_speculative_tokens: int
     prefill_seq_len: int
     ctx_len: int
@@ -156,8 +166,11 @@ def draft_spec_decode_inference(
     draft_model_name: str,
     target_model_name: str,
     full_batch_size: Optional[int],
-    device_group: List[int],
-) -> CloudAI100ExecInfo:
+    target_device_group: List[int],
+    draft_device_group: List[int],
+    draft_model_session: Optional[QAICInferenceSession] = None,
+    target_model_session: Optional[QAICInferenceSession] = None,
+) -> SpDCloudAI100ExecInfo:
     """
     Perform draft speculative decode inference on the given prompts.
@@ -170,10 +183,11 @@ def draft_spec_decode_inference(
         draft_model_name (str): Name of the draft model.
         target_model_name (str): Name of the target model.
         full_batch_size (Optional[int]): Full batch size.
-        device_group (List[int]): List of device IDs.
+        target_device_group (List[int]): List of device IDs for target model.
+        draft_device_group (List[int]): List of device IDs for draft model.

     Returns:
-        CloudAI100ExecInfo: Execution information, including performance metrics and generated text.
+        SpDCloudAI100ExecInfo: Execution information, including performance metrics and generated text.
     """
     # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size
     # get vocab size
@@ -184,31 +198,34 @@ def draft_spec_decode_inference(

     # export_and_compile tlm and dlm
     continuous_batching = full_batch_size is not None
-    target_model = AutoModelForCausalLM.from_pretrained(
-        target_model_name, continuous_batching=continuous_batching, is_tlm=True
-    )
-    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=continuous_batching)
-
-    num_devices = len(device_group)
-    target_model_qpc_path: str = target_model.compile(
-        num_cores=11,
-        num_devices=num_devices,
-        prefill_seq_len=prefill_seq_len,
-        ctx_len=ctx_len,
-        aic_enable_depth_first=True,
-        full_batch_size=full_batch_size,
-        num_speculative_tokens=num_speculative_tokens,
-    )
-    draft_model_qpc_path: str = draft_model.compile(
-        num_cores=5,
-        prefill_seq_len=prefill_seq_len,
-        ctx_len=ctx_len,
-        aic_enable_depth_first=True,
-        full_batch_size=full_batch_size,
-    )
-    # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
-    draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group)
+    if target_model_session is None:
+        target_model = AutoModelForCausalLM.from_pretrained(
+            target_model_name, continuous_batching=continuous_batching, is_tlm=True
+        )
+        target_num_devices = len(target_device_group)
+        target_model_qpc_path: str = target_model.compile(
+            num_cores=11,
+            num_devices=target_num_devices,
+            prefill_seq_len=prefill_seq_len,
+            ctx_len=ctx_len,
+            aic_enable_depth_first=True,
+            full_batch_size=full_batch_size,
+            num_speculative_tokens=num_speculative_tokens,
+        )
+        target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=target_device_group)
+    if draft_model_session is None:
+        draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=continuous_batching)
+        draft_num_devices = len(draft_device_group)
+        draft_model_qpc_path: str = draft_model.compile(
+            num_cores=5,
+            num_devices=draft_num_devices,
+            prefill_seq_len=prefill_seq_len,
+            ctx_len=ctx_len,
+            aic_enable_depth_first=True,
+            full_batch_size=full_batch_size,
+        )
+        # init qaic session
+        draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=draft_device_group)
Review comment: same as above
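For illustration, here is one reuse pattern the new optional session arguments appear to enable: compile once, then call the function again with pre-built sessions so export and compile are skipped. This is a hedged sketch, not code from the PR; the QPC paths and the draft model name are placeholders, and the keyword names mirror the signature and `arg_parse()` above.

```python
# Placeholder QPC paths: assumes both models were already compiled once.
target_session = QAICInferenceSession("/path/to/target/qpcs", device_ids=[0, 1])
draft_session = QAICInferenceSession("/path/to/draft/qpcs", device_ids=[2])

exec_info = draft_spec_decode_inference(
    prompts=["My name is"],
    num_speculative_tokens=4,
    prefill_seq_len=32,
    ctx_len=128,
    draft_model_name="JackFram/llama-68m",  # placeholder draft model
    target_model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    full_batch_size=None,
    target_device_group=[0, 1],
    draft_device_group=[2],
    target_model_session=target_session,  # skips target export/compile
    draft_model_session=draft_session,    # skips draft export/compile
)
```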
     # skip inputs/outputs buffers
     target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))
@@ -293,12 +310,15 @@ def draft_spec_decode_inference(
     valid_batch_indices = np.full(decode_batch_size, True, dtype=bool)
-    all_accept = False
     it = 0
+    decode_draft_time = 0.0
+    decode_target_time = 0.0
     decode_start = perf_counter()
     mean_num_accepted_tokens = 0
+    all_accept = np.full(decode_batch_size, False, dtype=bool)
     while True:
         it += 1
         # generate proposals from draft model
+        draft_start = perf_counter()
         for k_ in range(num_speculative_tokens):
             if all_accept.any():
                 # running decode one extra time in the first speculative iteration
@@ -311,31 +331,30 @@ def draft_spec_decode_inference(
             tlm_precode_inputs["input_ids"][:, k_ + 1] = input_ids.flatten()
             dlm_decode_inputs["input_ids"] = input_ids
             dlm_decode_inputs["position_ids"][valid_batch_indices] += 1
+        draft_end = perf_counter() - draft_start
+        decode_draft_time += draft_end
         # run precode on TLM to score the proposed tokens
+        target_start = perf_counter()
         tlm_outputs = target_model_session.run(tlm_precode_inputs)
         target_logits = tlm_outputs["logits"]
         # greedy sampling from target model
         target_tokens = target_logits.argmax(-1)
+        target_end = perf_counter() - target_start
+        decode_target_time += target_end
         # exact matching between draft and target tokens
         draft_tokens = tlm_precode_inputs["input_ids"][:, 1:]
         matching = draft_tokens == target_tokens[:, :-1]  # shape: [decode_batch_size, num_speculative_tokens]
         num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) + 1  # shape: [decode_batch_size]
         all_accept[valid_batch_indices] = num_tokens_selected[valid_batch_indices] == num_speculative_tokens + 1
         mean_num_accepted_tokens += num_tokens_selected[valid_batch_indices].mean().item()
         # append selected tokens to the generated_ids
         tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(
             decode_batch_size, 1
         )
         # tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(decode_batch_size,1)+1
         for bi, valid in enumerate(valid_batch_indices):
             if not valid:
                 continue
             accepted_tokens = num_tokens_selected[bi]
             num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi]))
             generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist())
-            # position_ids > ctx_len-1 result in erronous output for logits at each seq_len of TLM
-            # (e.g., ctx_len=128 -> position_ids=[127,128,129] will give erronous output at each predicted token)
-            if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len - 1).any():
+            if len(generated_ids[bi]) >= max_gen_len[bi]:
Review comment: why do we have a (>=) check here instead of (>), unless we are using it as an iterator?

Reply: Ideally we should stop the generation process once the generated IDs reach max_gen_len for the batch index. If we don't, it will generate max_gen_len + 1 token IDs, which is not correct for our case.

Reply: Looks okay to me.
                 valid_batch_indices[bi] = False
         # check if all generations are done
         if not valid_batch_indices.any():
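To make the acceptance arithmetic concrete, here is a small self-contained check (toy token IDs, my own example) of the `cumprod` trick used above: `cumprod` zeroes everything after the first mismatch, so the row sum is the length of the matching prefix, and the `+ 1` accounts for the bonus token sampled from the target's own logits.

```python
import numpy as np

num_speculative_tokens = 4
draft_tokens = np.array([[5, 7, 9, 2],        # target agrees with first 2 proposals
                         [1, 4, 4, 4]])       # target agrees with all 4 proposals
target_tokens = np.array([[5, 7, 3, 8, 6],
                          [1, 4, 4, 4, 9]])
matching = draft_tokens == target_tokens[:, :-1]
num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) + 1
print(num_tokens_selected)                                 # [3 5]
print(num_tokens_selected == num_speculative_tokens + 1)   # all_accept: [False  True]
```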
@@ -379,16 +398,21 @@ def draft_spec_decode_inference(
     e2e_throughput = (sum(generated_tokens_per_prompt) + decode_batch_size) / e2e_end
     batch_decode = tokenizer.batch_decode(generated_ids)
     mean_num_accepted_tokens /= it
-    perf_metrics = PerfMetrics(
+    perf_metrics = SpDPerfMetrics(
         mean_ttft,
         batch_ttft,
         decode_throughput,
         e2e_throughput,
         mean_num_accepted_tokens,
         max_gen_len,
         generated_tokens_per_prompt,
+        e2e_end,
+        decode_end,
+        decode_draft_time,
+        decode_target_time,
+        it,
     )
-    exec_info = CloudAI100ExecInfo(
+    exec_info = SpDCloudAI100ExecInfo(
         prompts,
         decode_batch_size,
         batch_decode,
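Since `SpDPerfMetrics` now takes twelve positional arguments, constructing it with keywords would guard against silent ordering mistakes. A sketch of the equivalent call (the `batch_ttft`, `decode_throughput`, and `e2e_throughput` field names are inferred from the call site, not shown in this diff):

```python
perf_metrics = SpDPerfMetrics(
    mean_ttft=mean_ttft,
    batch_ttft=batch_ttft,
    decode_throughput=decode_throughput,
    e2e_throughput=e2e_throughput,
    mean_num_accepted_tokens=mean_num_accepted_tokens,
    max_gen_len=max_gen_len,
    generated_tokens_per_prompt=generated_tokens_per_prompt,
    e2e_time=e2e_end,
    decode_time=decode_end,
    decode_draft_time=decode_draft_time,
    decode_target_time=decode_target_time,
    decode_iterations=it,
)
```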
@@ -405,15 +429,19 @@ def draft_spec_decode_inference(
     return exec_info


-def optional_int(x):
+def optional_int(x: Optional[str]):
     if x is None:
         return None
     return int(x)


+def comma_separated_ints(x: str):
+    return [int(qid) for qid in x.split(",")]
+
+
 def arg_parse():
     parser = ArgumentParser(description="Draft-based SpD Inference")
-    parser.add_argument("--prompts", type=str, nargs="+", default=Constants.INPUT_STR, help="Input prompt(s)")
+    parser.add_argument("--prompts", action="append", default=None, help="Input prompt(s)")
     parser.add_argument("--num-speculative-tokens", type=int, default=4, help="Number of speculative tokens")
     parser.add_argument("--prefill-seq-len", type=int, default=32, help="Prefill sequence length")
     parser.add_argument("--ctx-len", type=int, default=128, help="Context length")
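The two helpers are small enough to sanity-check inline; using them exactly as defined above:

```python
assert optional_int(None) is None
assert optional_int("8") == 8
assert comma_separated_ints("0") == [0]
assert comma_separated_ints("1,2,3") == [1, 2, 3]
```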
@@ -425,13 +453,26 @@ def arg_parse():
         "--target-model-name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", help="Target model name"
     )
     parser.add_argument("--full-batch-size", type=optional_int, default=None, help="Full batch size")
-    parser.add_argument("--device-group", type=int, nargs="+", default=[0], help="device QIDs")
+    parser.add_argument(
+        "--target-device-group",
+        type=comma_separated_ints,
+        default="0",
+        help="comma separated device QIDs (e.g., '1,2,3')",
+    )
+    parser.add_argument(
+        "--draft-device-group",
+        type=comma_separated_ints,
+        default="0",
+        help="comma separated device QIDs (e.g., '1,2,3')",
+    )
     args = parser.parse_args()
     return args
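One argparse subtlety worth noting: because `default="0"` is a string and `type` is a callable, argparse also runs `comma_separated_ints` on the default, so an omitted flag still yields `[0]`. A standalone check (my own, reusing the helper from this diff):

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--target-device-group", type=comma_separated_ints, default="0")

assert parser.parse_args([]).target_device_group == [0]  # string default converted
assert parser.parse_args(["--target-device-group", "1,2,3"]).target_device_group == [1, 2, 3]
```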
 def main():
     args = arg_parse()
+    if args.prompts is None:
+        args.prompts = Constants.INPUT_STR
     exec_info = draft_spec_decode_inference(**vars(args))
     print(exec_info)
     prompts = exec_info.prompts
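With `action="append"`, each `--prompts` occurrence adds one entry, and when the flag is omitted entirely `args.prompts` stays `None`, so `main()` falls back to `Constants.INPUT_STR`. A quick standalone check of that behavior:

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--prompts", action="append", default=None)

assert parser.parse_args([]).prompts is None  # main() then substitutes Constants.INPUT_STR
assert parser.parse_args(["--prompts", "Hello", "--prompts", "Hi"]).prompts == ["Hello", "Hi"]
```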
Review comment: Don't we need to gracefully handle the else case?

Reply: target_model_session is passed as an argument (line 172), so when it is None the model session is created here. Do we still need an else condition?

Reply: Should be fine then; we don't need an else case here.