Nanochat: A Deep Dive
25 Mar 2026

Nanochat is a GPT-2-beating LLM training codebase by Karpathy: tokenizer, pretraining, SFT, RL, and eval in readable single-file scripts. No frameworks, no abstractions you can’t trace end-to-end. The speedrun pipeline trains a 24-layer model that beats GPT-2’s CORE score in minimal wall-clock time, using value embeddings, FP8 training, and a collection of non-standard but justified architectural choices.
Summary: Why Nanochat Is Worth a Look
- Complete pipeline in one repo: tokenizer → pretrain → SFT → RL → eval, all readable single-file scripts.
- Value Embeddings (VE): per-layer token lookup tables added to attention values at every other layer. Zero FLOPs, ~55% of total parameters. The model “loves” them.
- Scaling laws done right: IsoFLOP grid sweep (4 budgets × 6 depths), ~10.5 tokens-per-parameter compute-optimal ratio. Speedrun uses ratio=8 to optimize wall-clock.
- Minimal FP8 training: ~150 lines vs torchao’s ~2000.
- No DataLoader, no DistributedSampler: manual rank-based striding over parquet row groups. BOS-aligned best-fit packing, ~100% sequence utilization.
- Non-standard architecture choices: ReLU², QK norm after RoPE, logit softcap, x0 residual, smear, backout, sliding window (SSSL pattern).
- Five eval methods: base has 3 (multiple_choice, schema, language_modeling — all loss/argmax-based, no generation), SFT has 2 (categorical via logit argmax over letter tokens, generative via sampling with tool use). Base uses CORE metric, SFT uses ChatCORE metric — both are mean centered accuracy across their respective task sets.
- RL is dead simple: vanilla REINFORCE on GSM8K, no std normalization, no PPO, no value network.
- Engine with tool use: KV-cache inference with calculator calls mid-generation via forced token injection.
- Experiment log as first-class artifact: dev/LOG.md documents every decision: what worked, what failed, and why.
- Attention learns document boundaries without hard masks: BOS markers, RoPE distance decay, and sliding window. No varlen attention needed.
- Everything is traceable: weight init math derived, per-type optimizer LRs, batch size auto-scales from Power Lines paper. No magic numbers.
Q&A
Naming and Conventions
Q: Why is the token embedding called wte?
“weight token embedding” — inherited from GPT-2’s naming convention. w = weight, t = token, e = embedding.
Q: Why are attention weights named with c_ prefix, like c_q, c_proj?
From GPT-2’s original implementation which used Conv1D (1×1 convolution) instead of nn.Linear for all projections. The c_ prefix meant “conv”. Nanochat uses nn.Linear but keeps the naming for continuity with the GPT lineage.
Model Architecture
Q: One-sentence summaries for all non-GPT-2 features?
- RoPE: Rotary position embeddings replacing learned positional embeddings, applied to Q and K.
- RMSNorm: Root-mean-square normalization with no learnable parameters, replacing LayerNorm.
- QK Norm: L2-normalizes Q and K after RoPE, scaled by 1.2, preventing attention logit explosion.
- ReLU²: F.relu(x).square() activation in MLP, sparse and cheap, replacing GELU.
- GQA: Grouped-query attention allowing fewer KV heads than query heads.
- Untied embeddings: Separate wte (input) and lm_head (output) embedding matrices.
- Sliding window attention: SSSL pattern (Short-Short-Short-Long) alternating window sizes across layers.
- Value Embeddings: Per-layer token lookups gated and added to attention values at every other layer.
- x0 residual: Every layer adds to the original input (layer 0 output), not just the previous layer.
- Smear: Bigram-like mixing of previous token into current token representation at input.
- Backout: Cache residual at halfway layer, subtract a backout_lambda-scaled version before final norm; removes low-level features from logit computation.
- Logit softcap: 15 * tanh(logits/15), soft clipping to prevent extreme logit values.
Q: What are all the weight initialization choices?
The full init scheme (gpt.py:init_weights):
| Component | Method | Value | Rationale |
|---|---|---|---|
| wte (input embed) | normal | std=0.8 | |
| lm_head (output embed) | normal | std=0.001 | Near-zero so logits start small and uniform |
| c_q, c_k, c_v | uniform | U(-s, s), s=√(3/n_embd) | Uniform avoids outliers vs normal; variance-matched (Var of U(-s,s) = s²/3 = 1/n_embd) |
| c_proj (attn output) | zeros | 0 | Zero-init output projections: each block starts as identity |
| c_fc (MLP up) | uniform | U(-0.4s, 0.4s) | 0.4× reduced scale for MLP input |
| c_proj (MLP output) | zeros | 0 | Same zero-init pattern as attn output |
| Value embeddings | uniform | U(-s, s) | Same as c_v |
| ve_gate | uniform | U(0, 0.02) | Small positive so gates start slightly above neutral |
| resid_lambdas | per-layer linear decay | 1.15 → 1.05 | Stronger residual at early layers, weaker at deep layers |
| x0_lambdas | per-layer linear decay | 0.20 → 0.05 | Earlier layers get more input embedding blending |
| backout_lambda | constant | 0.2 | |
| smear_lambda | constant | 0.0 | Disabled at init, learned during training |
Key design principles:
- Output projections are zero-initialized (c_proj in both attn and MLP). At init each transformer block contributes nothing, so the residual stream passes through the block untouched. This is a common trick for stable deep training.
- Uniform over normal for transformer matrices: same variance but bounded, avoiding outlier weights at init.
- Per-layer decay for resid_lambdas and x0_lambdas: early layers get stronger residual connections and more x0 blending, deep layers are more independent.
- From dev/LOG.md: orthogonal init was tried but didn’t help. The x0 residual scalars were originally zero-initialized (disabled at start), but the current per-layer decaying init (1.15→1.05 and 0.20→0.05) was found to work better.
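A minimal sketch of the matrix rows of the init table. It uses standalone toy nn.Linear / nn.Embedding modules, not nanochat’s actual GPT class. The helper name init_block_weights and the toy sizes are hypothetical, and the VE, ve_gate, and lambda rows are left out.

```python
import torch
import torch.nn as nn

n_embd, vocab_size = 256, 1000          # toy sizes, not the speedrun config

wte = nn.Embedding(vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)
c_q = nn.Linear(n_embd, n_embd, bias=False)
c_k = nn.Linear(n_embd, n_embd, bias=False)
c_v = nn.Linear(n_embd, n_embd, bias=False)
attn_proj = nn.Linear(n_embd, n_embd, bias=False)      # attention c_proj
c_fc = nn.Linear(n_embd, 4 * n_embd, bias=False)       # MLP up-projection
mlp_proj = nn.Linear(4 * n_embd, n_embd, bias=False)   # MLP c_proj

@torch.no_grad()
def init_block_weights():
    s = 3**0.5 * n_embd**-0.5                   # U(-s, s) has variance 1/n_embd
    nn.init.normal_(wte.weight, mean=0.0, std=0.8)
    nn.init.normal_(lm_head.weight, mean=0.0, std=0.001)
    for lin in (c_q, c_k, c_v):
        nn.init.uniform_(lin.weight, -s, s)
    nn.init.uniform_(c_fc.weight, -0.4 * s, 0.4 * s)
    nn.init.zeros_(attn_proj.weight)             # block outputs start at exactly zero
    nn.init.zeros_(mlp_proj.weight)
    return s

s = init_block_weights()
```

Because mlp_proj is all zeros, the MLP branch outputs zeros for any input at step 0. Gradients still reach mlp_proj itself, since its weight gradient depends on the (nonzero) activations.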
Q: How does this compare to Kaiming initialization?
Kaiming (He) init sets Var(W) = gain²/fan_in, with gain chosen for the downstream nonlinearity. The three common gains:
| Setting | gain² | Variance | Uniform bound |
|---|---|---|---|
| Linear / identity (≡ Xavier) | 1 | 1/fan_in | √(3/fan_in) |
| ReLU (standard Kaiming-He) | 2 | 2/fan_in | √(6/fan_in) |
| PyTorch nn.Linear default (a=√5) | 1/3 | 1/(3·fan_in) | 1/√fan_in |
Mapping nanochat’s choices onto this table:
- c_q, c_k, c_v, and VE use s = √(3/n_embd), i.e. Var = 1/n_embd. That’s the Xavier / linear-gain Kaiming-uniform setting. It is a natural choice for attention projections: they feed into a softmax (no ReLU-style activation), and QK norm already controls the attention-logit scale, so the “linear gain” baseline is appropriate.
- c_fc (MLP up-proj) uses 0.4s, i.e. Var = 0.16/n_embd. That is about 6× smaller than Xavier and 12× smaller than Kaiming-ReLU. This deliberately undershoots because the MLP activation is ReLU² (squared ReLU), which amplifies variance much more aggressively than plain ReLU. Shrinking the input weights keeps the post-activation variance in check.
- c_proj (both attn and MLP output) is zero-initialized. Kaiming would never do this, because in a plain stack a zero weight matrix sends zero gradient back to earlier layers. The trick works here because the residual stream carries gradients past the block: zero-init makes each block an identity at start, and gradients still flow freely through the skip connection.
- wte (input embed) uses std=0.8, i.e. Var = 0.64. If you treat the embedding as a linear layer on one-hot inputs (fan_in = vocab_size), Kaiming would prescribe a tiny std, so 0.8 is far larger than that. It is closer to PyTorch’s nn.Embedding default of N(0, 1), which matches fan_in=1 because each lookup selects a single row. This is token-embedding lore: scale up so embeddings dominate early in training, then lm_head’s near-zero init (std=0.001) ensures logits start small and uniform regardless.
- PyTorch’s default nn.Linear (a=√5) gives Var = 1/(3·fan_in), a historical accident that’s 3× smaller than Xavier. Nanochat explicitly overrides this everywhere, using the Xavier/linear baseline for most matrices and zero-init for output projections.
Q: What is backout?
Cache the residual at the halfway layer (n_layer // 2), then subtract a learned-scaled version (backout_lambda, init 0.2) before the final norm and LM head. The idea is to remove low-level features from the mid-layer that helped internal processing but shouldn’t influence logit computation. Active in the codebase at gpt.py:449-459.
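A minimal sketch of the residual-stream bookkeeping: per-layer lambdas, x0 mixing, and backout. The function names are hypothetical. Two details are assumptions rather than confirmed nanochat behavior: the placement of the lambda mixing before each block, and caching the residual after block n_layer // 2. Blocks here return only their contribution to the stream.

```python
import torch

def init_lambdas(n_layer):
    # per-layer linear decay, matching the init table: 1.15 -> 1.05 and 0.20 -> 0.05
    resid_lambdas = torch.linspace(1.15, 1.05, n_layer)
    x0_lambdas = torch.linspace(0.20, 0.05, n_layer)
    return resid_lambdas, x0_lambdas

def residual_forward(x0, blocks, resid_lambdas, x0_lambdas, backout_lambda=0.2):
    """Hypothetical residual loop: x0 mixing, block update, backout before the final norm."""
    n_layer = len(blocks)
    x = x0
    x_mid = None
    for i, block in enumerate(blocks):
        x = resid_lambdas[i] * x + x0_lambdas[i] * x0   # x0 residual (assumed placement)
        x = x + block(x)                                 # block adds its output to the stream
        if i == n_layer // 2:
            x_mid = x                                    # cache the halfway residual
    return x - backout_lambda * x_mid                    # backout; final norm + lm_head would follow
```

With zero-output blocks (as at init), unit resid_lambdas, and zero x0_lambdas, the result is simply (1 - backout_lambda) · x0.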
Q: What are value embeddings?
Per-layer token lookup tables at every other layer, whose output is gated and added directly to the attention values. Each table has vocab_size rows like wte, but its width is kv_dim instead of n_embd. Which layers get one is determined by has_ve(layer_idx, n_layer), which uses layer_idx % 2 == (n_layer - 1) % 2, so the last layer always has VE. The gating is gate = 3 * sigmoid(ve_gate(x[..., :12])), so each gate lies in (0, 3). The update is then v = v + gate.unsqueeze(-1) * ve. VE is unconditionally enabled; there is no toggle flag.
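A minimal sketch of one VE-enabled attention layer’s value path, with toy sizes. The ve_gate module shape (12 input channels to one gate per KV head) is inferred from the shape table below and is an assumption. value_embed is a hypothetical name for the per-layer table.

```python
import torch
import torch.nn as nn

def has_ve(layer_idx, n_layer):
    # every other layer, aligned so that the last layer always gets a VE table
    return layer_idx % 2 == (n_layer - 1) % 2

B, T, n_embd, n_kv_head, head_dim, vocab_size = 2, 5, 64, 4, 16, 100
kv_dim = n_kv_head * head_dim
ve_gate_channels = 12

value_embed = nn.Embedding(vocab_size, kv_dim)                 # one such table per VE layer
ve_gate = nn.Linear(ve_gate_channels, n_kv_head, bias=False)   # one gate per KV head (assumed)

idx = torch.randint(0, vocab_size, (B, T))     # token ids
x = torch.randn(B, T, n_embd)                  # block input
v = torch.randn(B, T, n_kv_head, head_dim)     # c_v output, already split into heads

ve = value_embed(idx).view(B, T, n_kv_head, head_dim)          # table lookup, no matmul
gate = 3 * torch.sigmoid(ve_gate(x[..., :ve_gate_channels]))  # (B, T, n_kv_head), in (0, 3)
v = v + gate.unsqueeze(-1) * ve                                # gate broadcast over head_dim
```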
Q: How big are VE embeddings compared to wte?
For the speedrun d24 config (n_layer=24, n_embd=1536, vocab_size=32768): wte is ~50M params. VE exists at 12 of 24 layers, each sized (vocab_size, kv_dim). Total VE is ~55-61% of all model parameters. They cost zero FLOPs (just table lookups).
Q: How much do VE help? (from dev/LOG.md)
The experiment log shows VE is one of the biggest wins. Every attempt to reduce VE capacity hurt performance — the model “loves” them. They provide a massive parameter reservoir that enriches attention values without adding compute.
Tensor Operations
Q: Does gate.unsqueeze(-1) add a new dimension?
Yes. If gate is (B, T, H), then gate.unsqueeze(-1) becomes (B, T, H, 1) — a new trailing dimension for broadcasting against the (B, T, H, D) value tensor.
Q: What other ways to do x[..., :self.ve_gate_channels]?
- x.narrow(-1, 0, channels): no copy, same as slicing
- x[..., :channels]: the ellipsis form used in code
- x[:, :, :channels]: explicit dims (fragile if rank changes)
- torch.split(x, [channels, rest], dim=-1)[0]: overkill here
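A quick check that the four forms agree, and that all of them return views rather than copies. The second half shows why the explicit-dims form is fragile: on a rank-4 tensor it silently slices the wrong axis.

```python
import torch

x = torch.randn(2, 3, 16)
channels = 12

a = x.narrow(-1, 0, channels)                                   # view, no copy
b = x[..., :channels]                                           # ellipsis: works for any rank
c = x[:, :, :channels]                                          # explicit dims: assumes rank 3
d = torch.split(x, [channels, x.size(-1) - channels], dim=-1)[0]

x4 = torch.randn(2, 3, 4, 16)
good = x4[..., :channels]     # slices the last dim -> (2, 3, 4, 12)
wrong = x4[:, :, :channels]   # slices dim 2 (size 4) -> shape unchanged (2, 3, 4, 16)
```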
Q: Is torch.zeros(batch_size) a row or column vector?
Neither — it’s a 1D tensor of shape (batch_size,). PyTorch 1D tensors have no row/column distinction. You need (1, n) for row or (n, 1) for column.
Q: How to outer product two 1D tensors?
torch.outer(a, b) or equivalently a.unsqueeze(-1) * b.unsqueeze(0) or torch.einsum('i,j->ij', a, b). All produce shape (len(a), len(b)).
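The three outer-product spellings side by side. The unsqueeze version makes the row/column distinction from the previous answer explicit: (3, 1) times (1, 2) broadcasts to (3, 2).

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])     # 1D, shape (3,): neither row nor column
b = torch.tensor([10.0, 20.0])

o1 = torch.outer(a, b)
o2 = a.unsqueeze(-1) * b.unsqueeze(0)   # column (3, 1) * row (1, 2) -> (3, 2)
o3 = torch.einsum('i,j->ij', a, b)

print(o1.tolist())   # [[10.0, 20.0], [20.0, 40.0], [30.0, 60.0]]
```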
Q: Explain idx.gather(1, choice)?
gather(dim, index) selects elements along dim using positions from index. The result always has the same shape as index, regardless of the source tensor’s shape. For idx of shape (B, V) and choice of shape (B, 1):
```python
# For each row b, picks idx[b, choice[b, 0]]
result = idx.gather(1, choice)  # (B, 1): same shape as choice, not idx
```
Used in sample_next_token to map multinomial’s position indices back to actual vocab token IDs after top-k filtering.
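A minimal sketch of that top-k pattern. sample_top_k is a hypothetical stand-in, not nanochat’s actual sample_next_token. torch.topk returns candidate values and their vocab ids, multinomial samples a position within the k candidates, and gather maps that position back to a vocab id.

```python
import torch
import torch.nn.functional as F

def sample_top_k(logits, k, temperature=1.0, generator=None):
    # logits: (B, V)
    vals, idx = torch.topk(logits, k, dim=-1)                 # (B, k) values and their vocab ids
    probs = F.softmax(vals / temperature, dim=-1)             # distribution over the k candidates
    choice = torch.multinomial(probs, num_samples=1, generator=generator)  # (B, 1), values in [0, k)
    return idx.gather(1, choice)                              # (B, 1) actual vocab token ids
```

With k=1 this reduces to greedy argmax decoding.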
Q: Intuitively what is torch.gather doing?
Gather is a parallel lookup where the output keeps the same grid layout as the input, but along one axis each cell gets to pick which element it wants.
Imagine input as a 2D spreadsheet of shape (rows, cols):
```
      col0  col1  col2  col3
row0   10    20    30    40
row1   50    60    70    80
row2   90   100   110   120
```
gather(dim=1, index=...) means: “keep the rows fixed, each row picks which columns it wants.” You give each row its own shopping list:
```python
index = [[2, 0],  # row 0: "I want col 2, then col 0" → [30, 10]
         [1, 1],  # row 1: "I want col 1 twice"       → [60, 60]
         [3, 2]]  # row 2: "I want col 3, then col 2" → [120, 110]
```
The rows stay aligned. Row 0’s picks end up in row 0 of the output. That’s why index must have the same ndim as input — each row’s shopping list lives in that row of the index. If index were just 1D [2, 0], “is this row 0’s list or column 0’s list?” would be ambiguous.
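The spreadsheet example as runnable code, next to the explicit loop that gather replaces (out[r][j] = inp[r][index[r][j]] for dim=1):

```python
import torch

inp = torch.tensor([[10, 20, 30, 40],
                    [50, 60, 70, 80],
                    [90, 100, 110, 120]])
index = torch.tensor([[2, 0],
                      [1, 1],
                      [3, 2]])

out = torch.gather(inp, dim=1, index=index)

# equivalent explicit loop: rows stay fixed, each row picks its own columns
loop = [[inp[r, index[r, j]].item() for j in range(index.size(1))] for r in range(index.size(0))]

print(out.tolist())   # [[30, 10], [60, 60], [120, 110]]
```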
Three-sentence summary:
- gather does a parallel lookup: every output cell picks one element from input.
- The non-dim coordinates stay the same (rows stay rows, batches stay batches).
- Along dim, each cell picks according to index, which is why index must have the same number of dimensions as input.
Concrete use in LMs (selective_log_softmax in one line):
```
logits: (B, T, V)                   # batch × time × vocab
tokens: (B, T, 1)                   # each (b, t) asks: "what's the logit for MY token?"
logits.gather(dim=2, index=tokens)  # (B, T, 1)
```
For each (batch, timestep) cell, the output picks out the logit of the actual token at that cell. The (B, T) grid is preserved; only the vocab axis collapses to the chosen token’s value.
Contrast with torch.index_select, which picks entire slices along dim (all rows/all cols). Its index is 1D because you only need to say which slice — the other dims come along for free. gather picks individual elements with a potentially different choice per row, so it needs the full grid-shaped index.
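A sketch of both ideas on toy tensors. Gather-based selective log-softmax returns the same numbers as negated per-token cross-entropy. index_select, by contrast, pulls the same vocab columns for every (b, t) cell.

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 3, 5
logits = torch.randn(B, T, V)
tokens = torch.randint(0, V, (B, T))

# gather: one element per (b, t) cell, chosen by that cell's own token id
logprobs = F.log_softmax(logits, dim=-1)
token_logp = logprobs.gather(dim=2, index=tokens.unsqueeze(-1)).squeeze(-1)   # (B, T)

# the same quantity via cross_entropy, which returns the negative log-likelihood
nll = F.cross_entropy(logits.view(-1, V), tokens.view(-1), reduction='none').view(B, T)

# index_select: a 1D index picks whole vocab slices, shared by every (b, t)
cols = logits.index_select(dim=2, index=torch.tensor([0, 4]))                  # (B, T, 2)
```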
Q: What are all the tensor shape tricks used in the repo?
Dimension Expansion / Broadcasting:
| Code | File:Line | Pre Shape | Post Shape | Purpose |
|---|---|---|---|---|
| gate.unsqueeze(-1) | gpt.py:95 | (B,T,n_kv_head) | (B,T,n_kv_head,1) | Broadcast gate scalar per-head across head_dim |
| cos[None,:,None,:] | gpt.py:277 | (T,D/2) | (1,T,1,D/2) | Add batch+head dims for broadcasting with (B,T,H,D) |
| arange().unsqueeze(1) | flash_attention.py:94 | (Tq,) | (Tq,1) | Column vector for outer-product-style mask |
| arange().unsqueeze(0) | flash_attention.py:95 | (Tk,) | (1,Tk) | Row vector for outer-product-style mask |
| tokens.unsqueeze(1) | engine.py:279 | (B,) | (B,1) | Add seq dim for single-token decode |
| advantages.unsqueeze(-1) | chat_rl.py:266 | (B,) | (B,1) | Broadcast per-sample advantage across token positions |
| logits.expand(N,-1) | engine.py:211 | (1,V) | (N,V) | View batch=1 logits as N parallel samples (no copy) |
| prev_emb.expand(B,-1,-1) | engine.py:137 | (1,1,C) | (B,1,C) | Expand cached embedding to batch for prefill |
Reshape (split/merge dimensions):
| Code | File:Line | Pre Shape | Post Shape | Purpose |
|---|---|---|---|---|
| c_q(x).view(B,T,H,D) | gpt.py:87-89 | (B,T,H*D) | (B,T,H,D) | Split linear output into heads |
| ve.view(B,T,n_kv,D) | gpt.py:93 | (B,T,kv_dim) | (B,T,n_kv_head,D) | Split VE into heads |
| y.contiguous().view(B,T,-1) | gpt.py:124 | (B,T,H,D) | (B,T,C) | Merge heads back after attention |
| logits.view(-1,V) | gpt.py:472 | (B,T,V) | (B*T,V) | Flatten for cross_entropy |
| targets.view(-1) | gpt.py:472 | (B,T) | (B*T,) | Flatten for cross_entropy |
| buffer.view(B,T) | dataloader.py:117-120 | (B*T,) | (B,T) | Shape flat token buffer into batch |
| input.reshape(-1,C) | fp8.py:208 | (B,T,C) | (B*T,C) | Flatten for 2D-only _scaled_mm |
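A small demonstration of why the merge step is written y.contiguous().view(...). Splitting a contiguous tensor into heads is a free view. A tensor that comes out of a transpose (for example, SDPA output moved back to (B, T, H, D)) has strides that view() cannot merge. merge_heads is a hypothetical helper name.

```python
import torch

B, T, H, D = 2, 4, 3, 8

x = torch.randn(B, T, H * D)        # e.g. output of c_q(x)
q = x.view(B, T, H, D)              # split into heads: free view, x is contiguous
merged = q.view(B, T, H * D)        # merging back is also free

y_sdpa = torch.randn(B, H, T, D)    # SDPA-style output layout
y = y_sdpa.transpose(1, 2)          # (B, T, H, D) but with non-contiguous strides

def merge_heads(t):
    # view() needs compatible strides, so copy into a contiguous layout first
    return t.contiguous().view(t.size(0), t.size(1), -1)

flat = merge_heads(y)               # (B, T, H*D)
```

reshape() would do the same thing, copying only when it has to. The explicit .contiguous() just makes the copy visible.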
Concatenation / Stacking:
| Code | File:Line | Pre Shapes | Post Shape | Purpose |
|---|---|---|---|---|
| torch.cat([y1,y2], 3) | gpt.py:63 | 2×(B,T,H,D/2) | (B,T,H,D) | Rejoin rotated halves in RoPE |
| torch.cat([x[:,:1], smeared], 1) | gpt.py:432 | (B,1,C)+(B,T-1,C) | (B,T,C) | Smear: keep 1st token, mix rest with previous |
| torch.cat((ids, next_ids), 1) | gpt.py:505 | (B,T)+(B,1) | (B,T+1) | Append generated token |
| torch.stack(grads) | optim.py:259 | K×(M,N) | (K,M,N) | Batch same-shape grads for Muon Newton-Schulz |
Layout Transpose (SDPA fallback):
| Code | File:Line | Pre Shape | Post Shape | Purpose |
|---|---|---|---|---|
| q.transpose(1,2) | flash_attention.py:123 | (B,T,H,D) | (B,H,T,D) | FA3 layout → SDPA layout |
| y.transpose(1,2) | flash_attention.py:128 | (B,H,T,D) | (B,T,H,D) | SDPA layout → FA3 layout back |
Python / PyTorch Patterns
Q: @torch.inference_mode() vs no_grad() and model.eval()?
- inference_mode(): stricter and faster than no_grad(). It disables both gradient computation AND version tracking on tensors. Tensors created inside cannot be used for backward later.
- no_grad(): only disables gradient computation; tensors still track versions.
- model.eval(): changes module behavior (e.g., BatchNorm uses running stats, Dropout disabled). Must be called separately: neither inference_mode nor no_grad does this.
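A toy check of the three behaviors on a throwaway model. Outputs from no_grad are ordinary tensors. Outputs from inference_mode are flagged as inference tensors. Neither context manager changes model.training; only eval() does.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))   # toy model, starts in training mode
x = torch.randn(2, 4)

with torch.no_grad():
    a = model(x)          # no autograd graph, but `a` is an ordinary tensor
with torch.inference_mode():
    b = model(x)          # "inference tensor": no graph and no version counter

was_training = model.training   # still True: Dropout was active in both calls above
model.eval()                    # this is the call that switches module behavior
```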
Q: How does the timeout context manager work?
Uses Unix SIGALRM signals:
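```python
import signal
from contextlib import contextmanager

@contextmanager
def timeout(seconds):
    def handler(signum, frame):                # handler body is an assumption: raise on alarm
        raise TimeoutError(f"timed out after {seconds}s")
    signal.signal(signal.SIGALRM, handler)     # register alarm handler
    signal.alarm(seconds)                      # start countdown (whole seconds, Unix only)
    try:
        yield                                  # run the guarded code
    finally:
        signal.alarm(0)                        # cancel alarm
```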
The with statement doesn’t need an as clause — the context manager yields nothing. If you write with timeout(5) as x, x would be None.
Q: torch.compile(model, dynamic=False) — what does dynamic mean?
dynamic=False tells the compiler to treat all tensor shapes as static. It specializes the compiled code to the exact shapes it sees and reuses it as long as they stay the same; a new shape triggers a recompile. dynamic=True would generate shape-generic code that handles varying sizes at runtime, but with less optimization. Nanochat uses False because sequence length and batch size are constant during training.
Q: What is the meta device pattern? What does model.to_empty(device) do?
The meta device creates tensors with shapes and dtypes but no memory allocation, which is useful for building large models without OOM during init. to_empty(device=device) (the argument is keyword-only) allocates real memory on the target device but leaves values uninitialized (garbage). Then init_weights() fills them. This differs from model.to(device), which would try to copy the (nonexistent) meta tensor data and raises an error.
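The meta → to_empty → init sequence on a toy two-layer model. The normal_/zeros_ calls stand in for nanochat’s init_weights(); they are not its actual scheme.

```python
import torch
import torch.nn as nn

with torch.device("meta"):
    model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))  # shapes only, no storage

on_meta = model[0].weight.is_meta       # True: nothing allocated yet

model.to_empty(device="cpu")            # real but uninitialized memory on the target device
with torch.no_grad():
    for m in model:                     # stand-in for init_weights()
        nn.init.normal_(m.weight, std=0.02)
        nn.init.zeros_(m.bias)
```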
Q: inv_scale = scale.reciprocal() vs 1/scale?
Functionally identical. For a tensor scale, both are ordinary differentiable tensor ops that autograd records. 1/scale dispatches to Tensor.__rtruediv__, which PyTorch implements via reciprocal() anyway. In practice there’s no measurable difference; it’s a style choice.
Q: ColoredFormatter.COLORS — class variable or instance variable?
Class variable. Defined directly in the class body (not in __init__), shared across all instances.
Q: Why s = 3**0.5 * n_embd**-0.5 for weight init?
The goal is U(-s, s) with variance matching N(0, 1/n_embd). Variance of U(-s, s) is s²/3. Setting s²/3 = 1/n_embd gives s = sqrt(3/n_embd) = 3^0.5 * n_embd^(-0.5).
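Derivation of Var(U(-s,s)) = s²/3:

\[\operatorname{Var}[X] = \mathbb{E}[X^2] - \mathbb{E}[X]^2 = \mathbb{E}[X^2] - 0\]

\[\mathbb{E}[X^2] = \int_{-s}^{s} x^2 \cdot \frac{1}{2s}\,dx = \frac{1}{2s}\left[\frac{x^3}{3}\right]_{-s}^{s} = \frac{1}{2s}\cdot\frac{2s^3}{3} = \frac{s^2}{3}\]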
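A numeric sanity check of the derivation: sample U(-s, s) with s = √(3/n_embd) and compare the empirical variance with 1/n_embd. n_embd = 768 is an arbitrary example width.

```python
import torch

n_embd = 768
s = 3**0.5 * n_embd**-0.5
w = torch.empty(1_000_000).uniform_(-s, s)

empirical = w.var().item()
target = 1 / n_embd        # = s**2 / 3
```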
Q: Is apply_rotary_emb’s rotation direction correct?
Yes. The rotation is applied in the opposite direction compared to the standard formulation, but since both Q and K are rotated the same way, the signs cancel in the Q·K dot product. The relative rotation (which encodes relative position) is preserved.
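A sketch of split-half RoPE with the sign-flipped rotation described above. The frequency base of 10000 and the helper names are assumptions. The check confirms the claim: the q·k score depends only on the position offset, and the rotation preserves vector norms.

```python
import torch

def rope_cos_sin(seq_len, head_dim, base=10000.0):
    # one frequency per channel pair (head_dim / 2 of them)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
    t = torch.arange(seq_len, dtype=torch.float64)
    freqs = torch.outer(t, inv_freq)          # (T, D/2)
    return freqs.cos(), freqs.sin()

def apply_rotary_emb(x, cos, sin):
    # x: (B, T, H, D); channel i is paired with channel i + D/2
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    cos, sin = cos[None, :, None, :], sin[None, :, None, :]   # (1, T, 1, D/2)
    y1 = x1 * cos + x2 * sin       # rotation by -theta (sign-flipped vs. the usual convention)
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3)

# same q and k vectors at every position, so scores differ only by relative position
B, T, H, D = 1, 8, 1, 16
cos, sin = rope_cos_sin(T, D)
q = torch.randn(D, dtype=torch.float64).expand(B, T, H, D)
k = torch.randn(D, dtype=torch.float64).expand(B, T, H, D)
qr, kr = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
S = qr[0, :, 0] @ kr[0, :, 0].T           # (T, T) score matrix
```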
Q: Why short_window = -(-long_window // 4 // 128) * 128?
Double negation implements ceiling division. -(-x // n) = ceil(x/n). This computes ceil(long_window / 4) rounded up to the next multiple of 128 (for tensor core alignment).
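The expression next to an explicit math.ceil version. Because floor(floor(-x/4)/128) = floor(-x/512) for integers, the double negation yields ceil(long_window/512)·128, i.e. long_window/4 rounded up to a multiple of 128.

```python
import math

def short_window(long_window):
    # -(-x // n) == ceil(x / n); here the divisions by 4 and 128 compose into one by 512
    return -(-long_window // 4 // 128) * 128

def short_window_explicit(long_window):
    return math.ceil(long_window / 4 / 128) * 128

print(short_window(2048), short_window(1000), short_window(100))   # 512 256 128
```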
Flash Attention and FP8
Q: Why is Flash Attention 3 loaded via kernels.get_kernel(...)?
FA3 requires Hopper (SM90) and isn’t in PyTorch’s standard distribution. HuggingFace’s kernels package provides pre-built wheels that can be fetched at runtime: get_kernel('varunneal/flash-attention-3'). Falls back to SDPA on non-Hopper GPUs.
Q: What does the SDPA fallback do for sliding window?
SDPA doesn’t natively support sliding window, so the code builds explicit boolean attention masks. For chunk inference (where Tq != Tk), it also handles the offset between query and key positions. The mask is constructed to only allow attention within the window size.
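A minimal sketch of such a mask for the chunked case where the Tq queries are the last Tq of Tk key positions, plus the FA3↔SDPA layout transposes. The helper names are hypothetical. One detail is an assumption: whether the window counts the query token itself (here it does, so window=2 means "self plus one previous token").

```python
import torch
import torch.nn.functional as F

def sliding_window_mask(Tq, Tk, window, device=None):
    # queries occupy absolute positions Tk-Tq .. Tk-1 (KV-cache chunk inference)
    offset = Tk - Tq
    q_pos = torch.arange(Tq, device=device).unsqueeze(1) + offset   # (Tq, 1)
    k_pos = torch.arange(Tk, device=device).unsqueeze(0)            # (1, Tk)
    causal = k_pos <= q_pos
    in_window = (q_pos - k_pos) < window                            # assumed: window includes self
    return causal & in_window                                       # (Tq, Tk), True = may attend

def sdpa_sliding_window(q, k, v, window):
    # q: (B, Tq, H, D), k/v: (B, Tk, H, D) in FA3 layout
    mask = sliding_window_mask(q.size(1), k.size(1), window, device=q.device)
    y = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=mask
    )                                                               # boolean mask broadcasts over (B, H)
    return y.transpose(1, 2)                                        # back to (B, Tq, H, D)
```

Each query row always keeps at least its own position, so no softmax row is fully masked.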
Q: In fp8.py, what is rowwise vs tensorwise scaling?
Nanochat uses tensorwise scaling: one scale factor per entire tensor (scale = FP8_MAX / max(|tensor|)). Rowwise would compute a separate scale per row — finer grained but more overhead. Tensorwise is simpler and sufficient here.
Data Loading
Q: How do documents become batches in the dataloader?
_document_batches is an infinite iterator over parquet row groups, sharded by rank (rg_idx += ddp_world_size). Documents are tokenized, then tokenizing_distributed_data_loader_with_state_bos_bestfit packs multiple documents per row using BOS-aligned best-fit packing. A pre-allocated pinned CPU buffer collects packed rows, then a single HtoD transfer moves the batch to GPU. ~100% sequence utilization, ~35% document crop waste.
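A simplified sketch of BOS-aligned best-fit packing over a finite list of tokenized documents, each of which starts with BOS. The real loader works on a streaming buffer; pack_rows_bestfit and its cropping rule (crop the shortest remaining document when nothing fits) are hypothetical. Every row starts at a document start, so rows are BOS-aligned, and cropped tails are the waste.

```python
def pack_rows_bestfit(docs, row_len):
    """Pack token lists into rows of exactly row_len (except possibly the last row)."""
    pool = list(docs)
    rows = []
    while pool:
        row = []
        while len(row) < row_len and pool:
            space = row_len - len(row)
            fitting = [i for i, d in enumerate(pool) if len(d) <= space]
            if fitting:
                # best fit: the largest document that fits entirely in the remaining space
                i = max(fitting, key=lambda j: len(pool[j]))
                row.extend(pool.pop(i))
            else:
                # nothing fits: crop a document to fill the row; its tail is discarded
                i = min(range(len(pool)), key=lambda j: len(pool[j]))
                row.extend(pool.pop(i)[:space])
        rows.append(row)
    return rows

BOS = 0
docs = [[0, 1, 2], [0, 3, 4, 5, 6], [0, 7], [0, 8, 9, 10, 11, 12, 13]]
rows = pack_rows_bestfit(docs, row_len=8)
```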
Q: Can one row include more than one document?
Yes — that’s the point of best-fit packing. Multiple documents are packed into a single sequence row, separated by BOS tokens. The model learns soft document boundaries without hard attention masks.
Q: So no PyTorch DataLoader or DistributedSampler?
Correct. All data distribution is manual. Each rank strides through row groups: rg_idx += ddp_world_size. No sampler, no DataLoader.
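The striding rule in isolation, as a hypothetical generator (the real iterator is infinite; epochs here just makes the example terminate). Each rank starts at its own index and jumps by the world size, so ranks see disjoint row groups that together cover everything.

```python
def row_groups_for_rank(num_row_groups, rank, world_size, epochs=1):
    for _ in range(epochs):
        rg_idx = rank
        while rg_idx < num_row_groups:
            yield rg_idx
            rg_idx += world_size     # the manual "sampler": stride by world size

parts = [list(row_groups_for_rank(10, r, 4)) for r in range(4)]
# [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```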
Q: How does attention learn to ignore cross-document tokens when packed together?
Through soft signals: BOS tokens mark boundaries, RoPE’s relative-distance decay down-weights far-away tokens, and sliding window limits how far attention can reach. Experiments with varlen attention (hard boundaries) showed no benefit; the model learns effective boundaries on its own.
Q: Why was midtraining removed? (commit 1ddaad1)
The pipeline was originally 3 stages: pretrain → midtrain → SFT. Midtraining ran on the SFT data mixture (SmolTalk, GSM8K, MMLU, etc.) with the pretraining-style packed dataloader, but discarded the loss mask (ids, _ = render_conversation(...)) — so loss was on all tokens (user + assistant). It served as a domain adaptation step at pretraining-scale throughput.
The old SFT was inefficient: no sequence packing, one conversation per row, small batch sizes. It couldn’t process enough tokens for domain adaptation on its own, so midtraining bridged that gap.
Once the BOS-aligned best-fit packing dataloader was introduced, SFT was upgraded to use it — gaining the same packed throughput (total_batch_size=524288 in tokens) while retaining assistant-only loss masking. The new SFT absorbed midtraining’s role, making it redundant. The pipeline simplified to pretrain → SFT.
Distributed Training
Q: Does dist.init_process_group(backend='nccl', device_id=device) need a host/port?
Yes — NCCL needs a rendezvous point. torchrun sets environment variables MASTER_ADDR and MASTER_PORT automatically. init_process_group reads them from the environment by default.
Evaluation
Q: Explain evaluate_bpb — the where makes it complex.
BPB = bits per byte. The function:
- Runs cross-entropy loss with reduction='none' → per-token NLL in nats
- Uses torch.where to safely index token_bytes[y]. y contains -1 at ignored positions (ignore_index), and indexing with -1 would silently wrap around to the last vocab entry instead of failing. So torch.where replaces -1 with 0 for safe indexing, then zeros out those entries’ byte counts
- Sums nats over all valid tokens
- Divides by total UTF-8 bytes of the original text
- Converts nats to bits: / log(2)
This gives a vocab-independent metric: how many bits the model needs per byte of raw text.
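A single-batch sketch of the computation. bits_per_byte is a hypothetical name; the real evaluate_bpb loops over many batches and accumulates totals. Positions that are ignored (y = -1) or whose token has zero bytes (special tokens) are dropped from both numerator and denominator.

```python
import math
import torch
import torch.nn.functional as F

def bits_per_byte(logits, y, token_bytes):
    # logits: (B, T, V); y: (B, T) with -1 = ignore; token_bytes: (V,) UTF-8 length per token id
    V = logits.size(-1)
    nll = F.cross_entropy(logits.view(-1, V), y.view(-1),
                          ignore_index=-1, reduction='none')        # nats, 0 where y == -1
    y_flat = y.view(-1)
    valid = y_flat >= 0
    safe_y = torch.where(valid, y_flat, torch.zeros_like(y_flat))   # avoid -1 wrapping around
    nbytes = torch.where(valid, token_bytes[safe_y], torch.zeros_like(safe_y))
    counted = nbytes > 0                                            # drop ignored + special tokens
    total_nats = (nll * counted).sum()
    total_bytes = nbytes.sum()
    return (total_nats / (math.log(2) * total_bytes)).item()
```

Sanity check: with uniform logits over 4 tokens that are each 1 byte long, every counted token costs log2(4) = 2 bits, so the result is 2.0.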
Q: token_bytes is 0 for special tokens — what is token_bytes[y]?
token_bytes is a lookup table: token_bytes[token_id] = number of UTF-8 bytes that token represents. Special tokens map to 0 bytes, so they don’t inflate the denominator. The total bytes denominator counts only real text.
Q: One-sentence summary of evaluate_bpb?
Sums the cross-entropy loss (in bits) over all non-special tokens, then divides by the total UTF-8 bytes of the original text — giving a vocab-independent measure of how many bits the model needs per byte.
Q: How can a base LLM evaluate tasks — explain evaluate_example?
Three task types, all loss-based (no generation):
- multiple_choice: Common prefix + different continuations. Score = mean loss over the continuation tokens. Pick the continuation with lowest loss. Example: “The capital of France is” → [“Paris”, “London”, “Berlin”]; lowest loss on “Paris”.
- schema: Different contexts with the same continuation. Score = mean loss over the shared suffix. Pick the context that makes the suffix most likely. Example: contexts describe different scenarios, suffix is “The answer is yes”; which context makes that suffix most natural?
- language_modeling: Feed prompt + answer, check if the model’s argmax predictions match the answer tokens exactly. No loss comparison, just binary match/no-match.
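A sketch of the multiple_choice and language_modeling scoring rules on one example at a time (schema is the same computation with the roles of context and suffix swapped). The helper names and the toy bigram “model” are hypothetical; any callable mapping token ids (T,) to logits (T, V) would do. Position t’s logits predict token t + 1.

```python
import torch
import torch.nn.functional as F

def continuation_losses(logits, ids, start):
    # per-token NLL of ids[start:], predicted from positions start-1 .. T-2
    logp = F.log_softmax(logits[start - 1:-1], dim=-1)
    targets = ids[start:]
    return -logp.gather(1, targets.unsqueeze(1)).squeeze(1)

def score_multiple_choice(model, prefix, continuations):
    # lowest mean loss over the continuation tokens wins
    means = []
    for cont in continuations:
        ids = torch.tensor(prefix + cont)
        means.append(continuation_losses(model(ids), ids, len(prefix)).mean().item())
    return min(range(len(means)), key=means.__getitem__)

def lm_exact_match(model, prompt, answer):
    # teacher forcing: the argmax at every answer position must equal the answer token
    ids = torch.tensor(prompt + answer)
    preds = model(ids)[len(prompt) - 1:-1].argmax(dim=-1)
    return bool(torch.equal(preds, ids[len(prompt):]))

# toy bigram "model" that confidently predicts (token + 1) % V, for illustration only
V = 5
table = torch.full((V, V), -5.0)
table[torch.arange(V), (torch.arange(V) + 1) % V] = 5.0
bigram = lambda ids: table[ids]          # (T,) -> (T, V)
```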
Q: For multiple_choice, the continuations have different lengths — do we batch them?
Yes. find_common_length identifies the shared prefix. Different continuations are batched together. Mean loss normalizes for length differences so longer continuations aren’t penalized.
Q: For language_modeling, how do we know how long to generate?
We don’t generate. The answer’s token length is known. We feed the prompt + answer and check if argmax predictions at each position match. It’s a teacher-forcing check, not free-form generation.
Q: In both eval types, is it greedy decoding or sampling?
Neither — base CORE eval uses no decoding at all. It’s purely loss-based (multiple_choice, schema) or argmax-matching (language_modeling). No temperature, no sampling, no top-k.
Q: What is the CORE metric?
A 22-task composite benchmark using centered accuracies. Each task’s accuracy is centered so that random-chance performance maps to 0 and perfect accuracy to 1, i.e. (accuracy - random) / (1 - random). The centered scores are then averaged across all tasks. Used to benchmark against GPT-2.
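A sketch of the centering and averaging with hypothetical helper names. Accuracies and random baselines are fractions in [0, 1], e.g. a 4-way multiple-choice task has a random baseline of 0.25.

```python
def centered_accuracy(accuracy, random_baseline):
    # 0 at chance level, 1 at perfect accuracy
    return (accuracy - random_baseline) / (1.0 - random_baseline)

def core_score(results):
    # results: {task_name: (accuracy, random_baseline)}; CORE = mean centered accuracy
    centered = [centered_accuracy(acc, rnd) for acc, rnd in results.values()]
    return sum(centered) / len(centered)

print(centered_accuracy(0.625, 0.25))   # 0.5: halfway between chance and perfect
```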
Q: Explain run_chat_eval — how does it work?
Two modes:
- Categorical: Forward pass only. At the answer position, look at logits for letter tokens (A, B, C, D). Argmax over just those logits = the model’s answer. No generation needed.
- Generative: Use engine.generate_batch to actually sample completions. Each task’s evaluate() method parses the output (e.g., regex for numbers in GSM8K). pass@k = any of k completions correct.
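Both modes reduced to their decision rules. The function names and the letter token ids in the usage below are hypothetical. The categorical rule ignores every non-letter logit, even one that is larger. pass@k only asks whether any sampled completion passes the task’s checker.

```python
import torch

def categorical_answer(logits_at_answer, letter_token_ids, letters="ABCD"):
    # logits_at_answer: (V,) logits at the answer position; only the letter tokens compete
    letter_logits = logits_at_answer[letter_token_ids]
    return letters[letter_logits.argmax().item()]

def pass_at_k(completions, is_correct):
    # solved if any of the k sampled completions is judged correct
    return any(is_correct(c) for c in completions)

logits = torch.zeros(20)
logits[5] = 10.0     # a non-letter token the model likes more; irrelevant for categorical eval
logits[12] = 3.0     # highest among the (hypothetical) letter ids 10..13 -> "C"
```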
Q: For categorical, the model needs to respond with one letter — what if it says “the answer is A”?
It doesn’t generate at all. The logits method looks at the probability distribution at the single answer position and compares logits for just the letter tokens. It’s not about what the model would generate — it’s about which letter the model assigns highest probability to.
Q: Is categorical the same as pretrain’s loss-based eval?
No. Pretrain CORE uses mean loss over continuations (comparing multiple completions). Chat categorical compares logits at a single position for specific letter tokens. Different approach, different task format.
Q: For GSM8K, how robust is the answer parser?
It uses regex to find #### number at the end of the output. Strips commas for comparison. It’s simple string matching — evaluate() returns True/False, and reward() reuses this for RL.
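A sketch of that parse-then-compare logic. The exact pattern nanochat uses may differ; the regex below is an assumption that takes the last “#### number” in the text and accepts commas, a sign, and decimals. reward() reuses evaluate() as the 0/1 RL reward.

```python
import re

ANSWER_RE = re.compile(r"####\s*(-?[\d,]*\.?\d+)")   # hypothetical pattern

def extract_answer(text):
    matches = ANSWER_RE.findall(text)
    return matches[-1].replace(",", "") if matches else None

def evaluate(completion, reference):
    pred = extract_answer(completion)
    return pred is not None and pred == extract_answer(reference)

def reward(completion, reference):
    return float(evaluate(completion, reference))
```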
Q: In SFT generation eval, does the model know to stop at assistant_end?
Yes. The engine checks for the assistant_end token during generation and marks that row as complete. Generation stops for that sequence.
Training Scripts
Q: How does build_model_meta(depth) work?
model_dim = depth × aspect_ratio (aspect_ratio=64), rounded to a multiple of head_dim (128). Builds on meta device (no memory), then to_empty(device=device) for uninitialized allocation, then init_weights().
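The sizing arithmetic as a hypothetical helper. Rounding up (rather than to nearest) is an assumption; for even depths the product is already a multiple of 128, so the rounding only matters for odd depths.

```python
def model_dims(depth, aspect_ratio=64, head_dim=128):
    base_dim = depth * aspect_ratio
    model_dim = -(-base_dim // head_dim) * head_dim   # round up to a multiple of head_dim (assumed)
    n_head = model_dim // head_dim
    return model_dim, n_head

print(model_dims(24))   # (1536, 12): the speedrun's n_embd
```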
Q: How is target-flops enforced — does it convert to num_iterations?
Yes: num_iterations = round(target_flops / (flops_per_token × total_batch_size)). Precedence: --num-iterations > --target-flops > --target-param-data-ratio.
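The precedence rule as a hypothetical resolver. None stands for “flag not given”, which may differ from how the script encodes unset flags. The ratio branch assumes target tokens = ratio × parameter count; which parameters nanochat counts for that is not shown here.

```python
def resolve_num_iterations(total_batch_size, flops_per_token, num_params,
                           num_iterations=None, target_flops=None, target_param_data_ratio=None):
    # precedence: explicit iterations > FLOP budget > tokens-per-parameter ratio
    if num_iterations is not None:
        return num_iterations
    if target_flops is not None:
        return round(target_flops / (flops_per_token * total_batch_size))
    if target_param_data_ratio is not None:
        target_tokens = target_param_data_ratio * num_params     # assumed definition
        return round(target_tokens / total_batch_size)
    raise ValueError("no training horizon specified")

# 1e18 FLOPs / (1e9 FLOPs/token * 500,000 tokens/step) = 2000 steps
steps = resolve_num_iterations(500_000, 1e9, 1e8, target_flops=1e18)
```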
Q: Where does weight decay ramp to zero? I don’t see it in base_train.py.
It’s there — weight decay follows a cosine schedule: wd * 0.5 * (1 + cos(π * it / num_iterations)). At it = num_iterations, cosine = -1, so wd * 0.5 * 0 = 0. The SFT script inherits this (already zero) and sets weight_decay=0.0 explicitly.
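The schedule as stated above, with a hypothetical function name: full weight decay at step 0, half at the midpoint, and exactly zero at the last step.

```python
import math

def weight_decay_at(it, num_iterations, wd):
    # cosine ramp: wd at it=0, wd/2 at the midpoint, 0 at it=num_iterations
    return wd * 0.5 * (1 + math.cos(math.pi * it / num_iterations))
```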
Q: How does sft_data_generator_bos_bestfit work? What format is SFT data?
Tasks (SmolTalk, MMLU×3, GSM8K×4, etc.) → TaskMixture flattens and interleaves → each example is a conversation dict → render_conversation() converts to (token_ids, loss_mask) where mask=1 for assistant text and assistant_end, mask=0 for user text and special tokens → best-fit packing into sequences → padded positions get target=-1.
Q: Does <|assistant_end|> have mask 0 or 1?
Mask = 1. The model needs to learn to predict when to stop. assistant_end is part of the assistant’s output that should be learned.
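A toy version of render_conversation plus the input/target shift. The special-token ids and the byte-level “tokenizer” are hypothetical. The mask follows the rule above: 1 for assistant text and assistant_end, 0 for user text and other special tokens. Targets become -1 wherever the token being predicted has mask 0.

```python
import torch

# hypothetical special-token ids above a toy byte-level vocabulary
BOS, USER_START, USER_END, ASSISTANT_START, ASSISTANT_END = 256, 257, 258, 259, 260

def encode(text):
    return list(text.encode("utf-8"))

def render_conversation(messages):
    ids, mask = [BOS], [0]
    for msg in messages:
        if msg["role"] == "user":
            toks = [USER_START] + encode(msg["content"]) + [USER_END]
            ids += toks
            mask += [0] * len(toks)
        else:
            ids.append(ASSISTANT_START)
            mask.append(0)                                    # special prompt token: not trained on
            body = encode(msg["content"]) + [ASSISTANT_END]   # assistant_end IS trained on
            ids += body
            mask += [1] * len(body)
    return ids, mask

def to_inputs_targets(ids, mask):
    x = torch.tensor(ids[:-1])
    y = torch.tensor(ids[1:])
    y[torch.tensor(mask[1:]) == 0] = -1       # no loss where the predicted token is masked out
    return x, y

ids, mask = render_conversation([{"role": "user", "content": "hi"},
                                 {"role": "assistant", "content": "ok"}])
x, y = to_inputs_targets(ids, mask)
```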
SFT Data and Chat
Q: How does SFT handle DDP without a DistributedSampler?
Manual striding: cursor = ddp_rank, then cursor += ddp_world_size after each example. Each rank sees every Nth example. Same pattern as pretraining’s row-group striding.
RL Training (chat_rl.py)
Q: Is there sequence packing in RL? One row = one task?
No packing. One question per batch → num_samples rollouts → all rollouts for that question are processed together.
Q: Pseudocode for the RL training loop?
```
for step in range(num_steps):
    # Generate rollouts
    for each question:
        generate num_samples completions using engine
        score each: reward = 1 if correct, 0 if wrong
        advantages = rewards - mean(rewards)          # NO std normalization

    # Train on rollouts
    for each question's rollouts:
        for each micro-batch (device_batch_size):
            logp = -cross_entropy(model(inputs), targets, ignore_index=-1, reduction='none')  # (B, T)
            pg_obj = (logp * advantages.unsqueeze(-1)).sum()
            num_valid = (targets >= 0).sum()
            loss = -pg_obj / (num_valid * num_passes * examples_per_rank)
            loss.backward()
    optimizer.step()
```
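The inner loss computation from the pseudocode as a runnable function on raw logits. reinforce_loss is a hypothetical name. The example assumes one micro-batch holds all rollouts of one question, so the mean baseline is taken over that batch.

```python
import torch
import torch.nn.functional as F

def reinforce_loss(logits, targets, advantages, num_passes=1, examples_per_rank=1):
    # logits: (B, T, V); targets: (B, T) with -1 on prompt/padding; advantages: (B,)
    B, T, V = logits.shape
    nll = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1),
                          ignore_index=-1, reduction='none')      # (B*T,), 0 where target == -1
    logp = -nll.view(B, T)                                       # log-prob of each sampled token
    pg_obj = (logp * advantages.unsqueeze(-1)).sum()             # advantage broadcast across T
    num_valid = (targets >= 0).sum().clamp(min=1)
    return -pg_obj / (num_valid * num_passes * examples_per_rank)

rewards = torch.tensor([1.0, 0.0])        # rollout 0 solved the problem, rollout 1 did not
advantages = rewards - rewards.mean()     # mean baseline only: [0.5, -0.5]
```

One gradient step on this loss raises the log-probability of rollout 0’s tokens and lowers rollout 1’s. If all rewards are equal, the advantages are zero and so is the update.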
Q: What are the shapes of logp, pg_obj, num_valid?
With device_batch_size=4, max_seq_len=512:
| Variable | Shape | Notes |
|---|---|---|
| inputs | (4, 512) | (B, T) |
| targets | (4, 512) | (B, T), -1 for masked positions |
| model(..., loss_reduction='none') | (2048,) | (B*T,) flat |
| logp | (4, 512) | .view_as(inputs) reshapes back, 0 at masked positions |
| advantages | (4,) | (B,) one per rollout |
| advantages.unsqueeze(-1) | (4, 1) | broadcasts across T |
| logp * advantages.unsqueeze(-1) | (4, 512) | broadcast multiply |
| pg_obj | scalar | .sum() over everything |
| num_valid | scalar | count of targets >= 0 |
logp is 0 at masked positions because cross_entropy(..., ignore_index=-1, reduction='none') returns 0.0 where target=-1. So masked positions contribute nothing to pg_obj.
Q: There is no std(rewards) in advantage normalization?
Correct. advantages = rewards - mean(rewards) only. No std normalization. This is an explicit design choice (commented in the code). Works because the system is on-policy — no need for PPO ratio clipping or value baselines.
Scaling Laws and Speedrun
Q: What does scaling_laws.sh do?
Runs a grid: 4 FLOP budgets × 6 depths. For each combination, trains a model and records the loss. This produces IsoFLOP curves — for each FLOP budget, plot loss vs depth to find the compute-optimal depth.
Q: What are IsoFLOP curves?
Fix a FLOP budget, vary model size (depth). Plot final loss vs size. The minimum of each curve gives the compute-optimal model size for that budget. Connecting the minima across budgets reveals scaling laws.
Q: What is the Chinchilla scaling law formula?
Hoffmann et al. (2022) fit the loss as a sum of three terms: an irreducible floor plus a power-law decay in model size N (parameters) and dataset size D (tokens):

\[L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}\]
With the published fit coefficients (Table A.3 of the paper):
| Coefficient | Value | Meaning |
|---|---|---|
| E | 1.69 | Irreducible loss (entropy of natural text) |
| A | 406.4 | Model-size term amplitude |
| α | 0.34 | Model-size decay exponent |
| B | 410.7 | Data-size term amplitude |
| β | 0.28 | Data-size decay exponent |
Q: What’s the compute-optimal version — L as a function of C alone?
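Given the FLOP constraint C = 6ND (standard transformer forward+backward), we can eliminate N and D by optimizing the loss subject to that budget. Lagrange multipliers on $L(N, D)$ give:

\[N^*(C) = G\left(\frac{C}{6}\right)^{a}, \quad D^*(C) = G^{-1}\left(\frac{C}{6}\right)^{b}, \quad G = \left(\frac{\alpha A}{\beta B}\right)^{\frac{1}{\alpha+\beta}}, \quad a = \frac{\beta}{\alpha+\beta}, \quad b = \frac{\alpha}{\alpha+\beta}\]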
Substituting back into the loss:
\[L^*(C) = E + K \cdot C^{-\eta}, \quad \eta = \frac{\alpha\beta}{\alpha + \beta} \approx 0.154\]

Here K is a constant collecting A, B, and G. So compute-optimally-trained loss decays as $C^{-0.154}$ toward the floor E = 1.69. The compute-optimal tokens-per-parameter ratio is $D^*/N^* \propto C^{b - a} = C^{0.096}$; it grows weakly with scale. At Chinchilla’s scale (~$10^{24}$ FLOPs) this ratio landed at ~20.
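A numeric check of the closed form, using the published coefficients from the table above. The check sweeps N on a log grid at fixed C (an IsoFLOP curve with D = C/6N), takes the argmin, and compares it with N*(C). It also checks that the fitted scaling exponent matches a = β/(α+β). The helper names are illustrative.

```python
import numpy as np

E, A, alpha, B, beta = 1.69, 406.4, 0.34, 410.7, 0.28

def loss(N, D):
    return E + A / N**alpha + B / D**beta

def n_opt_closed_form(C):
    G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))
    a = beta / (alpha + beta)
    return G * (C / 6) ** a

def n_opt_isoflop(C, num=20001):
    # IsoFLOP curve: vary N, set D = C / (6N), keep the minimum
    N = np.logspace(6, 13, num)
    return N[np.argmin(loss(N, C / (6 * N)))]

eta = alpha * beta / (alpha + beta)
print(round(eta, 3))   # 0.154
```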
Q: Why does nanochat use ~10.5, not Chinchilla’s 20?
The 20:1 ratio is not a universal constant. It is the value the Chinchilla fit happens to predict at ~$10^{23}$–$10^{24}$ FLOPs. Since $D^*/N^*$ grows as $C^{0.096}$, the ratio shrinks at smaller scales. Extrapolating Chinchilla’s formula to nanochat scale ($C \sim 10^{18}$–$10^{19}$ FLOPs) predicts a ratio closer to ~8, which is in the same ballpark as the ~10.5 that nanochat actually finds.
Other reasons the exact number differs:
- Architecture: nanochat uses value embeddings (parameter-heavy, FLOP-free), ReLU², QK norm, and other non-standard choices that shift the effective FLOP/parameter relationship. Chinchilla’s fit assumed a standard dense transformer.
- Extrapolation is risky: Chinchilla fit at 70M–16B parameters. Nanochat’s ~100M scale sits near the low end, so re-fitting from IsoFLOP sweeps (the point of scaling_laws.sh and miniseries.sh) produces more accurate local coefficients than blindly trusting 20:1.
- Tokenizer/data differ: both B and the effective D depend on tokenizer efficiency and dataset composition.
The 10.5 figure comes from nanochat’s own IsoFLOP experiments — the minima of its loss-vs-depth curves at each FLOP budget — not from assuming any prior scaling law.
Q: What does miniseries.sh do?
Trains depths 12-26 at the compute-optimal tokens-per-parameter ratio (~10.5), recording results to CSV. This maps out how performance scales with model size when each model is optimally trained.
Q: How did the speedrun determine depth 24?
From scaling law experiments. The CORE metric (22-task composite centered accuracy) is the target — beat GPT-2’s score. d24 with aspect_ratio=64 gives ~124M active parameters, enough to surpass GPT-2 (124M params) on CORE.
Q: Is the speedrun model compute-optimal?
No. It uses ratio=8 (tokens per parameter), while compute-optimal is ~10.5. It’s intentionally undertrained to optimize wall-clock time, not FLOP efficiency. The goal is beating GPT-2 CORE score as fast as possible, not achieving the best loss per FLOP.
Q: Why not use a smaller model at compute-optimal ratio 10.5 for shorter wall-clock?
Because “compute-optimal” refers to FLOP efficiency, not wall-clock. A smaller model at ratio 10.5 uses fewer FLOPs total but may not beat GPT-2’s CORE score. The speedrun needs a model large enough to exceed GPT-2 performance — d24 is that threshold. Undertrained (ratio=8) gets there faster on the clock because fewer iterations are needed.
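To put numbers on "fewer iterations": for a fixed model size N, training uses D = r·N tokens and C ≈ 6ND FLOPs. The tokens-per-parameter ratio r therefore scales tokens, FLOPs and, at a fixed batch size, optimizer steps linearly. In this sketch the parameter count is a hypothetical placeholder, and the resulting fraction does not depend on it.

```python
def training_budget(n_params: float, ratio: float) -> tuple[float, float]:
    """Tokens and approximate training FLOPs (C ≈ 6ND) for n_params trained at a tokens-per-parameter ratio."""
    tokens = ratio * n_params
    return tokens, 6 * n_params * tokens


N = 1e9  # hypothetical parameter count, for illustration only
tok_speedrun, flops_speedrun = training_budget(N, 8)
tok_optimal, flops_optimal = training_budget(N, 10.5)
print(f"ratio 8 uses {tok_speedrun / tok_optimal:.0%} of the tokens and FLOPs of ratio 10.5")
```

For the same model, ratio 8 needs about 76% of the tokens (and steps) of ratio 10.5. That roughly 24% saving is what the speedrun trades against the slightly worse loss-per-FLOP.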
Q: So compute-optimal is about FLOPs, not wall-clock?
Yes. Compute-optimal means best loss per FLOP. An architecture that is bottlenecked by memory bandwidth rather than FLOPs (like VE: massive parameter count, zero FLOPs) shifts the optimal ratio. Counting VE parameters, the optimal tokens-per-parameter ratio looks small because those parameters are cheap, and a FLOP-based scaling law doesn’t directly capture that.
Q: What’s the difference between SP (standard parametrization) and μP (maximal update parametrization)?
Both are ways of setting init scale and per-parameter learning rates as a function of width n. They differ in where the width-dependence lives — inside the tunable knob η_base (SP) or inside fixed per-group scalings (μP).
| Aspect | SP (standard parametrization) | μP (maximal update parametrization, Adam variant) |
|---|---|---|
| Input layers (embed) init | Var = c/fan_in | Var = σ² (O(1), width-independent) |
| Hidden matrices init | Var = c/fan_in | Var = σ²/fan_in |
| Output layers (lm_head) init | Var = c/fan_in | Var = σ²/fan_in² (near-zero) |
| Input LR (Adam) | η | η × 1 |
| Hidden LR (Adam) | η | η × 1/n |
| Output LR (Adam) | η | η × 1/n |
| Output forward multiplier | none | 1/n |
| Does tuned η change with width? | Yes, retune at every scale | No, width-invariant |
| Does effective LR on hidden weights change with width? | No (flat η) | Yes (shrinks as 1/n) |
| Transfer recipe | None: run a fresh sweep at every size | μTransfer: tune η at tiny width, copy to huge width |
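The LR rows of the table can be written as a tiny function. This sketch assumes the common μTransfer convention of expressing n as a width multiplier relative to a base width, so that both parametrizations coincide at the base width. The function and argument names are hypothetical, and init scales are left out.

```python
from typing import Dict


def adam_lrs(eta: float, width: int, base_width: int, parametrization: str) -> Dict[str, float]:
    """Per-group Adam learning rates following the SP / muP LR rows of the table.

    Assumption: n enters as the width multiplier m = width / base_width.
    """
    if parametrization == "sp":
        # SP: one flat LR everywhere, so eta itself has to be retuned per width
        return {"input": eta, "hidden": eta, "output": eta}
    if parametrization == "mup":
        m = width / base_width
        # muP (Adam): input LR stays O(1); hidden and output LRs shrink as 1/n
        return {"input": eta, "hidden": eta / m, "output": eta / m}
    raise ValueError(f"unknown parametrization: {parametrization!r}")


# Tune eta once at the base width, then reuse the same eta at larger widths
for width in (256, 1024, 4096):
    print(width, adam_lrs(1e-3, width, base_width=256, parametrization="mup"))
```

The knob `eta` never changes across the loop. Only the fixed per-group multipliers do, which is the "width-dependence moved into fixed machinery" point.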
So the sharp answer to “does μP LR not change with scale?”:
- The tunable hyperparameter `η_base` stays constant. That’s the whole point: you tune it once at a cheap width and reuse it.
- The effective per-parameter LR that actually hits the weights does change: hidden and output matrices get a `1/n` multiplier baked into the parametrization. μP hasn’t made width go away; it has moved the width-dependence from the knob into fixed machinery, so the knob itself becomes transferable.
(SGD has a different recipe: hidden LR scales as 1, embedding LR as n. The family of optimizers matters.)
Q: How does nanochat adapt this to pick hyperparameters?
Karpathy calls his approach “muP style” (base_train.py:271) — pragmatic, not textbook μP. He tunes hyperparameters at a d12 reference model (n_embd=768, B_ref=2^19 tokens) and applies two independent LR corrections plus a weight-decay rule to transfer to larger scales:
The two √-terms scale along different axes (width vs. batch size) and are independent; they just happen to share the same square-root shape.
1. Width scaling, the μP knob (`gpt.py:383`): `dmodel_lr_scale = (model_dim / 768) ** -0.5`. Applied to the AdamW groups `wte` (`embedding_lr=0.3`), `lm_head` (`unembedding_lr=0.008`) and `value_embeds` (`embedding_lr × 0.5`). This is 1/√n, between SP (no width scaling) and strict μP (1/n for Adam): a “conservative Adam” heuristic.
2. Muon handles matrix width invariance separately. Matrix parameters use the Muon optimizer at a fixed `matrix_lr=0.02`, with no width scaling applied at optimizer setup (`gpt.py:398-403`). Muon’s own update rule normalizes per shape (`optim.py`: `lr * max(1, shape[-2]/shape[-1])**0.5`), giving Muon some μP-like width invariance “for free”.
3. Batch size from the Power Lines paper (arxiv 2505.13738, `base_train.py:277-283`):
   \[B_{\text{opt}} \propto D^{0.383}\]
   This is where the second √-term comes from: `predicted_B = B_ref · (target_tokens / D_ref)^0.383`, rounded to the nearest power of 2. Once `B` is known, the LR picks up the batch correction `η *= √(B/B_ref)`. This is the standard AdamW square-root batch-scaling rule, also applied to Muon “as an assumption” (`base_train.py:291-293`).
4. Weight decay scaling from the T_epoch framework (arxiv 2405.13698, `base_train.py:297-302`): keeps `T_epoch = B/(η·λ·D)` constant across scales. Combined with `η ∝ √(B/B_ref)` from step 3, this forces:
   \[\lambda = \lambda_{\text{ref}} \cdot \sqrt{B/B_{\text{ref}}} \cdot (D_{\text{ref}}/D)\]
   Then `λ` follows a cosine decay to zero over the run, applied only to Muon groups. AdamW groups have fixed per-group weight decay (0.001 for `wte`, 0.01 for `lm_head`).
Note that steps 1 and 3 scale different things: step 1 is pure μP width correction (triggered by making the model wider), step 3 is pure batch-size correction (triggered by the Power Lines rule enlarging B as D grows). Both fire whenever you scale up from d12, so in practice both √-terms are active.
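The four rules reduce to a handful of formulas. The sketch below restates them as plain functions so the invariance in step 4 can be checked numerically. Several pieces are assumptions rather than values from nanochat:

- `D_REF` (reference token budget) and `WD_REF` (reference weight decay) are hypothetical placeholders, because the text does not give them.
- "Nearest power of 2" is interpreted as rounding in log2 space.
- The function names are ours.

`REF_DIM = 768`, `B_REF = 2**19`, the 0.383 exponent and `matrix_lr = 0.02` come from the text.

```python
import math

REF_DIM = 768          # d12 reference width (n_embd), from the text
B_REF = 2**19          # reference batch size in tokens, from the text
POWER_LINES_EXP = 0.383


def dmodel_lr_scale(model_dim: int) -> float:
    # Step 1: 1/sqrt(n) width correction for the AdamW groups
    return (model_dim / REF_DIM) ** -0.5


def muon_shape_scale(shape: tuple) -> float:
    # Step 2: Muon's per-shape factor, max(1, rows / cols) ** 0.5
    return max(1, shape[-2] / shape[-1]) ** 0.5


def predicted_batch_size(target_tokens: float, d_ref: float) -> int:
    # Step 3: B_opt ∝ D^0.383, snapped to a power of two (log2 rounding is an assumption)
    b = B_REF * (target_tokens / d_ref) ** POWER_LINES_EXP
    return 2 ** round(math.log2(b))


def batch_lr_scale(batch_size: int) -> float:
    # Step 3: square-root batch rule, eta *= sqrt(B / B_ref)
    return math.sqrt(batch_size / B_REF)


def scaled_weight_decay(wd_ref: float, batch_size: int, target_tokens: float, d_ref: float) -> float:
    # Step 4: lambda = lambda_ref * sqrt(B / B_ref) * (D_ref / D)
    return wd_ref * math.sqrt(batch_size / B_REF) * (d_ref / target_tokens)


def t_epoch(batch_size: float, lr: float, wd: float, tokens: float) -> float:
    # The quantity step 4 holds fixed: T_epoch = B / (eta * lambda * D)
    return batch_size / (lr * wd * tokens)


def cosine_wd(wd: float, step: int, total_steps: int) -> float:
    # Cosine decay of lambda to zero over the run (applied to Muon groups only, per the text)
    return wd * 0.5 * (1 + math.cos(math.pi * step / total_steps))


D_REF = 2e9                  # hypothetical reference token budget
LR_REF, WD_REF = 0.02, 0.2   # matrix_lr from the text; WD_REF is hypothetical
D = 8 * D_REF                # scale the token budget up 8x
B = predicted_batch_size(D, D_REF)
lr = LR_REF * batch_lr_scale(B)
wd = scaled_weight_decay(WD_REF, B, D, D_REF)
print(B, lr, wd)
print(t_epoch(B_REF, LR_REF, WD_REF, D_REF), t_epoch(B, lr, wd, D))
```

An 8× larger token budget pushes the batch from 2^19 to 2^20. The LR rises by √2 and the weight decay falls accordingly, which leaves `T_epoch` unchanged.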
Q: Did this μP-style transfer actually work?
Partially. dev/LOG.md (2026-01-19 entry) is blunt about the limits:
Hyperparameters are scale-dependent. What works at d12 doesn’t transfer to d20. The elaborate fine-tuning that won at d12 actively hurts at d20. Don’t over-tune on small proxies. Validate at target scale before shipping.
So the explicit width/batch/data-size rules above get you into the right neighborhood, but the last few percent still needs retuning at target scale. This is the practical reality of μP-style transfer outside the idealized regime — Muon’s width invariance + 1/√n AdamW scaling is a good starting point, not a zero-shot guarantee.
MoE (Mixture of Experts)
Q: How does the MoE layer work?
Drop-in replacement for the dense MLP. For each token:
- Sigmoid router scores all experts: `sigmoid(gate(x))` → `(T, E)`
- Top-K experts selected per token (with auxiliary-loss-free bias from DeepSeek-V3)
- Tokens sorted by expert assignment for contiguous processing
- Pre-multiplied by routing scores before expert computation
- Routed experts process via `grouped_mm` (single kernel per projection)
- Shared expert processes ALL tokens via standard dense matmul
- Outputs scattered back and summed: routed + shared
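The steps above can be traced end to end in a pure-Python toy. This is a sketch under stated assumptions, not nanochat's implementation:

- Vectors are plain lists.
- Experts are arbitrary callables (the test uses scalar-weight linear maps standing in for MLPs).
- Routing scores are used un-normalized.
- A per-expert loop over contiguous slices stands in for the single `grouped_mm` call.

All names (`moe_forward`, `gate`, `bias`, `shared`) are hypothetical.

```python
import math
from itertools import groupby
from typing import Callable, List

Vec = List[float]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def moe_forward(
    xs: List[Vec],                        # T token vectors
    gate: List[Vec],                      # E router rows, one per expert
    bias: List[float],                    # per-expert balancing bias (selection only)
    experts: List[Callable[[Vec], Vec]],  # routed experts
    shared: Callable[[Vec], Vec],         # shared expert, sees every token
    top_k: int,
) -> List[Vec]:
    T, E = len(xs), len(gate)
    # 1. Sigmoid router: an independent score per (token, expert), shape (T, E)
    scores = [[sigmoid(sum(xi * gi for xi, gi in zip(x, g))) for g in gate] for x in xs]
    # 2. Top-k per token; the bias changes which experts win, not the routing weights
    selected = [
        sorted(range(E), key=lambda e: scores[t][e] + bias[e], reverse=True)[:top_k]
        for t in range(T)
    ]
    # 3. Flatten (T, K) -> (T*K,) and stable-sort the slots by expert id
    flat = [e for row in selected for e in row]
    order = sorted(range(T * top_k), key=flat.__getitem__)
    token_ids = [i // top_k for i in order]
    # 4. Pre-multiply each routed copy of a token by its routing score
    routed_in = [[scores[t][flat[i]] * v for v in xs[t]] for i, t in zip(order, token_ids)]
    # 5. Each expert processes one contiguous slice of the sorted inputs
    routed_out: List[Vec] = []
    pos = 0
    for e, group in groupby(order, key=flat.__getitem__):
        count = len(list(group))
        routed_out.extend(experts[e](v) for v in routed_in[pos:pos + count])
        pos += count
    # 6. Shared expert on all tokens, then scatter-add routed outputs back to their tokens
    out = [list(shared(x)) for x in xs]
    for t, y in zip(token_ids, routed_out):
        out[t] = [o + yi for o, yi in zip(out[t], y)]
    return out
```

Sorting is what lets each expert see one contiguous block of inputs. The final scatter-add has to undo that permutation, which is exactly what `token_ids` records.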
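Key sizing: `expert_hidden_dim = round(4 * dim / (top_k + num_shared) / 128) * 128` splits the dense MLP’s `4 * dim` hidden width across the `top_k + num_shared` experts each token visits. Per-token MLP FLOPs therefore match the dense MLP, up to the rounding to a multiple of 128.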
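A quick check of how close "iso-FLOP" is after rounding. This sketch counts only the up- and down-projection matmuls (router FLOPs are ignored). The `(dim, top_k, num_shared)` configurations are arbitrary examples, not nanochat settings.

```python
def expert_hidden_dim(dim: int, top_k: int, num_shared: int) -> int:
    """Sizing rule quoted in the text: split 4*dim across active experts, round to a multiple of 128."""
    return round(4 * dim / (top_k + num_shared) / 128) * 128


def mlp_matmul_flops(dim: int, hidden: int) -> int:
    """Forward FLOPs per token for up- plus down-projection (2 FLOPs per multiply-add)."""
    return 2 * (dim * hidden + hidden * dim)


def moe_to_dense_flop_ratio(dim: int, top_k: int, num_shared: int) -> float:
    h = expert_hidden_dim(dim, top_k, num_shared)
    # Each token passes through top_k routed experts plus num_shared shared experts
    active = (top_k + num_shared) * mlp_matmul_flops(dim, h)
    return active / mlp_matmul_flops(dim, 4 * dim)


for dim, k, s in [(768, 2, 1), (1024, 2, 1), (1536, 8, 1)]:
    print(dim, k, s, expert_hidden_dim(dim, k, s), moe_to_dense_flop_ratio(dim, k, s))
```

When `4 * dim / (top_k + num_shared)` is already a multiple of 128 the match is exact. Otherwise the rounding drifts a few percent either way (here +3.1% and −6.25%).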
Q: Why token_ids = token_indices_sorted // self.top_k?
selected_experts is (T, K). Flattened to (T*K,), positions [0..K-1] = token 0, [K..2K-1] = token 1, etc. token_indices_sorted are indices into this flat array, so // top_k recovers the original token index.
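A concrete trace with T = 3 tokens and top_k = 2 makes the indexing visible. The routing decisions are made-up toy values. A stable `sorted` over the flat positions plays the role of the argsort; in PyTorch this would be an argsort over the flattened tensor.

```python
# Toy routing decision (hypothetical values): 3 tokens, top_k = 2
selected_experts = [[2, 0], [1, 2], [0, 1]]  # shape (T, K)
top_k = 2

flat = [e for row in selected_experts for e in row]                     # [2, 0, 1, 2, 0, 1]
token_indices_sorted = sorted(range(len(flat)), key=flat.__getitem__)  # stable argsort: [1, 4, 2, 5, 0, 3]
token_ids = [i // top_k for i in token_indices_sorted]                 # owning token:   [0, 2, 1, 2, 0, 1]
slot_ids = [i % top_k for i in token_indices_sorted]                   # which of its K picks: [1, 0, 0, 1, 0, 1]

print([flat[i] for i in token_indices_sorted])  # [0, 0, 1, 1, 2, 2]
print(token_ids)                                # [0, 2, 1, 2, 0, 1]
```

Integer division by `top_k` gives the row (token) of each flat position and the remainder gives the column (which of its K choices). After sorting, expert 0's inputs are tokens 0 and 2, expert 1's are tokens 1 and 2, and expert 2's are tokens 0 and 1.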
Limitations
The d32 model is available for chat at nanochat.karpathy.ai. From the official training log:
- 1.88B total parameters, 12.1B FLOPs per token
- 37.6B training tokens (tokens:params ratio = 20.0)
- 71,680 iterations on 8 GPUs, ~31 hours total
- Final validation BPB: 0.7236, CORE metric: 0.3274, MFU: 51.79%
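The logged figures can be cross-checked with back-of-the-envelope arithmetic. This sketch assumes the 12.1B figure counts training (forward+backward) FLOPs per token, in line with the 6ND convention above. Because "~31 hours" is rounded, the throughput numbers are approximate.

```python
import math

# Figures from the d32 training log quoted above (rounded as reported)
params = 1.88e9
flops_per_token = 12.1e9   # assumed: training (fwd+bwd) FLOPs per token
tokens = 37.6e9
iterations = 71_680
gpus = 8
hours = 31                 # "~31 hours total"
mfu = 0.5179

tokens_per_param = tokens / params
tokens_per_iter = tokens / iterations
total_flops = flops_per_token * tokens
achieved_per_gpu = total_flops / (hours * 3600 * gpus)   # FLOP/s actually sustained per GPU
implied_peak_per_gpu = achieved_per_gpu / mfu            # peak that the reported MFU implies

print(f"tokens/param     ~ {tokens_per_param:.1f}")
print(f"tokens/iteration ~ {tokens_per_iter:,.0f} (2^{math.log2(tokens_per_iter):.2f})")
print(f"total training   ~ {total_flops:.2e} FLOPs")
print(f"achieved per GPU ~ {achieved_per_gpu / 1e12:.0f} TFLOP/s")
print(f"implied peak/GPU ~ {implied_peak_per_gpu / 1e12:.0f} TFLOP/s")
```

The numbers hang together. The run used about 2^19 tokens per iteration and roughly 4.5e20 FLOPs in total, sustaining about 510 TFLOP/s per GPU. Dividing by the reported MFU implies a per-GPU peak of roughly 980 TFLOP/s.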
While it demonstrates that the full pipeline works end-to-end, the model frequently hallucinates and lacks factual knowledge. This is understandable: 1.88B params trained on 37.6B tokens is tiny compared to production models (e.g. Llama 3 8B, trained on 15T tokens). Nanochat’s value is as a research and educational codebase, not as a general-purpose chat model.