88from vllm import LLM , SamplingParams
99
1010
11- def print_msg (msg ):
12- print (f"[debug] { msg } \n " )
13-
14-
1511def load_prompts (dataset_path , num_prompts ):
1612 if os .path .exists (dataset_path ):
1713 prompts = []
@@ -21,14 +17,12 @@ def load_prompts(dataset_path, num_prompts):
2117 data = json .loads (line )
2218 prompts .append (data ["turns" ][0 ])
2319 except Exception as e :
24- print_msg (f"Error reading dataset: { e } " )
20+ print (f"Error reading dataset: { e } " )
2521 return []
2622 else :
2723 prompts = [
2824 "The future of AI is" , "The president of the United States is"
2925 ]
30- print_msg (
31- f"Dataset not found at { dataset_path } , using prompts:\n { prompts } ." )
3226
3327 return prompts [:num_prompts ]
3428
@@ -53,22 +47,15 @@ def main():
5347 parser .add_argument ("--temp" , type = float , default = 0 )
5448 args = parser .parse_args ()
5549
56- print_msg (f"Starting inference with the following parameters:\n { args } " )
57-
5850 model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
5951 eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
6052
6153 max_model_len = 2048
6254
63- # Initialize tokenizer
64- print_msg (f"Loading tokenizer for model { model_dir } " )
6555 tokenizer = AutoTokenizer .from_pretrained (model_dir )
6656
67- # Load prompts
6857 prompts = load_prompts (args .dataset , args .num_prompts )
69- print_msg (f"Loaded and tokenized { len (prompts )} prompts" )
7058
71- # Tokenize prompts
7259 prompt_ids = [
7360 tokenizer .apply_chat_template ([{
7461 "role" : "user" ,
@@ -78,9 +65,6 @@ def main():
7865 for prompt in prompts
7966 ]
8067
81- # Initialize LLM
82- print_msg (
83- f"Initializing model { model_dir } with tensor parallel size { args .tp } " )
8468 llm = LLM (
8569 model = model_dir ,
8670 trust_remote_code = True ,
@@ -102,8 +86,6 @@ def main():
10286
10387 sampling_params = SamplingParams (temperature = args .temp , max_tokens = 256 )
10488
105- # Start inference
106- print_msg ("Starting inference..." )
10789 outputs = llm .generate (prompt_token_ids = prompt_ids ,
10890 sampling_params = sampling_params )
10991
0 commit comments