@@ -530,21 +530,19 @@ def greedy_until(
530530 starting_batch_size = STARTING_BATCH_SIZE
531531 results = []
532532
533- for split_start , split_end in tqdm (
534- dataset .splits_start_end_iterator (),
533+ for split in tqdm (
534+ dataset .splits_iterator (),
535535 total = dataset .num_dataset_splits ,
536536 desc = "Splits" ,
537537 position = 0 ,
538538 disable = self .disable_tqdm ,
539539 ):
540- if dataset [0 ].generation_size is None :
540+ if split [0 ].generation_size is None :
541541 # No constraints on the generation size: max length allowed is the max model context
542542 max_context_continuation_size_allowed = self .max_length
543543 else :
544544 # Longest context in the current split is the first item (since we sort reversed)
545- longest_context_continuation_size_in_split = (
546- len (dataset [0 ].tokenized_context ) + dataset [0 ].generation_size
547- )
545+ longest_context_continuation_size_in_split = len (split [0 ].tokenized_context ) + split [0 ].generation_size
548546 max_context_continuation_size_allowed = min (
549547 longest_context_continuation_size_in_split , self .max_length
550548 )
@@ -556,7 +554,7 @@ def greedy_until(
556554 # For next iteration, since the batch will be smaller, we'll test a bigger batch size
557555 starting_batch_size = batch_size * 2
558556
559- dataloader = DataLoader (dataset , batch_size = batch_size , collate_fn = lambda batch : batch )
557+ dataloader = DataLoader (split , batch_size = batch_size , collate_fn = lambda batch : batch )
560558 if self .accelerator :
561559 dataloader = self .accelerator .prepare (dataloader )
562560
@@ -765,9 +763,9 @@ def _loglikelihood_tokens(
765763 starting_batch_size = STARTING_BATCH_SIZE
766764 res = []
767765
768- for split_start , split_end in tqdm (dataset .splits_start_end_iterator ()):
769- context_enc = dataset [0 ].tokenized_context
770- continuation_enc = dataset [0 ].tokenized_continuation
766+ for split in tqdm (dataset .splits_iterator ()):
767+ context_enc = split [0 ].tokenized_context
768+ continuation_enc = split [0 ].tokenized_continuation
771769 if rolling : # we take all the sequence in rolling mode
772770 max_context_continuation_size_allowed = len (context_enc + continuation_enc )
773771 else : # in normal mode, we left cut the context if needed
@@ -782,7 +780,7 @@ def _loglikelihood_tokens(
782780 )
783781 starting_batch_size = batch_size * 2
784782
785- dataloader = DataLoader (dataset , batch_size = batch_size , collate_fn = lambda batch : batch )
783+ dataloader = DataLoader (split , batch_size = batch_size , collate_fn = lambda batch : batch )
786784 if self .accelerator :
787785 dataloader = self .accelerator .prepare (dataloader )
788786
@@ -1009,13 +1007,13 @@ def _loglikelihood_single_token(
10091007 starting_batch_size = STARTING_BATCH_SIZE
10101008 res = []
10111009
1012- for split_start , split_end in tqdm (dataset .splits_start_end_iterator ()):
1013- context_enc = dataset [0 ].tokenized_context
1010+ for split in tqdm (dataset .splits_iterator ()):
1011+ context_enc = split [0 ].tokenized_context
10141012 max_context = len (context_enc [- self .max_length :])
10151013 batch_size = self ._get_batch_size (override_bs = self .config .batch_size , max_input_length = max_context )
10161014 starting_batch_size = batch_size * 2
10171015
1018- dataloader = DataLoader (dataset , batch_size = starting_batch_size , collate_fn = lambda batch : batch )
1016+ dataloader = DataLoader (split , batch_size = starting_batch_size , collate_fn = lambda batch : batch )
10191017 if self .accelerator is not None :
10201018 dataloader = self .accelerator .prepare (dataloader )
10211019
0 commit comments