Dong Wang · A machine learning blog

Nanochat: A Deep Dive

Nanochat is a GPT-2-beating LLM training codebase by Karpathy: tokenizer, pretraining, SFT, RL, and eval in readable single-file scripts. No frameworks, no abstractions you can’t trace end-to-end. The speedrun pipeline trains a 24-layer model that beats GPT-2’s CORE score while optimizing for wall-clock time, using value embeddings, FP8 training, and a collection of non-standard but justified architectural choices.

Summary: Why Nanochat Is Worth a Look

  1. Complete pipeline in one repo: tokenizer → pretrain → SFT → RL → eval, all readable single-file scripts.
  2. Value Embeddings (VE): per-layer token lookup tables added to attention values at every other layer. Zero FLOPs, ~55% of total parameters. The model “loves” them.
  3. Scaling laws done right: IsoFLOP grid sweep (4 budgets × 6 depths), ~10.5 tokens-per-parameter compute-optimal ratio. Speedrun uses ratio=8 to optimize wall-clock.
  4. Minimal FP8 training: ~150 lines vs torchao’s ~2000.
  5. No DataLoader, no DistributedSampler: manual rank-based striding over parquet row groups. BOS-aligned best-fit packing, ~100% sequence utilization.
  6. Non-standard architecture choices: ReLU², QK norm after RoPE, logit softcap, x0 residual, smear, backout, sliding window (SSSL pattern).
  7. Five eval methods: base has 3 (multiple_choice, schema, language_modeling — all loss/argmax-based, no generation), SFT has 2 (categorical via logit argmax over letter tokens, generative via sampling with tool use). Base uses CORE metric, SFT uses ChatCORE metric — both are mean centered accuracy across their respective task sets.
  8. RL is dead simple: vanilla REINFORCE on GSM8K, no std normalization, no PPO, no value network.
  9. Engine with tool use: KV-cache inference with calculator calls mid-generation via forced token injection.
  10. Experiment log as first-class artifact: dev/LOG.md documents every decision — what worked, what failed, and why.
  11. Attention learns document boundaries without hard masks: BOS markers, RoPE distance decay, and sliding window. No varlen attention needed.
  12. Everything is traceable: weight init math derived, per-type optimizer LRs, batch size auto-scales from Power Lines paper. No magic numbers.

Q&A

Naming and Conventions

Q: Why is the token embedding called wte?

“weight token embedding” — inherited from GPT-2’s naming convention. w = weight, t = token, e = embedding.

Q: Why are attention weights named with c_ prefix, like c_q, c_proj?

From GPT-2’s original implementation which used Conv1D (1×1 convolution) instead of nn.Linear for all projections. The c_ prefix meant “conv”. Nanochat uses nn.Linear but keeps the naming for continuity with the GPT lineage.


Model Architecture

Q: One-sentence summaries for all non-GPT-2 features?

  • RoPE: Rotary position embeddings replacing learned positional embeddings, applied to Q and K.
  • RMSNorm: Root-mean-square normalization with no learnable parameters, replacing LayerNorm.
  • QK Norm: L2-normalizes Q and K after RoPE, scaled by 1.2, preventing attention logit explosion.
  • ReLU²: F.relu(x).square() activation in MLP — sparse and cheap, replacing GELU.
  • GQA: Grouped-query attention allowing fewer KV heads than query heads.
  • Untied embeddings: Separate wte (input) and lm_head (output) embedding matrices.
  • Sliding window attention: SSSL pattern (Short-Short-Short-Long) alternating window sizes across layers.
  • Value Embeddings: Per-layer token lookups gated and added to attention values at every other layer.
  • x0 residual: Every layer adds to the original input (layer 0 output), not just the previous layer.
  • Smear: Bigram-like mixing of previous token into current token representation at input.
  • Backout: Cache residual at halfway layer, subtract backout_lambda-scaled version before final norm — removes low-level features from logit computation.
  • Logit softcap: 15 * tanh(logits/15) — soft clipping to prevent extreme logit values.
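Two of these features are easy to show in isolation. A minimal sketch of ReLU² and the logit softcap (function names are mine, not the repo’s):

```python
import torch
import torch.nn.functional as F

def relu_squared(x: torch.Tensor) -> torch.Tensor:
    # ReLU²: zero for negative inputs, x² for positive — sparse and cheap
    return F.relu(x).square()

def softcap_logits(logits: torch.Tensor, cap: float = 15.0) -> torch.Tensor:
    # Soft clipping: output is bounded to (-cap, cap) but stays differentiable
    return cap * torch.tanh(logits / cap)

print(relu_squared(torch.tensor([-2.0, 0.0, 3.0])))  # tensor([0., 0., 9.])
print(softcap_logits(torch.tensor([100.0])))         # just under 15, never exceeds it
```

Unlike a hard clamp, the tanh softcap keeps gradients flowing even for extreme logits.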

Q: What are all the weight initialization choices?

The full init scheme (gpt.py:init_weights):

| Component | Method | Value | Rationale |
|---|---|---|---|
| wte (input embed) | normal | std=0.8 | |
| lm_head (output embed) | normal | std=0.001 | Near-zero so logits start small and uniform |
| c_q, c_k, c_v | uniform | U(-s, s), s=√(3/n_embd) | Uniform avoids outliers vs normal; variance-matched (Var of U(-s, s) = s²/3 = 1/n_embd) |
| c_proj (attn output) | zeros | 0 | Zero-init output projections: each block starts as identity |
| c_fc (MLP up) | uniform | U(-0.4s, 0.4s) | 0.4× reduced scale for MLP input |
| c_proj (MLP output) | zeros | 0 | Same zero-init pattern as attn output |
| Value embeddings | uniform | U(-s, s) | Same as c_v |
| ve_gate | uniform | U(0, 0.02) | Small positive so gates start slightly above neutral |
| resid_lambdas | per-layer linear decay | 1.15 → 1.05 | Stronger residual at early layers, weaker at deep layers |
| x0_lambdas | per-layer linear decay | 0.20 → 0.05 | Earlier layers get more input embedding blending |
| backout_lambda | constant | 0.2 | |
| smear_lambda | constant | 0.0 | Disabled at init, learned during training |

Key design principles:

  • Output projections are zero-initialized (c_proj in both attn and MLP). At init each transformer block is effectively an identity — the residual stream passes through untouched. This is a common trick for stable deep training.
  • Uniform over normal for transformer matrices: same variance but bounded, avoiding outlier weights at init.
  • Per-layer decay for resid_lambdas and x0_lambdas: early layers get stronger residual connections and more x0 blending, deep layers are more independent.
  • From dev/LOG.md: orthogonal init was tried but didn’t help. The x0 residual scalars were originally zero-initialized (disabled at start), but the current per-layer decaying init (1.15→1.05 and 0.20→0.05) was found to work better.

Q: What is backout?

Cache the residual at the halfway layer (n_layer // 2), then subtract a learned-scaled version (backout_lambda, init 0.2) before the final norm and LM head. The idea is to remove low-level features from the mid-layer that helped internal processing but shouldn’t influence logit computation. Active in the codebase at gpt.py:449-459.
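A minimal sketch of the idea with generic blocks (illustrative only, not the repo’s exact code at gpt.py:449-459):

```python
import torch

def forward_with_backout(blocks, x, backout_lambda=0.2):
    """Sketch of backout: cache the residual stream at the halfway layer,
    then subtract a scaled copy before the final norm / LM head."""
    x_mid = None
    for i, block in enumerate(blocks):
        x = block(x)
        if i == len(blocks) // 2 - 1:
            x_mid = x  # cache residual at the halfway point
    # remove low-level mid-layer features from logit computation
    return x - backout_lambda * x_mid

# With identity blocks the result is simply (1 - 0.2) * x:
out = forward_with_backout([lambda t: t] * 4, torch.ones(3))
print(out)  # tensor([0.8000, 0.8000, 0.8000])
```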

Q: What are value embeddings?

Per-layer token lookup tables (same shape as wte but sized to kv_dim) at every other layer, whose output is gated and added directly to the attention values. Determined by has_ve(layer_idx, n_layer) which uses layer_idx % 2 == (n_layer - 1) % 2. The gating is gate = 3 * sigmoid(ve_gate(x[..., :12])), then v = v + gate.unsqueeze(-1) * ve. VE is unconditionally enabled — there is no toggle flag.
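A minimal sketch of the lookup-and-gate mechanism (class name and dimensions are illustrative, not the repo’s exact code):

```python
import torch
import torch.nn as nn

class ValueEmbedding(nn.Module):
    """Per-layer token lookup whose output is gated and added to attention values."""
    def __init__(self, vocab_size, kv_dim, n_kv_head, gate_channels=12):
        super().__init__()
        self.ve = nn.Embedding(vocab_size, kv_dim)          # zero-FLOP table lookup
        self.ve_gate = nn.Linear(gate_channels, n_kv_head)  # tiny gate read from x
        self.gate_channels = gate_channels

    def forward(self, idx, x, v):
        B, T, H, D = v.shape                                # (B, T, n_kv_head, head_dim)
        ve = self.ve(idx).view(B, T, H, D)                  # split lookup into KV heads
        gate = 3 * torch.sigmoid(self.ve_gate(x[..., :self.gate_channels]))  # (B, T, H)
        return v + gate.unsqueeze(-1) * ve                  # broadcast over head_dim

idx = torch.randint(0, 100, (2, 5))
x = torch.randn(2, 5, 32)
v = torch.randn(2, 5, 4, 4)
out = ValueEmbedding(100, 16, 4)(idx, x, v)
print(out.shape)  # torch.Size([2, 5, 4, 4])
```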

Q: How big are VE embeddings compared to wte?

For the speedrun d24 config (n_layer=24, n_embd=1536, vocab_size=32768): wte is ~50M params. VE exists at 12 of 24 layers, each sized (vocab_size, kv_dim). Total VE is ~55-61% of all model parameters. They cost zero FLOPs (just table lookups).

Q: How much do VE help? (from dev/LOG.md)

The experiment log shows VE is one of the biggest wins. Every attempt to reduce VE capacity hurt performance — the model “loves” them. They provide a massive parameter reservoir that enriches attention values without adding compute.


Tensor Operations

Q: Does gate.unsqueeze(-1) add a new dimension?

Yes. If gate is (B, T, H), then gate.unsqueeze(-1) becomes (B, T, H, 1) — a new trailing dimension for broadcasting against the (B, T, H, D) value tensor.

Q: What other ways to do x[..., :self.ve_gate_channels]?

  • x.narrow(-1, 0, channels) — no copy, same as slicing
  • x[..., :channels] — the ellipsis form used in code
  • x[:, :, :channels] — explicit dims (fragile if rank changes)
  • torch.split(x, [channels, rest], dim=-1)[0] — overkill here

Q: Is torch.zeros(batch_size) a row or column vector?

Neither — it’s a 1D tensor of shape (batch_size,). PyTorch 1D tensors have no row/column distinction. You need (1, n) for row or (n, 1) for column.

Q: How to outer product two 1D tensors?

torch.outer(a, b) or equivalently a.unsqueeze(-1) * b.unsqueeze(0) or torch.einsum('i,j->ij', a, b). All produce shape (len(a), len(b)).

Q: Explain idx.gather(1, choice)?

gather(dim, index) selects elements along dim using positions from index. The result always has the same shape as index, regardless of the source tensor’s shape. For idx of shape (B, V) and choice of shape (B, 1):

# For each row b, picks idx[b, choice[b, 0]]
result = idx.gather(1, choice)  # (B, 1) — same shape as choice, not idx

Used in sample_next_token to map multinomial’s position indices back to actual vocab token IDs after top-k filtering.
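A toy illustration of that mapping (probabilities and shapes invented):

```python
import torch

# Top-k sampling: multinomial returns positions into the top-k slice;
# gather maps those positions back to real vocab token ids.
probs = torch.tensor([[0.1, 0.5, 0.4],
                      [0.7, 0.2, 0.1]])            # (B=2, V=3)
topk_probs, topk_idx = probs.topk(2, dim=-1)       # both (2, 2)
pos = torch.multinomial(topk_probs, num_samples=1)  # (2, 1) positions in [0, 2)
token_ids = topk_idx.gather(1, pos)                # (2, 1) actual vocab ids
print(token_ids.shape)  # torch.Size([2, 1])
```

Row 0’s top-2 ids are {1, 2} and row 1’s are {0, 1}, so the gathered result always lands in those sets regardless of what multinomial draws.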

Q: What are all the tensor shape tricks used in the repo?

Dimension Expansion / Broadcasting:

| Code | File:Line | Pre Shape | Post Shape | Purpose |
|---|---|---|---|---|
| gate.unsqueeze(-1) | gpt.py:95 | (B,T,n_kv_head) | (B,T,n_kv_head,1) | Broadcast gate scalar per-head across head_dim |
| cos[None,:,None,:] | gpt.py:277 | (T,D/2) | (1,T,1,D/2) | Add batch+head dims for broadcasting with (B,T,H,D) |
| arange().unsqueeze(1) | flash_attention.py:94 | (Tq,) | (Tq,1) | Column vector for outer-product-style mask |
| arange().unsqueeze(0) | flash_attention.py:95 | (Tk,) | (1,Tk) | Row vector for outer-product-style mask |
| tokens.unsqueeze(1) | engine.py:279 | (B,) | (B,1) | Add seq dim for single-token decode |
| advantages.unsqueeze(-1) | chat_rl.py:266 | (B,) | (B,1) | Broadcast per-sample advantage across token positions |
| logits.expand(N,-1) | engine.py:211 | (1,V) | (N,V) | Copy batch=1 logits to N parallel samples |
| prev_emb.expand(B,-1,-1) | engine.py:137 | (1,1,C) | (B,1,C) | Expand cached embedding to batch for prefill |

Reshape (split/merge dimensions):

| Code | File:Line | Pre Shape | Post Shape | Purpose |
|---|---|---|---|---|
| c_q(x).view(B,T,H,D) | gpt.py:87-89 | (B,T,H*D) | (B,T,H,D) | Split linear output into heads |
| ve.view(B,T,n_kv,D) | gpt.py:93 | (B,T,kv_dim) | (B,T,n_kv_head,D) | Split VE into heads |
| y.contiguous().view(B,T,-1) | gpt.py:124 | (B,T,H,D) | (B,T,C) | Merge heads back after attention |
| logits.view(-1,V) | gpt.py:472 | (B,T,V) | (B*T,V) | Flatten for cross_entropy |
| targets.view(-1) | gpt.py:472 | (B,T) | (B*T,) | Flatten for cross_entropy |
| buffer.view(B,T) | dataloader.py:117-120 | (B*T,) | (B,T) | Shape flat token buffer into batch |
| input.reshape(-1,C) | fp8.py:208 | (B,T,C) | (B*T,C) | Flatten for 2D-only _scaled_mm |

Concatenation / Stacking:

| Code | File:Line | Pre Shapes | Post Shape | Purpose |
|---|---|---|---|---|
| torch.cat([y1,y2], 3) | gpt.py:63 | 2×(B,T,H,D/2) | (B,T,H,D) | Rejoin rotated halves in RoPE |
| torch.cat([x[:,:1], smeared], 1) | gpt.py:432 | (B,1,C)+(B,T-1,C) | (B,T,C) | Smear: keep 1st token, mix rest with previous |
| torch.cat((ids, next_ids), 1) | gpt.py:505 | (B,T)+(B,1) | (B,T+1) | Append generated token |
| torch.stack(grads) | optim.py:259 | K×(M,N) | (K,M,N) | Batch same-shape grads for Muon Newton-Schulz |

Layout Transpose (SDPA fallback):

| Code | File:Line | Pre Shape | Post Shape | Purpose |
|---|---|---|---|---|
| q.transpose(1,2) | flash_attention.py:123 | (B,T,H,D) | (B,H,T,D) | FA3 layout → SDPA layout |
| y.transpose(1,2) | flash_attention.py:128 | (B,H,T,D) | (B,T,H,D) | SDPA layout → FA3 layout back |

Python / PyTorch Patterns

Q: @torch.inference_mode() vs no_grad() and model.eval()?

  • inference_mode(): stricter and faster than no_grad() — disables both gradient computation AND version tracking on tensors. Tensors created inside cannot be used for backward later.
  • no_grad(): only disables gradient computation, tensors still track versions.
  • model.eval(): changes module behavior (e.g., BatchNorm uses running stats, Dropout disabled). Must be called separately — neither inference_mode nor no_grad does this.

Q: How does the timeout context manager work?

Uses Unix SIGALRM signals:

import signal
from contextlib import contextmanager

@contextmanager
def timeout(seconds):
    def handler(signum, frame):
        raise TimeoutError(f"timed out after {seconds}s")
    signal.signal(signal.SIGALRM, handler)  # register alarm handler
    signal.alarm(seconds)                    # start countdown
    try:
        yield                                # run the guarded code
    finally:
        signal.alarm(0)                      # cancel alarm

The with statement doesn’t need an as clause — the context manager yields nothing. If you write with timeout(5) as x, x would be None.

Q: torch.compile(model, dynamic=False) — what does dynamic mean?

dynamic=False tells the compiler all tensor shapes are fixed. It compiles one optimized kernel per shape and reuses it. dynamic=True would generate shape-generic code that handles varying sizes at runtime, but with less optimization. Nanochat uses False because sequence length and batch size are constant during training.

Q: What is the meta device pattern? What does model.to_empty(device) do?

meta device creates tensors with shapes and dtypes but no memory allocation — useful for building large models without OOM during init. to_empty(device) allocates real memory on the target device but leaves values uninitialized (garbage). Then init_weights() fills them. This differs from model.to(device) which would try to copy the (nonexistent) meta tensor data.
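A minimal sketch of the pattern, using a single nn.Linear as a stand-in for the full model:

```python
import torch
import torch.nn as nn

# Build on meta: shapes/dtypes only, no memory allocated.
with torch.device('meta'):
    model = nn.Linear(1024, 1024)
assert model.weight.is_meta

# Allocate real (uninitialized) storage on the target device, then init.
model = model.to_empty(device='cpu')
assert not model.weight.is_meta
nn.init.normal_(model.weight, std=0.02)  # stand-in for init_weights()
```

Calling model.to('cpu') instead of to_empty here would fail, because there is no meta-tensor data to copy.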

Q: inv_scale = scale.reciprocal() vs 1/scale?

Functionally identical for scalars. .reciprocal() is a method that stays in the tensor computation graph, while 1/scale calls torch.div. In practice no measurable difference — it’s a style choice.

Q: ColoredFormatter.COLORS — class variable or instance variable?

Class variable. Defined directly in the class body (not in __init__), shared across all instances.

Q: Why s = 3**0.5 * n_embd**-0.5 for weight init?

The goal is U(-s, s) with variance matching N(0, 1/n_embd). Variance of U(-s, s) is s²/3. Setting s²/3 = 1/n_embd gives s = sqrt(3/n_embd) = 3^0.5 * n_embd^(-0.5).

Derivation of Var(U(-s,s)) = s²/3:

Var = E[X²] - E[X]² = E[X²] - 0
E[X²] = ∫₋ₛˢ x²·(1/2s) dx = (1/2s)·[x³/3]₋ₛˢ = (1/2s)·(2s³/3) = s²/3
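A quick empirical check of this derivation:

```python
import torch

n_embd = 1536
s = 3**0.5 * n_embd**-0.5
w = torch.empty(10**6).uniform_(-s, s)
# Sample variance should match the target 1/n_embd ≈ 6.5e-4
print(w.var().item(), 1 / n_embd)
```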

Q: Is apply_rotary_emb’s rotation direction correct?

Yes. The rotation is applied in the opposite direction compared to the standard formulation, but since both Q and K are rotated the same way, the signs cancel in the Q·K dot product. The relative rotation (which encodes relative position) is preserved.

Q: Why short_window = -(-long_window // 4 // 128) * 128?

Double negation implements ceiling division. -(-x // n) = ceil(x/n). This computes ceil(long_window / 4) rounded up to the next multiple of 128 (for tensor core alignment).
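A worked example:

```python
# Ceiling division via double negation: -(-x // n) == ceil(x / n)
assert -(-10 // 4) == 3   # ceil(2.5)
assert -(-8 // 4) == 2    # exact division unchanged

long_window = 1000
short_window = -(-long_window // 4 // 128) * 128
print(short_window)  # 256 = ceil(ceil(1000/4) / 128) * 128
```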


Flash Attention and FP8

Q: Why is Flash Attention 3 loaded via kernels.get_kernel(...)?

FA3 requires Hopper (SM90) and isn’t in PyTorch’s standard distribution. HuggingFace’s kernels package provides pre-built wheels that can be fetched at runtime: get_kernel('varunneal/flash-attention-3'). Falls back to SDPA on non-Hopper GPUs.

Q: What does the SDPA fallback do for sliding window?

SDPA doesn’t natively support sliding window, so the code builds explicit boolean attention masks. For chunk inference (where Tq != Tk), it also handles the offset between query and key positions. The mask is constructed to only allow attention within the window size.

Q: In fp8.py, what is rowwise vs tensorwise scaling?

Nanochat uses tensorwise scaling: one scale factor per entire tensor (scale = FP8_MAX / max(|tensor|)). Rowwise would compute a separate scale per row — finer grained but more overhead. Tensorwise is simpler and sufficient here.


Data Loading

Q: How do documents become batches in the dataloader?

_document_batches is an infinite iterator over parquet row groups, sharded by rank (rg_idx += ddp_world_size). Documents are tokenized, then tokenizing_distributed_data_loader_with_state_bos_bestfit packs multiple documents per row using BOS-aligned best-fit packing. A pre-allocated pinned CPU buffer collects packed rows, then a single HtoD transfer moves the batch to GPU. ~100% sequence utilization, ~35% document crop waste.

Q: Can one row include more than one document?

Yes — that’s the point of best-fit packing. Multiple documents are packed into a single sequence row, separated by BOS tokens. The model learns soft document boundaries without hard attention masks.

Q: So no PyTorch DataLoader or DistributedSampler?

Correct. All data distribution is manual. Each rank strides through row groups: rg_idx += ddp_world_size. No sampler, no DataLoader.
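A finite toy version of the striding pattern (the real iterator is infinite and wraps around):

```python
# Each rank strides through row groups with no sampler or DataLoader.
def rank_row_groups(num_row_groups, ddp_rank, ddp_world_size):
    rg_idx = ddp_rank
    while rg_idx < num_row_groups:
        yield rg_idx
        rg_idx += ddp_world_size

# With 10 row groups and 4 ranks, rank 1 sees groups 1, 5, 9:
print(list(rank_row_groups(10, 1, 4)))  # [1, 5, 9]
```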

Q: How does attention learn to ignore cross-document tokens when packed together?

Through soft signals: BOS tokens mark boundaries, RoPE resets positional distance, and sliding window limits how far attention can reach. Experiments with varlen attention (hard boundaries) showed no benefit — the model learns effective boundaries on its own.

Q: Why was midtraining removed? (commit 1ddaad1)

The pipeline was originally 3 stages: pretrain → midtrain → SFT. Midtraining ran on the SFT data mixture (SmolTalk, GSM8K, MMLU, etc.) with the pretraining-style packed dataloader, but discarded the loss mask (ids, _ = render_conversation(...)) — so loss was on all tokens (user + assistant). It served as a domain adaptation step at pretraining-scale throughput.

The old SFT was inefficient: no sequence packing, one conversation per row, small batch sizes. It couldn’t process enough tokens for domain adaptation on its own, so midtraining bridged that gap.

Once the BOS-aligned best-fit packing dataloader was introduced, SFT was upgraded to use it — gaining the same packed throughput (total_batch_size=524288 in tokens) while retaining assistant-only loss masking. The new SFT absorbed midtraining’s role, making it redundant. The pipeline simplified to pretrain → SFT.


Distributed Training

Q: Does dist.init_process_group(backend='nccl', device_id=device) need a host/port?

Yes — NCCL needs a rendezvous point. torchrun sets environment variables MASTER_ADDR and MASTER_PORT automatically. init_process_group reads them from the environment by default.


Evaluation

Q: Explain evaluate_bpb — the where makes it complex.

BPB = bits per byte. The function:

  1. Runs cross-entropy loss with reduction='none' → per-token NLL in nats
  2. Uses torch.where to safely index token_bytes[y] — when y contains -1 (ignore_index), you can’t index a tensor with negative values (it would wrap around). So torch.where replaces -1 with 0 for safe indexing, then zeros out those entries’ byte counts
  3. Sums nats over all valid tokens
  4. Divides by total UTF-8 bytes of the original text
  5. Converts nats to bits: / log(2)

This gives a vocab-independent metric: how many bits the model needs per byte of raw text.
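A minimal sketch of the computation (function signature is mine, not the repo’s):

```python
import math
import torch
import torch.nn.functional as F

def bpb(logits, y, token_bytes):
    """logits: (N, V); y: (N,) with -1 = ignore; token_bytes: (V,) bytes per token."""
    nll = F.cross_entropy(logits, y.clamp(min=0), reduction='none')  # nats per token
    valid = y >= 0
    # Safe indexing: replace -1 with 0, then zero out those entries' byte counts
    y_safe = torch.where(valid, y, torch.zeros_like(y))
    nbytes = torch.where(valid, token_bytes[y_safe], torch.zeros_like(y))
    total_nats = (nll * valid).sum()
    return (total_nats / nbytes.sum() / math.log(2)).item()

# Sanity check: uniform logits over 4 tokens, 1 byte each -> exactly 2 bits per byte
print(bpb(torch.zeros(3, 4), torch.tensor([0, 1, -1]), torch.ones(4, dtype=torch.long)))  # 2.0
```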

Q: token_bytes is 0 for special tokens — what is token_bytes[y]?

token_bytes is a lookup table: token_bytes[token_id] = number of UTF-8 bytes that token represents. Special tokens map to 0 bytes, so they don’t inflate the denominator. The total bytes denominator counts only real text.

Q: One-sentence summary of evaluate_bpb?

Sums the cross-entropy loss (in bits) over all non-special tokens, then divides by the total UTF-8 bytes of the original text — giving a vocab-independent measure of how many bits the model needs per byte.

Q: How can a base LLM evaluate tasks — explain evaluate_example?

Three task types, all loss-based (no generation):

  1. multiple_choice: Common prefix + different continuations. Score = mean loss over the continuation tokens. Pick the continuation with lowest loss. Example: “The capital of France is” → [“Paris”, “London”, “Berlin”] — lowest loss on “Paris”.

  2. schema: Different contexts with the same continuation. Score = mean loss over the shared suffix. Pick the context that makes the suffix most likely. Example: contexts describe different scenarios, suffix is “The answer is yes” — which context makes that suffix most natural?

  3. language_modeling: Feed prompt + answer, check if the model’s argmax predictions match the answer tokens exactly. No loss comparison — binary match/no-match.

Q: For multiple_choice, the continuations have different lengths — do we batch them?

Yes. find_common_length identifies the shared prefix. Different continuations are batched together. Mean loss normalizes for length differences so longer continuations aren’t penalized.

Q: For language_modeling, how do we know how long to generate?

We don’t generate. The answer’s token length is known. We feed the prompt + answer and check if argmax predictions at each position match. It’s a teacher-forcing check, not free-form generation.

Q: In both eval types, is it greedy decoding or sampling?

Neither — base CORE eval uses no decoding at all. It’s purely loss-based (multiple_choice, schema) or argmax-matching (language_modeling). No temperature, no sampling, no top-k.

Q: What is the CORE metric?

A 22-task composite benchmark using centered accuracies. Each task’s accuracy is centered (subtract expected random performance), then averaged across all tasks. Used to benchmark against GPT-2.

Q: Explain run_chat_eval — how does it work?

Two modes:

  1. Categorical: Forward pass only. At the answer position, look at logits for letter tokens (A, B, C, D). Argmax over just those logits = the model’s answer. No generation needed.

  2. Generative: Use engine.generate_batch to actually sample completions. Each task’s evaluate() method parses the output (e.g., regex for numbers in GSM8K). pass@k = any of k completions correct.

Q: For categorical, the model needs to respond with one letter — what if it says “the answer is A”?

It doesn’t generate at all. The logits method looks at the probability distribution at the single answer position and compares logits for just the letter tokens. It’s not about what the model would generate — it’s about which letter the model assigns highest probability to.
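A toy illustration of the logit comparison (token ids here are invented):

```python
import torch

# Compare logits for just the letter tokens at the answer position.
letter_token_ids = torch.tensor([317, 347, 327, 360])  # ids for "A","B","C","D" (illustrative)
logits = torch.randn(50_000)                            # full-vocab logits at the answer position
letter_logits = logits[letter_token_ids]                # (4,)
answer = "ABCD"[letter_logits.argmax().item()]          # no generation needed
print(answer)
```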

Q: Is categorical the same as pretrain’s loss-based eval?

No. Pretrain CORE uses mean loss over continuations (comparing multiple completions). Chat categorical compares logits at a single position for specific letter tokens. Different approach, different task format.

Q: For GSM8K, how robust is the answer parser?

It uses regex to find #### number at the end of the output. Strips commas for comparison. It’s simple string matching — evaluate() returns True/False, and reward() reuses this for RL.

Q: In SFT generation eval, does the model know to stop at assistant_end?

Yes. The engine checks for the assistant_end token during generation and marks that row as complete. Generation stops for that sequence.


Training Scripts

Q: How does build_model_meta(depth) work?

model_dim = depth × aspect_ratio (aspect_ratio=64), rounded to a multiple of head_dim (128). Builds on meta device (no memory), then to_empty(device) for uninitialized allocation, then init_weights().

Q: How is target-flops enforced — does it convert to num_iterations?

Yes: num_iterations = round(target_flops / (flops_per_token × total_batch_size)). Precedence: --num-iterations > --target-flops > --target-param-data-ratio.

Q: Where does weight decay ramp to zero? I don’t see it in base_train.py.

It’s there — weight decay follows a cosine schedule: wd * 0.5 * (1 + cos(π * it / num_iterations)). At it = num_iterations, cosine = -1, so wd * 0.5 * 0 = 0. The SFT script inherits this (already zero) and sets weight_decay=0.0 explicitly.
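The schedule in isolation:

```python
import math

def wd_schedule(wd, it, num_iterations):
    # Cosine decay from wd at it=0 to exactly 0 at it=num_iterations
    return wd * 0.5 * (1 + math.cos(math.pi * it / num_iterations))

print(wd_schedule(0.1, 0, 1000))     # 0.1 at the start
print(wd_schedule(0.1, 500, 1000))   # 0.05 halfway
print(wd_schedule(0.1, 1000, 1000))  # 0.0 at the end
```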

Q: How does sft_data_generator_bos_bestfit work? What format is SFT data?

Tasks (SmolTalk, MMLU×3, GSM8K×4, etc.) → TaskMixture flattens and interleaves → each example is a conversation dict → render_conversation() converts to (token_ids, loss_mask) where mask=1 for assistant text and assistant_end, mask=0 for user text and special tokens → best-fit packing into sequences → padded positions get target=-1.

Q: Does <|assistant_end|> have mask 0 or 1?

Mask = 1. The model needs to learn to predict when to stop. assistant_end is part of the assistant’s output that should be learned.


SFT Data and Chat

Q: How does SFT handle DDP without a DistributedSampler?

Manual striding: cursor = ddp_rank, then cursor += ddp_world_size after each example. Each rank sees every Nth example. Same pattern as pretraining’s row-group striding.


RL Training (chat_rl.py)

Q: Is there sequence packing in RL? One row = one task?

No packing. One question per batch → num_samples rollouts → all rollouts for that question are processed together.

Q: Pseudocode for the RL training loop?

for step in range(num_steps):
    # Generate rollouts
    for each question:
        generate num_samples completions using engine
        score each: reward = 1 if correct, 0 if wrong
        advantages = rewards - mean(rewards)  # NO std normalization

    # Train on rollouts
    for each question's rollouts:
        for each micro-batch (device_batch_size):
            logp = -cross_entropy(model(inputs), targets, reduction='none')  # (B, T)
            pg_obj = (logp * advantages.unsqueeze(-1)).sum()
            num_valid = (targets >= 0).sum()
            loss = -pg_obj / (num_valid * num_passes * examples_per_rank)
            loss.backward()
    optimizer.step()

Q: What are the shapes of logp, pg_obj, num_valid?

With device_batch_size=4, max_seq_len=512:

| Variable | Shape | Notes |
|---|---|---|
| inputs | (4, 512) | (B, T) |
| targets | (4, 512) | (B, T), -1 for masked positions |
| model(..., loss_reduction='none') | (2048,) | (B*T,) flat |
| logp | (4, 512) | .view_as(inputs) reshapes back, 0 at masked positions |
| advantages | (4,) | (B,) one per rollout |
| advantages.unsqueeze(-1) | (4, 1) | broadcasts across T |
| logp * advantages.unsqueeze(-1) | (4, 512) | broadcast multiply |
| pg_obj | scalar | .sum() over everything |
| num_valid | scalar | count of targets >= 0 |

logp is 0 at masked positions because cross_entropy(..., ignore_index=-1, reduction='none') returns 0.0 where target=-1. So masked positions contribute nothing to pg_obj.

Q: There is no std(rewards) in advantage normalization?

Correct. advantages = rewards - mean(rewards) only. No std normalization. This is an explicit design choice (commented in the code). Works because the system is on-policy — no need for PPO ratio clipping or value baselines.
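A tiny worked example:

```python
import torch

# Mean-only advantage: with rewards in {0, 1}, correct rollouts get a
# positive advantage and incorrect ones a negative advantage.
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
advantages = rewards - rewards.mean()   # no division by std
print(advantages)  # tensor([ 0.5000, -0.5000, -0.5000,  0.5000])
```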


Scaling Laws and Speedrun

Q: What does scaling_laws.sh do?

Runs a grid: 4 FLOP budgets × 6 depths. For each combination, trains a model and records the loss. This produces IsoFLOP curves — for each FLOP budget, plot loss vs depth to find the compute-optimal depth.

Q: What are IsoFLOP curves?

Fix a FLOP budget, vary model size (depth). Plot final loss vs size. The minimum of each curve gives the compute-optimal model size for that budget. Connecting the minima across budgets reveals scaling laws.

Q: What does miniseries.sh do?

Trains depths 12-26 at compute-optimal token-per-parameter ratio (~10.5), recording results to CSV. This maps out how performance scales with model size when each model is optimally trained.

Q: How did the speedrun determine depth 24?

From scaling law experiments. The CORE metric (22-task composite centered accuracy) is the target — beat GPT-2’s score. d24 with aspect_ratio=64 gives ~124M active parameters, enough to surpass GPT-2 (124M params) on CORE.

Q: Is the speedrun model compute-optimal?

No. It uses ratio=8 (tokens per parameter), while compute-optimal is ~10.5. It’s intentionally undertrained to optimize wall-clock time, not FLOP efficiency. The goal is beating GPT-2 CORE score as fast as possible, not achieving the best loss per FLOP.

Q: Why not use a smaller model at compute-optimal ratio 10.5 for shorter wall-clock?

Because “compute-optimal” refers to FLOP efficiency, not wall-clock. A smaller model at ratio 10.5 uses fewer FLOPs total but may not beat GPT-2’s CORE score. The speedrun needs a model large enough to exceed GPT-2 performance — d24 is that threshold. Undertrained (ratio=8) gets there faster on the clock because fewer iterations are needed.

Q: So compute-optimal is about FLOPs, not wall-clock?

Yes. Compute-optimal means best loss per FLOP. An architecture bottlenecked by memory bandwidth (like VE: massive parameters, zero FLOPs) shifts the optimal ratio. VE tokens-per-parameter ratio would be small because parameters are cheap (no FLOPs) but the FLOP-based scaling law doesn’t directly capture this.


MoE (Mixture of Experts)

Q: How does the MoE layer work?

Drop-in replacement for the dense MLP. For each token:

  1. Sigmoid router scores all experts: sigmoid(gate(x)), shape (T, E)
  2. Top-K experts selected per token (with auxiliary-loss-free bias from DeepSeekV3)
  3. Tokens sorted by expert assignment for contiguous processing
  4. Pre-multiplied by routing scores before expert computation
  5. Routed experts process via grouped_mm (single kernel per projection)
  6. Shared expert processes ALL tokens via standard dense matmul
  7. Outputs scattered back and summed: routed + shared

Key sizing: expert_hidden_dim = round(4 * dim / (top_k + num_shared) / 128) * 128 ensures iso-FLOP with dense MLP.

Q: Why token_ids = token_indices_sorted // self.top_k?

selected_experts is (T, K). Flattened to (T*K,), positions [0..K-1] = token 0, [K..2K-1] = token 1, etc. token_indices_sorted are indices into this flat array, so // top_k recovers the original token index.
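A worked example of the index recovery (values invented):

```python
import torch

T, K = 3, 2   # 3 tokens, top-2 experts each
selected_experts = torch.tensor([[2, 0],
                                 [1, 2],
                                 [0, 1]])          # (T, K)
flat = selected_experts.reshape(-1)                # (T*K,) = [2, 0, 1, 2, 0, 1]
token_indices_sorted = flat.argsort(stable=True)   # group flat slots by expert id
token_ids = token_indices_sorted // K              # recover original token index
print(token_ids)  # tensor([0, 2, 1, 2, 0, 1])
```

Sorted slot order is [1, 4, 2, 5, 0, 3] (expert 0’s tokens first, then expert 1’s, then expert 2’s), and integer-dividing by K maps each slot back to its token.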


Limitations

The d32 model is available for chat at nanochat.karpathy.ai. From the official training log:

  • 1.88B total parameters, 12.1B FLOPs per token
  • 37.6B training tokens (tokens:params ratio = 20.0)
  • 71,680 iterations on 8 GPUs, ~31 hours total
  • Final validation BPB: 0.7236, CORE metric: 0.3274, MFU: 51.79%

While it demonstrates the full pipeline works end-to-end, the model frequently hallucinates and lacks factual knowledge. This is understandable: 1.88B params trained on 37.6B tokens is tiny compared to production models (e.g. Llama 3 8B trained on 15T tokens). Nanochat’s value is as a research and educational codebase, not as a general-purpose chat model.