grokking, from scratch, in the browser

a neural network memorizes, then suddenly understands

let's say you show a neural network half the answers to an addition table (a + b mod p). it memorizes them perfectly: train accuracy hits 100%, while test accuracy on the other half stays near zero. looks like it just memorized, right?
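here's a minimal sketch of that setup, assuming p = 97 and a 50/50 split. `makeSplit` and the seeded `mulberry32` generator are hypothetical helpers for illustration, not the demo's actual code.

```javascript
// Hypothetical seeded PRNG (mulberry32) so the split is reproducible.
function mulberry32(seed) {
  let state = seed >>> 0;
  return function () {
    state = (state + 0x6d2b79f5) >>> 0;
    let t = state;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296; // uniform in [0, 1)
  };
}

// Build every pair (a, b) with label (a + b) mod p, shuffle, then split.
function makeSplit(p, trainFraction, seed) {
  const pairs = [];
  for (let a = 0; a < p; a++) {
    for (let b = 0; b < p; b++) {
      pairs.push({ a, b, label: (a + b) % p });
    }
  }
  // Fisher-Yates shuffle driven by the seeded PRNG.
  const rand = mulberry32(seed);
  for (let i = pairs.length - 1; i > 0; i--) {
    const j = Math.floor(rand() * (i + 1));
    [pairs[i], pairs[j]] = [pairs[j], pairs[i]];
  }
  const nTrain = Math.floor(pairs.length * trainFraction);
  return { train: pairs.slice(0, nTrain), test: pairs.slice(nTrain) };
}

const { train: trainSet, test: testSet } = makeSplit(97, 0.5, 42);
console.log(trainSet.length, testSet.length); // 4704 4705
```

the network only ever gets gradients from `trainSet`; `testSet` is the half it never sees.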

then you keep training. nothing happens for a long time. and then, thousands of steps later, something clicks. test accuracy jumps from near-zero to near-perfect. the network figured out the rule, not just the answers. that's grokking.

you can try an MLP or a Transformer (the original paper used a Transformer). both are written from scratch in javascript. the knob that matters most is weight decay. play with it.
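a sketch of what the weight-decay knob does to each weight, assuming plain SGD. `sgdStep` is a hypothetical helper; the demo's actual optimizer isn't stated here, and for Adam-style optimizers decoupled decay (AdamW) is not the same as adding an L2 term to the gradient.

```javascript
// Hypothetical plain-SGD step with weight decay (not the demo's optimizer).
function sgdStep(weights, grads, lr, weightDecay) {
  for (let i = 0; i < weights.length; i++) {
    // Adding weightDecay * w to the gradient pulls every weight toward zero.
    // For plain SGD this equals an L2 penalty of (weightDecay / 2) * sum(w^2).
    weights[i] -= lr * (grads[i] + weightDecay * weights[i]);
  }
}

// With a zero gradient, decay alone shrinks weights geometrically:
// each step multiplies them by (1 - lr * weightDecay) = 0.95 here.
const w = [1.0, -2.0];
for (let step = 0; step < 10; step++) sgdStep(w, [0, 0], 0.1, 0.5);
console.log(w.map((x) => x.toFixed(3))); // [ '0.599', '-1.197' ]
```

one common explanation for grokking: once training loss is tiny, its gradient is small, so decay keeps steadily pulling the weights toward lower-norm solutions. the general rule is thought to need smaller weights than a memorized lookup table, so it eventually wins.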

what the network sees

a p × p grid with one cell per pair: x-axis = a, y-axis = b. cells are colored correct, wrong, or no prediction yet; training pairs carry a white dot, test pairs have none. watch the test cells turn green.
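a sketch of how each cell's state could be computed. `gridCells`, `isTrain`, and `predict` are hypothetical names; `predict` returning `null` stands in for "no prediction yet".

```javascript
// Hypothetical per-cell state for the p x p grid.
// isTrain(a, b): was this pair in the training split?
// predict(a, b): the network's answer, or null before any prediction exists.
function gridCells(p, isTrain, predict) {
  const cells = [];
  for (let b = 0; b < p; b++) {     // rows: y-axis = b
    for (let a = 0; a < p; a++) {   // columns: x-axis = a
      const guess = predict(a, b);
      let status = "none";
      if (guess !== null) status = guess === (a + b) % p ? "correct" : "wrong";
      // training cells get a white dot, test cells do not
      cells.push({ a, b, dot: isTrain(a, b), status });
    }
  }
  return cells;
}

// Example: p = 5, a checkerboard split, and a network that always answers 0.
const cells = gridCells(5, (a, b) => (a + b) % 2 === 0, () => 0);
const correct = cells.filter((c) => c.status === "correct").length;
console.log(correct, cells.length - correct); // 5 20
```

answering 0 is only right on the p cells where a + b ≡ 0 (mod p), which is why the constant guesser gets 5 of 25.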

Accuracy Over Time

blue = train accuracy (seen pairs), green = test accuracy (unseen pairs). watch the green line.
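both curves come from the same measurement: take the highest-scoring answer for each pair and count how often it matches (a + b) mod p. a minimal sketch, assuming the network outputs one row of p logits per pair; `argmax`, `accuracy`, and `logAccuracy` are hypothetical helpers.

```javascript
// Index of the largest logit, i.e. the predicted answer c.
function argmax(row) {
  let best = 0;
  for (let i = 1; i < row.length; i++) if (row[i] > row[best]) best = i;
  return best;
}

// Fraction of pairs whose predicted answer equals the label.
function accuracy(logits, labels) {
  let hits = 0;
  for (let n = 0; n < labels.length; n++) {
    if (argmax(logits[n]) === labels[n]) hits++;
  }
  return hits / labels.length;
}

// One point on each curve per logging step.
const history = [];
function logAccuracy(step, trainLogits, trainLabels, testLogits, testLabels) {
  history.push({
    step,
    train: accuracy(trainLogits, trainLabels),
    test: accuracy(testLogits, testLabels),
  });
}

logAccuracy(100, [[0.1, 2, 0.3], [5, 1, 0]], [1, 2], [[3, 0, 0]], [0]);
console.log(history); // [ { step: 100, train: 0.5, test: 1 } ]
```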

Loss Over Time

lower is better. watch test loss plateau, then crash: that drop is the grokking transition.
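the loss here is presumably the usual softmax cross-entropy over the p possible answers; that is an assumption, since the page doesn't name it. a minimal sketch, with the max logit subtracted before `exp` for numerical stability (`crossEntropy` is a hypothetical helper):

```javascript
// Hypothetical mean softmax cross-entropy; logits[n] holds p scores for pair n.
function crossEntropy(logits, labels) {
  let total = 0;
  for (let n = 0; n < labels.length; n++) {
    const row = logits[n];
    // log-sum-exp with the max subtracted, so large logits don't overflow exp().
    const m = Math.max(...row);
    let sumExp = 0;
    for (const z of row) sumExp += Math.exp(z - m);
    const logSumExp = m + Math.log(sumExp);
    // -log softmax(row)[label]
    total += logSumExp - row[labels[n]];
  }
  return total / labels.length;
}

// Uniform logits over p = 97 classes give loss ln(97), the "know nothing" baseline.
console.log(crossEntropy([new Array(97).fill(0)], [3])); // ≈ 4.5747
```

ln p (about 4.57 for p = 97) makes a handy reference line: a test loss above it means the network is confidently wrong on unseen pairs.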

fourier spectrum — what frequencies the network uses

each bar = one frequency component (k = 1 to ⌊p/2⌋). after grokking, only a few bars should spike: those are the handful of frequencies the network uses to compute modular addition (which ones it picks can differ between runs).
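one way to compute those bars, assuming `W[a][j]` is the first-layer weight from input value a to hidden neuron j: take a discrete Fourier transform over a for each neuron and sum the squared magnitudes at each frequency. `fourierSpectrum` is a hypothetical helper, and the demo may normalize differently.

```javascript
// Hypothetical spectrum of a p x d weight matrix W (row a = input value a).
function fourierSpectrum(W) {
  const p = W.length;
  const d = W[0].length;
  const spectrum = [];
  // For real weights, frequency p - k mirrors frequency k, so k = 1..floor(p/2) suffices.
  for (let k = 1; k <= Math.floor(p / 2); k++) {
    let power = 0;
    for (let j = 0; j < d; j++) {
      // DFT coefficient of neuron j's weights at frequency k.
      let re = 0;
      let im = 0;
      for (let a = 0; a < p; a++) {
        const angle = (2 * Math.PI * k * a) / p;
        re += W[a][j] * Math.cos(angle);
        im -= W[a][j] * Math.sin(angle);
      }
      power += re * re + im * im;
    }
    spectrum.push(power);
  }
  return spectrum; // spectrum[k - 1] = total power at frequency k
}

// A neuron that is a pure k = 3 wave puts all its power in bar 3.
const p = 97;
const W = Array.from({ length: p }, (_, a) => [Math.cos((2 * Math.PI * 3 * a) / p), 0, 0, 0]);
const spectrum = fourierSpectrum(W);
console.log(spectrum.indexOf(Math.max(...spectrum)) + 1); // 3
```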

try it yourself

enter a and b, and the network predicts a + b mod 97. try this before and after grokking. early on, it answers training pairs correctly but guesses at random on new ones. after grokking, it gets unseen pairs right too.
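why does it get unseen pairs right? interpretability work on grokking (Nanda et al., 2023) found networks trained on this task computing something like the Fourier "clock" rule sketched below, which is correct for every pair, seen or not. `clockAdd` and the frequency list are illustrative assumptions, not weights read out of this demo.

```javascript
// Hypothetical Fourier "clock" rule for a + b mod p using a few key frequencies.
function clockAdd(a, b, p, freqs) {
  // For each frequency k, build cos/sin of w*(a+b) from separate features of a and b
  // via the angle-addition identities (w = 2*pi*k/p).
  const feats = freqs.map((k) => {
    const w = (2 * Math.PI * k) / p;
    const ca = Math.cos(w * a), sa = Math.sin(w * a);
    const cb = Math.cos(w * b), sb = Math.sin(w * b);
    return { w, cosAB: ca * cb - sa * sb, sinAB: sa * cb + ca * sb };
  });
  let best = 0;
  let bestScore = -Infinity;
  for (let c = 0; c < p; c++) {
    // score(c) = sum_k cos(w*(a + b - c)): equals freqs.length when c = (a + b) mod p,
    // strictly smaller otherwise (p prime, k not a multiple of p).
    let score = 0;
    for (const { w, cosAB, sinAB } of feats) {
      score += cosAB * Math.cos(w * c) + sinAB * Math.sin(w * c);
    }
    if (score > bestScore) {
      bestScore = score;
      best = c;
    }
  }
  return best;
}

console.log(clockAdd(50, 60, 97, [14, 35, 41])); // 13
```

no table lookup is involved: the same few frequencies answer every pair, which is what "figured out the rule" means here.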

what the neurons learn — first-layer weights on a clock

each circle = one hidden neuron. the p points around the clock show its weight for each input value a. smooth waves = Fourier features: the weights follow something like cos(2πka/p) for some frequency k, so the network has found the circular structure of arithmetic mod p.
two panels: before (step 0) and now (current training step).
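a sketch of the clock layout, assuming input value a sits at angle 2πa/p (a = 0 at 12 o'clock, going clockwise) and each point's distance from the center grows with its weight. `clockPoints` and its parameters are hypothetical; the demo's exact drawing may differ.

```javascript
// Hypothetical clock layout for one neuron's p input weights.
function clockPoints(weights, cx, cy, baseRadius, scale) {
  const p = weights.length;
  return weights.map((w, a) => {
    const angle = (2 * Math.PI * a) / p - Math.PI / 2; // a = 0 starts at the top
    const r = baseRadius + scale * w;                  // larger weight = farther out
    // Screen coordinates: y grows downward, so increasing angle runs clockwise.
    return { a, x: cx + r * Math.cos(angle), y: cy + r * Math.sin(angle) };
  });
}

// A Fourier-feature neuron with frequency k = 5.
const p = 97;
const k = 5;
const wave = Array.from({ length: p }, (_, a) => Math.cos((2 * Math.PI * k * a) / p));
const pts = clockPoints(wave, 100, 100, 60, 20);
console.log(pts.length); // 97
```

for a wave like this the points bulge outward k times around the circle, which is the smooth pattern to look for; a memorizing neuron tends to look jagged instead.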

learned representations — PCA of weight vectors

each dot = one hidden neuron's weight vector, projected to 2D via PCA. if neurons learn Fourier features, neurons with similar frequencies should cluster. color = neuron index.
two panels: before (step 0) and now (current training step).
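a minimal sketch of the projection, assuming each row is one neuron's weight vector: center the rows, find the top two eigenvectors of the covariance by power iteration with deflation, and project. `pca2d` is a hypothetical helper; a real implementation would more likely use an SVD routine.

```javascript
// Hypothetical 2D PCA; rows[j] is neuron j's weight vector (length p).
function pca2d(rows, iters = 200) {
  const n = rows.length;
  const dim = rows[0].length;
  // Center each coordinate.
  const mean = new Array(dim).fill(0);
  for (const r of rows) for (let i = 0; i < dim; i++) mean[i] += r[i] / n;
  const X = rows.map((r) => r.map((v, i) => v - mean[i]));
  // Covariance matrix (dim x dim).
  const C = Array.from({ length: dim }, () => new Array(dim).fill(0));
  for (const x of X)
    for (let i = 0; i < dim; i++)
      for (let k = 0; k < dim; k++) C[i][k] += (x[i] * x[k]) / n;

  const mul = (M, v) => M.map((row) => row.reduce((s, m, k) => s + m * v[k], 0));
  const components = [];
  for (let c = 0; c < 2; c++) {
    // Power iteration from a fixed, non-uniform start vector.
    let v = Array.from({ length: dim }, (_, i) => 1 + 0.1 * i);
    for (let t = 0; t < iters; t++) {
      const w = mul(C, v);
      const norm = Math.hypot(...w) || 1;
      v = w.map((x) => x / norm);
    }
    // Eigenvalue estimate, then deflate so the next pass finds the 2nd component.
    const lambda = mul(C, v).reduce((s, x, i) => s + x * v[i], 0);
    for (let i = 0; i < dim; i++)
      for (let k = 0; k < dim; k++) C[i][k] -= lambda * v[i] * v[k];
    components.push(v);
  }
  // Project each centered row onto the two components: one [x, y] per neuron.
  return X.map((x) => components.map((v) => x.reduce((s, xi, i) => s + xi * v[i], 0)));
}

// Toy check: most spread along axis 0, less along axis 1 (signs are arbitrary).
const P = pca2d([[1, 0, 0], [-1, 0, 0], [0, 0.5, 0], [0, -0.5, 0]]);
```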