cayley-small-3L-mlp_in-20B

A 202.5M-parameter GPT with a 3-level CayleySAE inserted at mlp_in in every transformer block, trained on 20B tokens of FineWeb-Edu.

The "small" companion to cayley-large-3L-mlp_in-20B, and the 3-level counterpart to cayley-small-2L-mlp_in-20B.

Headline

  • Val loss: 3.1330 (best, at iter 12500; the final iter 12716 reached 3.1461, a slight noise-floor wobble)
  • Training tokens: 20B (FineWeb-Edu, sample-100BT slice)
  • Wall clock: ~17h12m on 2× NVIDIA H200 (143 GB)

Beats the 2-level sibling cayley-small-2L-mlp_in-20B (val 3.1584) by 0.025 nats at the same 20B token budget: the third level buys a small but real expressivity gain at this scale.

Beats the older 1-level aemack-org/cayley-10b (val 3.173, 10B tokens) by 0.040 nats.
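
For a sense of scale, the gaps above can be recomputed from the reported validation losses and expressed as perplexity ratios. This is a minimal sketch, assuming the losses are mean per-token cross-entropy in nats; the dictionary and function names are illustrative and not taken from the repo.

```python
import math

# Reported validation losses (nats per token) from this card.
VAL_LOSS = {
    "cayley-small-3L-mlp_in-20B": 3.1330,
    "cayley-small-2L-mlp_in-20B": 3.1584,
    "aemack-org/cayley-10b": 3.173,
}

def gap_vs(baseline: str, model: str = "cayley-small-3L-mlp_in-20B"):
    """Return (loss gap in nats, perplexity ratio baseline / model)."""
    gap = VAL_LOSS[baseline] - VAL_LOSS[model]
    return gap, math.exp(gap)

for name in ("cayley-small-2L-mlp_in-20B", "aemack-org/cayley-10b"):
    gap, ratio = gap_vs(name)
    print(f"{name}: {gap:.4f} nats, perplexity ratio {ratio:.4f}")
```

A gap of a few hundredths of a nat corresponds to a perplexity difference of a few percent, which is why the improvement is described as small.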

Backbone

  • 12 transformer blocks, d_model 1024, 8 heads (head_dim 128)
  • 202.5M total parameters
  • seq_len 1024
  • RMSNorm, learned absolute position embeddings (no RoPE)

CayleySAE

Inserted at mlp_in in every block: RMSNorm → CayleySAE → MLP. Output is dense d=1024; sparsity lives in the intermediate code.

  • 3 levels with hierarchy 10,16,0;15,32,256;17,64,256
    • L0: n=10 (1024 coords), k=16
    • L1: n=15 (32k leaves), k=32, parent budget 256
    • L2: n=17 (131k leaves), k=64, parent budget 256
  • 112 active features per token (16 + 32 + 64)
  • Parameter-free algebraic dictionary; only per-feature biases are learned
  • cayley-per-parent-budget, cayley-score-standardize, and cayley-forward-standardized (the "zombie fix"; see report 29 in repo) all enabled

n0 is forced to log2(d_model) = 10.
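
The hierarchy string reads as semicolon-separated levels, each of the form n,k,parent_budget. The sketch below parses it and reproduces the per-level numbers listed above. The parser and the `Level` dataclass are illustrative and not taken from the training code, which may handle the flag differently.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Level:
    n: int              # log2 of the number of features at this level
    k: int              # features kept active per token
    parent_budget: int  # parent budget (0 for L0 in this spec)

    @property
    def size(self) -> int:
        return 2 ** self.n

def parse_hierarchy(spec: str) -> list:
    """Parse 'n,k,budget;n,k,budget;...' into a list of Level records."""
    levels = []
    for chunk in spec.split(";"):
        n, k, budget = (int(x) for x in chunk.split(","))
        levels.append(Level(n, k, budget))
    return levels

levels = parse_hierarchy("10,16,0;15,32,256;17,64,256")
d_model = 1024
# n0 must equal log2(d_model), as stated above.
assert levels[0].size == d_model
active_per_token = sum(level.k for level in levels)

for i, level in enumerate(levels):
    print(f"L{i}: n={level.n} size={level.size} k={level.k} budget={level.parent_budget}")
print("active features per token:", active_per_token)
```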

Training recipe

Knob                     Value
Optimizer                Muon (2D weights) + AdamW (embeddings, biases)
Peak Muon LR             1.2e-2 (lockstep with AdamW)
Min Muon LR              1.5e-4
Peak AdamW LR            1.2e-2 (lockstep)
Min AdamW LR             1.5e-4
LR schedule              linear_warmdown
warmdown_frac            0.9 (super-Chinchilla)
Warmup iters             200
Batch size (per rank)    32
Gradient accumulation    48 (global, 24 microsteps per rank)
Tokens per iter          1,572,864
Total iters              12,716
World size               2× NVIDIA H200 (143 GB)
Steady-state tok/s       ~328k
Peak VRAM                44.8 GB / 143 GB per GPU
Dataset                  FineWeb-Edu sample-100BT (25B-token slice)
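
The batch figures in the table are mutually consistent, which the following arithmetic checks. It is a sketch that takes the table at its word that gradient accumulation of 48 counts micro-batches across both ranks (24 each); the variable names are illustrative.

```python
seq_len = 1024
micro_batch = 32          # sequences per micro-batch on each rank
world_size = 2
microsteps_per_rank = 24  # 48 micro-batches per optimizer step in total
total_iters = 12_716

tokens_per_iter = micro_batch * seq_len * microsteps_per_rank * world_size
total_tokens = tokens_per_iter * total_iters
print(f"tokens per iter: {tokens_per_iter:,}")
print(f"total tokens:    {total_tokens / 1e9:.2f}B")

# Rough training-time estimate at the reported steady-state throughput;
# it ignores evaluation and checkpointing overhead.
hours = total_tokens / 328_000 / 3600
print(f"~{hours:.1f} h at 328k tok/s")
```

The estimate lands slightly under the reported wall clock, which is what one would expect once periodic evaluation and checkpoint writes are added.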

Warmup 0 → 200; flat phase 200 → 1272; warmdown 1272 → 12716 (linear 1.2e-2 → 1.5e-4 for both Muon and AdamW).
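
The three phases can be written as a single function of the iteration index. This is a sketch of the schedule as described, assuming a linear ramp from zero during warmup (the card gives only the phase boundaries); `lr_at` is an illustrative name, not the repo's scheduler.

```python
PEAK_LR = 1.2e-2
MIN_LR = 1.5e-4
WARMUP_ITERS = 200
TOTAL_ITERS = 12_716
WARMDOWN_FRAC = 0.9
# Warmdown covers the last 90% of training, so it starts at iter 1272.
WARMDOWN_START = TOTAL_ITERS - int(WARMDOWN_FRAC * TOTAL_ITERS)

def lr_at(it: int) -> float:
    """Learning rate at iteration `it` (same curve for Muon and AdamW)."""
    if it < WARMUP_ITERS:
        # Assumed linear warmup from 0 to the peak.
        return PEAK_LR * it / WARMUP_ITERS
    if it < WARMDOWN_START:
        return PEAK_LR  # flat phase
    # Linear decay from the peak to the floor over the warmdown window.
    progress = (it - WARMDOWN_START) / (TOTAL_ITERS - WARMDOWN_START)
    return PEAK_LR + (MIN_LR - PEAK_LR) * min(progress, 1.0)

for it in (0, 100, 200, 1272, 7000, 12_716):
    print(it, f"{lr_at(it):.3e}")
```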

This is the v6-fullrun-floor recipe with the zombie fix applied: the established formula for 12L/d=1024 cayley + super-Chinchilla data budgets.

Recipe relation to siblings

Identical recipe to cayley-small-2L-mlp_in-20B; the only architectural difference is the third Cayley level (L2: n=17, k=64). Apples-to-apples: peak/min LRs, warmdown fraction, batch × ga, and zombie-fix flags all match, so the 0.025-nat improvement is attributable to the added L2.

Training health

Signal                     Outcome
Steady-state tok/s         328k (unchanging from iter 50 onward)
grad_norm in warmdown      declined 0.20 → 0.06 (clean)
Dead features (C=0.1)      L0 0/1024, L1 0/31744, L2 0/98304; 0% throughout
L1 reachable (of 32768)    31744 (97%)
L2 reachable (of 131072)   98304 (75%)
l0_mu_mean                 ~0.033 (final)
l0_sigma_mean              ~1.216 (final)

Val descent (every 500 iters):

iter   500   1000   1500   2000   2500   3000   3500   4000   4500
val   3.826  3.543  3.469  3.407  3.383  3.386  3.345  3.333  3.314
iter  5000   5500   6000   6500   7000   7500   8000   8500   9000
val   3.304  3.279  3.264  3.256  3.248  3.224  3.225  3.211  3.195
iter  9500  10000  10500  11000  11500  12000  12500  ...
val   3.196  3.176  3.181  3.155  3.163  3.147  3.133

Five small upticks (iters 3000, 8000, 9500, 10500, 11500), each between 0.001 and 0.008 above the prior best: floor noise, not pathology.
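
The upticks can be read off the table mechanically by comparing each evaluation against the best value seen so far. A small sketch using the numbers above (three-decimal values as printed, so differences are only accurate to about 0.001); `upticks` is an illustrative helper.

```python
iters = list(range(500, 12_501, 500))
vals = [
    3.826, 3.543, 3.469, 3.407, 3.383, 3.386, 3.345, 3.333, 3.314,
    3.304, 3.279, 3.264, 3.256, 3.248, 3.224, 3.225, 3.211, 3.195,
    3.196, 3.176, 3.181, 3.155, 3.163, 3.147, 3.133,
]
assert len(iters) == len(vals)

def upticks(iters, vals):
    """Return (iter, rise) for every eval that is worse than the best so far."""
    out, best = [], float("inf")
    for it, v in zip(iters, vals):
        if v > best:
            out.append((it, round(v - best, 3)))
        best = min(best, v)
    return out

for it, rise in upticks(iters, vals):
    print(f"iter {it}: +{rise:.3f} above prior best")
```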

Quick / hierarchy evals

quick_eval (hellaswag/lambada/squad) and hierarchy_eval were both skipped: the evals module wasn't on PYTHONPATH on this node ("No module named 'evals'"). Logged gracefully with no NCCL impact. Re-evaluate from this checkpoint to populate downstream metrics.

VRAM-headroom note (for future runs)

This run held bs=32 for apples-to-apples comparison with the 8× A100 small-2L sibling. On 2× H200 (143 GB each), VRAM at bs=32 was only 44.8 GB. A brief bs=64/ga=24 attempt before the canonical run showed 84.6 GB / 99% util at ~342k tok/s; linear extrapolation suggests bs=96/ga=16 would fit at ~124 GB. For future H200 runs of this model size, bs=96 is the practical ceiling (bs=128 extrapolates to ~164 GB and would OOM).
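
The extrapolation above is a straight line through the two measured points. A minimal sketch, assuming peak VRAM scales linearly with per-rank batch size (real memory use rarely scales perfectly linearly, so treat the outputs as estimates); the function names are illustrative.

```python
SEQ_LEN = 1024
TOKENS_PER_ITER = 1_572_864
GPU_CAPACITY_GB = 143.0
# Measured (per-rank batch size, peak VRAM in GB) on 2x H200.
MEASURED = [(32, 44.8), (64, 84.6)]

def predict_vram_gb(bs: int) -> float:
    """Linear extrapolation of peak VRAM through the two measured points."""
    (b0, m0), (b1, m1) = MEASURED
    slope = (m1 - m0) / (b1 - b0)
    return m0 + slope * (bs - b0)

def grad_accum_for(bs: int) -> int:
    """Global gradient-accumulation count that keeps tokens per iter fixed."""
    return TOKENS_PER_ITER // (bs * SEQ_LEN)

for bs in (32, 64, 96, 128):
    gb = predict_vram_gb(bs)
    verdict = "fits" if gb < GPU_CAPACITY_GB else "exceeds capacity"
    print(f"bs={bs:3d} ga={grad_accum_for(bs):2d} ~{gb:6.1f} GB ({verdict})")
```

Scaling ga inversely with bs keeps the 1,572,864 tokens per iteration unchanged, so a larger batch only changes memory and throughput, not the optimization recipe.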

Files

  • ckpt.pt: PyTorch checkpoint (~2.5 GB). Contains model, optimizer_states, config, model_config, iter_num, best_val_loss.
  • config.json: training config snapshot.
  • train_cayley_small_3L_mlp_in_20B.sh: exact training script.

Loading

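import torch
from sparse_nanogpt.model import GPT
from sparse_nanogpt.config import DeepTopKGPTConfig

# weights_only=False unpickles arbitrary objects, so only load checkpoints you trust.
ckpt = torch.load("ckpt.pt", map_location="cpu", weights_only=False)
# Rebuild the architecture from the saved model_config, then restore the weights.
model_config = DeepTopKGPTConfig(**ckpt["model_config"])
model = GPT(model_config)
model.load_state_dict(ckpt["model"])
model.eval()  # switch to inference mode before evaluating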

Lineage

Citation

Part of the Sparse NanoGPT project.
