@@ -252,11 +252,11 @@ def init_split_limits(self, num_dataset_splits):
252252 )
253253
254254 if len (self .sorted_data ) > 0 :
255- all_sorting_criterion = [self ._sorting_criteria (self .sorted_data [0 ])[:2 ]]
255+ all_sorting_criterion = [self ._sorting_criteria (self .sorted_data [0 ])[:- 1 ]]
256256 splits_indices = [[0 , None ]]
257257 for ix , req in enumerate (self .sorted_data ):
258258 current_sorting_criteria = self ._sorting_criteria (req )
259- current_key = current_sorting_criteria [:2 ]
259+ current_key = current_sorting_criteria [:- 1 ]
260260 if current_key not in all_sorting_criterion :
261261 all_sorting_criterion .append (current_key )
262262 splits_indices [- 1 ][1 ] = ix
@@ -269,7 +269,7 @@ def init_split_limits(self, num_dataset_splits):
269269 splits_indices = [tuple (e ) for e in splits_indices ]
270270 return num_dataset_splits , splits_indices
271271
272- def _sorting_criteria (self , request : GreedyUntilRequest ) -> tuple [bool , bool , list , int ]:
272+ def _sorting_criteria (self , request : GreedyUntilRequest ) -> tuple [bool , bool , list , int , int ]:
273273 """
274274 Collate function for generating batches.
275275
@@ -284,7 +284,13 @@ def _sorting_criteria(self, request: GreedyUntilRequest) -> tuple[bool, bool, li
284284 # The generative task has no limit except the model context
285285 if gen_length is None :
286286 gen_length = 0
287- return request .do_sample , request .use_logits , request .stop_sequence , - (len (toks ) + gen_length )
287+ return (
288+ request .do_sample ,
289+ request .use_logits ,
290+ tuple (request .stop_sequence ),
291+ gen_length ,
292+ - (len (toks ) + gen_length ),
293+ )
288294
289295
290296class GenerativeTaskDatasetNanotron (GenerativeTaskDataset ):
0 commit comments