@@ -74,36 +74,42 @@ def setup_class(cls):
7474
7575    def  test_generate (self ):
7676        prompts  =  ["Hello, AI!" , "Tell me a joke" ]
77-         outputs  =  self .client .generate (prompts )["completion_ids" ]
77+         outputs  =  self .client .generate (prompts )
78+         prompt_ids  =  outputs ["prompt_ids" ]
79+         completion_ids  =  outputs ["completion_ids" ]
7880
79-         # Check that the output is a list 
80-         assert  isinstance (outputs , list )
81+         # Check that the outputs are lists 
82+         assert  isinstance (prompt_ids , list )
83+         assert  isinstance (completion_ids , list )
8184
82-         # Check that the number of generated sequences is equal to the number of prompts 
83-         assert  len (outputs ) ==  len (prompts )
85+         # Check that the number of sequences are equal to the number of prompts 
86+         assert  len (prompt_ids ) ==  len (prompts )
87+         assert  len (completion_ids ) ==  len (prompts )
8488
85-         # Check that the generated sequences are lists of integers 
86-         for  seq  in  outputs :
89+         # Check that the sequences are lists of integers 
90+         for  seq  in  prompt_ids :
91+             assert  all (isinstance (tok , int ) for  tok  in  seq )
92+         for  seq  in  completion_ids :
8793            assert  all (isinstance (tok , int ) for  tok  in  seq )
8894
8995    def  test_generate_with_params (self ):
9096        prompts  =  ["Hello, AI!" , "Tell me a joke" ]
91-         outputs  =  self .client .generate (prompts , n = 2 , repetition_penalty = 0.9 , temperature = 0.8 , max_tokens = 32 )[
97+         completion_ids  =  self .client .generate (prompts , n = 2 , repetition_penalty = 0.9 , temperature = 0.8 , max_tokens = 32 )[
9298            "completion_ids" 
9399        ]
94100
95101        # Check that the output is a list 
96-         assert  isinstance (outputs , list )
102+         assert  isinstance (completion_ids , list )
97103
98104        # Check that the number of generated sequences is 2 times the number of prompts 
99-         assert  len (outputs ) ==  2  *  len (prompts )
105+         assert  len (completion_ids ) ==  2  *  len (prompts )
100106
101107        # Check that the generated sequences are lists of integers 
102-         for  seq  in  outputs :
108+         for  seq  in  completion_ids :
103109            assert  all (isinstance (tok , int ) for  tok  in  seq )
104110
105111        # Check that the length of the generated sequences is less than or equal to 32 
106-         for  seq  in  outputs :
112+         for  seq  in  completion_ids :
107113            assert  len (seq ) <=  32 
108114
109115    def  test_update_model_params (self ):
@@ -148,36 +154,42 @@ def setup_class(cls):
148154
149155    def  test_generate (self ):
150156        prompts  =  ["Hello, AI!" , "Tell me a joke" ]
151-         outputs  =  self .client .generate (prompts )["completion_ids" ]
157+         outputs  =  self .client .generate (prompts )
158+         prompt_ids  =  outputs ["prompt_ids" ]
159+         completion_ids  =  outputs ["completion_ids" ]
152160
153-         # Check that the output is a list 
154-         assert  isinstance (outputs , list )
161+         # Check that the outputs are lists 
162+         assert  isinstance (prompt_ids , list )
163+         assert  isinstance (completion_ids , list )
155164
156-         # Check that the number of generated sequences is equal to the number of prompts 
157-         assert  len (outputs ) ==  len (prompts )
165+         # Check that the number of sequences are equal to the number of prompts 
166+         assert  len (prompt_ids ) ==  len (prompts )
167+         assert  len (completion_ids ) ==  len (prompts )
158168
159-         # Check that the generated sequences are lists of integers 
160-         for  seq  in  outputs :
169+         # Check that the sequences are lists of integers 
170+         for  seq  in  prompt_ids :
171+             assert  all (isinstance (tok , int ) for  tok  in  seq )
172+         for  seq  in  completion_ids :
161173            assert  all (isinstance (tok , int ) for  tok  in  seq )
162174
163175    def  test_generate_with_params (self ):
164176        prompts  =  ["Hello, AI!" , "Tell me a joke" ]
165-         outputs  =  self .client .generate (prompts , n = 2 , repetition_penalty = 0.9 , temperature = 0.8 , max_tokens = 32 )[
177+         completion_ids  =  self .client .generate (prompts , n = 2 , repetition_penalty = 0.9 , temperature = 0.8 , max_tokens = 32 )[
166178            "completion_ids" 
167179        ]
168180
169181        # Check that the output is a list 
170-         assert  isinstance (outputs , list )
182+         assert  isinstance (completion_ids , list )
171183
172184        # Check that the number of generated sequences is 2 times the number of prompts 
173-         assert  len (outputs ) ==  2  *  len (prompts )
185+         assert  len (completion_ids ) ==  2  *  len (prompts )
174186
175187        # Check that the generated sequences are lists of integers 
176-         for  seq  in  outputs :
188+         for  seq  in  completion_ids :
177189            assert  all (isinstance (tok , int ) for  tok  in  seq )
178190
179191        # Check that the length of the generated sequences is less than or equal to 32 
180-         for  seq  in  outputs :
192+         for  seq  in  completion_ids :
181193            assert  len (seq ) <=  32 
182194
183195    def  test_update_model_params (self ):
@@ -224,16 +236,22 @@ def setup_class(cls):
224236
225237    def  test_generate (self ):
226238        prompts  =  ["Hello, AI!" , "Tell me a joke" ]
227-         outputs  =  self .client .generate (prompts )["completion_ids" ]
239+         outputs  =  self .client .generate (prompts )
240+         prompt_ids  =  outputs ["prompt_ids" ]
241+         completion_ids  =  outputs ["completion_ids" ]
228242
229-         # Check that the output is a list 
230-         assert  isinstance (outputs , list )
243+         # Check that the outputs are lists 
244+         assert  isinstance (prompt_ids , list )
245+         assert  isinstance (completion_ids , list )
231246
232-         # Check that the number of generated sequences is equal to the number of prompts 
233-         assert  len (outputs ) ==  len (prompts )
247+         # Check that the number of sequences are equal to the number of prompts 
248+         assert  len (prompt_ids ) ==  len (prompts )
249+         assert  len (completion_ids ) ==  len (prompts )
234250
235-         # Check that the generated sequences are lists of integers 
236-         for  seq  in  outputs :
251+         # Check that the sequences are lists of integers 
252+         for  seq  in  prompt_ids :
253+             assert  all (isinstance (tok , int ) for  tok  in  seq )
254+         for  seq  in  completion_ids :
237255            assert  all (isinstance (tok , int ) for  tok  in  seq )
238256
239257    def  test_update_model_params (self ):
@@ -280,16 +298,22 @@ def setup_class(cls):
280298
281299    def  test_generate (self ):
282300        prompts  =  ["Hello, AI!" , "Tell me a joke" ]
283-         outputs  =  self .client .generate (prompts )["completion_ids" ]
301+         outputs  =  self .client .generate (prompts )
302+         prompt_ids  =  outputs ["prompt_ids" ]
303+         completion_ids  =  outputs ["completion_ids" ]
284304
285-         # Check that the output is a list 
286-         assert  isinstance (outputs , list )
305+         # Check that the outputs are lists 
306+         assert  isinstance (prompt_ids , list )
307+         assert  isinstance (completion_ids , list )
287308
288-         # Check that the number of generated sequences is equal to the number of prompts 
289-         assert  len (outputs ) ==  len (prompts )
309+         # Check that the number of sequences are equal to the number of prompts 
310+         assert  len (prompt_ids ) ==  len (prompts )
311+         assert  len (completion_ids ) ==  len (prompts )
290312
291-         # Check that the generated sequences are lists of integers 
292-         for  seq  in  outputs :
313+         # Check that the sequences are lists of integers 
314+         for  seq  in  prompt_ids :
315+             assert  all (isinstance (tok , int ) for  tok  in  seq )
316+         for  seq  in  completion_ids :
293317            assert  all (isinstance (tok , int ) for  tok  in  seq )
294318
295319    def  test_update_model_params (self ):
@@ -336,9 +360,13 @@ def test_init_communicator_with_device_int(self):
336360
337361        # Test basic functionality 
338362        prompts  =  ["Hello, AI!" ]
339-         outputs  =  client .generate (prompts )["completion_ids" ]
340-         assert  isinstance (outputs , list )
341-         assert  len (outputs ) ==  len (prompts )
363+         outputs  =  client .generate (prompts )
364+         prompt_ids  =  outputs ["prompt_ids" ]
365+         completion_ids  =  outputs ["completion_ids" ]
366+         assert  isinstance (prompt_ids , list )
367+         assert  len (prompt_ids ) ==  len (prompts )
368+         assert  isinstance (completion_ids , list )
369+         assert  len (completion_ids ) ==  len (prompts )
342370
343371        client .close_communicator ()
344372
0 commit comments