Merged
131 changes: 86 additions & 45 deletions examples/draft_spd_inference.py
@@ -19,7 +19,7 @@


@dataclass
class PerfMetrics:
class SpDPerfMetrics:
"""
Holds all performance metrics

@@ -31,6 +31,11 @@ class PerfMetrics:
:mean_num_accepted_tokens (float): Average number of accepted tokens.
:max_gen_len (int): Max generation length.
:generated_tokens_per_prompt (List[int]): Total generated tokens per prompt.
:e2e_time (float): Total end-to-end time.
:decode_time (float): Total decode time.
:decode_draft_time (float): Total draft time.
:decode_target_time (float): Total target time.
:decode_iterations (int): Total decode iterations.
"""

mean_ttft: float
@@ -40,10 +45,15 @@ class PerfMetrics:
mean_num_accepted_tokens: float
max_gen_len: int
generated_tokens_per_prompt: List[int]
e2e_time: float
decode_time: float
decode_draft_time: float
decode_target_time: float
decode_iterations: int


@dataclass
class CloudAI100ExecInfo:
class SpDCloudAI100ExecInfo:
"""
Holds all the information about Cloud AI 100 execution

@@ -52,7 +62,7 @@ class CloudAI100ExecInfo:
:batch_size (int): Batch size of the QPC compilation.
:generated_texts (Union[List[List[str]], List[str]]): Generated text(s).
:generated_ids (Union[List[np.ndarray], np.ndarray]): Generated IDs.
:perf_metrics (PerfMetrics): Performance metrics.
:perf_metrics (SpDPerfMetrics): Performance metrics.
:num_speculative_tokens (int): Number of speculative tokens.
:prefill_seq_len (int): Prefill sequence length.
:ctx_len (int): Context length.
@@ -66,7 +76,7 @@ class CloudAI100ExecInfo:
batch_size: int
generated_texts: Union[List[str], List[List[str]]]
generated_ids: Union[List[np.ndarray], np.ndarray]
perf_metrics: PerfMetrics
perf_metrics: SpDPerfMetrics
num_speculative_tokens: int
prefill_seq_len: int
ctx_len: int
@@ -156,8 +166,11 @@ def draft_spec_decode_inference(
draft_model_name: str,
target_model_name: str,
full_batch_size: Optional[int],
device_group: List[int],
) -> CloudAI100ExecInfo:
target_device_group: List[int],
draft_device_group: List[int],
draft_model_session: Optional[QAICInferenceSession] = None,
target_model_session: Optional[QAICInferenceSession] = None,
) -> SpDCloudAI100ExecInfo:
"""
Perform draft speculative decode inference on the given prompts.

@@ -170,10 +183,11 @@ def draft_spec_decode_inference(
draft_model_name (str): Name of the draft model.
target_model_name (str): Name of the target model.
full_batch_size (Optional[int]): Full batch size.
device_group (List[int]): List of device IDs.
target_device_group (List[int]): List of device IDs for target model.
draft_device_group (List[int]): List of device IDs for draft model.

Returns:
CloudAI100ExecInfo: Execution information, including performance metrics and generated text.
SpDCloudAI100ExecInfo: Execution information, including performance metrics and generated text.
"""
# assumes dlm and tlm are compiled to the same prompt-chunk-size, context length and full_batch_size/batch-size
# get vocab size
@@ -184,31 +198,34 @@ def draft_spec_decode_inference(

# export_and_compile tlm and dlm
continuous_batching = full_batch_size is not None
target_model = AutoModelForCausalLM.from_pretrained(
target_model_name, continuous_batching=continuous_batching, is_tlm=True
)
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=continuous_batching)

num_devices = len(device_group)
target_model_qpc_path: str = target_model.compile(
num_cores=11,
num_devices=num_devices,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
aic_enable_depth_first=True,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
)
draft_model_qpc_path: str = draft_model.compile(
num_cores=5,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
aic_enable_depth_first=True,
full_batch_size=full_batch_size,
)
# init qaic session
target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group)
if target_model_session is None:
target_model = AutoModelForCausalLM.from_pretrained(
Contributor: Don't we need to gracefully handle the else case?

Contributor: Basically, target_model_session is passed as an argument in line 172, so if target_model_session is None the model session is created here. Do we need to have an else condition here?

Contributor: Should be fine then; we don't need an else case here.

(A short sketch of this optional-session pattern follows after this hunk.)

target_model_name, continuous_batching=continuous_batching, is_tlm=True
)
target_num_devices = len(target_device_group)
target_model_qpc_path: str = target_model.compile(
num_cores=11,
num_devices=target_num_devices,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
aic_enable_depth_first=True,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
)
target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=target_device_group)
if draft_model_session is None:
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=continuous_batching)
draft_num_devices = len(draft_device_group)
draft_model_qpc_path: str = draft_model.compile(
num_cores=5,
num_devices=draft_num_devices,
prefill_seq_len=prefill_seq_len,
ctx_len=ctx_len,
aic_enable_depth_first=True,
full_batch_size=full_batch_size,
)
# init qaic session
draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=draft_device_group)

Contributor: Same as above.

# skip inputs/outputs buffers
target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))
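
A minimal, self-contained sketch of the optional-session pattern discussed in the review thread above (not the PR code itself; `FakeSession` is an illustrative stand-in for `QAICInferenceSession`): the expensive compile/load path runs only when the caller does not supply a pre-built session, so no explicit else branch is needed.

```python
from typing import Optional


class FakeSession:
    """Illustrative stand-in for QAICInferenceSession."""

    def __init__(self, qpc_path: str) -> None:
        self.qpc_path = qpc_path


def get_session(qpc_path: str, session: Optional[FakeSession] = None) -> FakeSession:
    # Compile/load only when the caller did not pass a session; otherwise reuse it.
    if session is None:
        session = FakeSession(qpc_path)
    return session


fresh = get_session("target_qpc")        # builds a new session
reused = get_session("ignored", fresh)   # reuses the caller's session unchanged
assert reused is fresh
```

This mirrors how `draft_spec_decode_inference` now treats `target_model_session` and `draft_model_session`: passing a pre-built session skips the corresponding `from_pretrained`/`compile` step entirely.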
@@ -293,12 +310,15 @@ def draft_spec_decode_inference(
valid_batch_indices = np.full(decode_batch_size, True, dtype=bool)
all_accept = False
it = 0
decode_draft_time = 0.0
decode_target_time = 0.0
decode_start = perf_counter()
mean_num_accepted_tokens = 0
all_accept = np.full(decode_batch_size, False, dtype=bool)
while True:
it += 1
# generate proposals from draft model
draft_start = perf_counter()
for k_ in range(num_speculative_tokens):
if all_accept.any():
# running decode one extra time in the first speculative iteration
@@ -311,31 +331,30 @@ def draft_spec_decode_inference(
tlm_precode_inputs["input_ids"][:, k_ + 1] = input_ids.flatten()
dlm_decode_inputs["input_ids"] = input_ids
dlm_decode_inputs["position_ids"][valid_batch_indices] += 1
draft_end = perf_counter() - draft_start
decode_draft_time += draft_end
# run precode on TLM to score the proposed tokens
target_start = perf_counter()
tlm_outputs = target_model_session.run(tlm_precode_inputs)
target_logits = tlm_outputs["logits"]
# greedy sampling from target model
target_tokens = target_logits.argmax(-1)
target_end = perf_counter() - target_start
decode_target_time += target_end
# exact matching between draft and target tokens
draft_tokens = tlm_precode_inputs["input_ids"][:, 1:]
matching = draft_tokens == target_tokens[:, :-1] # shape: [decode_batch_size, num_speculative_tokens]
num_tokens_selected = matching.cumprod(axis=1).sum(axis=1) + 1 # shape: [decode_batch_size]
all_accept[valid_batch_indices] = num_tokens_selected[valid_batch_indices] == num_speculative_tokens + 1
mean_num_accepted_tokens += num_tokens_selected[valid_batch_indices].mean().item()
# append selected tokens to the generated_ids
tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(
decode_batch_size, 1
)
# tlm_precode_position_ids = tlm_precode_inputs["position_ids"] + num_tokens_selected.reshape(decode_batch_size,1)+1
for bi, valid in enumerate(valid_batch_indices):
if not valid:
continue
accepted_tokens = num_tokens_selected[bi]
num_tokens_to_append = min(accepted_tokens, max_gen_len[bi] - len(generated_ids[bi]))
generated_ids[bi].extend(target_tokens[bi, :num_tokens_to_append].tolist())
# position_ids > ctx_len-1 result in erroneous output for logits at each seq_len of TLM
# (e.g., ctx_len=128 -> position_ids=[127,128,129] will give erroneous output at each predicted token)
if len(generated_ids[bi]) >= max_gen_len[bi] or (tlm_precode_position_ids[bi] > ctx_len - 1).any():
if len(generated_ids[bi]) >= max_gen_len[bi]:
Contributor: Why are we using (>=) instead of (>) for the greater-than check, unless we are using it as an iterator?

Contributor: Ideally we should stop the generation process when the generated IDs reach max_gen_len for the batch index. If we don't, it will generate max_gen_len + 1 token IDs, which is not correct for our case.

Contributor: Looks okay to me.

(A short illustration of this stop condition follows after this hunk.)

valid_batch_indices[bi] = False
# check if all generations are done
if not valid_batch_indices.any():
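
A short, hedged illustration (hypothetical numbers) of the stop condition discussed in the thread above: the `>=` comparison retires a batch index as soon as its generated ids reach `max_gen_len`, whereas a strict `>` would leave the index marked valid for at least one more decode iteration after the budget is already met.

```python
max_gen_len_bi = 8                   # assumed per-prompt generation budget
generated_ids_bi = [0] * 6           # six ids generated so far
num_tokens_to_append = 2             # accepted tokens appended this iteration
generated_ids_bi.extend([1] * num_tokens_to_append)

assert len(generated_ids_bi) == max_gen_len_bi
assert len(generated_ids_bi) >= max_gen_len_bi        # >= : retire the index now
assert not (len(generated_ids_bi) > max_gen_len_bi)   # >  : would keep it "valid"
```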
@@ -379,16 +398,21 @@ def draft_spec_decode_inference(
e2e_throughput = (sum(generated_tokens_per_prompt) + decode_batch_size) / e2e_end
batch_decode = tokenizer.batch_decode(generated_ids)
mean_num_accepted_tokens /= it
perf_metrics = PerfMetrics(
perf_metrics = SpDPerfMetrics(
mean_ttft,
batch_ttft,
decode_throughput,
e2e_throughput,
mean_num_accepted_tokens,
max_gen_len,
generated_tokens_per_prompt,
e2e_end,
decode_end,
decode_draft_time,
decode_target_time,
it,
)
exec_info = CloudAI100ExecInfo(
exec_info = SpDCloudAI100ExecInfo(
prompts,
decode_batch_size,
batch_decode,
@@ -405,15 +429,19 @@ def arg_parse():
return exec_info


def optional_int(x):
def optional_int(x: Optional[str]):
if x is None:
return None
return int(x)


def comma_separated_ints(x: str):
return [int(qid) for qid in x.split(",")]


def arg_parse():
parser = ArgumentParser(description="Draft-based SpD Inference")
parser.add_argument("--prompts", type=str, nargs="+", default=Constants.INPUT_STR, help="Input prompt(s)")
parser.add_argument("--prompts", action="append", default=None, help="Input prompt(s)")
parser.add_argument("--num-speculative-tokens", type=int, default=4, help="Number of speculative tokens")
parser.add_argument("--prefill-seq-len", type=int, default=32, help="Prefill sequence length")
parser.add_argument("--ctx-len", type=int, default=128, help="Context length")
@@ -425,13 +453,26 @@ def arg_parse():
"--target-model-name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", help="Target model name"
)
parser.add_argument("--full-batch-size", type=optional_int, default=None, help="Full batch size")
parser.add_argument("--device-group", type=int, nargs="+", default=[0], help="device QIDs")
parser.add_argument(
"--target-device-group",
type=comma_separated_ints,
default="0",
help="comma separated device QIDs (e.g., '1,2,3')",
)
parser.add_argument(
"--draft-device-group",
type=comma_separated_ints,
default="0",
help="comma separated device QIDs (e.g., '1,2,3')",
)
args = parser.parse_args()
return args


def main():
args = arg_parse()
if args.prompts is None:
args.prompts = Constants.INPUT_STR
exec_info = draft_spec_decode_inference(**vars(args))
print(exec_info)
prompts = exec_info.prompts