Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions nanochat/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,32 +89,46 @@ def get_dist_info():
else:
return False, 0, 0, 1

def compute_init():
def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
if torch.cuda.is_available():
device_type = "cuda"
elif torch.backends.mps.is_available():
device_type = "mps"
else:
device_type = "cpu"
print0(f"Autodetected device type: {device_type}")
return device_type

def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""

# CUDA is currently required
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
if device_type == "cuda":
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
if device_type == "mps":
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

# Reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
if device_type == "cuda":
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# Precision
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls

# Distributed setup: Distributed Data Parallel (DDP), optional
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp:
if ddp and device_type == "cuda":
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
else:
device = torch.device("cuda")
device = torch.device(device_type) # mps|cpu

if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")
Expand Down
6 changes: 3 additions & 3 deletions nanochat/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer

def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
assert split in ["train", "val"], "split must be 'train' or 'val'"
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
Expand Down Expand Up @@ -44,6 +44,6 @@ def document_batches():
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: I'm getting the following error at L42

!self.is_mps() INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp":1414, please report a bug to PyTorch. as_strided_tensorimpl does not work with MPS; call self.as_strided(...) instead

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To fix it, change L41-42 to

tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=device.type == "cuda")

targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
yield inputs, targets
5 changes: 3 additions & 2 deletions nanochat/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,6 @@ def __init__(self, config):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
self.transformer.wte.to(dtype=torch.bfloat16)

def init_weights(self):
self.apply(self._init_weights)
Expand All @@ -184,6 +182,9 @@ def init_weights(self):
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
if self.transformer.wte.weight.device.type == "cuda":
self.transformer.wte.to(dtype=torch.bfloat16)

def _init_weights(self, module):
if isinstance(module, nn.Linear):
Expand Down
2 changes: 1 addition & 1 deletion nanochat/loss_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (y < 0).any():
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
valid = y >= 0
Expand Down
4 changes: 4 additions & 0 deletions nanochat/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,10 @@ def generate(self):
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
bloat_data = bloat_data.group(1) if bloat_data else ""
else:
start_time = None # will cause us to not write the total wall clock time
bloat_data = "[bloat data missing]"
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
# process all the individual sections
for file_name in EXPECTED_FILES:
section_file = os.path.join(report_dir, file_name)
Expand Down
12 changes: 1 addition & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"numpy==1.26.4",
"psutil>=7.1.0",
"regex>=2025.9.1",
"setuptools>=80.9.0",
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch>=2.8.0",
Expand All @@ -22,17 +23,6 @@ dependencies = [
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"

# target torch to cuda 12.8
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128" },
]

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.maturin]
module-name = "rustbpe"
bindings = "pyo3"
Expand Down
20 changes: 13 additions & 7 deletions scripts/base_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import json
import random
import yaml
from contextlib import nullcontext

import pandas as pd
import torch

from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
Expand Down Expand Up @@ -118,16 +119,21 @@ def load_hf_model(hf_path: str, device):

# -----------------------------------------------------------------------------
def main():
assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
args = parser.parse_args()

# distributed / precision setup
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

# Load model and tokenizer from command line or from file system
if len(sys.argv) >= 2:
if args.hf_path is not None:
# atm assume that if a path is given, it's a huggingface model path
hf_path = sys.argv[1]
hf_path = args.hf_path
print0(f"Loading huggingface model from: {hf_path}")
model, tokenizer = load_hf_model(hf_path, device)
model_name = hf_path # just for logging
Expand All @@ -140,7 +146,7 @@ def main():

# Evaluate the model
with autocast_ctx:
out = evaluate_model(model, tokenizer, device)
out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)

# Write out the results to a csv file
core_metric = None
Expand Down
13 changes: 7 additions & 6 deletions scripts/base_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
"""
import os
from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
from nanochat.common import compute_init, print0, compute_cleanup
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb
Expand All @@ -20,15 +21,15 @@
split_tokens = 20*524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file

# Load the base model and the tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really

# Set up the precision we'll run with
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()

# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
Expand All @@ -37,7 +38,7 @@
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")
Expand Down
39 changes: 25 additions & 14 deletions scripts/base_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@
or distributed as:

torchrun --nproc_per_node=8 base_train.py

If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
from contextlib import nullcontext

import wandb
import torch

from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
Expand All @@ -27,6 +32,8 @@
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
Expand All @@ -45,7 +52,7 @@
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Output
Expand All @@ -57,9 +64,12 @@
# -----------------------------------------------------------------------------

# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
Expand Down Expand Up @@ -96,7 +106,7 @@
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device="cuda")
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
Expand Down Expand Up @@ -133,8 +143,8 @@
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -193,7 +203,8 @@ def get_muon_momentum(it):

# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
if last_step or (step > 0 and step % core_metric_every == 0):
results = {}
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
Expand All @@ -219,7 +230,7 @@ def get_muon_momentum(it):
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(model, tokenizer)
engine = Engine(orig_model, tokenizer)
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
Expand Down Expand Up @@ -252,7 +263,7 @@ def get_muon_momentum(it):
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
torch.cuda.synchronize()
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
Expand All @@ -275,7 +286,7 @@ def get_muon_momentum(it):
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize()
synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
Expand Down Expand Up @@ -304,7 +315,7 @@ def get_muon_momentum(it):
})

# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

Expand All @@ -326,11 +337,11 @@ def get_muon_momentum(it):
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": val_bpb,
"CORE metric estimate": results["core_metric"],
"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
])

Expand Down
Loading