@@ -61,6 +61,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.server_metrics = False
     context.server_process = None
     context.seed = None
+    context.draft = None
     context.server_seed = None
     context.user_api_key = None
     context.response_format = None
@@ -107,6 +108,11 @@ def step_n_gpu_layer(context, ngl):
     context.n_gpu_layer = ngl


+@step('{draft:d} as draft')
+def step_draft(context, draft):
+    context.draft = draft
+
+
 @step('{n_ctx:d} KV cache size')
 def step_n_ctx(context, n_ctx):
     context.n_ctx = n_ctx
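A note on the new step pattern: behave's step matching is built on the `parse` library, so the `{draft:d}` capture above arrives as an int, and a feature file line such as "And 5 as draft" stores the integer 5 in context.draft. A minimal sketch of that conversion, using only the `parse` library's public API (the value 5 is illustrative):

    from parse import parse  # behave's @step patterns use this library's syntax
    result = parse('{draft:d} as draft', '5 as draft')
    assert result['draft'] == 5  # ':d' converts the capture to an int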
@@ -254,6 +260,15 @@ def step_n_tokens_predicted(context, predicted_n):
     assert_n_tokens_predicted(context.completion, predicted_n)


+@step('all predictions are equal')
+@async_run_until_complete
+async def step_predictions_equal(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions >= 2, "need at least 2 completions"
+    assert_all_predictions_equal(context.tasks_result)
+    context.tasks_result = []
+
+
 @step('the completion is truncated')
 def step_assert_completion_truncated(context):
     step_assert_completion_truncated(context, '')
@@ -1020,6 +1035,23 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')

+def assert_all_predictions_equal(completion_responses):
+    content_0 = completion_responses[0]['content']
+
+    if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+        print(f"content 0: {content_0}")
+
+    i = 1
+    for response in completion_responses[1:]:
+        content = response['content']
+
+        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
+            print(f"content {i}: {content}")
+
+        assert content == content_0, "contents not equal"
+
+        i += 1
+

 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
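For reference, a minimal usage sketch of the new assert_all_predictions_equal helper; the response dicts are hypothetical but shaped like the entries the suite accumulates in context.tasks_result:

    responses = [
        {'content': 'The quick brown fox'},
        {'content': 'The quick brown fox'},
    ]
    assert_all_predictions_equal(responses)  # passes: every content matches the first
    responses[1]['content'] = 'jumps over'
    # assert_all_predictions_equal(responses)  # would now raise "contents not equal"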
@@ -1148,6 +1180,8 @@ def start_server_background(context):
         server_args.extend(['--ubatch-size', context.n_ubatch])
     if context.n_gpu_layer:
         server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
+    if context.draft is not None:
+        server_args.extend(['--draft', context.draft])
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
     if context.server_embeddings:
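Putting the pieces together, a hedged sketch (the draft value and surrounding args are hypothetical) of what the new branch contributes to the server argument list; note that context.draft is an int here, so the launcher is assumed to stringify each element before handing the list to subprocess.Popen:

    server_args = ['--n-gpu-layers', 33]        # illustrative pre-existing args
    draft = 7                                   # as set by the '{draft:d} as draft' step
    if draft is not None:
        server_args.extend(['--draft', draft])
    print(server_args)  # ['--n-gpu-layers', 33, '--draft', 7]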