235 changes: 165 additions & 70 deletions scripts/collect_dictionary_activations.py
@@ -16,34 +16,60 @@


@th.no_grad()
def get_positive_activations(sequences, ranges, dataset, cc, latent_ids):
def get_positive_activations_incremental(tokens, boundaries, dataset, cc, latent_ids, checkpoint_every_n_seqs=5000, temp_dir=None, dataset_name=""):
"""
Extract positive activations and their indices from sequences.
Also compute the maximum activation for each latent feature.

Extract positive activations incrementally and save results to a temp directory.

Args:
sequences: List of sequences
ranges: List of (start_idx, end_idx) tuples for each sequence
tokens: Flat 1-D tensor of all tokens
boundaries: List of boundary indices; sequence i spans tokens[boundaries[i]:boundaries[i+1]]
dataset: Dataset containing activations
cc: Object with get_activations method
latent_ids: Tensor of latent indices to extract
checkpoint_every_n_seqs: Save checkpoint every N sequences (default 5000)
temp_dir: Directory to save temporary results
dataset_name: Name prefix for files (e.g., "fineweb" or "lmsys")

Returns:
Tuple of:
- activations tensor: positive activation values
- indices tensor: in (seq_idx, seq_pos, feature_pos) format
- max_activations: maximum activation value for each latent feature
Path to the saved results directory
"""
out_activations = []
out_ids = []
seq_ranges = [0]

sequences = []

num_sequences = len(boundaries) - 1

# Initialize tensors to track max activations for each latent
max_activations = th.zeros(len(latent_ids), device="cuda")

for seq_idx in trange(len(sequences)):

# Check for existing checkpoint
start_seq_idx = 0
checkpoint_file = None
if temp_dir is not None:
checkpoint_file = temp_dir / f"checkpoint_{dataset_name}.pt"
if checkpoint_file.exists():
logger.info(f"Loading checkpoint from {checkpoint_file}")
checkpoint = th.load(checkpoint_file, weights_only=True)
out_activations = checkpoint["activations"]
out_ids = checkpoint["ids"]
seq_ranges = checkpoint["seq_ranges"]
sequences = checkpoint["sequences"]
max_activations = checkpoint["max_activations"].cuda()
start_seq_idx = checkpoint["last_seq_idx"] + 1
logger.info(f"Resuming from sequence {start_seq_idx}/{num_sequences}")

for seq_idx in trange(start_seq_idx, num_sequences, desc=f"Processing {dataset_name} sequences"):
start_idx = boundaries[seq_idx]
end_idx = boundaries[seq_idx + 1]

# Get the sequence tokens
sequence = tokens[start_idx:end_idx]
sequences.append(sequence)

# Get activations for this sequence
activations = th.stack(
[dataset[j].cuda() for j in range(ranges[seq_idx][0], ranges[seq_idx][1])]
[dataset[j].cuda() for j in range(start_idx, end_idx)]
)
feature_activations = cc.get_activations(activations)
assert feature_activations.shape == (
@@ -52,10 +78,7 @@ def get_positive_activations(sequences, ranges, dataset, cc, latent_ids):
), f"Feature activations shape: {feature_activations.shape}, expected: {(len(activations), len(latent_ids))}"

# Track maximum activations
# For each latent feature, find the max activation in this sequence
seq_max_values, seq_max_positions = feature_activations.max(dim=0)

# Update global maximums where this sequence has a higher value
seq_max_values, _ = feature_activations.max(dim=0)
update_mask = seq_max_values > max_activations
max_activations[update_mask] = seq_max_values[update_mask]
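# (Equivalent, branch-free update: max_activations = th.maximum(max_activations, seq_max_values).)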

@@ -72,38 +95,73 @@ def get_positive_activations(sequences, ranges, dataset, cc, latent_ids):
# Stack indices into (seq_idx, seq_pos, feature_pos) format
pos_ids = th.stack([seq_idx_tensor, pos_indices[0], pos_indices[1]], dim=1)

out_activations.append(pos_activations)
out_ids.append(pos_ids)
# Move to CPU immediately to free GPU memory
out_activations.append(pos_activations.cpu())
out_ids.append(pos_ids.cpu())
seq_ranges.append(seq_ranges[-1] + len(pos_ids))

out_activations = th.cat(out_activations).cpu()
out_ids = th.cat(out_ids).cpu()
return out_activations, out_ids, seq_ranges, max_activations


def split_into_sequences(tokenizer, tokens):
# Find indices of BOS tokens

# Clean up GPU memory
del activations, feature_activations, pos_mask, pos_indices, pos_activations, seq_idx_tensor, pos_ids

# Save checkpoint periodically
if checkpoint_file is not None and (seq_idx + 1) % checkpoint_every_n_seqs == 0:
logger.info(f"Saving checkpoint at sequence {seq_idx + 1}/{num_sequences}")
checkpoint = {
"activations": out_activations,
"ids": out_ids,
"seq_ranges": seq_ranges,
"sequences": sequences,
"max_activations": max_activations.cpu(),
"last_seq_idx": seq_idx
}
th.save(checkpoint, checkpoint_file)
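# Note (not implemented here): th.save overwrites checkpoint_file in place,
# so a crash mid-save can corrupt the checkpoint; writing to a temporary file
# and then os.replace()-ing it onto checkpoint_file would make the save atomic.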
th.cuda.empty_cache()

# Final concatenation and save to temp
logger.info(f"Finalizing {dataset_name} results...")
out_activations = th.cat(out_activations) if out_activations else th.tensor([])
out_ids = th.cat(out_ids) if out_ids else th.tensor([])

# Save results to temp directory
result_dir = temp_dir / dataset_name
result_dir.mkdir(exist_ok=True)

th.save(out_activations, result_dir / "out_acts.pt")
th.save(out_ids, result_dir / "out_ids.pt")
th.save(sequences, result_dir / "sequences.pt") # Keep original name since not padded yet
th.save(th.tensor(seq_ranges), result_dir / "seq_ranges.pt")
th.save(max_activations.cpu(), result_dir / "max_activations.pt")

logger.info(f"Saved {dataset_name} results to {result_dir}")

# Free memory but keep files
del out_activations, out_ids, sequences
th.cuda.empty_cache()

return result_dir
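A minimal usage sketch for the incremental path above (my_tokens, my_cache, and crosscoder are hypothetical stand-ins, not part of this PR):

from pathlib import Path
import torch as th

temp_dir = Path("latent_activations/temp")
temp_dir.mkdir(parents=True, exist_ok=True)
boundaries, has_bos = get_sequence_boundaries(tokenizer, my_tokens)
result_dir = get_positive_activations_incremental(
    my_tokens, boundaries, my_cache, crosscoder,
    latent_ids=th.arange(100),      # hypothetical latent selection
    checkpoint_every_n_seqs=5000,
    temp_dir=temp_dir,
    dataset_name="fineweb",
)
# Re-running after an interruption resumes from temp_dir / "checkpoint_fineweb.pt".
out_acts = th.load(result_dir / "out_acts.pt", weights_only=True)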


def get_sequence_boundaries(tokenizer, tokens):
"""
Get sequence boundaries without creating all sequences upfront.
Returns indices of BOS tokens or fixed boundaries for efficient processing.
"""
bos_mask = tokens == tokenizer.bos_token_id
indices_of_bos = th.where(bos_mask)[0]

if not bos_mask.any():
raise NotImplementedError(
"Sorry, can't fix into sequence as the model doesn't have BOS or those have been filtered out. We need to implement this in a cleaner way using the dataset directly"
)
# Split tokens into sequences starting with BOS token
sequences = []
index_to_seq_pos = [] # List of (sequence_idx, idx_in_sequence) tuples
ranges = []
for i in trange(len(indices_of_bos)):
start_idx = indices_of_bos[i]
end_idx = indices_of_bos[i + 1] if i < len(indices_of_bos) - 1 else len(tokens)
sequence = tokens[start_idx:end_idx]
sequences.append(sequence)
ranges.append((start_idx, end_idx))
# Add mapping for each token in this sequence
for j in range(len(sequence)):
index_to_seq_pos.append((i, j))

return sequences, index_to_seq_pos, ranges
# If no BOS tokens, create fixed-size sequence boundaries
seq_len = 1024
logger.warning(f"No BOS tokens found in data, using fixed-size sequences of {seq_len} tokens")
boundaries = list(range(0, len(tokens), seq_len))
boundaries.append(len(tokens))
return boundaries, False
else:
# Return BOS token positions as boundaries
boundaries = indices_of_bos.tolist()
if len(tokens) not in boundaries:
boundaries.append(len(tokens))
return boundaries, True
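A toy sanity check of the boundary logic above (assumes a tokenizer whose bos_token_id is 0; values are illustrative):

import torch as th

tokens = th.tensor([0, 5, 7, 0, 9, 0, 2, 3])       # BOS at positions 0, 3, 5
bos_positions = th.where(tokens == 0)[0].tolist()  # [0, 3, 5]
boundaries = bos_positions + [len(tokens)]         # [0, 3, 5, 8]
# Sequence i is tokens[boundaries[i]:boundaries[i+1]], so len(boundaries) - 1 == 3
# sequences here: [0, 5, 7], [0, 9], [0, 2, 3].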


def add_get_activations_sae(sae, model_idx, is_difference=False):
@@ -204,6 +262,7 @@ def collect_dictionary_activations(
is_difference_sae: bool = False,
sae_model_idx: int | None = None,
cache_suffix: str = "",
checkpoint_every_n_seqs: int = 5000,
) -> None:
"""
Compute and save latent activations for a given dictionary model.
Expand Down Expand Up @@ -236,6 +295,8 @@ def collect_dictionary_activations(
Defaults to False.
is_difference_sae (bool, optional): Whether the SAE is trained on activation differences.
Defaults to False.
checkpoint_every_n_seqs (int, optional): Save checkpoint and clear GPU memory every N sequences.
Helps prevent OOM errors on large datasets. Defaults to 5000.

Returns:
None
@@ -246,7 +307,13 @@
"sae_model_idx must be provided if is_sae is True. This is the index of the model activations to use for the SAE."
)

out_dir = Path(latent_activations_dir) / dictionary_model_name
# Handle case where dictionary_model_name is a file path
if "/" in dictionary_model_name and dictionary_model_name.endswith(".pt"):
# Extract meaningful name from path like /path/to/model-name/model_final.pt
model_dir_name = Path(dictionary_model_name).parent.name
out_dir = Path(latent_activations_dir) / model_dir_name
else:
out_dir = Path(latent_activations_dir) / dictionary_model_name
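# e.g. (hypothetical path): "/ckpts/crosscoder-l13/model_final.pt" resolves to
# model_dir_name == "crosscoder-l13", so outputs land under
# latent_activations_dir / "crosscoder-l13" rather than under the raw file path.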
if cache_suffix:
out_dir = out_dir / cache_suffix
out_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -288,35 +355,54 @@ def collect_dictionary_activations(
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)

seq_lmsys, idx_to_seq_pos_lmsys, ranges_lmsys = split_into_sequences(
tokenizer, tokens_lmsys
)
seq_fineweb, idx_to_seq_pos_fineweb, ranges_fineweb = split_into_sequences(
tokenizer, tokens_fineweb
)
# Create temp directory
temp_dir = out_dir / "temp"
temp_dir.mkdir(exist_ok=True)

# Get sequence boundaries without creating all sequences upfront
boundaries_fineweb, has_bos_fineweb = get_sequence_boundaries(tokenizer, tokens_fineweb)
boundaries_lmsys, has_bos_lmsys = get_sequence_boundaries(tokenizer, tokens_lmsys)

print(
f"Collecting activations for {len(seq_fineweb)} fineweb sequences and {len(seq_lmsys)} lmsys sequences"
f"Collecting activations for {len(boundaries_fineweb)-1} fineweb sequences and {len(boundaries_lmsys)-1} lmsys sequences"
)

(
out_acts_fineweb,
out_ids_fineweb,
seq_ranges_fineweb,
max_activations_fineweb,
) = get_positive_activations(
seq_fineweb, ranges_fineweb, fineweb_cache, dictionary_model, latent_ids
# Process fineweb and save to temp
fineweb_result_dir = get_positive_activations_incremental(
tokens_fineweb, boundaries_fineweb, fineweb_cache, dictionary_model, latent_ids,
checkpoint_every_n_seqs=checkpoint_every_n_seqs,
temp_dir=temp_dir,
dataset_name="fineweb"
)
out_acts_lmsys, out_ids_lmsys, seq_ranges_lmsys, max_activations_lmsys = (
get_positive_activations(
seq_lmsys, ranges_lmsys, lmsys_cache, dictionary_model, latent_ids
)
)

# Process lmsys and save to temp
lmsys_result_dir = get_positive_activations_incremental(
tokens_lmsys, boundaries_lmsys, lmsys_cache, dictionary_model, latent_ids,
checkpoint_every_n_seqs=checkpoint_every_n_seqs,
temp_dir=temp_dir,
dataset_name="lmsys"
)


# Load results from temp
logger.info("Loading results from temp...")
out_acts_fineweb = th.load(fineweb_result_dir / "out_acts.pt", weights_only=True)
out_ids_fineweb = th.load(fineweb_result_dir / "out_ids.pt", weights_only=True)
seq_fineweb = th.load(fineweb_result_dir / "sequences.pt", weights_only=True)
seq_ranges_fineweb = th.load(fineweb_result_dir / "seq_ranges.pt", weights_only=True).tolist()
max_activations_fineweb = th.load(fineweb_result_dir / "max_activations.pt", weights_only=True)

out_acts_lmsys = th.load(lmsys_result_dir / "out_acts.pt", weights_only=True)
out_ids_lmsys = th.load(lmsys_result_dir / "out_ids.pt", weights_only=True)
seq_lmsys = th.load(lmsys_result_dir / "sequences.pt", weights_only=True)
seq_ranges_lmsys = th.load(lmsys_result_dir / "seq_ranges.pt", weights_only=True).tolist()
max_activations_lmsys = th.load(lmsys_result_dir / "max_activations.pt", weights_only=True)

# Combine datasets for the merged output
out_acts = th.cat([out_acts_fineweb, out_acts_lmsys])
# add offset to seq_idx in out_ids_lmsys
out_ids_lmsys[:, 0] += len(seq_fineweb)
out_ids = th.cat([out_ids_fineweb, out_ids_lmsys])
out_ids_lmsys_combined = out_ids_lmsys.clone()
out_ids_lmsys_combined[:, 0] += len(seq_fineweb)
out_ids = th.cat([out_ids_fineweb, out_ids_lmsys_combined])
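# Toy illustration (hypothetical sizes): with 3 fineweb sequences, an lmsys row
# (seq_idx=0, seq_pos=7, feature_pos=12) becomes (3, 7, 12) in the combined
# index space, keeping sequence indices unique after concatenation.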

seq_ranges_lmsys = [i + len(out_acts_fineweb) for i in seq_ranges_lmsys]
seq_ranges = th.cat(
@@ -348,20 +434,25 @@ def collect_dictionary_activations(
# Convert to tensor and save
padded_tensor = th.stack(padded_seqs)

# Save tensors
# Save combined tensors
print("Saving combined dataset results...")
th.save(out_acts.cpu(), out_dir / "out_acts.pt")
th.save(out_ids.cpu(), out_dir / "out_ids.pt")
th.save(padded_tensor.cpu(), out_dir / "padded_sequences.pt")
th.save(latent_ids.cpu(), out_dir / "latent_ids.pt")
th.save(seq_ranges.cpu(), out_dir / "seq_ranges.pt")
th.save(seq_lengths.cpu(), out_dir / "seq_lengths.pt")
th.save(combined_max_activations.cpu(), out_dir / "max_activations.pt")
print(f" Saved combined results to {out_dir}")

# Print some stats about max activations
print("Maximum activation statistics:")
print("\nMaximum activation statistics (combined):")
print(f" Average: {combined_max_activations.mean().item():.4f}")
print(f" Maximum: {combined_max_activations.max().item():.4f}")
print(f" Minimum: {combined_max_activations.min().item():.4f}")

# Keep temp directory - it has the individual dataset results
logger.info(f"Individual dataset results are in: {temp_dir}")

if upload_to_hub:
# Initialize Hugging Face API
@@ -484,6 +575,9 @@ def hf_path(name: str):
)
latent_ids = th.cat(indices)

# Set default checkpoint interval - can be modified here if needed
checkpoint_every_n_seqs = 5000

collect_dictionary_activations(
dictionary_model_name=args.dictionary_model,
activation_store_dir=args.activation_store_dir,
Expand All @@ -499,4 +593,5 @@ def hf_path(name: str):
is_difference_sae=args.is_difference_sae,
sae_model_idx=args.sae_model_idx,
cache_suffix=args.cache_suffix,
checkpoint_every_n_seqs=checkpoint_every_n_seqs,
)
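A sketch of a typical invocation, assuming the elided argparse block defines flags matching the parameter names above (flag spellings are a guess, not confirmed by this diff):

# python scripts/collect_dictionary_activations.py \
#     --dictionary-model /ckpts/crosscoder-l13/model_final.pt \
#     --activation-store-dir ./activations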