2 changes: 1 addition & 1 deletion src/data/__init__.py
@@ -102,4 +102,4 @@ def get_collators(collator_cfgs, **kwargs):
 _register_data(ForgetRetainDataset)
 
 # Register collators
-_register_collator(DataCollatorForSupervisedDataset)
+_register_collator(DataCollatorForSupervisedDataset)
15 changes: 12 additions & 3 deletions src/data/unlearn.py
@@ -15,6 +15,15 @@ def __init__(self, forget, retain, anchor="forget"):
         self.forget = forget
         self.retain = retain
         self.anchor = anchor
+        self.generator = torch.Generator()
+
+    def set_rank_seed(self, seed: int):
+        """Set the rank-specific seed for this dataset.
+
+        This should be called after trainer initialization so that each rank
+        seeds its own generator and samples different unanchored data.
+        """
+        self.generator.manual_seed(seed)
 
     def __len__(self):
         """Ensures the sampled dataset matches the anchor dataset's length."""
@@ -36,11 +45,11 @@ def __getitem__(self, idx):
         if self.anchor == "forget":
             item["forget"] = self.forget[idx]
             if self.retain:
-                retain_idx = torch.randint(0, len(self.retain), (1,)).item()
+                retain_idx = torch.randint(0, len(self.retain), (1,), generator=self.generator).item()
                 item["retain"] = self.retain[retain_idx]
         elif self.anchor == "retain":
            item["retain"] = self.retain[idx]
            if self.forget:
-                forget_idx = torch.randint(0, len(self.forget), (1,)).item()
+                forget_idx = torch.randint(0, len(self.forget), (1,), generator=self.generator).item()
                item["forget"] = self.forget[forget_idx]
-        return item
+        return item
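
Aside (not part of the diff): a dedicated torch.Generator carries its own RNG state, so a per-rank seed set here survives later global reseeding, such as the set_seed call Hugging Face's Trainer typically makes during initialization. A minimal sketch of that isolation, with arbitrary illustrative seed values:

import torch

# A dedicated generator keeps private state: a later global reseed
# (standing in for a trainer-side set_seed) cannot disturb its draws.
gen = torch.Generator()
gen.manual_seed(123)
torch.manual_seed(0)        # global reseed
first = torch.randint(0, 10**6, (1,), generator=gen).item()

gen.manual_seed(123)
torch.manual_seed(999)      # a different global reseed
second = torch.randint(0, 10**6, (1,), generator=gen).item()
assert first == second      # the dedicated generator is unaffected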
15 changes: 14 additions & 1 deletion src/train.py
@@ -1,6 +1,8 @@
+import torch
 import hydra
 from omegaconf import DictConfig
 from data import get_data, get_collators
+from data.unlearn import ForgetRetainDataset
 from model import get_model
 from trainer import load_trainer
 from evals import get_evaluators
@@ -23,7 +25,11 @@ def main(cfg: DictConfig):
     # Load Dataset
     data_cfg = cfg.data
     data = get_data(
-        data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args
+        data_cfg,
+        mode=mode,
+        tokenizer=tokenizer,
+        template_args=template_args,
+        seed=cfg.trainer.args.seed,
     )
 
     # Load collator
@@ -56,6 +62,13 @@ def main(cfg: DictConfig):
         template_args=template_args,
     )
 
+    # Set rank-specific seed for ForgetRetainDataset after trainer initialization
+    train_dataset = data.get("train", None)
+    if isinstance(train_dataset, ForgetRetainDataset):
+        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+        rank_seed = cfg.trainer.args.seed + rank
+        train_dataset.set_rank_seed(rank_seed)
+
     if trainer_args.do_train:
         trainer.train()
         trainer.save_state()
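
For intuition, here is a small sketch (illustrative, not from the repo) of what the seed + rank scheme buys: each rank's generator draws its own sequence of unanchored indices, yet every rank stays individually reproducible.

import torch

def draw_indices(seed, n=5, high=1000):
    # Mirrors ForgetRetainDataset: one dedicated, seeded generator per process.
    gen = torch.Generator()
    gen.manual_seed(seed)
    return [torch.randint(0, high, (1,), generator=gen).item() for _ in range(n)]

base_seed = 42  # plays the role of cfg.trainer.args.seed
rank0 = draw_indices(base_seed + 0)
rank1 = draw_indices(base_seed + 1)
assert rank0 != rank1                     # different seeds, (almost surely) different draws
assert rank0 == draw_indices(base_seed)   # each rank is reproducible on its own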