This repository contains the code for the paper *Overcoming Sparsity Artifacts in Crosscoders to Interpret Chat-Tuning*.
The trained models, along with statistics and maximally activating examples for each latent, are hosted on our Hugging Face page. We also provide an interactive Colab notebook and training logs on wandb.
Our code heavily relies on an adapted version of the `dictionary_learning` library. Install the requirements with:
```
pip install -r requirements.txt
```

We cache model activations to disk. Our code assumes that you have around 4TB of storage available per model and that the environment variable `$DATASTORE` points to it. The training scripts log progress to wandb. All models are checkpointed to the `checkpoints` folder, and the resulting plots are generated in `$DATASTORE/results`.
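Before launching anything, it can help to confirm the storage setup. The snippet below is not part of the repository; it is a minimal sanity check, assuming only that `$DATASTORE` should point to a directory with roughly 4TB of free space per model.

```python
# Quick sanity check (not part of the repo): verify that $DATASTORE is set
# and points to a location with enough free space before caching activations.
import os
import shutil

datastore = os.environ.get("DATASTORE")
assert datastore is not None, "Set $DATASTORE to a directory with ~4TB of free space per model."

free_tb = shutil.disk_usage(datastore).free / 1e12
print(f"$DATASTORE = {datastore} ({free_tb:.1f} TB free)")
if free_tb < 4:
    print("Warning: less than 4TB free; activation caching may run out of space.")
```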
For Gemma 2 2B:

```
bash train_gemma2b.sh
```

For Llama 3.2 1B:

```
bash train_llama1b.sh
```

For Llama 3.1 8B:

```
bash train_llama8b.sh
```

Check out `notebooks/art.py` for generating the more complex plots.
The code that implements the crosscoders themselves is found in our `dictionary_learning` fork.
This repository is organized into two main directories:
The `scripts` folder contains the main execution scripts:
- `train_crosscoder.py` - Primary crosscoder training script. Trains crosscoders on paired activations from the base and chat models, with support for various architectures (ReLU, batch-top-k) and normalization schemes.
- `compute_scalers.py` - Computes latent scalers using a closed-form solution; calculates the beta values for a given crosscoder (see the sketch after this list).
- `collect_activations.py` - Caches model activations for training.
- `collect_dictionary_activations.py` - Collects activations through trained dictionaries.
- `collect_activating_examples.py` - Gathers max-activating examples for analysis. Required for the demo.
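For intuition on what `compute_scalers.py` produces, the sketch below shows one way to fit such scalers in closed form: for each latent j with decoder direction d_j and activation f_j(x), it solves for the scalar beta_j that minimizes the squared error between beta_j * f_j(x) * d_j and a target activation y(x) over a batch of tokens. The function name, tensor layout, and choice of target are illustrative assumptions; the actual script may normalize or aggregate differently.

```python
# Illustrative sketch only (hypothetical helper, not the repo's API):
# closed-form least-squares fit of a per-latent scalar beta_j that rescales
# the latent's contribution beta_j * f_j(x) * d_j towards a target y(x).
import torch

def fit_latent_scalers(
    latent_acts: torch.Tensor,  # (n_tokens, n_latents) latent activations f_j(x)
    decoder: torch.Tensor,      # (n_latents, d_model) decoder directions d_j
    targets: torch.Tensor,      # (n_tokens, d_model) target activations y(x)
) -> torch.Tensor:
    # Numerator: sum_x f_j(x) * (d_j . y(x)) for every latent j.
    proj = targets @ decoder.T                      # (n_tokens, n_latents)
    num = (latent_acts * proj).sum(dim=0)           # (n_latents,)
    # Denominator: sum_x f_j(x)^2 * ||d_j||^2.
    denom = (latent_acts ** 2).sum(dim=0) * decoder.pow(2).sum(dim=1)
    return num / denom.clamp_min(1e-8)              # (n_latents,) beta values
```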
The `tools` folder contains various utility functions. The `steering_app` folder contains a Streamlit app for generating steered outputs.
The `notebooks/dashboard-and-demo.ipynb` notebook lets you explore the crosscoders and their latents.