File tree Expand file tree Collapse file tree 3 files changed +5
-6
lines changed Expand file tree Collapse file tree 3 files changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -503,15 +503,13 @@ def collect(
503503        ), "Shuffling shards and storing tokens is not supported yet" 
504504
505505        # Check if we need to store sequence ranges 
506-         has_bos_token  =  model .tokenizer .bos_token_id  is  not None 
506+         has_bos_token  =  model .tokenizer .bos_token  is  not None 
507507        store_sequence_ranges  =  (
508508            store_tokens  and  
509509            not  shuffle_shards  and  
510510            not  has_bos_token 
511511        )
512-         if  store_sequence_ranges :
513-             print ("No BOS token found. Will store sequence ranges." )
514-         
512+   
515513        dataloader  =  DataLoader (data , batch_size = batch_size , num_workers = num_workers )
516514
517515        activation_cache  =  [[] for  _  in  submodules ]
Original file line number Diff line number Diff line change @@ -173,7 +173,7 @@ def loss(
173173        if  step  >  self .threshold_start_step :
174174            self .update_threshold (f )
175175
176-         x_hat  =  self .ae .decode (f , denormalize_activations = normalize_activations )
176+         x_hat  =  self .ae .decode (f , denormalize_activations = False )
177177
178178        e  =  x  -  x_hat 
179179
Original file line number Diff line number Diff line change 1111import  wandb 
1212from  typing  import  List , Optional 
1313
14+ from  .trainers .batch_top_k  import  BatchTopKTrainer 
1415from  .trainers .crosscoder  import  CrossCoderTrainer , BatchTopKCrossCoderTrainer 
1516
1617
@@ -300,7 +301,7 @@ def trainSAE(
300301                    use_threshold = False ,
301302                    epoch_idx_per_step = epoch_idx_per_step ,
302303                )
303-                 if  isinstance (trainer , BatchTopKCrossCoderTrainer ):
304+                 if  isinstance (trainer , BatchTopKCrossCoderTrainer )  or   isinstance ( trainer ,  BatchTopKTrainer ) :
304305                    log_stats (
305306                        trainer ,
306307                        step ,
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments