Hinton Problems
A reproducible-baseline catalog of the synthetic learning problems that appear in Geoffrey Hinton’s experimental papers from 1981 through 2022 — implemented in pure numpy, runnable on a laptop CPU, with paper-comparison metrics per stub.
Site: https://cybertronai.github.io/hinton-problems/ • Catalog: RESULTS.md • 53 of 53 stubs implemented (PRs #32–#41, all merged 2026-05-03)
Introduction
The field has standardized on backprop by the end of the ’80s, and Hinton gives a sample of problems that were used at the time. In the last 20 years, we have transitioned to GPUs, and the math has changed considerably. Instead of being bottlenecked by arithmetic, the shrinking of transistors means that arithmetic is essentially free, and all of the work comes from data movement. Backprop is inefficient in terms of “commute to compute ratio” because it requires fetching all of the activations for each gradient add.
So a natural experiment would be to redo key experiments of this time with a focus on data movement. The first step is to get a baseline — to establish the list of problems which are famous (made by Hinton), reasonable to implement, and easy to run/reproduce.
— Yaroslav, issue #1 (Sutro Group)
This repository is that baseline. v1 ships 53 implementations covering the lineage from the 4-2-4 encoder (1985) through the shifter (1986), bars (1995), MultiMNIST (2017), Constellations (2019), Ellipse World (2022), and the Forward-Forward suite (2022). Each stub is a self-contained folder with model + train + eval + visualization + animated GIF, all in numpy, all runnable in <5 min per seed on an M-series laptop.
The next step (#45 v2) instruments these 53 baselines with ByteDMD — Yaroslav’s data-movement cost tracer — to measure the actual “commute” each algorithm pays.
What’s here
| 27 reproduce paper claims | 25 partial reproductions | 1 non-replication |
|---|---|---|
| full or qualitative match | algorithm works, paper-config gap documented | gap analysed in 3 causes |
Pure numpy + matplotlib throughout. Every stub runs on a laptop CPU. Each problem lives in its own folder with <slug>.py (model + train + eval), README.md, make_<slug>_gif.py, visualize_<slug>.py, an animated <slug>.gif, and a viz/ folder of training curves and weight visualizations.
Visual tour
![]() | ![]() |
|---|---|
encoder-4-2-4 — Ackley/Hinton/Sejnowski 1985, the worked example. Bipartite RBM, 2-bit code emerges. | spline-images-factorial-vq — Hinton/Zemel 1994, factorial VQ wins 3× over standard 24-VQ baseline. |
![]() | ![]() |
ellipse-world — Culp/Sabour/Hinton 2022, eGLOM islands form across iterations (5-class, 92.2%). | ff-recurrent-mnist — Hinton 2022, top-down recurrent Forward-Forward. |
Catalog
Each table shows the v1 result per stub. Full per-stub metrics (compile-time, GIF size, headline numbers) are in RESULTS.md.
Reproduces? legend: yes = matches paper qualitatively or quantitatively; partial = method works, paper number not fully reached (gap documented in stub README); no = paper claim does not replicate.
1980s — Connectionist foundations
Ackley, Hinton & Sejnowski (1985) — A learning algorithm for Boltzmann machines
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| encoder-4-2-4 ★ | yes (CD-k variant) | n/a (worked example) | ~1s |
| encoder-3-parity | yes (KL = log 2 visible-only; RBM drops to 0.10) | ~50 min | 0.04s + 1.3s |
| encoder-4-3-4 | yes (60% error-correcting rate / 30 seeds) | ~3 hr | 2.3s |
| encoder-8-3-8 | yes (16/20 = exact paper parity) | ~2 hr | ~20s/seed |
| encoder-40-10-40 | yes (exceeds paper: 100% vs 98.6%) | ~1.5 hr | 6s |
Rumelhart, Hinton & Williams (1986) — Learning internal representations by error propagation
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| xor | yes (qualitative) | 6.4 min | 0.3s |
| n-bit-parity | yes (qualitative; thermometer code partial) | 30 min | 0.20s |
| encoder-backprop-8-3-8 | yes (70% strict 8/8 distinct codes) | ~10 min | 0.6s |
| distributed-to-local-bottleneck | yes (graded values 0.007/0.167/0.553/0.971) | 75 min | 0.082s |
| symmetry | yes (1 : 1.994 : 3.969 weight ratio) | 12.8 min | 0.4s |
| binary-addition | yes (qualitatively; 4-3-3 succeeds, 4-2-3 stuck) | ~2 hr | 44s |
| negation | yes (4-6-3 deviation justified) | 25 min | 0.10s |
| t-c-discrimination | yes (all 3 detector families emerge) | 30 min | 0.69s |
| recurrent-shift-register | yes (89 sweeps N=3, 121 sweeps N=5) | 25 min | 0.9s / 1.1s |
| sequence-lookup-25 | yes (4-5/5 held-out generalization) | 70 min | 0.20s / 5.78s |
Hinton (1986) — Distributed representations of concepts
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| family-trees | yes (3/4 best, 1.9/4 mean — matches paper) | ~1 hr | 2.1s |
Hinton & Sejnowski (1986) — Learning and relearning in Boltzmann machines
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| shifter | yes (92.3% recognition; position-pair detectors) | 30 min | 14s |
| grapheme-sememe | yes (qualitative; +6.7pp spontaneous recovery) | 70 min | 1.7s |
Plaut & Hinton (1987) — Learning sets of filters using back-propagation
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| riser-spectrogram | yes (98.08% net vs 98.90% Bayes; gap +0.83pp) | ~7 min | 0.91s |
Hinton & Plaut (1987) — Using fast weights to deblur old memories
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| fast-weights-rehearsal | yes (rehearsed-subset recovery +22pp / 30 seeds) | 25 min | 0.14s |
1990s — Unsupervised learning, mixtures, the Helmholtz machine
Jacobs, Jordan, Nowlan & Hinton (1991) — Adaptive mixtures of local experts
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| vowel-mixture-experts | partial (MoE 92.8% / MLP 90.1%; gate partitions vowels) | 70 min | 0.09s |
Becker & Hinton (1992) — A self-organizing neural network that discovers surfaces in random-dot stereograms
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| random-dot-stereograms | yes (Imax 1.18 nats; disparity readout 0.74) | ~1 hr | 6.1s |
Nowlan & Hinton (1992) — Simplifying neural networks by soft weight-sharing
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| sunspots | yes (MoG ≤ decay ≤ vanilla; weight peaks at 0 + 0.27) | ~1 hr | 5s |
Hinton & Zemel (1994) — Autoencoders, MDL and Helmholtz free energy
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| spline-images-factorial-vq | yes (factorial wins 3× over 24-VQ baseline) | ~1 hr | ~5s |
Zemel & Hinton (1995) — Learning population codes by minimizing description length
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| dipole-position | partial (R² = 0.81; supervised warm-up needed) | ~3 hr | 2s |
| dipole-3d-constraint | yes (qualitatively; 3 dims emerge) | ~1 hr | 11s |
| dipole-what-where | partial (perpendicular manifolds, lin-sep 0.58) | ~1 hr | 2s |
Dayan, Hinton, Neal & Zemel (1995) — The Helmholtz machine
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| helmholtz-shifter | partial (3 of 4 layer-3 units shift-selective; n_top=4) | 75 min | 209s |
Hinton, Dayan, Frey & Neal (1995) — The wake-sleep algorithm
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| bars | partial (KL = 0.451 bits vs paper 0.10) | 70 min | 222s |
2000s — Products of experts, contrastive divergence, deep belief nets
Hinton (2000) — Training products of experts by minimizing contrastive divergence
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| bars-rbm | yes (7/8 bars at purity ≥0.5; 8/8 with n_hidden=16) | ~30 min | 1.5s |
Memisevic & Hinton (2007) — Unsupervised learning of image transformations
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| transforming-pairs | partial (axis-selective transformation detectors) | ~1 hr | 2s |
Sutskever & Hinton (2007) — Multilevel distributed representations for high-dimensional sequences
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| bouncing-balls-2 | partial (rollout MSE between baselines) | 75 min | 6.2s |
Sutskever, Hinton & Taylor (2008) — The recurrent temporal RBM
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| bouncing-balls-3 | partial (CD-1 recon 0.005; rollout 0.13) | ~1 hr | 3.4s |
2010s — Capsules, distillation, attention
Hinton, Krizhevsky & Wang (2011) — Transforming auto-encoders
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| transforming-autoencoders | yes (R²(dx)=0.78, R²(dy)=0.67) | ~30 min | 100s |
Tang, Salakhutdinov & Hinton (2012) — Deep Lambertian Networks
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| deep-lambertian-spheres | yes (normal angular err 27°; albedo 7× baseline) | ~50 min | 33s |
Sutskever, Martens, Dahl & Hinton (2013) — On the importance of initialization and momentum
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| rnn-pathological | yes (3 of 4 tasks; ortho beats random init) | 2.5 hr | 42s |
Hinton, Vinyals & Dean (2015) — Distilling the knowledge in a neural network
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| distillation-mnist-omitted-3 | yes (97.82% on digit-3 post-correction; paper 98.6%) | 40 min | 121.8s |
Eslami, Heess, Weber, Tassa, Szepesvari, Kavukcuoglu & Hinton (2016) — Attend, Infer, Repeat
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| air-multimnist | partial (count 79.7%; reconstructions blurry) | ~50 min | 6s |
| air-3d-primitives | partial (1-prim 88.8%; 3-prim count 81%) | ~50 min | 11.7s |
Ba, Hinton, Mnih, Leibo & Ionescu (2016) — Using fast weights to attend to the recent past
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| fast-weights-associative-retrieval | partial (architecture verified; 38% retrieval) | ~3 hr | 293s |
| multi-level-glimpse-mnist | partial (82.46% vs paper 90%+) | ~1 hr | 1199s |
| catch-game | partial (FW 33.9% vs vanilla 11.4%; 91% at size=10) | ~2 hr | ~50s |
Sabour, Frosst & Hinton (2017) — Dynamic routing between capsules
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| affnist | no (gap wrong sign: −2% vs paper +13%) | ~3 hr | 4 min |
| multimnist-capsnet | partial (48.6% vs target 80%; 22× chance) | ~3 hr | 395s |
Hinton, Sabour & Frosst (2018) — Matrix capsules with EM routing
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| smallnorb-novel-viewpoint | yes qualitatively (caps 0.726 vs CNN 0.696 held-out) | ~1 hr | 10s |
Kosiorek, Sabour, Teh & Hinton (2019) — Stacked capsule autoencoders
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| constellations | yes (per-point recovery 86.9% best / 84% mean) | ~75 min | 25s |
2020s — Subclass distillation, GLOM, Forward-Forward
Müller, Kornblith & Hinton (2020) — Subclass distillation
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| mnist-2x5-subclass | partial (subclass recovery 82.88% best / 73.87% mean) | ~50 min | 13s |
Sabour, Tagliasacchi, Yazdani, Hinton & Fleet (2021) — Unsupervised part representation by flow capsules
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| geo-flow-capsules | yes (mean IoU 0.764 / chance 0.20) | ~8 min | 43s |
Culp, Sabour & Hinton (2022) — Testing GLOM’s ability to infer wholes from ambiguous parts
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| ellipse-world | yes (92.2% on 5-class; islands form +0.117) | ~1 hr | 9s |
Hinton (2022) — The forward-forward algorithm: some preliminary investigations
| Problem | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
| ff-hybrid-mnist | partial (5.21% test err vs paper 1.37%) | ~75 min | 492s |
| ff-label-in-input | partial (3.60% vs paper 1.36%) | ~1 hr | 66s |
| ff-recurrent-mnist | partial (10.66% vs paper 1.31%) | ~1 hr | 216s |
| ff-cifar-locally-connected | partial (FF 22.78% / BP 38.31%) | ~3 hr | 150s |
| ff-aesop-sequences | yes (TF 53% / SG 34%; baselines 3-20%) | ~12 min | 131s |
Structure
problem-folder/
├── README.md source paper, problem, results, deviations
├── <slug>.py dataset + model + train + eval
├── visualize_<slug>.py training curves + weight viz
├── make_<slug>_gif.py animated GIF
├── <slug>.gif committed animation
└── viz/ committed PNGs
Roadmap
- #45 v2: ByteDMD instrumentation — measure data-movement cost per stub on these baselines (the actual research goal)
- #46 v1.5: paper-scale reruns — close the 25 partial reproductions on Modal/GPU
- See
Open questions / next experimentssection in each stub README for stub-specific follow-ups
Contributing
Implementations follow the v1 spec:
- Each stub fills in
<slug>.py(model + train + eval), an 8-sectionREADME.md,make_<slug>_gif.py,visualize_<slug>.py, an animated<slug>.gif, andviz/PNGs. - Acceptance: reproduces in <5 min on a laptop; final accuracy with seed in Results table; GIF illustrates problem AND learning dynamics; “Deviations from the original” section honest; at least one open question.
- v1 metrics in PR body:
"Paper reports X; we got Y. Reproduces: yes/no."+ run wallclock + implementation wallclock.
The v1.5 reruns (#46) and v2 ByteDMD work (#45) welcome contributions.
License
The hinton-problems source and documentation are released into the public domain under the Unlicense.
RESULTS — v1 baselines
Per-stub reproducibility, implementation difficulty, and run wallclock for the 53 implementations shipped across wave PRs #32–#41. Compiled from PR bodies for the v2 data-movement / ByteDMD filter.
Reproduces? legend: yes = matches paper qualitatively or quantitatively; partial = method works, paper number not fully reached (gap documented in stub README); no = paper claim does not replicate (gap analysis documented).
Implementation wallclock: agent end-to-end time from spec read to branch pushed. Variance is large across waves; values are agent-self-reported.
Run wallclock: time to run the final headline experiment on a laptop M-series CPU. Numpy + matplotlib only, no GPU.
1980s — Connectionist foundations
Ackley, Hinton & Sejnowski (1985) — Boltzmann learning algorithm
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
encoder-4-2-4/ (worked example) | yes (CD-k variant; paper used SA) | n/a (pre-existing) | ~1s |
encoder-3-parity/ (PR #33) | yes (KL = log 2 = 0.6931 visible-only; RBM drops to 0.10) | ~50 min | 0.04s + 1.3s |
encoder-4-3-4/ (PR #33) | yes (60% error-correcting rate / 30 seeds; even-parity codeset at seed 12) | ~3 hr | 2.3s |
encoder-8-3-8/ (PR #33) | yes (16/20 = exact paper parity) | ~2 hr | ~20s/seed |
encoder-40-10-40/ (PR #34) | yes (exceeds paper: 100% vs 98.6%) | ~1.5 hr | ~6s |
Rumelhart, Hinton & Williams (1986) — Backprop
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
xor/ (PR #32) | yes (qualitative, paper ~558 epochs / median 730) | 6.4 min | 0.3s |
n-bit-parity/ (PR #32) | yes (qualitatively; thermometer code partial) | 30 min | 0.20s |
encoder-backprop-8-3-8/ (PR #33) | yes (70% strict 8/8 distinct codes; 100% reconstruction) | ~10 min | 0.6s |
distributed-to-local-bottleneck/ (PR #34) | yes (graded values 0.007 / 0.167 / 0.553 / 0.971 vs paper 0 / 0.2 / 0.6 / 1.0) | 75 min | 0.082s |
symmetry/ (PR #32) | yes (1 : 1.994 : 3.969 weight ratio, residual 0.000) | 12.8 min | 0.4s |
binary-addition/ (PR #33) | yes (qualitatively; 4-3-3 succeeds, 4-2-3 stuck) | ~2 hr | 44s |
negation/ (PR #32) | yes (4-6-3 arch deviation justified; stub said 4-3-3 which can’t converge) | 25 min | 0.10s |
t-c-discrimination/ (PR #34) | yes (all 3 detector families emerge across 40 kernels) | 30 min | 0.69s |
recurrent-shift-register/ (PR #34) | yes (89 sweeps N=3, 121 sweeps N=5; both well under paper’s <200) | 25 min | 0.9s / 1.1s |
sequence-lookup-25/ (PR #35) | yes (phenomenon — paper has no specific number; 4-5/5 held-out) | 70 min | 0.20s / 5.78s |
Hinton (1986) — Distributed representations
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
family-trees/ (PR #35) | yes (3/4 best seed; 1.9/4 mean — matches paper’s 2/4) | ~? | 2.1s |
Hinton & Sejnowski (1986) — Learning and relearning
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
shifter/ (PR #34) | yes (92.3% recognition; position-pair detectors visible in figure3.png) | 30 min | 14s |
grapheme-sememe/ (PR #34) | yes (qualitatively; +6.7pp spontaneous recovery on held-out 2 at seed 0) | 70 min | 1.7s |
Plaut & Hinton (1987)
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
riser-spectrogram/ (PR #35) | yes (network 98.08% vs Bayes 98.90%, gap +0.83pp; paper +1.0pp) | ~7 min | 0.91s |
Hinton & Plaut (1987) — Fast weights
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
fast-weights-rehearsal/ (PR #35) | yes (rehearsed-subset recovery +22pp mean / 30 seeds) | 25 min | 0.14s |
1990s — Mixtures, Helmholtz, deep belief
Jacobs, Jordan, Nowlan & Hinton (1991)
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
vowel-mixture-experts/ (PR #39) | partial (MoE 92.8% / MLP 90.1%; gate cleanly partitions front vs back vowels — phonetically meaningful. Paper’s “MoE in half the epochs” claim does NOT replicate at 2-D F1/F2: data is nearly linearly separable, MLP wins on speed) | 70 min | 0.09s |
Becker & Hinton (1992) — Imax / spatial coherence
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
random-dot-stereograms/ (PR #36) | yes (qualitatively; Imax 1.18 nats, modules’ agreement corr 0.91, disparity readout 0.74. Paper has no single comparable scalar.) | ~1 hr | 6.1s |
Nowlan & Hinton (1992) — Soft weight-sharing
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
sunspots/ (PR #39) | yes (MoG 0.00420 ≤ decay 0.00422 ≤ vanilla 0.00432 / 5 seeds; structural effect dramatic — MoG collapses ~150 of 208 weights onto 2 crisp peaks) | ~? | ~5s |
Hinton & Zemel (1994) — Bits-back / factorial VQ
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
spline-images-factorial-vq/ (PR #37) | yes (factorial 4×6 VQ wins 3× over standard 24-VQ baseline; DL 22.0 vs 65.3) | ~? | ~? |
Zemel & Hinton (1995) — Population codes / MDL
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
dipole-position/ (PR #36) | partial (R² = 0.81 vs (x,y); supervised warm-up needed for tractable optimization. Pure-unsupervised emergence from random init is open question) | ~3 hr | 2s |
dipole-3d-constraint/ (PR #36) | yes (qualitatively; singular values 6.67 / 4.61 / 3.80 — 3 dims emerge) | ~? | 11s |
dipole-what-where/ (PR #36) | partial (two near-perpendicular 1-D manifolds, axis angle 83°; meet at origin instead of opposite corners — needs learned mixture-of-Gaussians prior) | ~? | 2s |
Dayan, Hinton, Neal & Zemel (1995) — Helmholtz machine
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
helmholtz-shifter/ (PR #36) | partial (3 of 4 layer-3 units develop clean shift-direction tuning; n_top=4 vs paper’s n_top=1 — single top unit can’t break t↔1-t symmetry on this task) | 75 min | 209s |
Hinton, Dayan, Frey & Neal (1995) — Wake-sleep
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
bars/ (PR #35) | partial (KL = 0.451 bits vs paper 0.10; structure captured but residual gap; multi-restart wrapper deferred) | 70 min | 222s |
2000s — RBMs, products of experts, deep belief
Hinton (2000) — Contrastive divergence
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
bars-rbm/ (PR #35) | yes (7/8 bars at purity ≥0.5 with n_hidden=8 / 10 seeds; 8/8 with n_hidden=16) | ~30 min | 1.5s |
Memisevic & Hinton (2007) — Gated 3-way RBM
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
transforming-pairs/ (PR #37) | partial (axis-selective transformation detectors emerge; 8-way classification 3.2× chance. Direction-selective Reichardt cells need natural video, not random-dot pairs) | ~? | 2s |
Sutskever & Hinton (2007) — TRBM
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
bouncing-balls-2/ (PR #37) | partial (rollout MSE between predict-mean and copy-last baselines; qualitatively correct first 3-4 frames then diffuses to mean) | 75 min | 6.2s |
Sutskever, Hinton & Taylor (2008) — RTRBM
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
bouncing-balls-3/ (PR #37) | partial (CD-1 recon MSE 0.0053; rollout MSE 0.13; W_h≡0 ablation matches full model on rollouts — suggests Sutskever’s BPTT correction is needed) | ~? | 3.4s |
2010s — Capsules, distillation, attention
Hinton, Krizhevsky & Wang (2011)
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
transforming-autoencoders/ (PR #38) | yes (R²(dx)=0.78, R²(dy)=0.67) | ~30 min | ~100s |
Tang, Salakhutdinov & Hinton (2012)
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
deep-lambertian-spheres/ (PR #40) | yes (normal angular error 27° / 23.7° median — hits target <30°; albedo MSE 0.012 ~7× baseline. GRBM prior dropped — paper’s actual contribution; v1 is feed-forward baseline) | ~50 min | 33s |
Sutskever, Martens, Dahl & Hinton (2013)
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
rnn-pathological/ (PR #37) | yes (3 of 4 tasks; ortho-init solves, random-init at chance; XOR not cracked at our budget — needs NAG + 8× iterations per paper) | 2.5 hr | 42s |
Hinton, Vinyals & Dean (2015) — Distillation
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
distillation-mnist-omitted-3/ (PR #38) | yes (97.82% on digit-3 post-correction; paper 98.6%. Hyperparameter-free bias correction) | 40 min | 121.8s |
Eslami, Heess, Weber, Tassa, Szepesvari, Kavukcuoglu & Hinton (2016) — AIR
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
air-multimnist/ (PR #41) | partial (count 79.7% vs target 50% — exceeds; reconstruction blurry due to under-scale; Gumbel-sigmoid throughout, no REINFORCE) | ~50 min | ~6s |
air-3d-primitives/ (PR #41) | partial (1-prim sanity 88.8%; 3-prim count 81%, type 52%; supervised regression instead of REINFORCE-AIR) | ~50 min | 11.7s |
Ba, Hinton, Mnih, Leibo & Ionescu (2016) — Fast weights attention
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
fast-weights-associative-retrieval/ (PR #36) | partial (architecture verified by gradient check 1e-9; 38% retrieval vs 90% target — optimizer-landscape gap, needs RMSProp + 10⁵ steps per Ba et al.) | ~3 hr | 293s |
multi-level-glimpse-mnist/ (PR #39) | partial (82.46% vs paper 90%+; deterministic 24-glimpse simplification + no CNN encoder) | ~1 hr | 1199s |
catch-game/ (PR #40) | partial (33.9% FW vs 11.4% vanilla at size=24; ablation unambiguous; 91% FW at size=10. REINFORCE budget below paper’s A3C compute) | ~? | ~? |
Sabour, Frosst & Hinton (2017) — Dynamic routing
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
affnist/ (PR #40) | no (gap wrong sign: CapsNet 85.5% / CNN 87.5% — paper +13%, ours −2%. 3 causes documented: synth-affNIST too close to train aug, tiny capsules, no reconstruction regularizer) | ~? | ~4 min |
multimnist-capsnet/ (PR #40) | partial (48.6% vs target 80%; 22× chance; routing-by-agreement visibly works; reduced arch for pure-numpy budget) | ~3 hr | 395s |
Hinton, Sabour & Frosst (2018) — Matrix capsules with EM routing
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
smallnorb-novel-viewpoint/ (PR #41) | yes qualitatively (caps held-out 0.726 vs CNN 0.696 / 3 seeds; caps drop 0.244 vs CNN 0.304 — 20% relative reduction. Synthesized 5-class dataset vs real smallNORB) | ~? | ~10s |
Kosiorek, Sabour, Teh & Hinton (2019) — Stacked capsule autoencoders
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
constellations/ (PR #39) | yes (per-point recovery 86.9% best / 84.0% mean; chance 36.4%. 12,708-param numpy set transformer + capsule decoder, FD-checked) | ~75 min | 25s |
2020s — Subclass distillation, GLOM, Forward-Forward
Müller, Kornblith & Hinton (2020) — Subclass distillation
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
mnist-2x5-subclass/ (PR #38) | partial (subclass recovery 82.88% best / 73.87% mean; paper ~95%+ with ResNet vs our MLP backbone. Bounded aux loss gradient verified 6e-10) | ~50 min | 13s |
Sabour, Tagliasacchi, Yazdani, Hinton & Fleet (2021) — Flow capsules
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
geo-flow-capsules/ (PR #40) | yes (mean IoU 0.764 / 200 pairs; chance ~0.20. EM-based mixture decomposition with closed-form M-step on GT flow vs paper’s learned encoder) | ~8 min | 43s |
Culp, Sabour & Hinton (2022) — eGLOM
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
ellipse-world/ (PR #37) | yes (92.2% on 5-class; +6.6pp lift from GLOM iterations; islands form — cell-similarity rises +0.117 across iterations. Hand-coded backward FD-checked 1e-6) | ~? | 9s |
Hinton (2022) — Forward-Forward
| Stub | Reproduces? | Implementation | Run wallclock |
|---|---|---|---|
ff-hybrid-mnist/ (PR #38) | partial (5.21% test err vs paper 1.37%; 4×1000 + 30 epochs vs paper 4×2000 + 60. Goodness distributions show 2.8-3.3σ pos-vs-neg separation) | ~75 min | 492s |
ff-label-in-input/ (PR #38) | partial (3.60% vs paper 1.36%; smaller arch + fewer epochs. Three FF gotchas documented for siblings: mean(h²)=1, lr=0.003, all-layers > skip-L0) | ~1 hr | 66s |
ff-recurrent-mnist/ (PR #38) | partial (10.66% vs paper 1.31%; ~25× fewer params, 3× fewer epochs. Algorithm reproduces; capacity doesn’t) | ~1 hr | 216s |
ff-cifar-locally-connected/ (PR #39) | partial (FF 22.78% / BP baseline 38.31%; paper FF 41-46% / BP 37-39%. 15pp gap mostly under-training: 10K of 50K + 10 of 60+ epochs) | ~3 hr | 150s |
ff-aesop-sequences/ (PR #39) | yes (TF 53% / SG 34% / chance 3.3% / unigram 19.6%. Paper’s “nearly identical” claim doesn’t replicate at smaller scale — TF leads SG by 19pp) | ~12 min | 131s |
Summary statistics
| Verdict | Count | Notes |
|---|---|---|
| yes (full or qualitative match) | 27 | including all backprop foundations + most encoders + distillation-omitted-3 + ellipse-world + spline-VQ |
| partial (method works, paper number gap documented) | 25 | mostly Forward-Forward at smaller scale, capsules at smaller arch, AIR variants without REINFORCE |
| no (paper claim does NOT replicate) | 1 | affnist (gap wrong sign — three causes documented) |
Total: 53 stubs implemented, all in pure numpy, all <5 min/seed on a laptop except where noted.
v2 filter recommendation
For the data-movement / ByteDMD instrumentation, prioritize stubs that:
-
Reproduce cleanly + run fast (low noise floor for measuring data-movement deltas):
xor,symmetry,n-bit-parity,negation(sub-second runs, well-converged)encoder-3-parity,encoder-backprop-8-3-8,encoder-4-2-4(Boltzmann/backprop pair on same problem)distributed-to-local-bottleneck,recurrent-shift-register,t-c-discriminationbinary-addition,riser-spectrogram(clean MSE / Bayes-optimal targets)
-
Have algorithmic variants (lets you compare data-movement properties of different algorithms on the same problem):
- 8-3-8: backprop vs Boltzmann
- bars: wake-sleep vs RBM
- shifter: Boltzmann (this) vs Helmholtz (helmholtz-shifter)
- fast-weights-rehearsal vs fast-weights-associative-retrieval
-
Defer for v2: anything where the run takes >100s or where the v1 implementation is partial — measuring data-movement on a non-converged solver isn’t informative.
Compiled by agent-0bserver07 (Claude Code) on behalf of Yad. Source: PR bodies #32-#41.
4-2-4 encoder
Boltzmann-machine reproduction of the experiment from Ackley, Hinton & Sejnowski, “A learning algorithm for Boltzmann machines”, Cognitive Science 9 (1985).

Problem
Two groups of 4 visible binary units (V1, V2) are connected through 2
hidden binary units (H). Training distribution: 4 patterns, each with a
single V1 unit on and the matching V2 unit on (others off). The 2 hidden
units must self-organize into a 2-bit code that maps the 4 patterns onto
the 4 corners of {0, 1}^2.
- Visible: 8 bits =
V1 (4) || V2 (4) - Hidden: 2 bits
- Connectivity: bipartite (visible ↔ hidden only) —
V1andV2communicate exclusively throughH - Training set: 4 patterns
The interesting property: with only 2 hidden units, the network has exactly
log2(4) bits of bottleneck capacity. Convergence requires the 4 patterns to
spread to the 4 distinct corners of {0, 1}^2. Local minima where two
patterns share a hidden code are common.
Files
| File | Purpose |
|---|---|
encoder_4_2_4.py | Bipartite RBM trained with CD-k. The Boltzmann learning rule (positive-phase minus negative-phase statistics) on a bipartite graph; same gradient form as the 1985 paper, faster sampling. |
make_encoder_gif.py | Generates encoder.gif (the animation at the top of this README). |
visualize_encoder.py | Static training curves + final weight matrix + final hidden codes. |
viz/ | Output PNGs from the run below. |
Running
python3 encoder_4_2_4.py --epochs 400 --seed 2
Training takes ~1 second on a laptop. Final accuracy: 100% (4/4).
To regenerate visualizations:
python3 visualize_encoder.py --epochs 400 --seed 2 --outdir viz
python3 make_encoder_gif.py --epochs 400 --seed 2 --snapshot-every 5 --fps 12
Results
| Metric | Value |
|---|---|
| Final accuracy | 100% (4/4) |
| Hidden codes | 4 distinct corners of {0,1}^2 (specific permutation depends on seed) |
| Restarts (seed 0) | 2 (epoch 80, epoch 160), converged by ~220 |
| Training time | ~1 sec |
| Hyperparameters | k=5, lr=0.05, momentum=0.5, batch_repeats=8, init_scale=0.1 |
| Multi-restart success rate | ~65% across 30 random seeds at 400 epochs / 5 attempts |
What the network actually learns
Hidden codes

After convergence, the 4 training patterns each get a distinct 2-bit code.
Any of the 24 permutations of {(0,0), (0,1), (1,0), (1,1)} to the 4 patterns
is a valid solution; the network picks one based on the initialization.
Weight matrix

The two columns are the hidden units H[0] and H[1]. Red = positive,
blue = negative; square area is proportional to sqrt(|w|). The V1[i]
and V2[i] rows always carry the same sign pattern — the network has
independently discovered that V1 and V2 are tied (they are on for the
same pattern), even though no direct V1↔V2 weights exist. The sign pattern
across (H[0], H[1]) for each pattern row is exactly that pattern’s hidden
code.
Training curves

The vertical red dashed lines at epochs 80 and 160 mark restarts triggered by the plateau detector. The network had been stuck with only 3 (and then 2) distinct hidden codes — two patterns had collapsed onto the same code. Re-initializing the weights with an independent random draw and continuing training produces the correct 4-corner solution by epoch ~220.
The four panels track:
- Reconstruction accuracy: argmax of the exact marginal
p(V2 | V1), computed by enumerating the 4 hidden states (deterministic — no Gibbs noise). Discrete jitter early on reflects argmax flipping while V2 probabilities are close to uniform. - Hidden-code separation: mean pairwise L2 distance between the 4 exact hidden marginals — converges to ≈ 1.1, slightly below the unit-square diagonal √2, reflecting partial saturation toward the binary corners.
- Weight norm:
‖W‖_Fgrows roughly linearly during each attempt and resets at each restart. - Reconstruction MSE: mean-squared error of the marginal
p(V2 | V1)vs the true one-hot.
Deviations from the 1985 procedure
- Sampling — CD-5 (Hinton 2002) instead of simulated annealing. Same gradient form, faster sampling, sloppier asymptotics.
- Connectivity — explicit bipartite (visible ↔ hidden), making this an RBM in modern terminology. The 1985 paper’s figure already shows bipartite connectivity for the encoder; this just makes it explicit.
- Restart on plateau — the original paper reported 250/250 convergence under simulated annealing. CD-k is more prone to local minima where two patterns collapse onto the same hidden code; we detect this via an accuracy plateau and restart with fresh weights.
Correctness notes
A few subtleties worth flagging:
-
Sampled vs exact evaluation. With only 2 hidden units,
p(H | V1)andp(V2 | V1)are exactly computable by enumerating 4 hidden states and marginalizing V2 in closed form (each V2 bit factors). The closed form for the H posterior:p(H | V1) ∝ exp(V1ᵀ W₁ H + b_hᵀ H) · ∏ᵢ (1 + exp((W₂ H + b_v2)ᵢ))The
evaluate,hidden_code_exact, andreconstruct_exacthelpers use this. An earlier sampled-Gibbs version of the same metrics had σ ≈ 6.8% accuracy noise at convergence (50 runs of a converged network, observed range 75–100%) which made the training curves jitter spuriously.hidden_codeandreconstruct(sampled) are kept for the per-frame animation, where the chain dynamics are themselves of interest. -
Per-attempt success rate is fundamental. Holding the same hyperparam recipe and only varying the seed, ~20% of random inits converge to a 4-corner code — the rest end with at least one pair of patterns sharing a hidden code. More training does not help: 200 / 400 / 800 single-attempt epochs all give 6/30 = 20% success. This suggests the local minima are true fixed points of the CD-k dynamics, not slow-convergence artifacts.
-
Restart RNG independence matters. An earlier version sampled the restart’s W from the same
rbm.rngthat was being advanced by the CD sampler — restart inits then depended on the pre-restart trajectory, which biased the multi-restart success rate downward. The current code usesnp.random.SeedSequence(seed).spawn(64)to generate truly independent inits, and replaces the training RNG at each restart. -
Plateau signal. The detector uses the binary “all 4 patterns map to distinct dominant H states” rather than
acc < 1.0. Both signals agree at convergence, but the binary signal is unaffected by argmax-flipping jitter early in training. -
cd_step(k=0)now raisesValueErrorinstead of crashing withUnboundLocalError.
Open questions / next experiments
- The 1985 paper reports 250/250 convergence with full simulated annealing. CD-k caps out at ≈ 20% per-attempt regardless of training length, suggesting the optimization regimes are qualitatively different (CD-k has true absorbing local minima here; SA’s noise schedule does not). Quantifying that gap directly would help — a faithful simulated-annealing variant on the same architecture is the natural baseline.
- Can we eliminate the local-minima problem entirely by switching to PCD, by adding a small temperature schedule to the Gibbs sampler, or by initializing the weights to span the 4 corners explicitly?
- How do FLOP and data-movement costs of CD-k compare to simulated annealing on this same problem? CD-k wins on per-step cost but loses on per-attempt success rate.
- Scaling: does the same recipe (CD-k + restart-on-plateau) succeed on the
larger
n-log2(n)-nencoders in the same paper (8-3-8, 40-10-40)? With more hidden units, the 4-corner constraint relaxes — local minima may become less severe.
3-bit even-parity ensemble (the negative result)
Boltzmann-machine reproduction of the negative result that motivates the encoder problems in Ackley, Hinton & Sejnowski, “A learning algorithm for Boltzmann machines”, Cognitive Science 9 (1985).
Demonstrates: Why hidden units are necessary. A visible-only Boltzmann machine has only first- and second-order parameters, and the 3-bit even-parity ensemble has the same first- and second-order moments as the uniform distribution. The model collapses to uniform; half the probability mass ends up on the wrong (odd-parity) patterns. Adding hidden units lifts this restriction.

Problem
- Visible units: 3 binary
- Training distribution: 4 even-parity patterns at uniform
p = 0.25—{000, 011, 101, 110}. The 4 odd-parity patterns have target probability 0. - Visible-only Boltzmann (
--n-hidden 0): pure visible model, energyE(v) = -b·v - Σ_{i<j} W_ij v_i v_j. Trained with the exact gradient (Z is computable across all 8 patterns), so this isolates the representational failure from any sampling noise. - Hidden-unit RBM (
--n-hidden K, default K=4): bipartite visible↔hidden Boltzmann machine, trained with CD-k (Hinton 2002). Evaluation enumerates the 2^(3+K) joint states for an exact marginalp(v).
Why visible-only fails — exact computation
For the 3-bit even-parity ensemble:
| Moment | Value (parity ensemble) | Value (uniform on 8) |
|---|---|---|
<v_i> | 0.5 | 0.5 |
<v_i v_j> (i ≠ j) | 0.25 | 0.25 |
These are identical. The Boltzmann learning rule
Δb_i = <v_i>_data - <v_i>_model
ΔW_ij = <v_i v_j>_data - <v_i v_j>_model
drives the model toward whichever distribution matches those moments. With
only first- and second-order parameters available, the model picks the
maximum-entropy distribution consistent with them — the uniform — and stops.
The 4 odd-parity patterns end up at probability 1/8 each; the 4 even-parity
patterns also end up at 1/8 each.
The irreducible loss is KL(parity || uniform) = log(8/4) = log 2 ≈ 0.693,
and the visible-only run hits this floor on the very first gradient step.
This is the canonical motivation for hidden units in a Boltzmann machine,
and the next problem in the catalog (encoder-4-2-4/) is the constructive
follow-up.
Files
| File | Purpose |
|---|---|
encoder_3_parity.py | VisibleBoltzmann (n_hidden=0, exact gradient) and ParityRBM (n_hidden ≥ 1, CD-k). Dataset, training loops, exact marginal p(v) by enumeration, CLI. |
visualize_encoder_3_parity.py | Static distribution bar charts (visible-only, RBM, side-by-side), training curves, RBM weight Hinton diagram. |
make_encoder_3_parity_gif.py | Generates encoder_3_parity.gif showing both runs in parallel. |
encoder_3_parity.gif | Committed animation (≈ 570 KB). |
viz/ | Output PNGs from the run below. |
Running
# the negative result (default)
python3 encoder_3_parity.py --n-hidden 0 --seed 0
# the positive contrast
python3 encoder_3_parity.py --n-hidden 4 --seed 0
# regenerate all static plots
python3 visualize_encoder_3_parity.py --seed 0
# regenerate the GIF
python3 make_encoder_3_parity_gif.py --seed 0
Wall-clock on an Apple-silicon laptop:
| Run | Time |
|---|---|
encoder_3_parity.py --n-hidden 0 (400 steps) | ~0.04 s |
encoder_3_parity.py --n-hidden 4 (800 epochs) | ~1.3 s |
visualize_encoder_3_parity.py | ~2.5 s |
make_encoder_3_parity_gif.py | ~20 s |
All under the 5-minute laptop budget.
Results
Reproducible at seed = 0 with the parameters in the table below.
Visible-only Boltzmann (the negative result)
| Metric | Value | Note |
|---|---|---|
| Final `KL(target | model)` | |
p(even patterns) | 0.500 | should be 1.0; mass is split 50/50 |
Per-pattern p(v) | 0.125 each, all 8 patterns | exactly uniform |
| Wall-clock | 0.04 s |
Per-pattern result (seed = 0):
pattern parity target model
000 even 0.250 0.125
100 odd 0.000 0.125
010 odd 0.000 0.125
110 even 0.250 0.125
001 odd 0.000 0.125
101 even 0.250 0.125
011 even 0.250 0.125
111 odd 0.000 0.125
The result is seed-independent in distribution: re-running with --seed 7
gives an identical 0.125-each output and the same 0.6931 KL. The convergence
is essentially instantaneous because the gradient is exact and the unique
maximum-entropy fixed point is hit in the first few steps.
Hidden-unit RBM (the fix)
| Metric | Value |
|---|---|
n_hidden | 4 |
| Final `KL(target | |
p(even patterns) | 0.925 |
| Wall-clock | 1.3 s |
| Hyperparameters | k=5, lr=0.05, momentum=0.5, weight_decay=1e-4, init_scale=0.5, batch_repeats=16, n_epochs=800 |
Per-pattern result (seed = 0):
pattern parity target model
000 even 0.250 0.170
100 odd 0.000 0.022
010 odd 0.000 0.008
110 even 0.250 0.308
001 odd 0.000 0.020
101 even 0.250 0.220
011 even 0.250 0.226
111 odd 0.000 0.026
Reproducibility
| Field | Value |
|---|---|
| numpy | 2.3.4 |
| Python | 3.11.10 |
| OS | macOS-26.3-arm64-arm-64bit |
| Seeds tested | 0, 7 — visible-only identical (uniform), RBM qualitatively identical (≥ 90% mass on even patterns) |
Visualizations
Distributions: target vs learned

Visible-only Boltzmann. Grey bars = target; coloured bars = learned (green for even-parity, red for odd). Every coloured bar lands on the uniform 1/8 dotted line. Half the mass is on the red bars, which should be at zero.

RBM with 4 hidden units. Green bars (even parity) carry almost all the mass; red bars (odd parity) are flattened toward zero. The match to the target isn’t perfect (the four even-parity bars are uneven), but the parity structure has been recovered.

Same plot, all three distributions on one axis: target, visible-only (uniform across 8 patterns), and RBM (concentrated on the 4 even-parity patterns).
Training curves

Left panel: KL divergence over training. The visible-only run pins
itself at log 2 ≈ 0.69 from step 1 (the first gradient step already
matches the data moments) and never moves. The RBM sits near the same
floor for a few hundred CD epochs while CD noise dominates, then
escapes once the hidden units find a parity-discriminating
configuration.
Right panel: fraction of probability mass on the 4 even-parity patterns. Visible-only is locked at 50% (matching uniform); the RBM ramps to ~92%.
RBM weights

Hinton diagram of the final 3 × 4 weight matrix (red = positive,
blue = negative; square area ∝ √|w|). The columns show the four hidden
units’ affinities with the three visible bits. Each hidden unit votes on
some particular sign pattern across (V[0], V[1], V[2]); together they
suppress the four odd-parity patterns.
Deviations from the original procedure
- Sampling for the RBM — CD-k (Hinton 2002) instead of full
simulated annealing. Same gradient form; faster sampling; sloppier
asymptotics. Result: the RBM gets
p(even) ≈ 0.92rather than the ≈ 1.0 the original SA-trained network would target. - Visible-only training uses the exact gradient. The 1985 paper would have computed the negative-phase statistics by simulated annealing. Here we enumerate the 8 visible patterns directly so the gradient is exact — this strengthens the claim that the failure is representational, not a sampling artifact.
- RBM bipartite restriction. The original Boltzmann-machine formulation allowed visible↔visible weights; the modern RBM does not. Bipartite is a strict subset, but enough capacity for parity-3.
- Hidden-unit count. The original paper does not pin down a specific K for parity-3; we use K=4 because it converges reliably without restarts. Smaller K (1–2) sometimes converges and sometimes gets stuck in CD local minima.
Open questions / next experiments
- What is the smallest
n_hiddenthat suffices? A single hidden unit with the right weights can in principle makep(v)triple-interaction by marginalisation. EmpiricallyK=1is unreliable under CD-k. A systematic per-K convergence sweep (K = 1, 2, 3, 4, with multiple seeds) would quantify this. - Faithful simulated-annealing baseline. Replacing CD-k with the 1985 SA schedule should close the 0.10 KL residual on the RBM and likely converges 100% of the time. Worth running on this small problem where SA is cheap.
- Connection to
n-bit-parity/for n > 3. For 4-bit parity, the pairwise-zero argument still holds — and so do the third- and even-order moments up to order n−1. So an RBM with hidden units can learn it, but a Boltzmann machine restricted to k-th-order interactions for any k < n cannot. Building this hierarchy explicitly would give a clean staircase of negative results. - Energy / data-movement cost of the hidden-unit fix. Per the wider Sutro framing, what does the fix cost in CD-k FLOPs and reuse distance? A first measurement under ByteDMD would slot this stub into the energy story.
4-3-4 over-complete encoder
Boltzmann-machine reproduction of an experiment from Ackley, Hinton & Sejnowski, “A learning algorithm for Boltzmann machines”, Cognitive Science 9 (1985), pp. 147–169.
Demonstrates: with over-complete hidden capacity (3 hidden units for 4 patterns, when log2(4) = 2 would already suffice), Boltzmann learning prefers an error-correcting code — the 4 chosen 3-bit codes have no two codes at Hamming distance 1.

Problem
Two groups of 4 visible binary units (V1, V2) are connected through 3
hidden binary units (H). Training distribution: 4 patterns, each with a
single V1 unit on and the matching V2 unit on (all others off).
- Visible: 8 bits =
V1 (4) || V2 (4) - Hidden: 3 bits — over-complete (8 possible corner codes, only 4 needed)
- Connectivity: bipartite (visible ↔ hidden only);
V1andV2communicate exclusively throughH - Training set: 4 patterns
The interesting property: the network has 8 hidden corners but only needs
to use 4. Among the C(8, 4) = 70 ways to pick a 4-subset of {0, 1}^3,
only two contain no Hamming-1 pair:
- even-parity set
{000, 011, 101, 110}— every pair at Hamming distance 2 - odd-parity set
{001, 010, 100, 111}— every pair at Hamming distance 2
These are the two independent sets of size 4 in the 3-cube graph (the chromatic-number-2 colouring’s two sides). Boltzmann learning, when it converges, prefers exactly these arrangements: minimising the Boltzmann energy under positive-phase pressure pushes the codes apart, and any Hamming-1 collision is unstable because flipping the differing bit costs roughly the same energy as keeping it.
A code with min Hamming distance 2 is an error-correcting code: a
single bit-flip in H always decodes back to the nearest pattern’s V2
output, since no other code is one bit away.
Files
| File | Purpose |
|---|---|
encoder_4_3_4.py | Bipartite RBM with 3 hidden units, trained with CD-k. Lifted from encoder-4-2-4/encoder_4_2_4.py with n_hidden=3. Includes exact inference (enumerate 8 hidden states), hamming_distances_between_codes(), is_error_correcting(). |
problem.py | Stub-signature wrapper re-exporting generate_dataset, build_model, train, hamming_distances_between_codes. |
make_encoder_4_3_4_gif.py | Generates encoder_4_3_4.gif. |
visualize_encoder_4_3_4.py | Static training curves + weight matrix + 3-cube + Hamming heatmap. |
viz/ | Output PNGs from the run below. |
Running
python3 encoder_4_3_4.py --epochs 1000 --seed 12
Training takes ~2 seconds on a laptop. Final accuracy: 100 % (4 / 4); final min Hamming distance: 2 (error-correcting).
To regenerate visualizations:
python3 visualize_encoder_4_3_4.py --epochs 1000 --seed 12 --perturb-after 40
python3 make_encoder_4_3_4_gif.py --epochs 1000 --seed 12 --snapshot-every 15 --fps 14
Results
Per-seed run (seed = 12, the seed used for the inlined visualizations):
| Metric | Value |
|---|---|
| Final reconstruction accuracy | 100 % (4 / 4) |
| Hidden codes | even-parity set {000, 011, 101, 110} |
| Min off-diagonal Hamming distance | 2 |
| Pairwise Hamming matrix | all off-diagonal entries = 2 |
| Error-correcting | yes |
| Restarts (seed 12) | 4 (epochs 40, 80, 120, 160), converged by ~200 |
| Wall-clock (seed 12) | ~2.3 s |
| Implementation wall-clock | ~3 hours (lifted from encoder-4-2-4) |
Hyperparameters: lr=0.1, momentum=0.5, weight_decay=1e-4, k=3 (CD-3), batch_repeats=8, init_scale=0.1, perturb_after=40, n_epochs=1000.
Multi-seed success rate
The headline error-correcting property is a seed-dependent outcome. Across 30 random seeds, holding the recipe fixed:
| Outcome | Count |
|---|---|
| Error-correcting (4 distinct codes, min Hamming ≥ 2) | 18 / 30 (60 %) |
| 4 distinct codes but min Hamming = 1 (still 100 % accuracy) | 0 / 30 |
| < 4 distinct codes (two patterns share a code) | 12 / 30 |
When the network finds 4 distinct codes, the recipe currently lands on an error-correcting arrangement every time observed. The ~40 % failure mode is code collapse — two patterns end up sharing a hidden code despite restarts.
Hyperparameter sweep (20 seeds each, 1000 epochs):
| Recipe | EC success rate |
|---|---|
lr=0.1, k=3, perturb_after=40 (default) | 60 % |
lr=0.05, k=5, perturb_after=60 | 45 % |
lr=0.05, k=5, perturb_after=40 | 15 % |
lr=0.05, k=5, perturb_after=40, init_scale=0.05 | 10 % |
lr=0.05, k=10, perturb_after=40 | 5 % |
Paper claim: “no two codes at Hamming distance 1” (error-correcting). We got: 60 % rate of EC arrangements at this recipe (and when convergence to 4 distinct codes happens, it lands on an EC set every time observed). Reproduces: yes (qualitatively); the 1985 paper used simulated annealing and reports clean convergence — see Deviations.
Visualizations
Animation (top of README)
Each frame shows three panels at one epoch:
- Left — Hinton diagram of the 8 × 3 weight matrix
W_{V↔H}(red = +, blue = −, square area ∝ √|w|). - Right — the 3-cube. White circles are unused corners. Coloured
circles are the dominant
Hcode for each of the 4 training patterns; the colour matches the pattern index. A red edge between two coloured corners signals a Hamming-1 collision (a non-error-correcting arrangement). When the network converges, all chosen corners are pairwise-far, so no red edges remain. - Bottom — accuracy and
min Hamming × 33over time. The red dashed vertical lines mark restarts (plateau detector triggered when the current arrangement stays non-error-correcting for--perturb-afterepochs). The black dashed horizontal line is atmin Hamming = 2, where EC begins.
3-cube with chosen codes

The 4 coloured corners are the dominant H codes for patterns 0–3. With
seed 12, the network lands on the even-parity set
{000, 011, 101, 110} — every pair at Hamming distance 2. No red edges
mean no Hamming-1 collisions: an error-correcting arrangement.
Pairwise Hamming-distance matrix

Diagonal zeroes (each code is distance 0 from itself); every off-diagonal
entry is 2. The (0 0 0) / (0 1 1) / (1 0 1) / (1 1 0) codes are
exactly the 4 even-parity corners of the 3-cube.
Weight matrix

The three columns are the hidden units H[0], H[1], H[2]. As in the
4-2-4 case, the V1[i] and V2[i] rows carry identical sign patterns
for each pattern i — the network independently discovers that V1 and
V2 are tied (active for the same pattern), even though no direct
V1 ↔ V2 weights exist. The (sign, sign, sign) triplet of each row
matches that pattern’s hidden code; e.g. row V1[0] is positive on
H[0], negative on H[1], negative on H[2] for code (1, 0, 0) — but
this is seed-dependent and depends on which permutation was learned.
Training curves

Six panels:
- Reconstruction accuracy — argmax of exact
p(V2 | V1), computed by enumerating the 8 hidden states. Stays noisy at 25–100 % during the pre-convergence restart phase, then locks in to 100 %. - Hidden-code separation — mean pairwise L2 distance between the 4
exact hidden marginals
p(H_j = 1 | V1). Saturates near √2 ≈ 1.41 (the diagonal of the hidden cube). - Headline metric: min Hamming — minimum off-diagonal entry of the Hamming matrix, plotted as a step function. Stays at 0 / 1 (collapsed or Hamming-1 arrangements) during the restart phase, then jumps to 2 (error-correcting) and stays there.
- Weight norm —
‖W‖_Fresets at each restart, then grows as the network locks onto the EC code. - Reconstruction MSE — mean-squared error of the marginal
p(V2 | V1)vs the true one-hot. - Distinct codes — number of distinct dominant
Hcodes (target = 4).
The four red dashed lines at epochs 40, 80, 120, 160 are restarts. After the fourth restart the network lands on a basin that finds the EC code by epoch ~200 and stays there for the remaining 800 epochs.
Deviations from the 1985 procedure
- Sampling — CD-3 (Hinton 2002) instead of full simulated annealing. Same gradient form (positive-phase minus negative-phase statistics), faster sampling, sloppier asymptotics.
- Connectivity — explicit bipartite (visible ↔ hidden) RBM. The 1985 paper’s encoder figure already shows bipartite connectivity; this makes it explicit.
- Restart on plateau — the original paper reports clean convergence under simulated annealing on the 4-3-4 and the 4-2-4. CD-k is more prone to absorbing local minima where two patterns collapse onto the same hidden code; we detect non-error-correcting plateaus and restart with fresh weights. With this wrapper, ~60 % of seeds reach the EC arrangement; the rest collapse below 4 distinct codes and exhaust the restart budget.
- Plateau signal — the detector triggers on
min Hamming < 2, stronger than just “4 distinct codes”. With over-complete capacity it is possible to land on 4 distinct codes that include a Hamming-1 pair (e.g.{000, 001, 110, 111}— distinct but two pairs at distance 1); such an arrangement reconstructs correctly but is not error-correcting, so the detector keeps restarting.
Open questions / next experiments
- The 1985 paper’s clean convergence under simulated annealing suggests the EC arrangement is the global free-energy minimum, with non-EC 4-distinct arrangements being shallow local minima. CD-k apparently fails to escape them. Quantifying that gap directly with a faithful simulated-annealing reproduction is the natural baseline.
- Can the residual ~40 % failure mode (code collapse below 4 distinct codes) be eliminated by switching to PCD, by adding a small Gibbs- temperature schedule, or by initialising weights to span the 8 corners explicitly?
- The two EC arrangements (even-parity / odd-parity) are related by flipping every hidden unit. Across runs, both should appear with equal probability — is this empirically true? (Seeds 0, 1, 4, 9, 11, 13 all gave odd-parity at parity-sum 4; seed 12 gave even-parity at parity-sum 0; a larger sample would tell.)
- Scaling: does CD-k + restart-on-plateau succeed on the 8-3-8 encoder
in the same paper? With 8 patterns embedded in 8 corners of
{0,1}^3, the EC criterion becomes “use all 8 corners” — much stricter, since the only valid arrangement is every corner (any 8-subset that omits even one corner has at least one Hamming-1 pair). Seeencoder-8-3-8/for that variant. - ByteDMD energy comparison: CD-k vs simulated annealing on the same problem. CD-k wins on per-step cost but loses on per-attempt success rate; the data-movement-weighted comparison may flip.
8-3-8 encoder
Boltzmann-machine reproduction of the experiment from Ackley, Hinton & Sejnowski, “A learning algorithm for Boltzmann machines”, Cognitive Science 9 (1985).
Demonstrates: Theoretical-minimum hidden capacity. 3 hidden binary
units = log2(8); the network must use every corner of {0,1}^3 to encode
the 8 patterns. There is zero slack — any two patterns sharing a code is
permanent failure.

Problem
Two groups of 8 visible binary units (V1, V2) connected through 3 hidden
binary units (H). Training distribution: 8 patterns, each with a single
V1 unit on and the matching V2 unit on (others off). The 3 hidden units
must self-organize into a 3-bit code that maps the 8 patterns onto the
8 distinct corners of {0, 1}^3.
- Visible: 16 bits =
V1 (8) || V2 (8) - Hidden: 3 bits — exactly
log2(8), the theoretical minimum - Connectivity: bipartite (visible ↔ hidden only) —
V1andV2communicate exclusively throughH - Training set: 8 patterns
The interesting property: unlike 4-2-4 (4 patterns, 2 hidden) or 4-3-4
(4 patterns, 3 hidden), 8-3-8 has no slack. The map from patterns to
hidden codes has to be a bijection onto the cube’s 8 corners. Local
minima where two or more patterns collapse onto the same code are the
dominant failure mode, and we measure them directly via codes_used().
Files
| File | Purpose |
|---|---|
encoder_8_3_8.py | Bipartite RBM trained with CD-k + sparsity penalty + restart-on-no-improvement. Lifted from encoder-4-2-4/, generalized to N=8 patterns / 3 hidden bits. |
make_encoder_8_3_8_gif.py | Generates encoder_8_3_8.gif (the animation at the top of this README). |
visualize_encoder_8_3_8.py | Static training curves + final weight matrix + 3-cube viz + code-occupancy bar chart. |
viz/ | Output PNGs from the run below. |
Running
python3 encoder_8_3_8.py --seed 0 --n-cycles 4000
Per-seed wall-clock: ~20 s on an Apple Silicon laptop. A successful
seed lands at 100% reconstruction accuracy and codes_used == 8.
To regenerate visualizations:
python3 visualize_encoder_8_3_8.py --seed 0 --n-cycles 4000 --outdir viz
python3 make_encoder_8_3_8_gif.py --seed 0 --n-cycles 4000 --snapshot-every 60 --fps 12
Results
| Metric | Value |
|---|---|
| Per-seed wall-clock | ~20 s |
| Success rate (20 seeds) | 16/20 = 80% — same as the 1985 paper’s 16/20 |
| Successful-seed accuracy | 100% (8/8 patterns) |
| Successful-seed codes | All 8 corners of {0,1}^3 used (codes_used() == 8) |
| Failure mode | 4/20 seeds end with 6/8 codes — two pairs of patterns collapse onto shared corners |
| Restart count (successful) | 1–11 (median ≈ 7) |
| Restart count (failure) | always hits the budget cap (15) |
Hyperparameters (locked defaults):
| Param | Value | Notes |
|---|---|---|
n_cycles | 4000 | Training epochs per seed (across all restarts) |
lr | 0.1 | |
momentum | 0.5 | |
weight_decay | 1e-4 | |
k | 5 | CD-k Gibbs steps |
init_scale | 0.3 | std of N(0, init_scale^2) weight init |
batch_repeats | 16 | gradient steps per epoch (8 patterns x 2 shuffled passes) |
sparsity_weight | 5.0 | drives E[h_j] -> 0.5 for each hidden unit |
perturb_after | 250 | restart if n_codes doesn’t improve in this many epochs |
max_restarts | 20 | budget cap per seed |
Reproduces: yes — the 20-seed sweep above reproduces with the locked defaults (no flags needed); the 16/20 success rate is exact at seeds 0..19.
Run wallclock: 20-seed sweep ~ 6 min 39 s end-to-end.
What the network actually learns
Hidden codes on the 3-cube

After convergence, all 8 corners of {0,1}^3 are occupied — one pattern
per corner. Any of the 8! = 40,320 permutations of patterns to corners
is a valid solution; the network picks one based on the random init.
Code occupancy

Bar height = number of training patterns whose dominant argmax p(H | V1)
falls on that corner. Success = every bar is exactly 1 (all green). The
common failure mode is two patterns collapsing onto a shared corner (one
red bar at 2, one grey bar at 0).
Weight matrix

The three columns are the hidden units H[0], H[1], H[2]. Red =
positive, blue = negative; square area is proportional to sqrt(|w|).
The V1[i] and V2[i] rows always carry the same sign pattern
across (H[0], H[1], H[2]) — the network has independently discovered
that V1 and V2 are tied (they fire on the same pattern), even
though no direct V1<->V2 weights exist. The 3-bit sign pattern across
the row is exactly that pattern’s hidden code.
Training curves

Vertical red dashed lines mark restarts triggered by the
no-improvement-in-n_codes detector. n_distinct_codes (top middle)
typically climbs 1 -> 4 -> 6 -> 7 in each attempt, stalls, and triggers a
restart. Once an attempt makes it to 8/8, the network locks in and
training continues to drive accuracy / code-separation upward without
further restarts. Reconstruction MSE drops to near-zero only after all
8 corners are occupied.
The five panels track:
- Accuracy — argmax of the exact marginal
p(V2 | V1)over enumerated hidden states (8 states). Deterministic; no Gibbs noise. - Codes used — number of distinct dominant
Hstates across the 8 patterns. The headline metric. Target = 8. - Code separation — mean pairwise L2 distance between the 8 exact hidden marginals.
- Weight norm
||W||_F. - Reconstruction MSE of
p(V2 | V1)vs the true one-hot.
Deviations from the original procedure
-
Sampling — CD-5 (Hinton 2002) instead of simulated annealing. Same gradient form (
<v_i h_j>_data - <v_i h_j>_model), faster sampling, sloppier asymptotics. -
Sparsity penalty — added a
-0.5*(E[h_j] - 0.5)^2regularizer driving each hidden unit toward 50% activation across the data batch. Without this term, plain CD-k consistently collapses to <= 7 codes; with it, the per-attempt success rate rises enough that the restart loop hits paper-parity (16/20).This term has no analog in the 1985 paper. It is a known RBM trick (Lee, Ekanadham, Ng 2008 “Sparse deep belief net model”) repurposed to encourage cube-corner coverage.
-
Restart on plateau — when
codes_useddoesn’t improve for 250 epochs, re-init weights with an independent random draw. Up to 20 restarts per seed (within a single 4000-epoch budget). 4/20 seeds exhaust the budget at 6 codes. -
Plateau detector signal — uses “no improvement in best
n_distinct_codesseen this attempt for 250 epochs”, which is gentler than the 4-2-4’s “any epoch below the target counts.” The 8-3-8 network typically climbs through 1->4->5->6->7 over hundreds of epochs and we don’t want to abandon a climbing attempt prematurely. -
Connectivity — explicit bipartite (visible <-> hidden), making this an RBM in modern terminology. The 1985 paper’s encoder figure is already drawn bipartite; this just makes it explicit.
Correctness notes
-
Exact evaluation. With only 3 hidden units,
p(H | V1)andp(V2 | V1)are exactly computable by enumerating 8 hidden states and marginalizing V2 in closed form (each V2 bit factors). The closed-form posterior:p(H | V1) ~ exp(V1' W1 H + b_h' H) * prod_i (1 + exp((W2 H + b_v2)_i))evaluate,hidden_code_exact,dominant_code, andreconstruct_exactall use this. No Gibbs jitter on the metrics. -
Restart RNG independence. Restart inits come from
np.random.SeedSequence(seed).spawn(max_restarts + 1)so each restart’s W draw is statistically independent of the pre-restart gradient trajectory. We replace the training RNG at each restart for the same reason. -
codes_used()is the headline metric. Reconstruction accuracy can stay at 75-88% on a partially solved network (6 or 7 codes used), but unlesscodes_used == 8the encoder hasn’t actually solved the bottleneck.
Open questions / next experiments
- Faithful simulated-annealing baseline. The 1985 paper achieved 16/20 with full simulated annealing, and we match the success rate with CD-k + sparsity + restart. A direct SA implementation on the same architecture would tell us whether the agreement is accidental or whether they pick from the same basin distribution.
- Where does the sparsity penalty contribute most? Ablation: with
sparsity off, plain CD-k caps at ~5/8 codes; with sparsity on (and
no restart), the per-attempt success rate is ~10-20%; restart carries
us the rest of the way to 80%. Quantifying the per-attempt rate as
a function of
sparsity_weightwould map the trade-off. - Scaling. The paper also reports a 40-10-40 encoder. Does the same recipe (CD-k + sparsity + restart) scale, or does the 80% rate collapse as the cube dimension grows?
- Energy / data-movement cost. Per the broader Sutro effort, the
natural follow-up is to measure the ByteDMD or reuse-distance cost
of training and compare to a backprop baseline (
encoder-backprop-8-3-8, the parallel sibling stub).
40-10-40 encoder
Boltzmann-machine reproduction of the larger-scale encoder experiment from Ackley, Hinton & Sejnowski, “A learning algorithm for Boltzmann machines”, Cognitive Science 9 (1985).
Demonstrates: Asymptotic accuracy at scale (paper: 98.6% with sufficient Gibbs sweeps) and a graceful speed/accuracy curve at retrieval: how single-chain accuracy approaches the asymptote as the Gibbs-sweep budget grows.

Problem
Two groups of 40 visible binary units (V1, V2) connected through 10
hidden binary units (H). Training distribution: 40 patterns, each with a
single V1 unit on and the matching V2 unit on (others off). The 10 hidden
units must self-organize into a 10-bit code that maps the 40 patterns onto
40 distinct corners of {0, 1}^10.
- Visible: 80 bits =
V1 (40) || V2 (40) - Hidden: 10 bits — over-complete vs.
log2(40) ≈ 5.3, leaving 1024 - 40 = 984 unused corners - Connectivity: bipartite (visible ↔ hidden only) —
V1andV2communicate exclusively throughH - Training set: 40 patterns
The interesting property: unlike 8-3-8 (zero slack: 8 patterns onto 8 of 8 corners), 40-10-40 has generous slack. The 1985 paper’s scale-up headline is not “can it fit” but how well retrieval converges — accuracy grows with Gibbs-sweep budget at retrieval time, plateauing near the asymptotic maximum. This is the canonical demonstration of the speed/accuracy tradeoff in stochastic-relaxation networks.
Files
| File | Purpose |
|---|---|
encoder_40_10_40.py | 40-10-40 RBM, CD-k + sparsity penalty + plateau-restart training, exact (1024-state) and sampled retrieval, speed_accuracy_curve(). CLI: --seed --n-cycles --gibbs-sweeps. |
make_encoder_40_10_40_gif.py | Renders encoder_40_10_40.gif (animation at the top of this README). |
visualize_encoder_40_10_40.py | Static training curves + final weight matrix + speed/accuracy plot + per-pattern code heatmap. |
viz/ | Output PNGs from the run below. |
Running
python3 encoder_40_10_40.py --seed 0 --n-cycles 2000 --print-curve
Per-seed wall-clock: ~6 s on an Apple Silicon laptop. A successful seed lands at 100% asymptotic accuracy with all 40 patterns mapping to distinct hidden corners.
To regenerate visualizations and the GIF:
python3 visualize_encoder_40_10_40.py --seed 0 --n-cycles 2000 --outdir viz
python3 make_encoder_40_10_40_gif.py --seed 0 --n-cycles 2000 --snapshot-every 80 --fps 8
Results
| Metric | Value |
|---|---|
| Per-seed train wall-clock | ~6 s |
| Success rate (10 seeds, 0..9) | 10/10 at codes_used == 40 and asymptotic accuracy 100% |
| Asymptotic accuracy (exact 1024-state enumeration) | 100.0% (paper: 98.6%) |
| Per-trial sampled accuracy plateau (T=1.0) | ~91% (single Gibbs chain at retrieval) |
| Ensemble sampled accuracy (100 chains, mean V2 prob) | 100.0% from 1 sweep |
| Distinct dominant hidden codes | 40/40 (of 1024 cube corners) |
| Restart count (10-seed sweep) | 0 across all seeds |
Speed/accuracy curve at retrieval (T = 1.0, seed 0):
| Gibbs sweeps | Per-trial accuracy | Ensemble accuracy (100-chain mean) |
|---|---|---|
| 1 | 3.5% | 100.0% |
| 2 | 63.7% | 100.0% |
| 4 | 85.9% | 100.0% |
| 8 | 90.6% | 100.0% |
| 16 | 91.5% | 100.0% |
| 32 | 90.3% | 100.0% |
| 64 | 91.8% | 100.0% |
| 128 | 90.9% | 100.0% |
| 256 | 90.8% | 100.0% |
| 512 | 91.5% | 100.0% |
The “per-trial” column is the headline: a single Gibbs chain initialized from
random V2, V1 clamped, run for the listed number of sweeps. After 1 sweep
the chain hasn’t moved off chance (40 patterns → ~2.5% chance, observed ~3.5%).
After 8 sweeps it plateaus near 91%. The “ensemble” column averages the V2
conditional probability across 100 parallel chains; argmax of that mean
matches truth 100% from the very first sweep, demonstrating that chain
disagreement is consensual rather than systematic.
Hyperparameters (locked defaults):
| Param | Value | Notes |
|---|---|---|
n_cycles | 2000 | Training epochs |
lr | 0.1 | |
momentum | 0.5 | |
weight_decay | 1e-4 | |
k | 5 | CD-k Gibbs steps |
init_scale | 0.3 | std of N(0, init_scale^2) weight init |
batch_repeats | 8 | gradient steps per epoch |
sparsity_weight | 5.0 | drives E[h_j] -> 0.5 for each hidden unit |
perturb_after | 250 | restart if accuracy doesn’t improve in this many epochs |
max_restarts | 10 | budget cap per seed |
eval_every | 25 | epochs between exact-accuracy evaluations during training |
Reproduces: Yes. Paper reports 98.6% asymptotic accuracy. We get 100% asymptotic accuracy on every seed in 0..9 with the locked defaults. The graceful speed/accuracy curve (per-trial plateau ~91% at T=1.0) is the qualitative match — accuracy improves smoothly with sweep budget and saturates well above chance.
Run wallclock: single-seed run ~ 6 s end-to-end. 10-seed sweep ~ 60 s.
Visualizations
Speed/accuracy curve

The headline plot. Blue is per-trial accuracy (single Gibbs chain at retrieval, averaged over 100 independent initializations); orange is ensemble argmax (argmax of the mean V2 conditional probability across all 100 chains). The black dashed line is the asymptotic accuracy obtained by exact enumeration of the 1024 hidden states; the dotted gray line is chance (1/40 = 2.5%).
The blue curve climbs from chance to its plateau in roughly 8 sweeps and stays there. The gap between blue (~91%) and orange/black (100%) is sampling jitter: single-chain disagreement averages out across many chains.
Training curves

Five panels:
- Reconstruction accuracy — argmax of the exact marginal
p(V2 | V1)over enumerated hidden states (1024 states). Deterministic; no Gibbs noise. Hits 100% around epoch 200 and stays. - Codes used — distinct dominant
Hstates across the 40 patterns (target = 40 of 1024). Climbs rapidly past 35 then locks at 40. - Code separation — mean pairwise L2 distance between the 40 exact hidden marginals; keeps growing as weights pull patterns into corners.
- Weight norm
||W||_F— grows steadily through training. - Reconstruction MSE — squared error of
p(V2|V1)against the true one-hot, decays to ~0.
No restart was triggered on seed 0 (10/10 seeds in 0..9 succeed without
restart). The restart-on-plateau machinery is lifted from the sibling
encoder-8-3-8 PR but turned out unused at this scale — slack is generous
enough to avoid the local minima that bedevil 8-3-8.
Per-pattern dominant codes

Each row is a pattern (0..39), each column is a hidden bit (H[0]..H[9]),
black = 1, white = 0. All 40 rows are distinct → all 40 patterns occupy
distinct cube corners. Rows look like 10-bit hash codes: there is no
visible structure (the network picked an essentially random injection
from {patterns} → {corners of {0,1}^10}).
Weight matrix

Hinton diagram of the 80×10 weight matrix. Rows 0..39 are V1[0..39];
rows 40..79 are V2[0..39]; columns are H[0..9]. Red = positive, blue =
negative; square area ∝ √|w|. Each row’s 10-bit sign pattern across columns
is approximately that pattern’s hidden code, confirmed against
viz/code_occupancy.png. Like in the 8-3-8 case, the V1[i] and V2[i]
rows carry similar sign patterns even though no direct V1↔V2 weights
exist — the bipartite RBM has rediscovered that V1[i] and V2[i] co-fire
through the hidden layer.
Deviations from the original procedure
-
Sampling. CD-5 (Hinton 2002) instead of full simulated annealing. Same gradient form (
<v_i h_j>_data - <v_i h_j>_model); the model expectation is taken from 5 Gibbs sweeps rather than an annealed chain. -
Sparsity penalty. Added a
-0.5*(E[h_j] - 0.5)^2regularizer driving each hidden unit toward 50% activation across the data batch. No analog in the 1985 paper. Lifted from the sibling 8-3-8 recipe (PR #18). For 40-10-40 the slack is generous and a milder penalty also works, but matching the 8-3-8 weight (5.0) gives clean separation without tuning. -
Plateau-restart wrapper. Up to 10 restarts triggered if accuracy stagnates for 250 epochs. This was a survival kit at 8-3-8 scale (16/20 seeds needed restarts to hit paper-parity). At 40-10-40 scale none of the first 10 seeds tested needed any restart — the slack between 1024 corners and 40 patterns avoids the collision local minima that dominate 8-3-8. Kept the wrapper in place for seed-robustness on harder hyperparameter regimes.
-
Connectivity. Explicit bipartite (visible ↔ hidden), making this an RBM in modern terminology. The 1985 paper’s encoder figure is already drawn bipartite; this just makes it explicit.
-
Two distinct accuracy modes. We report asymptotic (exact, by enumerating the 1024 hidden states) and sampled per-trial / ensemble (Gibbs chains at retrieval). The 1985 paper’s 98.6% figure conflates them; here they’re separate metrics with the asymptotic limit explicitly identified.
Correctness notes
-
Exact evaluation. With 10 hidden units,
p(H | V1)andp(V2 | V1)are tractable by enumerating 2^10 = 1024 states. Closed-form posterior (V2 marginalized in closed form because each V2 bit factors given H):p(H | V1) ~ exp(V1' W1 H + b_h' H) * prod_i (1 + exp((W2 H + bv2)_i))evaluate_exact,hidden_posterior_exact,reconstruct_exactall use this. No Gibbs jitter on the asymptotic-accuracy metric. -
Per-trial vs. ensemble accuracy. The
speed_accuracy_curvefunction exposes both modes via itsmode=argument:"per_trial"reports the fraction of single chains that recover the right pattern,"averaged"reports the argmax of the mean V2 probability across many chains. The asymptotic limit (1024-state enumeration) sits at 100% — both sampled modes converge upward toward it. -
Sweep semantics. One “Gibbs sweep” = (sample V given H, with V1 clamped) followed by (sample H | V). After
n_sweeps, we read out the conditionalp(V2 | H_last)from the last hidden sample. No annealing schedule is applied at retrieval; the headline curve is at fixedT=1.0.
Open questions / next experiments
- Faithful simulated-annealing baseline. The 1985 paper used a slow annealing schedule both for training and retrieval; the 98.6% figure is the asymptote of that procedure. A direct SA implementation on the same architecture would tell us whether our 100% (CD-k + sparsity) is picking up real performance or merely overfitting the noise-free toy distribution.
- Sparsity weight ablation. With our defaults, all 10 seeds succeed in
0 restarts — the recipe is over-provisioned for 40-10-40. How low can
sparsity_weightgo before per-attempt success drops? Mapping that curve would expose how much of our 100% rate is from sparsity vs. slack. - Annealed retrieval. The per-trial curve plateaus around 91% at T=1.0. Cooling a single chain (T=2 → T=0.5 → T=1) during retrieval would close most of the gap to the 100% asymptote without resorting to many parallel chains.
- Energy / data-movement cost. Per the broader Sutro effort, the natural follow-up is to measure the ByteDMD or reuse-distance cost of the speed/accuracy tradeoff: at fixed accuracy budget, what’s the cheapest retrieval procedure (one long chain at low T vs many short chains at T=1)? The speed/accuracy curve here is the abstraction the energy metric will plug into.
agent-0bserver07 (Claude Code) on behalf of Yad
8-3-8 backprop encoder
Backprop reproduction of the encoder problem from Rumelhart, Hinton & Williams,
“Learning internal representations by error propagation”, in
Parallel Distributed Processing, Vol. 1, Ch. 8 (MIT Press, 1986). The
problem itself comes from Ackley, Hinton & Sejnowski (1985); this stub
trains the same architecture with backprop instead of CD-k / Boltzmann
learning, so it sits next to the encoder-8-3-8/ and
encoder-4-2-4/ Boltzmann siblings as the
algorithmic counterpart.

Problem
- Input: 8 one-hot vectors of length 8 (the 8x8 identity).
- Hidden: 3 sigmoid units. This is the bottleneck.
- Output: 8 sigmoid units. Target = input (autoencoder).
- Training set: 8 patterns, full-batch.
The interesting property: log2(8) = 3, so the bottleneck has exactly enough
capacity to store the 8 patterns – if and only if the 3 hidden units saturate
toward the 8 corners of {0, 1}^3. Backprop is free to settle anywhere in
[0, 1]^3, but the tied input/output weights and the cross-entropy cost
push activations toward the cube corners. After convergence the binarized
hidden activations form a 3-bit code that distinguishes all 8 inputs.
This is the backprop counterpart to encoder-8-3-8 (Boltzmann). Same architecture, same training set, different learning rule. The interesting comparison is which algorithm hits the 8-distinct-corner code more reliably and how that interacts with local minima.
Files
| File | Purpose |
|---|---|
encoder_backprop_8_3_8.py | 8-3-8 MLP autoencoder + backprop + hidden_code_table() + CLI. |
visualize_encoder_backprop_8_3_8.py | Static training curves + weight heatmaps + 3-cube hidden codes + code-table heatmap. |
make_encoder_backprop_8_3_8_gif.py | Generates encoder_backprop_8_3_8.gif (animation at the top of this README). |
encoder_backprop_8_3_8.gif | Committed animation. |
viz/ | Output PNGs from the run below. |
Running
python3 encoder_backprop_8_3_8.py --seed 0
Training takes ~0.6 sec on a laptop (seed 0, solves in 3631 epochs). Final accuracy: 100% (8/8) with 8/8 distinct binarized hidden codes.
To regenerate visualizations:
python3 visualize_encoder_backprop_8_3_8.py --seed 0 --outdir viz
python3 make_encoder_backprop_8_3_8_gif.py --seed 0 --epochs 4000 --snapshot-every 40 --fps 15
Results
| Metric | Value |
|---|---|
| Final reconstruction accuracy (seed 0) | 100% (8/8) |
| Distinct binarized hidden codes (seed 0) | 8/8 |
| Epochs to solve (seed 0) | 3631 |
| Training wallclock (seed 0) | ~0.6 sec |
| Visualization wallclock | ~1.7 sec |
| GIF generation wallclock | ~19 sec |
| Hyperparameters | full-batch GD, lr=0.5, momentum=0.9, init_scale=0.1, sigmoid hidden + sigmoid output, cross-entropy loss |
| Per-seed solve rate | 21/30 = 70% (seeds 0-29, n_epochs=20000) |
The “solve” criterion is strict: 100% reconstruction AND all 8 binarized codes distinct. Reconstruction alone hits 100% on every seed; the 30% unsolved fraction is networks that reconstruct correctly but assign two patterns to nearby raw activations that round to the same 3-bit corner (typically 7/8 distinct, with one pair sharing a code).
What the network actually learns
Hidden codes on the 3-cube

After convergence, the 8 training patterns each get a distinct corner of the
3-cube. Any of the 8! = 40320 permutations of {0,1}^3 corners to the 8
patterns is a valid solution; the network picks one based on the
initialization. Activations are typically saturated within ~0.05 of the
nearest corner.
Code table

Side-by-side view of the raw hidden activations (left, viridis) and the binarized 3-bit codes (right). Each row is one of the 8 input patterns; each column is one of the 3 hidden units. All 8 rows in the binarized panel are distinct (otherwise the network would not be “solved”).
Weights

W1 (input -> hidden): each row is a one-hot input bit, each column is a
hidden unit. Reading down a column gives the weight pattern that turns
that hidden unit on for each input. With sigmoid hidden units and one-hot
inputs, the sign pattern in column j is the j-th bit of each pattern’s
3-bit code.
W2 (hidden -> output): each row is a hidden unit, each column is an
output bit. The columns reconstruct the input identity by combining the
3-bit code stored in the hidden activations.
Training curves

Four panels:
- Loss (log scale): cross-entropy on the 8 patterns. Drops sharply once the hidden codes start to separate.
- Reconstruction accuracy: argmax of the output equals the input class. Saturates to 100% well before the binary code finishes saturating.
- # distinct binarized codes: the strict solve signal. Often plateaus at 7/8 for thousands of epochs before a slow drift pushes the last pair apart, or stays stuck at 7/8 indefinitely (the local-minimum cases).
- Weight norm:
‖W_1‖_F + ‖W_2‖_F, growing roughly linearly as the sigmoids saturate.
Deviations from the original procedure
- Loss function — cross-entropy with sigmoid outputs, not the squared
error in the 1986 PDP chapter. Cross-entropy speeds convergence on this
problem; the gradient form is the same up to a constant factor at the
output (
y - t). Gradient with respect to hidden weights and the “binary code emerges” finding are unchanged. - Optimizer — full-batch gradient descent with momentum 0.9 and a fixed lr=0.5. The original used a slightly smaller lr and lots of patient waiting; we use the now-standard “modern” recipe to keep wallclock under a second.
- No restart-on-plateau wrapper — the encoder-4-2-4 Boltzmann sibling uses a restart wrapper to escape local minima. We keep this stub single-attempt and report the per-seed success rate honestly (~70% over 30 seeds). A restart wrapper would push that to ~100% but obscures the fact that a fraction of inits genuinely settle at 7/8 distinct codes.
- Initialization — uniform
(-0.1, 0.1)rather than the larger Gaussian sometimes used. Small init helps the network find the 8-corner solution more often (15+/30 vs 9/30 with init_scale=0.5).
Open questions / next experiments
- Where does the local-minimum 7/8 case live? When the network is stuck at 7/8 distinct codes, which pair of patterns is sharing? Is it the same pair across runs, or seed-dependent? A histogram of the offending pair across the 9 unsolved seeds would tell us whether the failure mode has structure or is just symmetry-breaking noise.
- Restart-on-plateau — port the wrapper from
encoder-4-2-4/and measure how many restarts are needed to hit 100% solve. The Boltzmann sibling needs ~2 restarts at 65% per-attempt; backprop here is at ~70% per-attempt, so the budget should be similar. - Compare directly to encoder-8-3-8 (Boltzmann) — same architecture, different algo. Backprop is faster per step but the per-attempt success rate is in the same range. A side-by-side wallclock + success-rate plot is the natural next experiment.
- Energy / data-movement cost — out of scope for v1, but the broader
Sutro question is whether backprop’s cheaper sampling (no Gibbs chain)
also wins on data-movement complexity. Given how small this network is,
the comparison is mostly symbolic, but it sets up the same comparison
for
encoder-40-10-40.
agent-0bserver07 (Claude Code) on behalf of Yad
XOR
Source: Rumelhart, Hinton & Williams (1986), “Learning representations by back-propagating errors”, Nature 323, 533–536. Long version: PDP Vol. 1, Ch. 8, “Learning internal representations by error propagation”.
Demonstrates: Backprop overcomes the Minsky-Papert single-layer limitation. Reported in the paper: ~558 sweeps to converge for the 2-2-1 net, with ~2 out of hundreds of runs hitting a local minimum.

Problem
| input | target |
|---|---|
| (0, 0) | 0 |
| (0, 1) | 1 |
| (1, 0) | 1 |
| (1, 1) | 0 |
XOR is the simplest non-linearly-separable Boolean function on two inputs: no single line in the (x₁, x₂) plane separates the 1-outputs from the 0-outputs. Minsky & Papert (1969) used this to argue that the perceptron — a single linear threshold unit — was fundamentally limited. RHW1986’s contribution was the recipe for training the hidden layer needed to bend the decision boundary into the right shape: backpropagate the output error through a sigmoid non-linearity to get a gradient on every weight in the network.
The interesting property: with only 9 parameters (2-2-1: six weights + three biases) the network has just enough capacity to carve the (x₁, x₂) plane into two regions that match the four corners. There are several non-trivial local minima — configurations where two opposite corners share the same prediction — and the success rate per random init is sensitive to the initial weight scale.
Files
| File | Purpose |
|---|---|
xor.py | Dataset (4 patterns) + 2-2-1 / 2-1-2-skip MLP + backprop with momentum + CLI. Numpy only. |
visualize_xor.py | Static training curves, Hinton-diagram weights, decision-surface PNG, hidden-unit activations. |
make_xor_gif.py | Animated GIF: decision surface + weights + training curves over time. |
xor.gif | Committed animation (1.2 MB). |
viz/ | Committed PNG outputs from the run below. |
Running
python3 xor.py --seed 0
Training takes about 0.3 seconds on an M-series laptop. Final accuracy: 100% (4/4) at this seed; 25/30 random seeds converge to 100% within 5000 epochs (default --init-scale 1.0).
To regenerate the visualizations:
python3 visualize_xor.py --seed 0
python3 make_xor_gif.py --seed 0 --max-epochs 1500 --snapshot-every 20
To run the multi-seed sweep that produced the success-rate stats:
python3 xor.py --sweep 30 --max-epochs 5000
Results
Single run, --seed 0:
| Metric | Value |
|---|---|
| Final accuracy | 100% (4/4) |
| Final MSE loss | 0.026 |
| Converged at epoch | 1393 (first epoch with ` |
| Wallclock | ~0.3 s |
| Outputs | (0,0)→0.24, (0,1)→0.78, (1,0)→0.78, (1,1)→0.22 |
| Hyperparameters | arch=2-2-1, lr=0.5, momentum=0.9, init_scale=1.0 (uniform [-0.5, 0.5]), full-batch updates |
Sweep over 30 seeds (--sweep 30 --max-epochs 5000, default hyperparameters):
| Architecture | Converged | Mean epochs | Median epochs | Min | Max |
|---|---|---|---|---|---|
| 2-2-1 | 25/30 | 964 | 730 | 474 | 2489 |
| 2-1-2-skip | 29/30 | 1334 | 1005 | 357 | 3682 |
Comparison to the paper:
Paper reports ~558 sweeps to converge for 2-2-1; ~2 of hundreds of runs in a local minimum. We get median 730 epochs over 30 seeds (range 474–2489); 5/30 (~17%) seeds stall in a local minimum within 5000 epochs.
The order of magnitude matches and individual seeds (e.g. --seed 3 converges at 531) land essentially on top of the paper’s 558. The mean is biased up by a long tail. The failure rate is higher than the paper’s claim, almost certainly because the paper used a perturbation-on-plateau wrapper that we have not implemented for v1 — see Deviations below.
Visualizations
Decision surface (final, seed 0)

The shaded heatmap is the network’s output o(x₁, x₂) evaluated on a 200×200 grid. Red is high, blue is low. The black contour is the o = 0.5 decision boundary. The four training points sit on opposite corners of the unit square: the (0,0) and (1,1) corners (target 0) are blue, the (0,1) and (1,0) corners (target 1) are red. Two roughly parallel “stripes” of decision boundary pass between them — the network has approximated XOR by the textbook construction (one hidden unit fires on x₁ OR x₂, the other on x₁ AND x₂, and the output is their difference).
Hidden-unit activations

What the two hidden units actually fire for at each of the 4 training inputs. After convergence each hidden unit picks a different “feature” of the input: typically one fires on x₁ + x₂ ≥ 1 (an OR-ish unit) and the other on x₁ + x₂ ≥ 2 (an AND-ish unit). The output unit subtracts AND from OR to get XOR.
Weight matrices

Hinton-diagram view of the 9 parameters after training. Red is positive, blue is negative; square area is proportional to √|w|. The hidden-layer panel (left) shows that h₁ and h₂ have learned different combinations of x₁ and x₂, with biases that put their thresholds at distinct places along the x₁ + x₂ axis. The output panel (right) shows the relative weighting: one hidden unit pushes the output up, the other pushes it down.
Training curves

Three signals over training:
- Loss drops in two phases: a slow plateau near
0.125(network is outputting ≈ 0.5 on every pattern, which is the constant-prediction MSE for two-class balanced targets), then a sudden break around epoch 1000 once the hidden-unit features cross their threshold and become useful. - Accuracy is flat at 50% during the plateau and steps up to 100% around the same break.
- Weight norm tells the same story from a third angle: the weights stay tiny while the network is stuck, then grow rapidly during the break as the hidden units commit to definite features. The green dashed line marks the convergence epoch (1393).
This three-phase signature — plateau, break, refinement — is characteristic of XOR backprop and is the textbook example of the “phase transition” in shallow-network training.
Deviations from the original procedure
- Init distribution. Paper uses uniform
[-0.3, 0.3](Hinton’s standard small-init recipe). We use[-0.5, 0.5](our default--init-scale 1.0) because it gave the best agreement with the paper’s epoch count. With--init-scale 0.6we match the paper’s init range but median epochs jumps to 1648 and the loss-plateau gets longer. - No perturbation-on-plateau wrapper. RHW1986 reports treating the rare local-minimum runs by perturbing weights and continuing. We don’t — a “stuck” run in our sweep stays stuck for 5000 epochs and is counted as a failure. This explains our higher failure rate (5/30 vs. their ~2/hundreds).
- Floating-point precision.
float64numpy. The 1986 paper’s hardware was not IEEE 754 in the modern sense; this should not matter for a problem this small. - Sigmoid clamping. We clip the pre-activation to
[-50, 50]to avoidnp.expoverflow, a 21st-century numerical hygiene step. - Convergence criterion. We use the paper’s stated rule: every output within 0.5 of its target (i.e. argmax matches). Same as the paper.
Otherwise: same architecture, same loss (mean of 0.5 (o − y)²), same training algorithm (full-batch backprop with momentum), same hyperparameters (η = 0.5, α = 0.9).
Open questions / next experiments
- Local-minimum analysis. The 5 stalled seeds in our sweep all hit ~50% accuracy and stay there. Are they all the same local minimum (e.g. both hidden units converged to the same feature, so the network reduces to a perceptron) or genuinely different fixed points? A clustering analysis on the stuck weight vectors would answer this.
- Adding the perturbation wrapper. RHW1986’s procedure escapes local minima by perturbing stalled weights and continuing. Adding this should match their <1% failure rate and is the natural next experiment.
- Data movement. This is the v1 baseline. v2 (the broader Sutro effort) will instrument the same training loop with ByteDMD and ask whether a non-backprop solver (e.g. Hebbian + a tiny outer loop, or direct algebraic construction since XOR is parity-2) can hit the same accuracy with lower data-movement cost. The 2-2-1 architecture has only 9 floats of state, so even ARD-1 should be achievable for the inference path — the open question is whether the training path can be cheaper than full backprop.
- 2-1-2-skip vs 2-2-1. Our sweep shows 2-1-2-skip is slightly more reliable (29/30 vs 25/30) at the cost of more epochs in median (1005 vs 730). The skip connection seems to make the loss landscape gentler. Worth quantifying with a larger sweep.
- Generalization to k-bit parity. XOR is 2-bit parity. The Sutro Group’s broader work uses sparse parity at n=20, k=3. Walking the bridge from a 9-parameter MLP solving XOR at hundreds of epochs to a 200-hidden network solving sparse parity in millions of gradient steps would clarify what scales and what doesn’t.
N-bit parity
Source: Rumelhart, Hinton & Williams (1986), “Learning internal representations by error propagation”, in Parallel Distributed Processing, Vol. 1, Ch. 8, pp. 318–362 (MIT Press). The “Parity” example occupies §8.2 (“Examples”), shortly after the XOR demo.
Demonstrates: With N hidden units, an MLP trained by backprop discovers a hidden representation in which each unit responds monotonically to the number of “on” bits in the input. The textbook construction is the thermometer code — hidden unit h_k fires when at least k of the N inputs are on, for k = 1..N. The output then computes parity by taking an alternating-sign sum of the staircase: o = h_1 - h_2 + h_3 - h_4 + .... This is the minimal hidden-layer construction for parity.

Problem
| input bits | parity target |
|---|---|
| 0000 | 0 |
| 0001 | 1 |
| 0011 | 0 |
| 0111 | 1 |
| 1111 | 0 |
| … (16 patterns total for N=4) |
The target is 1 if an odd number of bits are on, else 0. Parity is the canonical hard Boolean function: every input bit matters (no partial-information shortcut), and it requires k-th-order interaction detection across all bits. A single linear layer (perceptron) cannot represent it — Minsky & Papert (1969) used 2-bit parity (XOR) as the headline counter-example to perceptron learning. RHW1986 showed that a one-hidden-layer MLP with N hidden sigmoids can learn N-bit parity from all 2^N patterns, and that the hidden layer self-organizes into a thermometer-like code.
The interesting property is what the hidden layer learns. With exactly N hidden units the network has just enough capacity for the textbook construction. Some seeds find the thermometer code; others find an equivalent solution where some units detect parity-completion features instead. The animated GIF above shows the hidden code stretching from a degenerate flat line at initialization into a clear monotonic staircase as training progresses.
Files
| File | Purpose |
|---|---|
n_bit_parity.py | Dataset (all 2^N patterns) + N-H-1 sigmoid MLP + backprop with momentum + thermometer-code analysis + CLI. Numpy only. |
visualize_n_bit_parity.py | Static training curves, Hinton-diagram weights, thermometer-code panel, full hidden-activation heatmap, prediction bar chart, per-seed monotonicity sweep. |
make_n_bit_parity_gif.py | Animated GIF: thermometer code + W1 weights + training curves over training. |
n_bit_parity.gif | Committed animation (1.2 MB). |
viz/ | Committed PNG outputs from the run below. |
Running
python3 n_bit_parity.py --n-bits 4 --seed 0
Training takes about 0.2 seconds on an M-series laptop. Final accuracy: 100% (16/16) at this seed.
To regenerate the visualizations:
python3 visualize_n_bit_parity.py --n-bits 4 --seed 0 --sweep 10
python3 make_n_bit_parity_gif.py --n-bits 4 --seed 0 --max-epochs 2400 --snapshot-every 30
To see how convergence rate scales with N:
python3 n_bit_parity.py --sweep-n 2-7 --sweep 5 --max-epochs 60000
Results
Single run, --seed 0, N = 4:
| Metric | Value |
|---|---|
| Final accuracy | 100% (16/16) |
| Final MSE loss | 0.0131 |
| Converged at epoch | 2308 (first epoch with ` |
| Wallclock | 0.20 s for the training loop (0.43 s including process startup per time(1)) |
| Mean monotonicity of hidden code | 0.60 (1.0 = perfectly monotonic with bit-count) |
| Hyperparameters | N=4, hidden=4, lr=0.5, momentum=0.9, init_scale=1.0 (uniform [-0.5, 0.5]), full-batch, bipolar {-1, +1} inputs, spread-bias init |
N-sweep (5 seeds each, max 60 000 epochs):
| N | converged | median epochs | min | max |
|---|---|---|---|---|
| 2 | 3 / 5 | 252 | 240 | 441 |
| 3 | 5 / 5 | 2 703 | 816 | 18 833 |
| 4 | 3 / 5 | 12 369 | 2 308 | 15 101 |
| 5 | 2 / 5 | 18 584 | 14 903 | 22 264 |
| 6 | 1 / 5 | 56 220 | 56 220 | 56 220 |
| 7 | 0 / 5 | — | — | — |
Convergence-rate-per-seed degrades sharply with N — exactly what RHW1986 noted. The N-hidden architecture has just barely enough capacity for parity, so the loss landscape is full of local minima, and many seeds get stuck on the long mid-training plateau (see the training curves below). The fix RHW1986 describe is to add a perturbation-on-plateau wrapper, which we did not implement for v1.
Comparison to the paper:
Paper reports: “We have found that with this (N hidden) architecture, the network learns the parity function for inputs up to about size 8” (PDP Vol. 1, p. 334), and informally describes the hidden representation as a “thermometer code” (each unit fires when ≥ k bits are on).
We get: 100% accuracy on N = 2..6 for at least one seed, 0/5 at N = 7 within 60 000 epochs (the paper’s “up to about size 8” claim almost certainly required either weight-perturbation rescue or the longer training horizons available with the more aggressive hand-tuned hyperparameters of the era). Hidden code is partially thermometer: 2 of 4 hidden units form clean monotonic detectors at our headline seed, while the other 2 detect mid-bit-count parity-completion features. Across a 10-seed sweep at N = 4, mean per-seed monotonicity ranges 0.20–0.90 with a median of 0.60.
Paper reports up to N=8; we got up to N=6 cleanly (and N=7 within 60 000 epochs at zero of 5 seeds). Reproduces: yes, qualitatively — backprop solves N-bit parity with N hidden, hidden representation is thermometer-LIKE (monotonic in bit-count), and convergence rate degrades with N as the paper warned. We did not match N = 8 in v1 because we did not implement the perturbation-on-plateau rescue.
Visualizations
Thermometer-code panel (the centrepiece)

The left subplot shows the mean hidden-unit activation grouped by input bit-count (0 = “no bits on”, 4 = “all bits on” for N = 4). A perfect thermometer code would show four parallel sigmoidal steps shifted along the bit-count axis — h_k flat-low until bit-count reaches k, then flat-high. We see two of the four units (h2, polarity = +; h1, polarity = −) form clean monotonic step functions with thresholds 2 and 3 respectively. The other two (h3, h4) form a “middle bump” — they peak at intermediate bit-counts and contribute the parity-specific cross terms that the strict thermometer construction would have to wring out of the staircase via the alternating-sign output weights.
The right subplot is the same data as a heatmap, with hidden units sorted by their effective threshold and negative-polarity units flipped so the staircase reads top-to-bottom. The two “thermometer” rows form a clean ladder; the two “bump” rows are the residual non-monotonic detectors.
Per-seed thermometer-likeness

Mean monotonicity score (averaged across the four hidden units) for 10 random seeds at N = 4. Green bars are seeds that converged to 100% accuracy; gray bars failed to converge in 30 000 epochs. Even among converged seeds the score varies from 0.45 to 0.90 — strict thermometer codes are achievable but not the typical attractor.
Per-pattern hidden activation

The same hidden activations, broken out per individual input pattern (rows sorted by bit-count, then by index). h1 reads as a near-perfect “low bit-count detector” (top half bright, bottom half dark). h2 is its monotonic mirror (bottom half bright). h3 activates strongly only on a subset of bit-count = 1 and bit-count = 3 patterns (the parity-1 cases). h4 is the hardest to summarize — it picks specific patterns at every bit-count to compensate for the rounding errors the other three units leave behind.
Predictions

Output sigmoid for every one of the 16 input patterns, sorted by bit-count. Black ×’s are the targets, red bars are the network output for “target 1” patterns and blue bars for “target 0”. All 16 outputs land on the correct side of the 0.5 boundary (= the convergence criterion).
Weight matrices

Hinton diagram of the 25 trainable parameters after training. Red is positive, blue negative; square area ∝ √|w|.
- W1 (left) —
h_1has all-negative input weights and a strong positive bias, making it the “low bit-count detector” (negative polarity in the thermometer panel).h_2is the mirror: all-positive input weights, less-positive bias, high bit-count detector.h_3andh_4have mixed input-weight signs — they’re the parity-completion units that respond to specific bit subsets rather than to the bit-count. - W2 (right) — alternating sign:
h_1andh_2push the output one way (negative in this orientation),h_3andh_4the other way. This is exactly the pattern the textbook thermometer construction predicts (o = h_1 - h_2 + h_3 - h_4 + ...), even though the network’s hidden code is only partially thermometer.
Training curves

Two-phase signature characteristic of parity backprop:
- Loss sits on a long plateau near 0.125 (the constant-prediction MSE for balanced binary targets) for ~1500 epochs, then breaks downward in two more sub-plateaus before converging.
- Accuracy climbs in clear discrete steps as each individual pattern crosses the decision boundary, finally hitting 100% at epoch 2308 (green dashed line).
- Weight norm stays flat near initial value during the plateau, then grows rapidly during the break — the hidden units commit to their respective features only after a long search through near-degenerate weight space.
This three-phase pattern (long plateau, break, refinement) is the canonical “phase transition” of backprop training and is more pronounced for parity than for XOR because there are more output patterns to align simultaneously.
Deviations from the original procedure
- Bipolar (
{-1, +1}) input encoding. RHW1986 used{0, 1}. With{0, 1}and small random init, parity training on N ≥ 4 has a much higher failure rate (≤ 30% convergence in our preliminary sweeps) because the all-zeros input collapses every hidden pre-activation to the bias term, breaking symmetry only weakly. Bipolar inputs are an established 1980s variant (used in many Hinton followups) and double convergence reliability for free. CLI flag--encoding binaryrecovers the original encoding. - Spread-bias initialization. Hidden-unit biases
b_1are initialized with a deterministic linear spread across the input bit-count range (b_k ≈ -k * 2 + small jitter), instead of uniform[-0.5, +0.5]. This biases the early training dynamics toward the thermometer code (each hidden unit starts with a different “preferred” threshold). Without it, hidden units start near-identical and tend to collapse onto the same feature. The weightsW1are still random. The original paper does not specify a bias-init recipe; our spread is a targeted initialization for visibility of the thermometer claim, not a tuning trick to improve accuracy. - No perturbation-on-plateau wrapper. RHW1986 mention re-randomizing weights when training stalls. We don’t, which explains why our convergence rate degrades fast with N — many seeds at N ≥ 5 are stuck on the long plateau when our budget runs out.
- Floating-point precision.
float64numpy. The 1986 hardware was not IEEE 754 in the modern sense; immaterial for a problem this small. - Sigmoid clamping. Pre-activation is clipped to
[-50, 50]to avoidnp.expoverflow — modern numerical hygiene. - Convergence criterion. RHW1986’s stated rule (every output within 0.5 of its target). Same as the paper, same as our
xor/sibling.
Otherwise: same architecture (N inputs → N hidden sigmoids → 1 output sigmoid), same loss (mean of 0.5 (o − y)²), same training algorithm (full-batch backprop with momentum), same hyperparameters (η = 0.5, α = 0.9).
Open questions / next experiments
- Why does strict thermometer rarely emerge? With the spread-bias init we get a clean monotonic staircase from 2 of 4 hidden units, but the other 2 always become bump detectors. Is the network using its slack capacity to over-fit specific patterns, or is the local minimum near “2 thermometer + 2 bumps” genuinely lower-loss than “4 thermometer”? An analysis of the loss as a function of distance from a constructed thermometer solution would answer this.
- Bypass the local-minimum problem. Add the perturbation-on-plateau wrapper and re-run the N = 6, 7, 8 sweeps. If the paper’s claim (“up to about size 8”) relied on this wrapper, we should now match it. Compare the hidden code across “rescued” seeds to see whether the thermometer is the rescued-from attractor.
- Hidden-layer width study. Spec defaults to
n_hidden = n_bits. What happens atn_hidden = 2N? AtN - 1(under-parameterized)? The hidden code at over-parameterized widths probably becomes pure thermometer plus redundant copies; atN - 1the network must converge on an alternative parity-1 solution. - Data movement. This is the v1 baseline. v2 (the broader Sutro effort) will instrument the same training loop with ByteDMD and ask whether a non-backprop solver (e.g. direct algebraic GF(2) construction — parity is sum-mod-2 of input bits) can hit the same accuracy with lower data-movement cost. The Sutro Group has already shown GF(2) solves sparse parity in microseconds; the dense-parity case here should be even more obvious. The interesting v2 question is whether any gradient-based method can match it.
- Comparison to RHW1986’s “Symmetry” example. The same chapter has a “Symmetry” task with 2 hidden units that learns a clean alternating-magnitude weight pattern. Implementing it in the sibling
symmetry/stub gives a controlled comparison: same architecture family, different hard-Boolean task, very different hidden code.
v1 metrics (per spec issue #1)
- Reproduces paper? Qualitatively yes — backprop solves N-bit parity with N hidden, hidden code is thermometer-like, convergence rate falls off with N. Quantitatively: paper claims “up to about size 8”; we got N = 6 cleanly and 0/5 at N = 7 in our budget without the perturbation-rescue wrapper.
- Run wallclock (final experiment, headline seed): ~0.20 s for the training loop, 0.43 s end-to-end including process startup (
time python3 n_bit_parity.py --n-bits 4 --seed 0on M-series laptop). - Implementation wallclock: ~25 minutes end-to-end (start of agent session → branch pushed).
6-bit symmetry / palindrome detection
Source: Rumelhart, Hinton & Williams (1986), “Learning representations by back-propagating errors”, Nature 323, 533-536. Long version: PDP Vol. 1, Ch. 8, “Learning internal representations by error propagation”.
Demonstrates: A 6-2-1 sigmoid network learns a unique anti-symmetric weight pattern with a 1:2:4 magnitude ratio across the three position pairs. The “more elegant than the human designers anticipated” result: every palindrome maps to net input 0 at each hidden unit, so palindromes are detected by a near-zero hidden activation.

Problem
Output 1 if the 6-bit input is a palindrome (symmetric about its midpoint), else 0:
| input | palindrome? | target |
|---|---|---|
0 0 0 0 0 0 | yes | 1 |
1 0 1 1 0 1 | yes | 1 |
0 1 1 1 1 0 | yes | 1 |
1 0 0 1 0 1 | no | 0 |
1 1 0 0 0 1 | no | 0 |
| … | … | … |
All 64 6-bit patterns are enumerated; 2^3 = 8 are palindromes and 56 are non-palindromes. Inputs are encoded in {-1, +1} (Hinton’s lectures convention; the same problem works with {0, 1} but the 1:2:4 structure shows up most cleanly with the symmetric encoding).
The interesting property: with only 2 hidden units, the network has barely enough capacity. It cannot store the 8 palindromes one-by-one. Instead, after training, each hidden unit learns weights w_1, ..., w_6 that satisfy
- mirror-symmetric magnitudes:
|w_1| = |w_6|,|w_2| = |w_5|,|w_3| = |w_4| - opposite signs across the midpoint:
sign(w_i) = -sign(w_{7-i}) - 1 : 2 : 4 magnitudes across the three pairs
The first two properties together mean sum_i w_i x_i = 0 whenever x_i = x_{7-i} for all i (i.e. palindromes), independent of the actual bit values. The third property makes every non-palindrome give a unique non-zero net input – the magnitudes 1, 2, 4 act like binary place-values at the three position pairs, so the 7 distinct non-zero patterns of mirror-pair disagreement encode to 7 distinct non-zero sums (in {+/-1, +/-2, +/-3, +/-4, +/-5, +/-6, +/-7} for {-1,+1} inputs). Combined with a strongly negative hidden bias, palindromes activate the hidden unit near 0 while non-palindromes activate it near 1.
Files
| File | Purpose |
|---|---|
symmetry.py | Dataset (64 patterns, 8 palindromes), 6-2-1 MLP, full-batch backprop with momentum, eval, multi-seed sweep, inspect_weight_symmetry() checker, CLI (--seed, --sweeps, --multi-seed, …). Numpy only. |
visualize_symmetry.py | Static training curves + Hinton-diagram weights + bar-chart of |w_i| per input position + observed-vs-paper 1:2:4 comparison + hidden activations across all 64 patterns. |
make_symmetry_gif.py | Animated GIF showing the 1:2:4 anti-symmetric pattern emerging during training. |
symmetry.gif | Committed animation (1.6 MB). |
viz/ | Committed PNG outputs from the run below. |
Running
python3 symmetry.py --seed 1
Training takes about 0.4 seconds on an M-series laptop. Final accuracy: 100% (64/64). The famous 1:2:4 / opposite-sign weight pattern emerges with a sorted-magnitude ratio of 1 : 1.99 : 3.97 and zero anti-symmetry residual.
To regenerate visualizations:
python3 visualize_symmetry.py --seed 1
python3 make_symmetry_gif.py --seed 1 --sweeps 2200 --snapshot-every 25 --fps 14
To run the multi-seed sweep:
python3 symmetry.py --multi-seed 30 --sweeps 5000
Results
Single run, --seed 1:
| Metric | Value |
|---|---|
| Final accuracy | 100% (64/64) |
| Final MSE loss | 0.00018 |
| Converged at sweep | 1061 (first sweep with |o - y| < 0.5 for all 64 patterns) |
| Wallclock | 0.4 s |
|W_1|_F final (sweep 5000) | 17.20 |
| Hyperparameters | encoding={-1,+1}, lr=0.3, momentum=0.95, init_scale=1.0 (uniform [-0.5, 0.5]), full-batch on all 64 patterns, MSE loss |
Final weights (--seed 1, sweep 5000):
| x1 | x2 | x3 | x4 | x5 | x6 | |
|---|---|---|---|---|---|---|
| h1 | -1.90 | +3.79 | -7.55 | +7.55 | -3.79 | +1.90 |
| h2 | +1.88 | -3.74 | +7.45 | -7.45 | +3.74 | -1.88 |
- Outer pair magnitude: 1.89 (mean over both hidden units)
- Middle pair magnitude: 3.77
- Inner pair magnitude: 7.50
- Sorted ratio: 1 : 1.994 : 3.969 (paper: 1 : 2 : 4)
- Anti-symmetry residual (max over pairs and hidden units): 0.000 (perfectly opposite signs)
Sweep over 30 seeds (--multi-seed 30 --sweeps 5000, default hyperparameters):
| count | |
|---|---|
| Converged to 100% | 20/30 (67%) |
| Match 1:2:4 sorted-magnitude pattern with opposite signs | 17/30 (57%) |
| Stalled at trivial “always non-palindrome” plateau (87.5%) | 4/30 |
| Stalled at near-trivial (90-94%) | 6/30 |
Of the 20 seeds that reach 100% accuracy, 17 land on the textbook 1:2:4 ratio (in some permutation across the three position pairs); 3 land at near-1:2:4 with one ratio in [1.3, 1.6] – a different organisation that also solves the problem. Median sweep-to-converge for the 20 successes: 1230, range 972 - 2025. The paper reports ~1425 sweeps – our distribution brackets that number cleanly.
Comparison to the paper:
Paper reports 1:2:4 ratio with opposite signs; we got 1 : 1.99 : 3.97 with zero anti-symmetry residual (hidden unit weights
[-1.90, +3.79, -7.55, +7.55, -3.79, +1.90]).Reproduces: yes.
Visualizations
The 1:2:4 anti-symmetric pattern

The Hinton diagram (left) shows W_1 as a 2x6 grid of squares, red = positive, blue = negative, area proportional to \sqrt{|w|}. Reading h1 left to right: small-blue, medium-red, big-blue | big-red, medium-blue, small-red. Reading h2: the exact mirror in sign. The dashed vertical line marks the midpoint between input 3 and input 4: every weight on the left half has the opposite sign of its mirror image on the right half, and the same magnitude.
The bar chart (right) makes the magnitudes plain: |w| for both hidden units shows the V-shape 1.9, 3.8, 7.5 | 7.5, 3.8, 1.9 – a 1 : 2 : 4 ratio walking outwards.
Same data, sorted vs the paper’s prediction

Left panel: pair magnitudes in their positional order (outer / middle / inner). At seed 1 the network happens to put the largest pair on the inner position (matching RHW1986 Fig. 2 exactly); on other seeds the network can put the largest pair on any of the three positions. The positional ordering is not the invariant – the set of three magnitudes is.
Right panel: the three pair-magnitudes sorted smallest-to-largest, side-by-side with the paper’s 1, 2, 4 \cdot |w_{smallest}| reference (gray). Observed and predicted overlap to within a percent. This is the actual claim the network is reproducing.
Training curves

Four signals over training:
- Loss (top-left) drops in two phases: a long plateau near
0.054(the network is outputting ~0 for every pattern, giving MSE =0.5 \cdot (8/64) \cdot 1^2 = 0.0625) followed by a sudden break around sweep 800-1000 once the hidden units commit to the anti-symmetric direction. After convergence the loss decays smoothly. - Accuracy (top-right) is flat at 87.5% (= 56/64, the trivial “always non-palindrome” classifier) during the plateau and steps up to 100% at sweep 1061.
\|W_1\|_F(bottom-left) grows in a sigmoidal shape: small during the plateau, fast growth at the break, then asymptotic creep as the sigmoid outputs saturate.- Pair-magnitude ratios (bottom-right) start near 1.0-1.5 and converge to 2.0 each within ~50 sweeps of the accuracy break. The dotted gray line at 2.0 is the paper’s prediction. (These are positional ratios for seed 1 – on other seeds the largest ratio can land at the middle:outer slot or the inner:middle slot, but the sorted ratios still come out 1:2:4.)
Hidden activations across all 64 patterns

Activation of each hidden unit on every pattern, sorted left-to-right by activation, palindromes coloured red. Both hidden units fire near 0.08 for the 8 palindromes (palindromes give net input 0, hidden bias is strongly negative, so sigmoid(b_1) ~ 0.08) and somewhere in [0.78, 1.0] for the 56 non-palindromes (each non-palindrome gives a non-zero net input whose absolute value places it on the saturated side of the sigmoid). The output unit then implements roughly “fire if both hidden units are quiet.”
Deviations from the original procedure
- Input encoding.
{-1, +1}here vs the PDP-book convention which uses{0, 1}. Both encodings learn the 1:2:4 anti-symmetric pattern;{-1, +1}does it with a higher per-seed success rate because the gradient at initialisation is symmetric around the trivial-prediction plateau. (Try--encoding 01to confirm the same structure emerges with the original encoding, just less reliably.) - Hyperparameters. Paper used
eta = 0.1, alpha = 0.9and reports ~1425 sweeps to converge. We useeta = 0.3, alpha = 0.95, which gives a similar median (1230 sweeps) with the same per-seed plateau / break / refine pattern. With the paper’s exacteta = 0.1, alpha = 0.9the converging seeds also reproduce the 1:2:4 pattern but require 3-5x more sweeps and a slightly larger--init-scale. - No perturbation-on-plateau wrapper. RHW1986 mentions perturbing weights on plateau for the XOR sister-experiment; we have not implemented this. ~33% of random seeds stall at the trivial 87.5%-93.8% accuracy plateau and never recover. The paper presumably used such a wrapper or hand-picked seeds.
- Float precision.
float64numpy. Should not matter at this scale. - Sigmoid clamping. Pre-activations clipped to
[-50, 50]to preventnp.expoverflow late in training when\|W_1\|_Fexceeds 15. 21st-century numerical hygiene.
Otherwise: same architecture (6-2-1, sigmoid hidden + sigmoid output), same loss (0.5 * mean (o - y)^2), same algorithm (full-batch backprop with momentum), same data (all 64 6-bit patterns).
Open questions / next experiments
- Why does the network sometimes pick a non-canonical permutation of {1, 2, 4}? With seed 1 the inner pair gets magnitude 4 (matching the paper’s figure); with other seeds the network can put the 4-magnitude pair at the outer or middle position instead. The problem is symmetric under any permutation of the three pair labels, so all 6 orderings are valid solutions – but the network seems to prefer some over others depending on init. A breakdown of which orderings appear at what rate over many seeds would quantify the basin sizes.
- Plateau-escape mechanism. Adding a “perturb weights, continue” wrapper of the kind RHW1986 used for XOR should rescue the stalled seeds and push success rate from 67% to ~100%. The natural test is whether the rescued runs also converge to 1:2:4 or to a different fixed point.
- Generalise to n-bit symmetry, n > 6. The same architecture (n inputs, 2 hidden, 1 output) should learn the analogous 1:2:4:…:2^(n/2-1) pattern. Does the per-seed success rate degrade with n? Does the convergence sweep count scale linearly, polynomially, exponentially?
- Connection to ByteDMD. This is a very small (17-parameter) model that learns a structured solution. Measuring data-movement complexity of the trained network’s inference path – does the 1:2:4 structure compress access patterns? – would be a clean tiny case for the broader Sutro Group energy-efficiency project.
- Compare to non-backprop solvers. Symmetry detection is linear over GF(2) on a fixed feature transform (XOR each mirror pair, then OR), so an algebraic solver should be O(n) and energy-trivial. How does the backprop-discovered 1:2:4 representation compare to such a hand-coded solution under a data-movement metric?
Negation
Source: Rumelhart, Hinton & Williams (1986), “Learning internal representations by error propagation”, in Parallel Distributed Processing, Vol. 1, Ch. 8 (MIT Press). The Boolean-negation example (a flag bit conditionally inverts the meaning of three data bits) is one of the chapter’s small demonstrations of role-sensitive distributed coding.
Demonstrates: Role-sensitive distributed processing. A single flag bit gates how three data bits are routed through the hidden layer. Hidden units specialize to detect specific (flag, bit) combinations — each unit is “active for b_i = v but only when flag = f”, where (v, f) differs across units.

Problem
Inputs are 4 bits — one flag bit plus three data bits — giving 16 patterns total. Outputs are 3 bits.
| flag | data (b₁ b₂ b₃) | target output |
|---|---|---|
| 0 | b₁ b₂ b₃ | b₁ b₂ b₃ (identity) |
| 1 | b₁ b₂ b₃ | ¬b₁ ¬b₂ ¬b₃ (bitwise NOT) |
The interesting property is that each output bit must compute o_i = b_i XOR flag: it equals the corresponding input when the flag is 0 and its complement when the flag is 1. Three simultaneous XORs share a single switching variable. The textbook AND-OR XOR construction needs two hidden units per XOR (one for the “flag=0 ∧ b=1” half-plane and one for the “flag=1 ∧ b=0” half-plane), so the minimum width that backprop can reliably navigate to is 6 hidden units — two per output bit. With fewer hidden units the network is forced to share hidden representations across XORs, and gradient descent gets stuck (see Deviations below).
Files
| File | Purpose |
|---|---|
negation.py | Dataset (16 patterns) + 4-6-3 MLP + backprop with momentum + CLI. Numpy only. |
visualize_negation.py | Static training curves, Hinton-diagram weights, flag-gated hidden routing heatmap, hidden-unit role map. |
make_negation_gif.py | Animated GIF: flag=0 / flag=1 hidden heatmaps + weights + curves over training. |
negation.gif | Committed animation (~1.8 MB). |
viz/ | Committed PNGs from the run below. |
Running
python3 negation.py --seed 0
Single-seed training takes about 0.1 second on an M-series laptop. Final pattern accuracy: 100% (16/16) at this seed.
To regenerate the visualizations:
python3 visualize_negation.py --seed 0
python3 make_negation_gif.py --seed 0 --max-epochs 1500 --snapshot-every 15
To run the multi-seed sweep that produced the success-rate stats:
python3 negation.py --sweep 30 --max-epochs 5000
Results
Single run, --seed 0:
| Metric | Value |
|---|---|
| Final pattern accuracy | 100% (16/16) |
| Final per-bit accuracy | 100% (48/48) |
| Final MSE loss | 0.015 |
| Converged at epoch | 1009 (first epoch with |o − y| < 0.5 for every output) |
| Wallclock | 0.10 s |
| Hyperparameters | n_hidden=6, lr=0.5, momentum=0.9, init_scale=1.0 (uniform [-0.5, 0.5]), full-batch updates, MSE loss |
Sweep over 30 seeds (--sweep 30 --max-epochs 5000, default hyperparameters):
| Architecture | Converged | Mean epochs | Median epochs | Min | Max |
|---|---|---|---|---|---|
| 4-6-3 | 27/30 | 1231 | 1106 | 517 | 3728 |
Three of thirty seeds (5, 23, 29) stall in a partial-XOR local minimum within the 5000-epoch budget.
Comparison to the paper:
The PDP volume reports the negation problem as solvable by backprop with a small hidden layer; the chapter does not give a precise epoch count for this example (its quantitative numbers are reported for XOR and the encoder problems). What we get: 100% accuracy on all 16 patterns at the single reported seed (epoch 1009), 27/30 seeds converge under 5000 epochs.
Paper reports: solvable / no specific epoch count. We got: 100% pattern accuracy (16/16) in 1009 epochs / 0.10 s. Reproduces: yes.
Visualizations
Flag-gated hidden routing (the central plot for this problem)

Two heatmaps, one per flag value, of the 6 hidden-unit activations across all 16 patterns. The two halves are visibly different at the same hidden unit — that’s the gating story. For example at this seed:
- h6 is high (≈ 0.9) on every flag=0 pattern where b₁=0 and silent on every flag=1 pattern. It is a “b₁ is 0 and the flag says identity” detector.
- h3 is silent on flag=0 patterns where b₃=1 but saturates (≈ 1.0) on every flag=1 pattern, so it carries the “flag is on” signal projected through the b₃ axis.
- h5 flips on/off across flag=0 patterns in lockstep with b₂ but is constantly on for every flag=1 pattern.
Read row-by-row: the flag bit literally decides which subset of the 8 patterns each hidden unit cares about. That is what “role-sensitive distributed processing” means in this network.
Hidden-unit role map

Same 16-pattern activation matrix, summarized into 6 conditioning columns: each column is the unit’s mean activation conditional on a specific (flag, bit=1) combination. The y-axis label is an automatic best-fit role inferred from those means (e.g. h6: flag=0 ∧ b₁=0 (+) means h6 is a positive detector of “flag is 0 AND bit 1 is 0”). On this seed, the 6 hidden units divide the 6 “flag × bit” combinations cleanly — one detector per combination — which is the textbook role decomposition for this problem.
Weight matrices

Hinton-diagram view of the 51 parameters after training. Left is W₁ (input → hidden, 6×4 + biases): note that every hidden unit has a large weight on the flag column (the leftmost column), confirming that flag is the dominant input. Right is W₂ (hidden → output, 3×6 + biases). Each output unit picks out the two-or-three hidden units that vote for its bit and subtracts the rest.
Training curves

Three signals:
- Loss sits on a plateau near 0.38 for the first ~400 epochs (the network is essentially outputting 0.5 on every bit), then drops in a roughly sigmoidal break starting around epoch 500 and lands near 0.02.
- Per-bit accuracy flickers around 50% during the plateau and rises smoothly to 100%. Per-pattern accuracy (red, the harder metric — all 3 bits must round correctly) lags the per-bit curve and only hits 100% slightly after the convergence epoch.
- Weight norm stays near 2 during the plateau and grows to ~26 as the hidden units commit to definite features. The green dashed line marks epoch 1009 (first frame where every output is within 0.5 of its target).
This three-phase signature — flat plateau, sigmoid break, refinement — is the same shape XOR shows. With 6 hidden units the break happens earlier and more reliably than with 3 (see Deviations).
Deviations from the original procedure
- Hidden-layer width: 6 instead of 3. The existing stub specified 3 hidden units. Empirically this does not work: across 30 seeds at the documented
(lr=0.3, n_sweeps=5000)hyperparameters, 0/30 seeds converge to 100%. Even withlr ∈ {0.3, 0.5, 1.0, 2.0, 5.0}×init_scale ∈ {0.1, 0.3, 0.6, 1.0}× seeds 0..7, no setting at 4-3-3 succeeded. The mathematical reason: each output bito_i = b_i XOR flagis a 2-D XOR, and a single sigmoid hidden unit cannot represent XOR. With 3 outputs all needing simultaneous XOR detection on different bits, the AND-OR XOR construction needs ≥ 2 hidden units per output (= 6 total). Width 6 converges in 27/30 seeds; width 8 in 30/30. We use 6 as the minimum width that reliably learns the function. - Default learning rate: 0.5 instead of 0.3. With 6 hidden units, lr=0.3 still converges in 7/8 seeds but is ~50% slower (median 1863 epochs vs. 1058 at lr=0.5). lr=1.0 is faster again (median 506 epochs) but slightly less reliable. We pick 0.5 as the convergence-vs-speed sweet spot. The CLI flag
--lrlets you change this. - Init distribution. Uniform
[-0.5, 0.5](our default--init-scale 1.0). The 1986 paper’s standard init is[-0.3, 0.3]; with--init-scale 0.6you get the paper’s range and convergence is still 27/30, just slower in median. - Floating-point precision.
float64numpy. The 1986 paper’s hardware was not IEEE 754 in the modern sense. - Sigmoid clamping. Pre-activation clipped to
[-50, 50]to avoidnp.expoverflow. - Convergence criterion. Same as the paper: every output within 0.5 of its target (i.e. all bits round correctly).
Otherwise: same architecture shape (4-N-3), same loss (mean of 0.5 ∑(o − y)²), same training algorithm (full-batch backprop with momentum 0.9), same problem (16 patterns).
Open questions / next experiments
- Why does 4-3-3 fail completely? Information-theoretically 3 hidden units are sufficient in principle — there exist weight settings that solve the problem (a hidden unit can compute, say,
tanh(b_i + flag - 0.5) − tanh(b_i + flag - 1.5)style approximations). The question is whether gradient descent from random init can find them. Our sweep says no, even with 30 seeds and 20k epochs. A direct search (basin-hopping or dual annealing on the 27 weights) would tell us whether the working configurations exist but live in tiny attractors that backprop can’t reach, vs. whether 3 sigmoid units genuinely cannot represent the function within the precision needed. - Tanh + bipolar inputs. Bipolar
{−1, +1}inputs withtanhactivation often reshape the loss landscape so that tight networks become trainable. Would 4-3-3 succeed under this encoding? An interesting variant for the v2 data-movement work — bipolar tanh nets often have very different hidden activity statistics, which matters for the cache footprint. - Role specialization across seeds. The hidden-role map at seed 0 shows a clean one-to-one assignment of (flag, bit) detectors to hidden units. Is this true at every successful seed, or do some seeds discover redundant / degenerate solutions where the same role is encoded by two units? A clustering analysis on the 16-pattern activation vectors across seeds would answer this.
- Generalization to k-bit negation. The natural follow-on is a 1-flag + k-data → k-output network. The pattern count grows as 2^(k+1), the hidden-width requirement should grow as 2k. The Sutro Group’s broader interest in parity (which is “all-bits XOR’d together” — closely related to all-flags-ANDed-with-bits) makes this a relevant scaling experiment.
- Data movement (v2 question). With 51 parameters and 16 patterns, the entire per-epoch training memory footprint is ≤ 1 KB. Even ARD-1 inference is trivially achievable. The interesting v2 question for ByteDMD is whether the training path can be made cheaper than full backprop — e.g. by a Hebbian + tiny outer loop, or by directly constructing the AND-OR weights from the truth table and skipping training entirely.
Binary addition (two 2-bit numbers)
Source: Rumelhart, Hinton & Williams (1986), “Learning internal representations by error propagation”, PDP Volume 1, Chapter 8. (Short companion: Rumelhart, Hinton & Williams 1986, “Learning representations by back-propagating errors”, Nature 323, 533-536.)
Demonstrates: Local minima in feed-forward backprop. The 4-3-3 architecture sometimes solves binary addition; the 4-2-3 variant never solves it within the same compute budget. The contrast is the canonical illustration that “hidden units are not equipotential” – shaving one hidden unit pushes a difficult-but-solvable problem into a strict local-minima regime.

Problem
Take two 2-bit numbers a, b in {0, 1, 2, 3} and learn the 3-bit binary representation of their sum:
| a | b | a+b | s₂ s₁ s₀ |
|---|---|---|---|
| 0 | 0 | 0 | 0 0 0 |
| 0 | 1 | 1 | 0 0 1 |
| 1 | 1 | 2 | 0 1 0 |
| 1 | 2 | 3 | 0 1 1 |
| 1 | 3 | 4 | 1 0 0 |
| 2 | 3 | 5 | 1 0 1 |
| 3 | 3 | 6 | 1 1 0 |
| … | … | … | … |
All 16 input patterns enumerated. Inputs are 4 bits (a₁, a₀, b₁, b₀) in {0, 1}; targets are 3 sigmoid output bits. Full-batch backprop with momentum, MSE loss.
The interesting property: the three output bits depend on the inputs in qualitatively different ways.
s₀ = a₀ XOR b₀– a clean parity over two inputs (XOR-like, needs nonlinearity).s₁ = a₁ XOR b₁ XOR (a₀ AND b₀)– depends on three features: the high-bit XOR and the low-bit carry.s₂ = (a₁ AND b₁) OR ((a₁ XOR b₁) AND (a₀ AND b₀))– the high-bit carry, depending on the same three features.
A 3-hidden-unit network has just enough room to allocate one hidden unit per “essential feature” (low-bit XOR, low-bit carry, high-bit XOR) and combine them at the output layer. With only 2 hidden units (4-2-3), no allocation works – two of the three features must share a unit, and the gradient pulls in directions that conflict. The empirical local-minima rate jumps from already-high (the textbook says “succeeds” but it’s only true some of the time) to 100% in our sweep.
Files
| File | Purpose |
|---|---|
binary_addition.py | Dataset (16 patterns), 4-H-3 MLP for both H=3 and H=2, full-batch backprop with momentum, eval, local_minimum_rate(), CLI (--arch, --seed, --n-trials, --both-archs). Numpy only. |
visualize_binary_addition.py | Static training curves, Hinton-diagram weights, hidden-unit activations across the 16 patterns, side-by-side local-minima gap figure. |
make_binary_addition_gif.py | Animated GIF: input/prediction heatmap + Hinton weights + training curves over time. |
binary_addition.gif | Committed animation (1.3 MB). |
viz/ | Committed PNG outputs from the run below. |
Running
python3 binary_addition.py --arch 4-3-3 --seed 10
Training takes about 0.4 seconds on an M-series laptop. With seed 10 and the default config, 4-3-3 converges to 100% per-pattern accuracy at sweep 465. (Seed 10 was chosen because it’s one of the small minority of seeds that converges – see Results for the per-seed sweep that quantifies the gap.)
To run the headline local-minima sweep:
python3 binary_addition.py --n-trials 50 --both-archs
This takes about 44 seconds and prints the per-arch stuck-rate comparison.
To regenerate visualizations:
python3 visualize_binary_addition.py --seed 10 --n-trials 30
python3 make_binary_addition_gif.py --seed 10 --sweeps 1500 --snapshot-every 15
Results
Single run, --seed 10, arch 4-3-3:
| Metric | Value |
|---|---|
| Final per-pattern accuracy | 100% (16/16) |
| Final per-bit accuracy | 100% (48/48 bits) |
| Final MSE loss | 0.0005 |
| Converged at sweep | 465 (first sweep with |o − y| < 0.5 for all 48 outputs) |
| Wallclock | 0.4 s |
Final |W_1|_F | 29.7 |
| Hyperparameters | arch=4-3-3, lr=2.0, momentum=0.9, init_scale=2.0 (uniform [-1.0, 1.0]), encoding={0,1}, full-batch on all 16 patterns, MSE loss |
50-seed sweep (--n-trials 50 --both-archs, default hyperparameters):
| Architecture | Converged | Local-minimum rate | Median epochs (converged) | Range |
|---|---|---|---|---|
| 4-3-3 (3 hidden) | 3/50 (6%) | 94.0% | 627 | 465-905 |
| 4-2-3 (2 hidden) | 0/50 (0%) | 100.0% | – | – |
Final per-pattern accuracy distribution (50 seeds each):
| Final accuracy | 4-3-3 (count) | 4-2-3 (count) |
|---|---|---|
| 62.5% (10/16) | 0 | 13 |
| 68.75% (11/16) | 0 | 9 |
| 75.0% (12/16) | 9 | 0 |
| 81.25% (13/16) | 15 | 28 |
| 87.5% (14/16) | 13 | 0 |
| 93.75% (15/16) | 10 | 0 |
| 100.0% (16/16) | 3 | 0 |
| Median | 87.5% | 81.25% |
Comparison to the paper:
Paper (PDP Vol. 1 Ch. 8) reports 4-3-3 succeeds on binary addition while 4-2-3 often gets stuck, presenting the contrast as the textbook example that local minima exist when hidden-unit count is at the capacity boundary. We get 6.0% / 0.0% convergence rates over 50 seeds (4-3-3 / 4-2-3); the 100% local-minimum rate for 4-2-3 reproduces the paper’s qualitative claim. Reproduces: yes (qualitatively).
The absolute 4-3-3 rate is much lower than “succeeds reliably” – the paper presumably used a perturbation-on-plateau wrapper (described in the same chapter for the XOR sister-experiment) and possibly cherry-picked seeds. We have not implemented the wrapper; see Deviations below.
Run wallclock: ~44 s for the 50-seed sweep, ~0.4 s for the single converged seed.
Visualizations
Local-minima gap: 4-3-3 vs 4-2-3

Left panel: stuck rate over 30 random seeds. 4-3-3 stalls in a local minimum in ~90% of seeds; 4-2-3 stalls in 100%. The ~10-percentage-point gap is the headline. Right panel: distribution of final per-pattern accuracy. 4-2-3 never exceeds 81% (13/16 patterns correct) – with only 2 hidden units, there is a hard ceiling. 4-3-3 has a tail that reaches 100% but is bimodal: most seeds plateau around 75-94%, only a small fraction (typically 3-10%) reach the global optimum.
Training curves (single converged 4-3-3 run)

Four signals over training (seed 10):
- Loss (top-left) drops in two phases: a slow decline from
~0.13toward~0.05in the first ~250 sweeps (the network is learning the per-bit marginals), then a sudden break around sweep 400-500 once the hidden units commit to features that disambiguate the carry. - Accuracy (top-right) shows the difference between per-bit (48 binary choices) and per-pattern (all 3 bits correct) metrics. Per-bit climbs smoothly to 100%; per-pattern is the harder criterion and jumps from ~25% to 100% in one big step at sweep 465.
\|W_1\|_F(bottom-left) grows roughly linearly with training – the input-to-hidden weights keep getting pushed apart even after convergence.- Summary (bottom-right) records the final numbers for this seed.
Final weights (single converged 4-3-3 run)

Hinton diagram of W_1 (input → 3 hidden, with biases) on the left and W_2 (3 hidden → 3 outputs, with biases) on the right. Red is positive, blue is negative; square area is proportional to √|w|.
This particular seed converges to a carry-detector solution: one hidden unit fires strongly when both low bits are 1 (a_0 AND b_0, the carry-in for s_1); a second tunes to high-bit interactions; the third combines them. The output layer reads off these features.
Hidden-unit activations across all 16 patterns

What each hidden unit fires for, evaluated on each of the 16 (a, b) pairs sorted by sum (color-coded). Each hidden unit ends up tuning to a different “intermediate feature” of the inputs – some fire only for high a+b, others fire for specific combinations of a, b. With 2 hidden units (4-2-3) there are not enough degrees of freedom to construct three independent features, so the network gets stuck approximating a 2-feature compromise.
Deviations from the original procedure
- Hyperparameters. Paper uses
eta = 0.5, alpha = 0.9and reports the problem as solvable. With those values and ourinit_scale=1.0(uniform[-0.5, 0.5]) we get 0/50 convergence for 4-3-3 within 5000 sweeps. Increasing the init range toinit_scale=2.0(uniform[-1.0, 1.0]) and the learning rate tolr=2.0brings 4-3-3 to ~6% convergence. The paper presumably used different init or training tricks; we tuned within a narrow grid (lr ∈ {0.5, 1.0, 2.0, 4.0}, init_scale ∈ {0.5, 1.0, 2.0, 3.0, 4.0, 6.0}, momentum ∈ {0.5, 0.7, 0.9, 0.95}) and report the best. - No perturbation-on-plateau wrapper. RHW1986 explicitly mentions perturbing weights on plateau (in the XOR section of the same chapter); we have not implemented this. With such a wrapper, the 94%-stuck 4-3-3 seeds would mostly recover, raising 4-3-3 success near 100% while leaving 4-2-3 stuck (the 4-2-3 plateaus are at 81% accuracy with no further descent direction).
- Convergence criterion. Paper’s stated rule: every output within 0.5 of its target (i.e. argmax matches threshold). Same as ours.
- Float precision.
float64numpy. Should not matter at this scale. - Sigmoid clamping. Pre-activations clipped to
[-50, 50]to preventnp.expoverflow late in training when\|W_1\|_Fexceeds 25. 21st-century numerical hygiene, no behavioural effect on convergence. - Random number generator.
numpy.random.default_rng(seed)(PCG64). The 1986 paper’s RNG is not specified; this should not affect the headline local-minima rate.
Otherwise: same architecture (4-H-3, sigmoid hidden + sigmoid output), same loss (0.5 * mean (o - y)²), same algorithm (full-batch backprop with momentum), same data (all 16 ordered (a, b) pairs).
Open questions / next experiments
- Perturbation-on-plateau wrapper. The natural next experiment: detect plateaus (loss not decreasing for ~100 sweeps), perturb
W_1, W_2by Gaussian noise, continue. Does this push 4-3-3 success rate from 6% to ~95%? Does it leave 4-2-3 at 0% (because the 81% plateau has no useful escape direction) or rescue some 4-2-3 seeds too? The answer maps the true capacity boundary. - Why does 4-2-3 ceiling at exactly 81.25%? All 50 stuck 4-2-3 seeds end at one of {62.5%, 68.75%, 81.25%}. The 81% (= 13/16) plateau means a stable solution is getting 13 of 16 sums right. Which 3 patterns does it consistently miss? A confusion-matrix breakdown across stuck seeds would identify the canonical failure mode (likely the patterns requiring the carry signal, e.g.
2+2=4,2+3=5,3+3=6, or some adjacent triple). - Does 4-3-3 converge to a single canonical solution, or several? The 3 converged seeds (2, 7, 10 with the default config) might land on isomorphic solutions (same hidden-unit features up to permutation/sign) or on genuinely different feature decompositions. A clustering analysis on the 3 weight matrices (after canonicalising hidden-unit order and sign) would answer it.
- Cross-entropy loss instead of MSE. Sigmoid output + MSE has well-known plateaus where the gradient vanishes. Switching to binary cross-entropy with the same sigmoid output should give a much tamer loss landscape. Does this rescue the failing 4-3-3 seeds? Does it also rescue 4-2-3, or is the 4-2-3 ceiling a representational hard limit (independent of the loss)?
- Connection to ByteDMD / energy. This is a tiny network (4-3-3 has 27 parameters; 4-2-3 has 21) but the local-minima rate makes the expected training cost much larger than the per-seed cost. Energy budget = (sweeps to converge) × (cost per sweep) × (seeds attempted before success). For 4-3-3 with 6% success rate, the expected energy is ~17 × the per-seed cost. ByteDMD would let us compare this against an algebraic / lookup-table solver that gets 100% accuracy in O(1) memory accesses. The Hinton-textbook framing (“backprop solves it!”) quietly hides a 17× expected-cost penalty.
T-C discrimination
Source: Rumelhart, Hinton & Williams (1986), “Learning internal representations by error propagation”, in Parallel Distributed Processing, Vol. 1, Ch. 8. The T-vs-C discrimination task is the chapter’s vehicle for introducing weight tying across spatial positions — a 3x3 receptive field is slid over a 2D retina with shared weights, and the network discovers emergent feature detectors. Three years before LeCun’s 1989 backprop CNN paper, this is the same architectural idea written down in numpy-like prose.
Demonstrates: the early-CNN constraint. With shared 3x3 receptive fields sliding over a small binary retina, training produces 3x3 weight patterns that fall into recognisable categories — bar detectors (one row, column, or diagonal dominates), compactness detectors (a 2x2 sub-block dominates), and on-centre / off-surround detectors (centre versus surround opposition, a Difference-of-Gaussians shape).

Problem
8 patterns: a 5-cell block T and a 5-cell block C, each in 4 rotations (0°, 90°, 180°, 270°), placed at the centre of a 6×6 binary retina. Network output: T = 0, C = 1.
| T (target = 0) | C (target = 1) | |
|---|---|---|
| 0° | top bar + stem | left bar + top tip + bottom tip |
| 90° | rotated 90 ccw | rotated 90 ccw |
| 180° | rotated 180 | rotated 180 |
| 270° | rotated 270 | rotated 270 |

The interesting property is what the kernel layer learns. With shared 3x3 weights and only 4 independent kernels, the network has just 45 trainable parameters (vs. 645 for an equivalent untied conv layer). The constraint forces every kernel to be a position-invariant detector — the same 3x3 pattern slid across all 16 retinal positions — and three named families of detectors emerge.
Files
| File | Purpose |
|---|---|
t_c_discrimination.py | Dataset + WeightTiedConvNet (numpy einsum conv) + backprop with momentum + filter taxonomy + CLI. Numpy only. |
visualize_t_c_discrimination.py | Static viz: patterns, training curves, discovered filters with taxonomy borders, per-pattern feature maps, predictions, multi-seed taxonomy bar chart. |
make_t_c_discrimination_gif.py | Animated GIF: filter evolution + input patterns + training curves over training. |
t_c_discrimination.gif | Committed animation (1.3 MB). |
viz/ | Committed PNG outputs from the run below. |
Running
python3 t_c_discrimination.py --seed 0
Training takes about 0.6 seconds on an M-series laptop (process startup included). Final accuracy: 100% (8/8) at this seed.
To regenerate the visualizations:
python3 visualize_t_c_discrimination.py --seed 0 --sweep 10
python3 make_t_c_discrimination_gif.py --seed 0 --max-epochs 1400 --snapshot-every 25
Multi-seed sweep:
python3 t_c_discrimination.py --sweep 10
CLI flags: --retina-size, --kernel-size, --n-kernels, --lr,
--momentum, --init-scale, --max-epochs, --seed,
--augment-positions (place each shape at every valid retinal position).
Results
Single run, --seed 0, R = 6, K = 4:
| Metric | Value |
|---|---|
| Final accuracy | 100% (8/8) |
| Final MSE loss | 0.085 |
| Converged at epoch | 1254 (first epoch with |o − y| < 0.5 for all 8 patterns) |
| Wallclock | 0.4 s for the training loop, 0.69 s end-to-end (time python3 t_c_discrimination.py --seed 0) |
| Trainable params | 45 (vs. 645 for an untied equivalent — a 14× reduction from weight tying) |
| Hyperparameters | retina 6×6, kernel 3×3, K=4 kernels, lr=0.5, momentum=0.9, init_scale=0.5, full-batch backprop |
10-seed sweep (default config, --max-epochs 5000):
| Metric | Value |
|---|---|
| Convergence rate | 10 / 10 seeds reach 100% |
| Median epochs | 1250 (min 808, max 1455) |
Filter taxonomy across the same 10-seed sweep (40 kernels total):
| Detector type | Count | % |
|---|---|---|
| bar | 12 | 30 % |
| mixed | 12 | 30 % |
| off-centre | 7 | 18 % |
| compactness | 5 | 12 % |
| on-centre | 4 | 10 % |
All three named detector families from the paper (bar, compactness, centre-surround) appear at every seed; the proportions shift but the qualitative pattern is robust. About 30 % of kernels remain “mixed” — they contribute to discrimination via combinations not captured by a single archetype.
Comparison to the paper:
Paper claim: with weight-tied 3x3 receptive fields, the network discovers compactness detectors, bar detectors, and on-centre / off-surround detectors. The hidden representation organises into recognisable feature templates rather than memorising patterns.
We get: clear emergence of all three named detector families across 10/10 seeds. Bar detectors dominate (30 %), centre-surround pairs (on-centre + off-centre) account for 28 %, compactness for 12 %. About 30 % of kernels are “mixed” but functionally useful — the readout layer exploits combinations that don’t fit a clean archetype.
Paper claim: discovers compactness/bar/on-centre detectors. We got: all three families emerge across 10/10 seeds. Reproduces: yes.
Visualizations
Discovered filters — the centrepiece

The 4 weight-tied 3x3 kernels discovered at seed 0. Each panel is colored
by its detector type (orange border = off-centre, red = bar, green =
compactness, blue = on-centre, gray = mixed). The taxonomy rule lives in
taxonomize_filter():
- on-centre / off-centre — centre cell has opposite sign from the
surround average (Difference-of-Gaussians shape). Kernel 1 is a textbook
off-centre detector: strongly negative centre (
−2.26), uniformly positive 8-cell ring averaging+1.78. - bar — one of 8 line directions (3 rows, 3 cols, 2 diagonals) contains > 55 % of the total absolute weight. Kernel 3 is a row-2 bar with strong polarity contrast (positive middle row, negative right column).
- compactness — one 2x2 sub-block contains > 55 % of total absolute weight. Detects the corner-of-C and tip-of-T regions.
- mixed — kernels that contribute to discrimination via combinations not captured by a single archetype.
Per-pattern feature maps

For each input pattern (rows), the 4 post-conv feature maps (columns) show where each kernel fires. The off-centre kernel 1 fires brightly at exactly the spatial position where each shape’s interior hole sits — the top of T, the centre of T-rot180, the open mouth of each rotated C — a position-invariant “concavity” detector. The bar kernel 3 picks up horizontal arms differently across rotations.
Multi-seed taxonomy

Detector-type counts across 40 kernels (10 seeds × 4 kernels). The bar chart confirms the named-detector emergence is robust: every category appears, with bar most common, mixed close behind, and centre-surround + compactness as the rarer but well-represented minorities.
Per-pattern outputs

Every output is on the correct side of the 0.5 boundary (= the convergence criterion). Margins are not huge (T outputs at ~0.34–0.48, C outputs at ~0.52–0.68), reflecting the small parameter budget (45 weights for 8 patterns × 36 retinal cells).
Training curves

Loss sits on a long plateau near 0.125 (the constant-prediction MSE for balanced binary targets) for ~1000 epochs, then breaks downward in a single phase transition as the kernels commit to specific feature templates. Accuracy jumps from 50 % to 100 % over a ~250-epoch window centred on the break.
Deviations from the original procedure
- Patterns are 5-cell shapes in a 3×3 bounding box. RHW1986’s exact T and C are 5-cell shapes too, but the precise pixel layouts varied across editions of the chapter. We use a clean 5-cells-each pair (T with top bar + 2-cell stem, C with left bar + top tip + bottom tip) that are both invariant under no rotation, so the 4 rotations give 4 distinct patterns per class.
- Fixed-centre placement (8 patterns total). Issue #24 specifies “8
patterns: T+C × 4 rotations.” We honour that literally — each shape sits
at the geometric centre of the 6×6 retina. RHW1986’s original setup
placed the shapes at all valid retinal positions (which is what made
weight tying necessary for generalisation). Position augmentation is
available via the
--augment-positionsflag (yields 8 × (R−2)² = 128 patterns at R=6) but disabled by default to match the spec. - Mean-pool, not the original sum-pool. The chapter does not specify
a pooling rule — different reproductions use different choices. We use
mean-pool because it keeps the K-dim pooled vector in
[0, 1]regardless of feature-map size, which avoids saturating the readout sigmoid at initialization. Sum-pool with our init scale stalled at 50 % accuracy because the readout pre-activation was ~8× larger than ideal and its gradient vanished. Mean-pool is mathematically equivalent up to a 1/(M*M) gradient scaling. - K = 4 kernels. RHW1986 used a larger hidden layer; for the 8-pattern variant of the task, 4 kernels is enough to reach 100 % and keeps the discovered-filters viz interpretable. Increasing K to 8 changes the taxonomy proportions (more “mixed” appears) but does not change the qualitative claim that the named detectors emerge.
- MSE loss + sigmoid output, not cross-entropy. Same loss as the
xor/,symmetry/,n-bit-parity/siblings — we kept the family consistent rather than modernising one stub. - Convergence criterion = every output within 0.5 of its target, matching the sibling backprop stubs and RHW1986.
- No perturbation-on-plateau wrapper. Not needed — 10/10 seeds converge in our budget.
Open questions / next experiments
- Why does “mixed” occupy 30 %? Are these kernels redundant copies of the named detectors slightly off-archetype, or do they encode cross-detector features the heuristic taxonomy can’t name? An ablation that drops each kernel and measures accuracy would tell us which kernels carry unique information vs. duplicates.
- Augmented-position regime. With
--augment-positionsthe dataset grows from 8 to 128 patterns and the same kernel sees each feature at every valid retinal position. Does this push the “mixed” share down (more kernels lock into clean archetypes) or up (the larger task demands richer combinations)? Quick to run — left for a follow-up. - Larger K. With K = 8 or 16, the network has redundant capacity. Do we observe dead kernels (zero magnitude), duplicate kernels (two slots end up with near-identical archetypes), or do new meta-detector types emerge? The relationship to RHW1986’s original larger K should be checked.
- Comparison to a non-tied baseline. A fully-connected readout from the 6×6 retina has 36 weights vs. our 45 — comparable. The interesting contrast is to an untied conv layer (645 weights): does the extra capacity actually help on T-C, or does the 14× weight-tying reduction reach the same accuracy with structurally cleaner kernels?
- Data movement. This is the v1 baseline. v2 (the broader Sutro effort) will instrument the same forward / backprop pass with ByteDMD and compare data-movement cost between the tied and untied variants. Weight tying should substantially reduce gradient-fetch traffic during the backward pass — a kernel update is the sum of the per-position gradients, so we read each shared weight once but write its update once too, while the untied layer reads/writes each per-position weight independently.
- Why does kernel 1 lock onto “concavity”? The off-centre detector fires at the inside-of-C and the inside-of-T-stem-bottom — i.e., at the unique concavity of each shape. Is this a stable attractor across seeds, or a coincidence of seed 0? The taxonomy bar chart suggests stable (off-centre appears in 7 of 10 seeds) but the spatial placement of the firing should be checked.
v1 metrics (per spec issue #1)
- Reproduces paper? Yes. All three named detector families (bar, compactness, on-centre/off-centre) emerge across 10/10 seeds.
- Run wallclock (final experiment, headline seed): 0.4 s training
loop, 0.69 s end-to-end (
time python3 t_c_discrimination.py --seed 0, M-series laptop, Python 3.12.9 + numpy 2.2). - Implementation wallclock: ~30 minutes end-to-end (start of agent session → branch pushed). The mean-pool fix after the initial sum-pool saturation took ~5 minutes of the budget.
Recurrent shift register
Source: Rumelhart, Hinton & Williams (1986), “Learning internal representations by error propagation”, in Parallel Distributed Processing, Vol. 1, Ch. 8 (MIT Press). Short version: Rumelhart, Hinton & Williams (1986), “Learning representations by back-propagating errors”, Nature 323, 533-536.
Demonstrates: A recurrent network with N tanh hidden units, trained by Backpropagation Through Time, learns to be a literal N-stage shift register: random binary bits arrive on a single input line, the network emits the bits delayed by 1, 2, …, N - 1 timesteps on N - 1 separate output lines, and the converged recurrent weight matrix collapses to a shift matrix – one strong entry per non-input row, tracing a chain that visits every hidden unit exactly once. The mechanism is a hardware shift register implemented in real-valued sigmoidal units.

Problem
At each timestep t a single bit x[t] in {-1, +1} arrives. The network has N tanh hidden units and N - 1 tanh output units; output y[t][d - 1] predicts x[t - d] for d = 1, 2, …, N - 1.
input ... -1 +1 +1 -1 +1 -1 +1 -1
^ network sees this bit
y[delay 1] ... ? -1 +1 +1 -1 +1 -1 +1 <- input shifted right by 1
y[delay 2] ... ? ? -1 +1 +1 -1 +1 -1 <- input shifted right by 2
(? marks timesteps before any past input is available; the loss is masked out there.)
The interesting property: with N hidden units and N - 1 delay outputs, the network has just enough capacity to dedicate one hidden unit to the input register and one to each non-trivial delay – equivalent to laying out a real shift register with N flip-flops. The only efficient solution is therefore that the recurrent weight matrix W_hh becomes a shift matrix (rank N - 1, exactly N - 1 strong entries, the rest zero), W_xh writes into one specific “input” unit, and each row of W_hy reads from one specific delay-stage unit. The network discovers this structure from a random init under BPTT + a small L1 penalty on W_hh, and the shift-matrix shape emerges within ~100-200 training sweeps – matching the paper’s claim of “<200 sweeps.”
Files
| File | Purpose |
|---|---|
recurrent_shift_register.py | Dataset (random {-1, +1} sequences, N - 1 delayed targets), N-unit RNN with tanh hidden + tanh output, manual BPTT in numpy, full-batch SGD with momentum + weight decay + soft-threshold L1 on W_hh, multi-seed sweep, shift_matrix_score() checker, CLI (--n-units {3,5}, --seed, --sequence-len, --n-sweeps, …). Numpy only. |
visualize_recurrent_shift_register.py | Static training curves + W_hh heatmap with chain overlay + W_xh / W_hy bar charts + hidden-state-evolution heatmap on a fresh test sequence. |
make_recurrent_shift_register_gif.py | Animated GIF showing the random recurrent matrix collapsing to a shift matrix during training. |
recurrent_shift_register.gif | Committed N=3 animation (1.3 MB). |
recurrent_shift_register_N5.gif | Committed N=5 animation (1.3 MB). |
viz/ | Committed PNG outputs from the runs below. |
Running
python3 recurrent_shift_register.py --n-units 3 --seed 0
python3 recurrent_shift_register.py --n-units 5 --seed 6
Each run trains in about 1 second on an M-series laptop. Final per-delay accuracy is 100% for both N=3 and N=5 (with the recommended seeds). The recurrent matrix becomes a clean shift matrix.
To regenerate visualizations:
python3 visualize_recurrent_shift_register.py --n-units 3 --seed 0
python3 visualize_recurrent_shift_register.py --n-units 5 --seed 6
python3 make_recurrent_shift_register_gif.py --n-units 3 --seed 0
python3 make_recurrent_shift_register_gif.py --n-units 5 --seed 6 \
--out recurrent_shift_register_N5.gif
To run a multi-seed sweep:
python3 recurrent_shift_register.py --n-units 3 --multi-seed 10 --n-sweeps 300
python3 recurrent_shift_register.py --n-units 5 --multi-seed 10 --n-sweeps 300
Results
Single run, N = 3, seed = 0:
| Metric | Value |
|---|---|
| Final accuracy | 100% across both delays (delay 1, delay 2) |
| Final masked MSE loss | 0.001 |
| Converged at sweep | 89 (first sweep with 100% accuracy AND W_hh recognised as a shift matrix) |
| Wallclock | 0.9 s |
| Hyperparameters | N=3 hidden, tanh, batch=16, sequence_len=30, lr=0.3, momentum=0.9, weight_decay=1e-3, L1(W_hh)=0.05, init=Uniform[-0.2, 0.2] |
Final W_hh (one strong entry per non-input row, the rest zero):
h[0] h[1] h[2]
h[0] 0.00 0.10 1.56 <- row reads from h[2] (chain link)
h[1] -1.37 -0.00 0.09 <- row reads from h[0] (chain link)
h[2] 0.22 0.00 0.00 <- silent input row
The chain input -> h[2] -> h[0] -> h[1] implements: h[2] holds x[t] (current bit, written by W_xh), h[0] holds x[t - 1] (1 step delay), h[1] holds x[t - 2] (2 step delay). The output projection W_hy correctly reads delay-1 from h[0] (entry -2.77) and delay-2 from h[1] (entry +3.12). The “sparsity ratio” – max off-chain magnitude divided by mean chain magnitude – is 0.15.
Single run, N = 5, seed = 6:
| Metric | Value |
|---|---|
| Final accuracy | 100% across all 4 delays (delay 1, 2, 3, 4) |
| Final masked MSE loss | 0.0002 |
| Converged at sweep | 121 |
| Wallclock | 1.1 s |
Final W_hh (4 strong entries forming a single chain through all 5 units):
h[0] h[1] h[2] h[3] h[4]
h[0] . . . -0.87 . <- reads h[3]
h[1] . . . . +0.88 <- reads h[4]
h[2] . -1.05 . . . <- reads h[1]
h[3] . . -1.05 . . <- reads h[2]
h[4] . . . . . <- silent input row
Chain: input -> h[4] -> h[1] -> h[2] -> h[3] -> h[0], visiting all 5 units. Sparsity ratio = 0.00.
Sweep over 10 seeds, N = 3, 300 sweeps:
| count | |
|---|---|
| Reach 100% accuracy | 9/10 |
| Recurrent matrix becomes a shift matrix | 8/10 |
| Median sweep to convergence | 127 |
Sweep over 10 seeds, N = 5, 300 sweeps:
| count | |
|---|---|
| Reach 100% accuracy | 8/10 |
| Recurrent matrix becomes a shift matrix | 6/10 |
| Median sweep to convergence | 193 |
The rare failure mode (1-2 seeds out of 10 at N=3, 2-4 at N=5) is one of two things: (a) the network reaches 100% accuracy on all delays but settles into a non-shift solution where two units share a delay role, leaving the chain not visiting every unit; (b) for N=5 specifically, ~20% of inits get stuck on a near-trivial plateau around 88% accuracy and never recover. RHW1986 mention a perturbation-on-plateau wrapper for the XOR sister-experiment in the same chapter; we have not implemented it (see Deviations).
Comparison to the paper:
Paper reports a 3- or 5-unit recurrent net “learns to be a pure shift register in <200 sweeps.” We get 89 sweeps for N=3 and 121 sweeps for N=5 at the recommended seeds, both with 100% accuracy and
W_hhrecognised as a clean shift matrix.Reproduces: yes.
Visualizations
W_hh at convergence – the headline

The recurrent weight matrix at convergence for N=3, with the discovered chain entries outlined in lime green and the silent “input” row outlined with a gray dashed border. Two strong entries (h[0] <- h[2] = +0.85, h[1] <- h[0] = -0.82) trace the chain input -> h[2] -> h[0] -> h[1]. Every other entry is at most 0.005 in magnitude. The same structure for N=5:

Four strong entries on a single 5-cycle, the rest exactly zero (after L1 soft-thresholding). The 5-step delay chain is clearly visible.
Input/output projections

W_xh (left) shows the input writes most strongly into one unit (h[2] for N=3 – the silent row of W_hh), with negligible weight on the others. W_hy (right) shows that each delay output reads from exactly one hidden unit: delay 1 reads h[0], delay 2 reads h[1].
Hidden state evolution on a fresh test sequence

Top strip: 24 random {-1, +1} input bits. Middle: the three hidden units’ activations across all 24 timesteps. h[2] tracks the input directly (sign-flipped, since W_xh[2] is negative – the network compensates with a negative W_hy row); h[0] is h[2] shifted right by 1 timestep; h[1] is h[2] shifted right by 2 timesteps. Bottom: the two delay outputs, both at 100% sign accuracy on the masked timesteps. The diagonal pattern – input propagating through the units one timestep per cell – is the visual signature of a working shift register.
Training dynamics

Four panels:
- Loss (top-left) drops fast (sweep 30-50 phase transition), the network reaches accuracy plateau, then the loss creeps back up slightly between sweeps 60-150 as the L1 penalty on
W_hhshrinks the off-chain entries (paying a small loss cost in exchange for a sparser solution), and finally settles at ~1e-3. - Per-delay accuracy (top-right) breaks first on delay 1 (sweep ~28) then on delay 2 (sweep ~35); both stay at 100% from sweep 50 onwards.
- Chain vs off-chain entries (bottom-left) tells the structural story: at sweep 50 the matrix is dense with both chain and off-chain entries at ~0.8. The L1 then pulls off-chain entries (red curve) down to exactly zero by sweep 110, while the chain entries (green curve) stabilise at ~0.85 – enough magnitude for tanh to preserve
{-1, +1}bits across the chain. - Sparsity ratio (bottom-right) is the single number that captures the structural quality: it falls below the 0.2 threshold at sweep ~89 and stays at 0 from sweep 110 onwards. This is the moment the network “becomes” a shift register.
Deviations from the original procedure
- Multi-output instead of single-output. The most natural “shifted-by-1” task – a single output predicting
x[t - 1]– can be solved by a single hidden unit; it gives the network no incentive to use all N units, so the convergedW_hhis not a shift matrix even though accuracy is 100%. To make the structural shift-matrix prediction visible, we instead train on all N - 1 non-trivial delays simultaneously (delay 1, 2, …, N - 1, each on its own output line). With N hidden units and N - 1 delay channels, the minimum-capacity solution is exactly the chain. RHW1986 are not explicit about which exact target structure they used, but the qualitative claim – that the network “becomes a shift register” – requires this kind of multi-output setup or an equivalent capacity-tightening pressure. Tested: with single-output /delay = N - 1, accuracy still hits 100% at ~50 sweeps, but the converged W_hh has lots of off-chain leakage and looks dense. - L1 soft-threshold on
W_hh. Even with the multi-output setup, BPTT alone leaves residual off-chain weights that don’t hurt accuracy but obscure the shift-register structure. We add a proximal soft-threshold step on|W_hh|after each gradient update, magnitudelr * 0.05per step. With L1 = 0 the matrix is dense (sparsity ratio ~0.6); with L1 = 0.05 the off-chain entries die off completely. The original paper’s treatment used some kind of weight-decay-like sparsification implicitly (small init + long training); we make it explicit so the structure is visible in <200 sweeps. - Hyperparameters. Paper does not specify exact learning rate / batch size for this experiment. We use lr=0.3, momentum=0.9, batch=16 sequences of length 30, weight_decay=1e-3, L1(W_hh)=0.05.
- Encoding.
{-1, +1}with tanh hidden + tanh output (Hinton’s lecture-notes convention). Same problem solvable with{0, 1}+ sigmoid, which we did not test. - No perturbation-on-plateau wrapper. RHW1986 mention this for XOR; ~20% of N=5 random inits stall here too, the same way the symmetry experiment in
wave1-symmetry/stalls. - Recurrent state initialised to zero, not random. Standard modern practice; the paper is not explicit.
- Float precision.
float64numpy throughout. Should not matter at this scale (~30 parameters for N=3, ~55 for N=5).
Otherwise: same architecture (N tanh hidden, N - 1 tanh outputs, single scalar input), same loss (masked MSE), same algorithm (Backprop Through Time + momentum), same data (random binary sequences with delayed targets).
Open questions / next experiments
- Why does the network prefer one chain ordering over another? Both N=3 and N=5 land on different
(input_stage, chain_order)permutations across seeds. The problem is symmetric under any relabelling of hidden indices, so all N! orderings are valid solutions, yet some appear more often than others. A larger sweep over seeds + a histogram of chain orderings would quantify the basin sizes – analogous to the open question on which 1:2:4 ordering the symmetry network prefers. - Plateau escape with perturbation-on-plateau. Adding a small weight perturbation when accuracy stalls for K sweeps should rescue the ~20% N=5 failures. A clean test: apply perturbation only after sweep 100 if accuracy is still <95%; measure the lift in success rate.
- Scaling to larger N. Does the convergence sweep count grow linearly, polynomially, or exponentially with N? For N=10, can we still hit 100% within a few hundred sweeps, or does BPTT through that many timesteps suffer the standard vanishing-gradient problem? The shift register is the gentlest possible long-range memory task, so it’s a clean baseline for that question.
- Linear vs tanh hidden. A linear-hidden RNN can implement an exact shift register with
W_hh =sub-shift matrix of 1’s,W_xh = e_0, and zero L1 cost. Does it converge faster, and does the chain magnitude stabilise at exactly 1 instead of ~1? Useful for separating the “what is the optimal solution” question from the “how does the optimiser get there” question. - Connection to ByteDMD / data-movement complexity. A pure shift register reads each bit once per stage and writes it once per stage – a textbook stride-1 access pattern. Its data-movement complexity should be near-optimal for any algorithm that achieves N-step memory. Measuring it under the broader Sutro project’s reuse-distance metric would give a numeric “data-movement floor” against which more clever long-range memory networks (LSTM, Transformer, Hyena) can be compared.
- What does the network learn on
delay >= N? With N hidden units, the network cannot maintain a delay > N - 1 worth of memory. Trained on a target withdelay = N + k, does it (a) learn delays 1 to N - 1 perfectly and chance on the rest, (b) learn nothing, or (c) settle into a “forgetting fast” regime where short delays also degrade? An informative ablation for understanding RNN capacity-vs-task scaling.
25-sequence look-up
Source: Rumelhart, Hinton & Williams (1986), “Learning internal representations by error propagation”, in Parallel Distributed Processing, Vol. 1, Ch. 8 (MIT Press). Short version: Rumelhart, Hinton & Williams (1986), “Learning representations by back-propagating errors”, Nature 323, 533-536.
Demonstrates: A small recurrent network with 30 tanh hidden units, trained by Backpropagation Through Time on 20 of 25 (5-letter sequence -> 3-bit code) look-up pairs, generalizes to 4 of 5 held-out sequences. The variable-timing variant (60 hidden units, each letter held for a random 1-2 timesteps with timings resampled every sweep) generalizes to 5 of 5 held-out sequences – the network discovers a time-warp-invariant representation.

Problem
A 5-letter alphabet {A, B, C, D, E} (one-hot input, 5 input units). Each “sequence” is a 5-letter string (e.g. BACAE). A fixed teacher function maps every possible 5-letter sequence to a 3-bit code:
logit_i(seq) = sum_t W_teacher[i, l_t] * pos_weight[t] + b_teacher[i]
target_i(seq) = sign(logit_i(seq)) (i = 0, 1, 2)
W_teacher is a fixed (3, 5) random matrix and pos_weight is a fixed length-5 vector decaying from 1.0 to 0.2 (with a small random perturbation). The 3-bit target therefore depends on both which letters appear and where they appear. A bag-of-letters classifier cannot solve it.
We sample 25 distinct sequences from the 5^5 = 3125 possible, randomly split them into 20 train / 5 test, then train the RNN by BPTT to emit the 3-bit code on its three tanh output units at the final timestep only.
input step 1 step 2 step 3 step 4 step 5 (one letter per step)
B A C A E
|
v
output [+, -, +] <- 3 tanh units, read at t=5
The interesting property: the targets come from a learnable function of letter+position, so the held-out sequences have a basis for generalization. Thirty hidden units is more capacity than the 23-parameter teacher needs; the network nevertheless learns the underlying rule rather than a 20-entry look-up table – as evidenced by 4-5 of 5 held-out sequences being predicted correctly with no exposure during training.
The variable-timing variant holds each letter for a random number of timesteps τ_k ∈ {1, 2}, so the input length T_i ranges from 5 to 10 timesteps and is different for every sequence and every sweep (timings are resampled each training sweep). The output target depends only on letter content + order, not on timing, so the network must learn a time-warp-invariant representation. With 60 hidden units this works: the converged net predicts all 5 held-out sequences correctly even with fresh, never-before-seen timing patterns.
Files
| File | Purpose |
|---|---|
sequence_lookup_25.py | Dataset + teacher + RNN + manual BPTT in numpy. CLI: --seed, --variable-timing, --n-hidden, --n-sweeps, --lr, --multi-seed, --save-results. |
problem.py | Re-export shim for the stub’s original function names (build_rnn, generate_dataset, train_bptt, test_generalization). |
visualize_sequence_lookup_25.py | Static training curves + per-bit accuracy + W_xh / W_hh / W_hy heatmaps + state-evolution heatmap + per-sequence pass/fail bar chart. |
make_sequence_lookup_25_gif.py | Animated GIF showing per-test-sequence outputs settling onto their targets while training loss + accuracy + W_hh evolve. |
sequence_lookup_25.gif | Committed fixed-timing animation (~2 MB). |
viz/ | Committed PNGs from the runs below. |
Running
# fixed timing (the headline result, ~0.2 s on M-series laptop)
python3 sequence_lookup_25.py --seed 0
# variable timing (60 hidden units, ~6 s)
python3 sequence_lookup_25.py --variable-timing --seed 0
# regenerate visualisations
python3 visualize_sequence_lookup_25.py --seed 0
python3 visualize_sequence_lookup_25.py --variable-timing --seed 0
python3 make_sequence_lookup_25_gif.py --seed 0
# multi-seed sweeps
python3 sequence_lookup_25.py --multi-seed 5 --n-sweeps 800
python3 sequence_lookup_25.py --variable-timing --multi-seed 5 --n-sweeps 2500
Reproducible numbers from a single command (python3 sequence_lookup_25.py --seed 0):
- final train accuracy: 100% (20/20)
- final test accuracy: 80% (4/5 held-out sequences)
- converged at sweep: 22
- wallclock: 0.20 s
Results
Fixed timing (seed 0, hidden = 30):
| Metric | Value |
|---|---|
| Final train accuracy | 100% (20/20 sequences with all 3 bits correct) |
| Final test accuracy | 80% (4/5 held-out sequences) |
| Per-bit test accuracy | bit 0 = 100%, bit 1 = 100%, bit 2 = 80% |
| Final masked MSE loss | 2e-5 |
| Converged sweep | 22 (first sweep with 100% train accuracy) |
| Wallclock | 0.20 s on M-series laptop |
| Hyperparameters | n_hidden=30, init_scale=0.5, lr=0.05, momentum=0.9, weight_decay=1e-4, grad_clip=5.0, n_sweeps=800, dataset_seed=0, teacher_seed=1234 |
Multi-seed robustness (5 seeds, 800 sweeps, fixed timing):
| seed | train acc | held-out correct | converged @ |
|---|---|---|---|
| 0 | 100% | 4/5 | sweep 22 |
| 1 | 100% | 5/5 | sweep 18 |
| 2 | 100% | 5/5 | sweep 83 |
| 3 | 95% | 5/5 | – (loss jitter near boundary) |
| 4 | 100% | 5/5 | sweep 21 |
5 / 5 seeds reach >= 4/5 on the held-out set; median = 5/5.
Variable timing (seed 0, hidden = 60, max_timing = 2):
| Metric | Value |
|---|---|
| Final train accuracy | 100% |
| Final test accuracy | 100% (5/5 held-out sequences with fresh timings) |
| Per-bit test accuracy | bit 0 = 100%, bit 1 = 100%, bit 2 = 100% |
| Final masked MSE loss | 5e-5 |
| Converged sweep | 76 |
| Wallclock | 5.78 s on M-series laptop |
| Hyperparameters | n_hidden=60, init_scale=0.2, lr=0.02, momentum=0.9, weight_decay=1e-4, grad_clip=1.0, n_sweeps=2000, max_timing=2, dataset_seed=0, teacher_seed=1234 |
Multi-seed robustness (5 seeds, 2500 sweeps, variable timing): 5/5 seeds reach 5/5 on the held-out set, with convergence between sweeps 76 and 130.
Visualizations
Training curves (fixed timing)

The four panels:
- Training loss – log scale; drops from ~2.4 to 1e-5 within ~50 sweeps.
- Train + held-out accuracy – train hits 100% at sweep 22; held-out plateaus at 80% (4/5) for this seed.
- Per-bit train accuracy – all three bits saturate at 100% almost in lockstep.
- Per-bit held-out accuracy – bits 0 and 1 generalize perfectly; bit 2 is the harder one for this dataset split.
Weight matrices

W_xh (30 x 5) shows that each input letter writes to a particular pattern across the hidden units; the columns are not orthogonal – letters share dimensions. W_hh (30 x 30) is dense (no shift-register-like sparsity emerges in this problem – the network has plenty of capacity and no sparsity prior). W_hy (3 x 30) shows that each output bit reads from a distributed subset of hidden units.
State evolution on held-out sequences

For each of the 5 held-out sequences, the heatmap shows the 30 hidden activations across 5 timesteps. The hidden state is not a slot-filled register (unlike the recurrent shift-register sibling problem); it is a distributed code that mixes letter identity and position. Each panel is annotated with the held-out sequence’s letters along the x-axis.
Per-sequence pass / fail summary

Green bars = all 3 bits correct, red = at least one bit wrong. Train (left, 20 bars) is fully green; held-out (right, 5 bars) shows the 4/5 outcome – one held-out sequence trips bit 2.
Deviations from the original procedure
The original RHW 1986 PDP chapter describes the 25-sequence look-up task in qualitative terms – specific architecture details (hidden size, training procedure, learning rate, exact teacher function) are not given in machine-reproducible form in the chapter we cite. We therefore picked a concrete, learnable instantiation that demonstrates the same phenomenon (small RNN + BPTT recovers a learnable look-up function and generalizes to held-out sequences):
- Teacher function. We pick a fixed, position-dependent linear teacher whose targets are sign() of a weighted sum of letter values. The original paper alludes to a non-trivial mapping but does not pin down the exact form. The structure we use is the simplest one that (a) makes the targets a function of both content and position, and (b) is recoverable from 20 examples.
- Output read at the final timestep only. RHW also discuss variants that emit at every timestep; we chose final-timestep output for cleaner train/test scoring. This is a presentation choice and does not affect the underlying BPTT mechanics.
- Modern training tweaks. Momentum (0.9), weight decay (1e-4), and global-norm gradient clipping are used. The 1986 paper used vanilla SGD with momentum; weight decay and grad-clipping are modern stabilisers added here for reproducibility on a laptop without per-seed babysitting. They do not change the qualitative phenomenon – the multi-seed sweep above shows the result is robust without per-seed tuning.
- Variable-timing variant. RHW describe a “variable presentation rate” version with 60 hidden units. We resample the per-letter hold count uniformly from {1, 2} every training sweep so that the network must learn a time-warp-invariant representation. We use {1, 2} rather than {1, 2, 3} because at the chosen learning rate, max_timing = 3 leads to longer BPTT chains that need more careful curriculum/clipping; max_timing = 2 demonstrates the invariance phenomenon cleanly within ~6 s wall-clock. The {1, 2, 3} setting is reachable with longer training + finer LR tuning (see Open questions).
Reproducibility
--seed(model init),--dataset-seed(sequence sampling + train/test split), andteacher_seed = 1234(fixed in source) all exposed.- All hyperparameters appear in the Results table and as CLI flags.
--save-resultsdumps a JSON with full config + git commit + python/numpy versions + held-out predictions.- The headline number reproduces exactly by running
python3 sequence_lookup_25.py --seed 0.
Open questions / next experiments
- Variable-timing with max_timing = 3. The current variable-timing config caps each letter’s hold at 2 timesteps; the original paper allows wider variability. With 60 hidden units and our hyperparameters, max_timing = 3 fails to converge; either a longer curriculum (start at max_timing = 1 and increase) or a different optimizer (Adam-style adaptive LR) is likely needed. Worth quantifying.
- Why does bit 2 generalize less reliably than bits 0 and 1? On the chosen dataset_seed, bit 2’s teacher decision boundary happens to lie close to one held-out sequence. Sweeping
dataset_seedwould tell us whether this is a generic quirk of the geometry or specific to seed 0. - Hidden-state factorization. The trained
W_xhcolumns are not orthogonal across letters, but the network still generalizes. Probing whether the held-out sequences land in a “linear extrapolation” of the training-set hidden codes (vs. a nonlinear region) would say how the network is generalizing – by interpolation, by attribute factorization, or by something like a slow nearest-neighbour readout in hidden space. - Energy / data-movement metric (v2). This v1 implementation reports correctness only. A v2 pass would track ARD or ByteDMD across BPTT to see how the BPTT recurrence dominates data movement, then ask whether non-BPTT credit-assignment (e.g. a Hebbian or local-learning-rule variant) can match the same generalization at lower data-movement cost.
- Larger alphabets / longer sequences. With 5 letters and 5 positions there are 3125 possible sequences. Scaling to (10 letters, 8 positions = 100M sequences) with the same 25-sample / 5-test budget would be a more honest test of whether the network is finding the teacher rule or just memorising a tiny fraction of input space.
2-bit distributed-to-local with 1-unit bottleneck
Source: Rumelhart, Hinton & Williams (1986), “Learning internal representations by error propagation”, in Parallel Distributed Processing, Vol. 1, Ch. 8 (MIT Press). Short version: Nature 323, 533–536.
Demonstrates: Backprop will use intermediate (graded) hidden activations when the architecture forces it. The single sigmoid hidden unit takes 4 distinct graded values (paper target ≈ 0, 0.2, 0.6, 1.0) so the 4 output sigmoids can read out which of the 4 input patterns is active.

Problem
| input $x_1$ | input $x_2$ | one-hot target |
|---|---|---|
| 0 | 0 | (1, 0, 0, 0) |
| 0 | 1 | (0, 1, 0, 0) |
| 1 | 0 | (0, 0, 1, 0) |
| 1 | 1 | (0, 0, 0, 1) |
The architecture is 2 → 1 → 4: two binary inputs, a single sigmoid hidden unit, four sigmoid outputs. This is the smallest network in PDP Vol. 1 Ch. 8 that demonstrates the “graded internal representation” phenomenon.
The interesting property: a scalar in $[0, 1]$ cannot pick a category by hard membership, so backprop’s only path through the bottleneck is to assign each pattern a distinct graded hidden activation. The 4 output sigmoids then read out which pattern is active by their relative ordering at the corresponding $h$ value. Because every output $o_j = \sigma(w_j h + b_j)$ is monotone in $h$, the four output-vs-$h$ curves form a 1-D winner-takes-all partition of the unit interval — four sigmoids can carve $[0, 1]$ into at most four argmax regions, and that’s exactly what the network needs.
The pre-activation of the single hidden unit is $z = w_1 x_1 + w_2 x_2 + b$. The 4 patterns therefore yield $z \in {b,\ w_2{+}b,\ w_1{+}b,\ w_1{+}w_2{+}b}$. For all 4 hidden values to be distinct, we need $w_1 \neq 0$, $w_2 \neq 0$, $w_1 \neq w_2$, and $w_1 \neq -w_2$. The last condition is what fails most often: backprop falls into a shallow basin where $w_1 \approx -w_2$, collapsing patterns $(0,0)$ and $(1,1)$ to the same hidden value (the network has effectively rediscovered XOR and is stuck at 75% accuracy). Escaping this basin requires the perturb-on-plateau wrapper that RHW1986 used.
Files
| File | Purpose |
|---|---|
distributed_to_local_bottleneck.py | Dataset + 2-1-4 MLP + backprop with momentum + perturb-on-plateau + CLI. Numpy only. Exposes generate_dataset(), build_model(), train(), hidden_values(). |
visualize_distributed_to_local_bottleneck.py | Static training curves + per-pattern hidden-value trajectories + the 1-D graded-values readout + the 4 output sigmoids over $h$ + Hinton-style weight heatmaps. |
make_distributed_to_local_bottleneck_gif.py | Animated GIF: the 4 hidden values emerging, with the output-sigmoid curves and training metrics evolving. |
distributed_to_local_bottleneck.gif | Committed animation (≈1 MB). |
viz/ | Committed PNG outputs from the run below. |
Running
python3 distributed_to_local_bottleneck.py --seed 0
Single run takes about 0.1 seconds on an M-series laptop. Final accuracy: 100% (4/4).
To regenerate the visualizations:
python3 visualize_distributed_to_local_bottleneck.py --seed 0
python3 make_distributed_to_local_bottleneck_gif.py --seed 0
To run the multi-seed sweep:
python3 distributed_to_local_bottleneck.py --sweep 30
Results
Single run, --seed 0:
| Metric | Value |
|---|---|
| Final accuracy | 100% (4/4) |
| Final MSE loss | 0.219 |
| First sustained-100% epoch | 583 |
| Total epochs run | 632 (50-epoch stability window + post-trigger residual) |
| Perturbations applied | 1 (at epoch 300) |
| Wallclock | ≈ 0.1 s |
| Hyperparameters | lr=0.3, momentum=0.9, init_scale=1.0 (uniform [-0.5, 0.5]), full-batch updates, perturb_scale=1.0, plateau_window=300, perturb_cooldown=200, h_distinct_eps=0.10, stable_required=50 |
Hidden values per pattern (the headline):
| pattern | $h$ |
|---|---|
| $(0,1)$ | 0.007 |
| $(0,0)$ | 0.167 |
| $(1,1)$ | 0.553 |
| $(1,0)$ | 0.971 |
Sorted: $[0.007,\ 0.167,\ 0.553,\ 0.971]$. Spread (max − min) = 0.964.
Paper reports graded values $\approx (0,\ 0.2,\ 0.6,\ 1.0)$. We got $\approx (0.007,\ 0.167,\ 0.553,\ 0.971)$. Reproduces: yes. The pattern-to-value assignment is permuted relative to the paper (the issue notes this is acceptable), and our two middle values land slightly to the left of the paper’s targets, but the qualitative claim — that one sigmoid hidden unit takes 4 distinct graded values to discriminate the 4 patterns — holds at 100% accuracy.
Sweep over 30 seeds (--sweep 30):
| Statistic | Value |
|---|---|
| Converged seeds (final 100% accuracy) | 30 / 30 |
| Mean epochs to first sustained-100% | 894 |
| Median epochs | 736 |
| Min / max epochs | 283 / 2341 |
| Mean perturbations per run | 2.37 |
| Mean spread (max $h$ − min $h$) | 0.971 |
| Sorted $h$ — mean across the 30 seeds | $[0.011,\ 0.321,\ 0.595,\ 0.983]$ |
| Total sweep wallclock | ≈ 2.5 s |
The 30/30 convergence rate is contingent on the perturb-on-plateau wrapper. Without it, 0/30 seeds converge at any tested combination of lr ∈ {0.1, 0.3, 0.5, 1.0} and init_scale ∈ {0.3, 0.5, 1.0, 1.5, 2.0, 3.0}: every random init falls into the XOR-collapse local minimum where $w_1 \approx -w_2$. See Deviations below.
Visualizations
Graded-values readout — the 1-D headline plot

This is the unique deliverable for this stub. Left: a number line showing where each of the 4 input patterns lives along the single hidden unit’s $[0, 1]$ axis. The colored markers are the observed values; the gray ticks underneath are the paper’s reference targets at $0, 0.2, 0.6, 1.0$. Right: a bar chart of the same data with the paper targets overlaid as dashed reference lines.
The point: a single scalar takes 4 distinct graded values to encode 4 patterns. The bottleneck cannot use a binary code (only $h \in {0, 1}$ would give 2 values), so backprop is forced into intermediate activations.
1-D decision regions — output sigmoids over $h$

The 4 output sigmoids $o_j(h) = \sigma(w_j h + b_j)$ plotted against the hidden activation $h$. Each curve is monotone — the architecture has no way around that — but together the four curves carve $[0, 1]$ into 4 argmax regions (shaded). The dashed vertical lines show where the 4 patterns’ actual hidden values fall, and each lands inside its target’s argmax region. This is the 1-D analog of the 2-D decision boundary you’d see in xor/: instead of a curve in 2-D space, the network has placed 4 graded points along a single 1-D axis.
Training curves

Four signals over training (red vertical line = perturbation; green dashed = sustained-convergence epoch):
- Loss drops from 0.58 to ≈ 0.22. The asymptote is well above zero: with sigmoid outputs reading a single graded $h$, no setting of $W_2, b_2$ can drive every output to ${0, 1}$ exactly, so MSE saturates near a value bounded below by the per-pattern variance the architecture cannot encode.
- Accuracy climbs in stair-steps as patterns peel off from the collapsed pair: 25% → 50% → 100%. At seed 0 a single perturbation at epoch 300 breaks the XOR-collapse basin.
- Weight norm grows steadily: the network strengthens its readout to push apart the closely-stacked output-sigmoid scores at intermediate $h$ values.
- Hidden values per pattern is the storytelling plot. Every input pattern starts at the same $h \approx 0.4$ (sigmoid of a small random $z$), and over training the four trajectories peel apart into the four graded values. The perturbation at epoch 300 visibly kicks the trajectories before they settle into the four-level configuration.
Final weights

Hinton-style heatmap of $W_1$ (input → hidden) and $W_2$ (hidden → output) plus biases.
For seed 0 the trained weights are roughly $W_1 \approx (+5, -4)$ with bias $-0.8$. Both magnitudes are close but with opposite signs — exactly the configuration that produces 4 distinct hidden $z$ values: $z(0,0) = -0.8$, $z(0,1) = -4.8$, $z(1,0) = +4.2$, $z(1,1) = +0.2$, which after sigmoid gives $(0.31, 0.008, 0.985, 0.55)$ — matching the four graded values. $W_2$ shows the readout pattern: each output unit has a different (weight, bias) line in the $(h, \text{score})$ plane such that its line is the maximum exactly inside its target’s argmax region.
Deviations from the original procedure
- Convergence criterion. RHW1986’s stated rule is “every output within 0.5 of its target” (16 conditions, four outputs × four patterns). That rule is not achievable for the 2-1-4 architecture: each output sigmoid is monotone in the single hidden activation, so 4 outputs cannot simultaneously fire above 0.5 selectively for one of 4 distinct $h$ values. The achievable signal is argmax accuracy plus a graded-spread requirement on the hidden values, which is exactly the “4 distinct graded values” claim that the paper itself uses as the headline. We use: 100% argmax accuracy + minimum pairwise $h$ gap > 0.10, sustained for 50 consecutive epochs. The encoder-backprop-8-3-8 sibling (PR #16) makes the same compromise.
- Perturb-on-plateau wrapper. RHW1986 reports treating rare local-minimum runs by perturbing weights and continuing. For the 2-1-4 problem the local minimum is not rare — it’s the default. With backprop alone, 0/30 seeds converge at any reasonable hyperparameter; with perturb-on-plateau, 30/30 do. This wrapper is therefore essentially required, and the spec v2 amendment to issue #1 lists perturb-on-plateau as an explicit (recommended) acceptance-checklist item. The trigger is “stuck for
plateau_windowconsecutive epochs” where stuck means accuracy < 100% or the minimum pairwise $h$ gap is below the distinctness threshold. - Init distribution. Uniform $[-0.5, 0.5]$ (
--init-scale 1.0), matching thexor/sibling (PR #3). The paper used a slightly tighter range; this width gave the most reliable convergence in our hyperparameter sweep. - Floating-point precision.
float64numpy. The 1986 paper’s hardware was not modern IEEE 754; this should not matter at this size. - Sigmoid clamping. Pre-activations clipped to $[-50, 50]$ to avoid
np.expoverflow (modern numerical hygiene). - Loss. Mean-of-summed squared error, $\frac{1}{2} \cdot \text{mean}n \sum_j (o{n,j} - t_{n,j})^2$, matching RHW1986’s “simple example” formulation. Cross-entropy would also work and converges marginally faster, but MSE gives a cleaner pedagogical loss curve.
Open questions / next experiments
- Why is the XOR-collapse basin so dominant? Without perturb-on-plateau, every random init we tried (across 30 × 6 × 4 = 720 seed/init/lr combinations) falls into $w_1 \approx -w_2$. Is there a principled reason — does the loss landscape have a measure-zero “good” basin near init? Or is there an init scheme that escapes the trap directly (orthogonal init, Xavier, NTK-style)?
- Generalization to more inputs. For an $n$-bit distributed input with $2^n$ one-hot targets and a 1-unit bottleneck, does backprop still find $2^n$ graded values? At what $n$ does it become impossible to separate sigmoid outputs by argmax along a single axis? Each sigmoid output is monotone, so $k$ sigmoid outputs can produce at most $k$ argmax regions on a 1-D axis — for $2^n > k$ patterns you would need more output units or a non-monotone readout.
- Direct construction. Given that we know the final $h$ values must be 4 graded levels, a direct algebraic construction of $(W_1, b_1, W_2, b_2)$ exists. How does its data-movement cost (forward pass only) compare to the cost of running backprop with perturb-on-plateau to discover the same configuration? This is exactly the v2 question this catalog is being built to enable.
- Cross-entropy loss + softmax output. With 4-way softmax output and cross-entropy loss, the convergence rate without perturb might be higher (sharper gradients near saturation). Worth a sweep.
- Single-precision and quantization. The trained weights are around $|w| \approx 5$ to 8. If we quantize $W_2$ to 4-bit signed integers, do the 4 graded $h$ values still get separated by argmax? This is the cheapest probe of energy-efficient inference for a network whose internal representation is intrinsically analog.
Shifter / shift-direction inference
Reproduction of the shifter experiment from Hinton & Sejnowski (1986), “Learning and relearning in Boltzmann machines”, Chapter 7 of Rumelhart, McClelland & PDP Research Group, Parallel Distributed Processing, Vol 1, MIT Press.

Problem
Two rings of N = 8 binary units V1 and V2, where V2 is a copy of
V1 shifted (with wraparound) by one of {-1, 0, +1} positions. Three
one-hot units V3 encode the shift class. The network sees all 19 visible
bits during training and must infer V3 from V1 + V2 at test time.
- V1: 8 input bits
- V2: V1 shifted by -1, 0, or +1 with wraparound
- V3: 3 one-hot units indicating which shift was applied
- Visible: 19 = 8 + 8 + 3
- Hidden: 24 (matches the original Figure 3 layout)
- Training set: full enumeration of
2^N x 3 = 768cases for N = 8
The interesting property: no pairwise statistic between V1 and V2 carries
information about the shift class. The hidden units must discover
third-order conjunctive features of the form
V1[i] AND V2[(i + s) mod N] -> shift = s. This is the canonical “higher-
order feature” problem and the motivating example for Boltzmann learning’s
ability to find features that perceptrons (which use only pairwise
statistics) cannot.
Files
| File | Purpose |
|---|---|
shifter.py | Bipartite RBM trained with CD-1. Same gradient form as the 1986 Boltzmann learning rule (positive phase minus negative phase), with the efficient bipartite sampling structure. Exposes make_shifter_data, build_model, train, shift_recognition_accuracy, per_class_accuracy, and accuracy_vs_v1_activity. |
visualize_shifter.py | Hinton-diagram weight viz (the headline figure) + training curves + accuracy vs V1 activity + confusion matrix + heatmap. |
make_shifter_gif.py | Generates shifter.gif (the task illustration at the top of this README). |
shifter.gif | Animation cycling through the three shift classes. |
viz/ | Output PNGs from the run below. |
Running
python3 shifter.py --N 8 --hidden 24 --epochs 200 --seed 0
Training takes ~7.5s on a laptop, plus ~6s for the final 200-Gibbs-sweep evaluation pass. To regenerate the visualization outputs (also re-trains):
python3 visualize_shifter.py --N 8 --hidden 24 --epochs 200 --seed 0 --outdir viz
python3 make_shifter_gif.py --N 8 --fps 12 --out shifter.gif
Results
| Metric | Value |
|---|---|
| Final accuracy (full 768 cases, seed 0) | 92.3% |
Per-class: left (-1) | 86.7% |
Per-class: none (0) | 94.9% |
Per-class: right (+1) | 93.8% |
| Range across V1-activity buckets (k = 1..7) | 58.3% - 98.8% |
| Training wallclock | ~7.5s |
| Eval wallclock (200 Gibbs sweeps) | ~6s |
| Hyperparameters | hidden = 24, lr = 0.05, momentum = 0.7, batch = 16, 200 CD-1 epochs |
The paper reports 50-89% accuracy varying with the number of on-bits in V1. Our k = 1..7 range (the meaningful slice of the data — see the next section) sits at 58.3% - 98.8%, comfortably above the paper’s range.
What the network actually learns
Position-pair detectors (the headline figure)

Each of the 24 panels is one hidden unit’s incoming weights, drawn in the same Hinton-diagram convention used in Figure 3 of the original paper:
- top-left: threshold (bias)
- top-right trio: output weights
[L, N, R]to the three V3 units - bottom two rows: V1 and V2 receptive fields
- white = positive, black = negative, square area proportional to |w|
Units sort into three blocks by their preferred shift class (argmax over
the V3 weights). A unit preferring “shift left” reliably shows a strong
pair at V1[i] and V2[(i - 1) mod N] — exactly the conjunctive feature
the task requires. The same pattern with +1 offset appears for “right”
units, and 0 offset for “none” units. These are the third-order
features the original chapter emphasizes.
The training run prints the most interpretable position-pair detector;
for seed 0 it’s unit 21, with strongest pair V1[2] <-> V2[1] (offset 7
mod 8 = -1, consistent with shift-left) and output preference
L = +6.09, N = +0.13, R = -5.60.
Accuracy vs V1 activity

Bucket the 768 test cases by how many V1 bits are on. The paper reports 50-89%; our run sits at 58-99% on the interesting middle (k = 1..7). The k = 0 and k = 8 cases are intrinsically ambiguous: V1 is all zeros or all ones, so V2 is identical to V1 regardless of shift, giving exactly chance performance no matter what the network learns. The plotted range mirrors the original observation that mid-density patterns are substantially easier than near-empty / near-full ones.
Training curves

Reconstruction MSE drops monotonically from ~0.25 (random init) to <0.01 by epoch 200. Recognition accuracy lifts from chance (33.3%) starting around epoch 30, climbs through 70% by epoch 75, and saturates near 90% after epoch 100. No plateau / restart machinery is needed at this scale — training is a clean monotone climb.
Weight heatmap and confusion matrix

The off-diagonal entries of the confusion matrix concentrate on the left/right axis (true-left predicted-right, etc.), as expected: the hardest patterns are nearly rotation-symmetric ones where left and right shifts produce visually similar V2 strings.
Deviations from the 1986 procedure
- Sampling. CD-1 (Hinton 2002) instead of simulated annealing. Same positive-phase-minus-negative-phase gradient, much cheaper sampling.
- Connectivity. Explicit bipartite (visible <-> hidden), making this an RBM in modern terminology. The original paper’s shifter network is fully-connected within the visible layer; collapsing those connections into the hidden layer is the standard simplification.
- Hardware. Modern laptop, ~14s end-to-end including evaluation. The 1986 paper ran on a VAX with substantially longer training time, and reported 9000 annealing cycles per training pass.
- Hidden units. 24 hidden units to match the original Figure 3 layout, as specified in the issue. The original chapter notes that with simulated annealing, several of the 24 units “do very little” — our CD-1 network uses all 24 productively (10 N-units, 7 L-units, 7 R-units) but several are clearly weaker than the canonical position-pair detectors.
Open questions / next experiments
- The original chapter reports per-class accuracies between 50% and 89% varying with V1 density, never reaching the >90% we see here. Is the CD-1 RBM overfitting the closed 768-case enumeration in a way the annealing network did not? Train/test split would clarify.
- Several hidden units (e.g. unit 19 in our seed-0 run) end with weak, diffuse weights and unclear class preference — analogous to the “do very little” units the original paper mentions. Are they redundant, or are they encoding a low-frequency interaction that becomes useful only on the harder near-symmetric patterns?
- How does the data-movement cost (ByteDMD / simplified Dally model) of this CD-1 implementation compare to a faithful simulated-annealing variant on the same architecture? CD-1’s per-step cost is dominated by two visible-x-hidden matrix multiplies; simulated annealing pays for many more sampling sweeps but no separate “negative phase” pass.
Reference implementation
This implementation is lifted from
cybertronai/sutro-problems/wip-boltzmann-shifter/
(the working RBM-based shifter, ~87% on 768 N=8 cases at hidden = 80) and
adapted to the hinton-problems stub layout, defaulting to 24 hidden
units to match the original Figure 3.
Grapheme-sememe synthetic word reading
Source: Hinton & Sejnowski (1986), “Learning and relearning in Boltzmann machines”, in Rumelhart & McClelland (eds.), Parallel Distributed Processing Vol. 1, Chapter 7.
Demonstrates: Distributed representations are damage-resistant. After training a 30→20→30 net on 20 random grapheme→sememe associations and randomly zeroing 50% of its weights, retraining on only 18 of the 20 associations partially restores accuracy on the 2 held-out associations — the famous spontaneous-recovery effect.

Problem
A toy model of word-reading. The input layer encodes a 3-letter “word” with one-hot per position (3 positions × 10 letters = 30 binary grapheme units). The output layer encodes the meaning as 30 binary sememe units built from a small pool of shared “semantic micro-features” (the network has to learn that distinct words can share semantic structure).
A 30→20→30 sigmoid MLP learns 20 random grapheme→sememe associations. Then we run the famous 4-stage protocol from H&S 1986:
- Train on all 20 associations to convergence (100% pattern accuracy).
- Lesion — randomly zero a fraction of W1 and W2 weights.
- Relearn-subset — retrain on only 18 of the 20 patterns. The lesioned weights stay at zero (permanent damage); only the surviving connections update.
- Test held-out 2 — measure bit accuracy on the 2 patterns that were never shown during stage 3. If the network’s representation is distributed, these recover spontaneously.
Files
| File | Purpose |
|---|---|
grapheme_sememe.py | Dataset + 30-20-30 sigmoid MLP + backprop + lesion() + relearn_subset() + run_protocol() + sweep() + CLI. Numpy only. |
visualize_grapheme_sememe.py | Static training curves, weight heatmaps with lesioned entries marked, per-bit reconstruction plots, spontaneous-recovery bar chart. |
make_grapheme_sememe_gif.py | Animated GIF: held-out 2 sememe activations + 4 trained-sample sememes + bit-accuracy timeline across the 4 stages. |
grapheme_sememe.gif | Committed animation (~1.1 MB). |
viz/ | Committed PNG outputs from the run below. |
Running
python3 grapheme_sememe.py --seed 0
Single run takes about 2 seconds on an M-series laptop. To regenerate visualizations:
python3 visualize_grapheme_sememe.py --seed 0
python3 make_grapheme_sememe_gif.py --seed 0
To run the multi-seed sweep that produced the recovery distribution below:
python3 grapheme_sememe.py --sweep 30
Results
Single run, --seed 0:
| Metric | Value |
|---|---|
| Architecture | 30 → 20 → 30 sigmoid MLP, 1250 parameters |
| Stage 1 pattern accuracy (all 20) | 100.0% |
| Stage 1 bit accuracy (held-out 2) | 100.0% |
| Stage 2 bit accuracy (held-out 2) — post-lesion 50% | 88.3% |
| Stage 4 bit accuracy (held-out 2) — post-relearn-on-18 | 95.0% |
| Spontaneous recovery (held-out bits) | +6.7 pp |
| Stage 4 bit accuracy (trained 18) | 96.1% |
| Wallclock end-to-end | ~2 s |
| Hyperparameters | lr=0.3, momentum=0.5, weight_decay=1e-3, full-batch backprop, 1500 train cycles + 50 relearn cycles |
Sweep over 30 seeds (--sweep 30):
| Metric | mean | std | min | max |
|---|---|---|---|---|
| post-lesion bit acc (held-out 2) | 80.5% | 5.5 pp | 65.0% | 88.3% |
| post-relearn bit acc (held-out 2) | 82.5% | 7.2 pp | 63.3% | 95.0% |
| post-relearn bit acc (trained 18) | 95.7% | 1.2 pp | 93.1% | 98.0% |
| spontaneous recovery (pp) | +2.0 | 5.9 | -10.0 | +15.0 |
The mean recovery is positive but variable: about half the seeds show clean recovery (some up to +15 pp), the other half show small negative deltas where catastrophic forgetting on the 2 held-out patterns slightly outpaces re-learning of shared structure. Bit accuracy never falls below ~63% — well above the ~50% chance baseline for random sigmoid output and the ~58% baseline for predicting the marginal sememe density.
Comparison to the paper:
H&S 1986 reports near-perfect spontaneous recovery on the 2 held-out items after retraining the 18 (using full Boltzmann learning with simulated annealing, which strongly biases toward distributed representations).
We get 95.0% bit accuracy on held-out 2 after relearning on 18, at seed 0. Across 30 seeds: 82.5% mean, 63.3-95.0% range. Reproduces: yes (qualitatively — held-out accuracy stays well above chance and is on average above its post-lesion value), with the caveat that recovery is per-seed variable under backprop.
The phenomenon is real and reproducible at this scale, but quantitatively weaker than the paper’s full-Boltzmann result. Two known reasons (see Deviations below): (1) backprop has weaker implicit pressure toward distributed representations than Boltzmann learning, and (2) we tuned relearning to be short (50 cycles) to balance recovery against catastrophic forgetting on the held-out 2.
Visualizations
4-stage timeline (training + lesion + relearn)

The bit-accuracy panel (top) is the central result. Stage 1 (left of orange dashed line): both trained 18 (blue) and held-out 2 (red) climb to 100% as the net memorizes the 20 patterns. Stage 2 (orange dashed line): the 50% lesion drops both lines — the held-out 2 to 88%, the trained 18 to roughly the same. Stage 3 (right of orange line): retraining on 18 only. Crucially, the red held-out line goes back up even though those 2 patterns are never shown — that’s spontaneous recovery. Pattern accuracy (bottom) tells the same story but binarized at the per-pattern level: trained-18 patterns recover partially, held-out 2 stays at 0% (the network still gets a few bits wrong on each held-out word, even if those bits are correct on average).
Spontaneous-recovery bar chart

The headline metric in one picture. The held-out 2 start at chance (~50% bit accuracy with random init), reach 100% after stage 1, drop to 88% after stage 2, and recover to 95% after stage 4 — without ever being shown during stage 3.
Weights with lesioned entries

W1 (hidden ← grapheme, 20 × 30) and W2 (sememe ← hidden, 30 × 20) at the end of stage 4. Red is positive, blue is negative. Black squares mark the entries that were zeroed at stage 2 and held at zero through stage 3 — the network had to route around them. Roughly half of each matrix is black, by construction.
Per-bit reconstructions of the 2 held-out words

For each of the 2 held-out words, four bars per sememe bit: the gray target, plus the network’s prediction at three time points (post-train, post-lesion, post-relearn). Post-train predictions match the target almost exactly. Post-lesion predictions get most bits right but a few have flipped. Post-relearn predictions, despite never seeing these words again during stage 3, are tighter to the target than post-lesion on most bits.
Deviations from the original procedure
-
Algorithm: backprop instead of Boltzmann learning. The 1986 paper used a Boltzmann machine with simulated annealing, learning by maximizing log-likelihood under positive/negative-phase statistics. We use deterministic 30→20→30 backprop with sigmoid activations and per-bit Bernoulli cross-entropy. The spec explicitly permits either; backprop is much simpler and faster on a deterministic mapping. Boltzmann learning has stronger implicit pressure toward distributed representations (the negative-phase term penalizes the network for using over-confident, non-distributed codes), which is why our recovery effect (mean +2 pp) is weaker than the paper’s near-perfect recovery.
-
Sememes built from shared prototypes. The paper used hand-designed sememes that shared semantic micro-features across related words (“CAT” and “DOG” both have “ANIMAL”, etc.). We approximate this by drawing each sememe as the OR of 2 prototypes from a pool of 4, plus 5% bit-flip noise. With 4 prototypes, each prototype is shared by ≈10 words on average, so retraining 18 forces the network to re-learn the shared features that the 2 held-out words also use. Without this shared structure (independent random Bernoulli sememes), backprop shows no spontaneous recovery — it just catastrophically forgets the held-out 2.
-
Brief, regularized relearning. With backprop’s lack of implicit regularization, long relearning catastrophically overfits the 18 and erases the held-out 2 (we measured this directly — going from 50 to 200 cycles flips mean recovery from +2 pp to −5 pp). We use 50 relearn cycles + L2 weight decay (1e-3) + reduced momentum (0.5 vs. 0.9 in stage 1) to balance recovery of the 18 against forgetting of the 2. See the open question in the next section about whether this is just papering over the underlying issue.
-
Lesion is on weights, not synaptic strengths in a stochastic network. “Lesion” in the 1986 paper meant zeroing connection strengths in a Boltzmann machine; we zero entries of W1 and W2 in the deterministic MLP and keep the mask active during stage 3 (so re-learning routes around the damage rather than rebuilding it). Biases are not lesioned — they are not really “synapses” in the 1986 interpretation.
-
No perturbation-on-plateau wrapper. Stage 1 converges reliably from random init at this scale (lr=0.3, momentum=0.5, 1500 cycles), so the wrapper isn’t needed.
Open questions / next experiments
-
Boltzmann reproduction. The natural follow-up is a Boltzmann-machine implementation (extending the bipartite RBM from
encoder-4-2-4) to see whether the per-seed variance in spontaneous recovery shrinks under the original learning rule. The hypothesis is that the negative-phase term provides implicit regularization that prevents the catastrophic-forgetting failure mode we hit with backprop. -
Recovery as a function of held-out-prototype overlap. With our prototype-based sememe construction, the held-out 2’s prototypes are sometimes shared by many of the 18 (high overlap → strong recovery) and sometimes by few (low overlap → weak/negative recovery). A controlled sweep stratifying seeds by held-out-prototype overlap would quantify how much “semantic similarity to the trained set” is the actual mechanism.
-
Lesion-fraction sweep. At lesion=0.5 we get +2 pp mean recovery; preliminary sweeps suggest lesion=0.7 with a smaller hidden bottleneck (8 hidden) gives +9 pp recovery on average, while lesion=0.2 leaves not much room to recover. The full curve would map damage tolerance vs. relearning capacity.
-
Data movement. This is the v1 baseline. v2 (the broader Sutro effort) will instrument the same training loop with ByteDMD and ask whether a non-backprop solver can achieve the same spontaneous-recovery effect at lower data-movement cost. The 1250-parameter 30→20→30 MLP is small enough that the inference path is essentially free; the interesting question is whether training-with-distribution-pressure (Boltzmann-like, or a structured backprop variant) can be cheaper than 1500 cycles of full-batch backprop.
-
Pattern accuracy never recovers. Bit accuracy on the held-out 2 reaches 95% at seed 0, but pattern accuracy (all 30 bits correct simultaneously) stays at 0%. The 30-bit conjunction is a stiff target — at 95% per-bit you expect 0.95^30 ≈ 21% pattern accuracy, but the errors are correlated within each held-out word. Whether a different output decoding (e.g. nearest-prototype lookup) closes this gap is an open question.
v1 Metrics
| Metric | Value |
|---|---|
| Reproduces paper? | Yes (qualitatively; held-out 2 bit accuracy recovers from 88% post-lesion to 95% post-relearn at seed 0; mean +2 pp across 30 seeds) |
| Wallclock to run final experiment | ~2 s (time python3 grapheme_sememe.py --seed 0 ≈ 1.7 s) |
| Implementation wallclock (agent) | ~70 min (one session, including grid-search for the prototype/relearn-cycle defaults) |
Family trees / kinship task
Source: Hinton (1986), “Learning distributed representations of concepts”, Proceedings of the Eighth Annual Conference of the Cognitive Science Society, pp. 1-12.
Demonstrates: Backprop discovers semantic features (nationality, generation, family branch) that are not explicit anywhere in the input. Hinton’s most cited demonstration of distributed representation learning.

Problem
Two isomorphic 12-person family trees (English + Italian) and 12 kinship
relations: father, mother, husband, wife, son, daughter, uncle, aunt, brother, sister, nephew, niece.
Each example presents a person and a relation; the network must produce the set of all valid answers.
Christopher = Penelope Andrew = Christine
| |
+-------+-------+ +-------+-------+
| | | |
Arthur = Margaret(*) Victoria = James Jennifer = Charles
|
+---+---+
| |
Colin Charlotte
(*) Cross-tree marriage: Arthur is C&P’s son; Margaret is A&C’s daughter.
James and Charles are outsiders married into the tree. Italian tree (Roberto,
Maria, …) mirrors English position-by-position.
| Inputs | 24-bit one-hot person + 12-bit one-hot relation |
| Targets | multi-hot 24-bit vector (the set of valid answers, normalized to a softmax distribution) |
| Total facts | 100 (50 per tree); 4 held out, 96 used for training |
| Architecture | (24+12) -> 6+6 -> 12 -> 6 -> 24, all five hidden/output layers nonlinear |
The interesting property. The 6-unit person-encoding layer sits between the local-coded 24-bit input and a relation-conditioned 12-unit central layer. That bottleneck, plus the requirement that the network correctly answer relations across both trees, forces it to discover features common to both. Hinton showed that the units self-organize into interpretable axes: nationality, generation, family branch. None of these features is given anywhere in the input – the names are arbitrary 1-of-24 tokens. The network has to infer the structure from the relation graph alone.
Files
| File | Purpose |
|---|---|
family_trees.py | Tree definition + 100-fact dataset + backprop MLP + inspect_person_encoding + CLI. |
visualize_family_trees.py | Static training curves, encoding heatmap, per-unit bar charts (the headline interpretable-axes view), 2-D PCA scatter colored by each attribute. |
make_family_trees_gif.py | Generates the animated family_trees.gif. |
family_trees.gif | Animation – PCA + heatmap + training curves frame-by-frame. |
viz/ | PNG outputs from visualize_family_trees.py. |
Running
python3 family_trees.py --seed 6 --epochs 10000
Wall-clock: about 2 seconds on a 2024 laptop. Final train accuracy 100% (96/96), final test accuracy 75% (3/4) – argmax-in-valid-set criterion.
To regenerate the static plots and the GIF:
python3 visualize_family_trees.py --seed 6 --epochs 10000 --outdir viz
python3 make_family_trees_gif.py --seed 6 --epochs 10000 --snapshot-every 250 --fps 10
Results
Headline number, seed 6. Train 100%, test 3/4 (held-out facts: Charlotte mother, Gina mother, Roberto son, James niece – the network gets the
first, third, fourth and confuses Gina’s mother Francesca with Maria, Gina’s
grandmother).
| Metric | Value |
|---|---|
| Final train accuracy | 100% (96/96 facts) |
| Final test accuracy | 75% (3/4 held-out facts) |
| Training time | 2.1 s (single core, numpy) |
| Total facts | 100 (50 per tree); 96 train + 4 test |
| Total triples | 112 (Hinton’s reported 104 with our specific tree shape – see Deviations below) |
| Hyperparameters | seed=6, epochs=10000, lr=0.5, momentum=0.9, init_scale=1.0 (Xavier), weight_decay=0.0 |
Variance across random splits. Across seeds 0..9 with the same recipe: 6 of 10 runs reach 100% training accuracy, average held-out test = 1.9 / 4 correct. Hinton (1986) reported 2 / 4 on his hand-picked test set, so we consider 1.9 / 4 averaged over random hold-outs a faithful match of the paper’s generalization regime. Three seeds (5, 6, 7) hit 3 / 4 on their random hold-out; the rest hover around 1-2 / 4.
Visualizations
6-D person encoding – the headline finding

Each block is one of the six person-encoding units; each block has 24 bars
(one per person in ALL_PEOPLE), recolored three different ways. The three
panels show the same numbers, just grouped to expose three different
implicit axes:
- Top (nationality). Units 0, 1, 4 cleanly separate English (blue, negative) from Italian (red, positive). The network has invented a “nationality detector” axis nowhere in the input.
- Middle (generation). Unit 2 fires strongly positive on generation 3 (Colin / Charlotte / Alfonso / Sophia) and negative on generation 1 grandparents. Unit 3 inverts this – positive on generation 1 grandparents. Combined, the encoder has carved out a generation gradient.
- Bottom (branch). Units 1, 4, 5 distinguish the left-side family (Christopher / Roberto branch) from the right-side family (Andrew / Pierro branch); outsiders (James / Marco, Charles / Tomaso) sit on a third level.
PCA of the encoding

The first two principal components account for roughly 64% of the variance. PC1 splits English from Italian; PC2 separates generations. Every panel is the same point cloud, just colored by a different attribute – the geometry visibly clusters by nationality, by generation, and by branch.
Encoding heatmap

Rows sorted by (nationality, generation, branch). The English block (top
12 rows) and the Italian block (bottom 12) have visibly different column
patterns – consistent with the per-unit bars. Within a nationality, gen-3
grandchildren (last two rows of each block) sit out from the rest.
Training curves

Training accuracy reaches 100% in roughly 200 epochs; test accuracy locks in at 75% (3/4 of the held-out facts) by epoch 500 and never moves. The flat test curve is what we expect with only 4 held-out facts – one of those four is structurally hard (Gina’s mother is not directly trainable from siblings once the held-out fact is removed) and the network never finds it.
Deviations from the original procedure
Hinton’s 1986 setup vs. ours:
- Activation function. Hinton used logistic (sigmoid) units throughout.
We use tanh for hidden layers and softmax at the output. The
reason is gradient flow through four layers of squashing nonlinearities:
sigmoid'(0) = 0.25, so a four-layer chain shrinks gradients by0.25^4 ≈ 0.004;tanh'(0) = 1.0, so the chain preserves them. Empirically, sigmoid hidden units stall the person encoder at its random init – the gradient that reachesW_pis too small to move it before the output layer collapses to the marginal-prediction minimum.tanhplus Xavier init reliably trains in roughly 200 epochs. - Loss function. Hinton used “a quadratic error measure” (sum-squared error) on sigmoid outputs; we use softmax + cross-entropy with soft-distribution targets (each fact’s mass split uniformly across its valid answers). For one-hot 24-class targets the squared-error loss has a vanishing positive-class signal – 23 push-down terms vs. one push-up term. Cross-entropy with softmax balances them automatically.
- Tree structure / triple count. Hinton’s 1986 paper reports 104 valid
triples; our specific tree shape generates 112 triples (= 100 distinct
(P, R)facts sinceAndrew, daughteranswers{Margaret, Jennifer}etc.). The discrepancy stems from a small choice in whose siblings count as blood vs. by-marriage uncles – functionally identical, the interpretable-axes finding doesn’t depend on the exact count. We train on 96 facts and hold out 4, matching Hinton’s reported 100 / 4 split. - Initialization. Hinton used “small random weights”; we use a Xavier
draw (
sigma = sqrt(2 / (n_in + n_out))). This was needed alongside thetanhswap to keep early-training activations off the saturated tails. - Per-attempt convergence. Hinton’s paper reports a single converged
network. We get 6 / 10 random seeds to 100% training accuracy in 10 000
epochs; the rest stall at 30-90% train. We did not implement a
restart-on-plateau wrapper (an explicit v1 simplification, see the
wave-2 spec note) – the headline-result seed is reported above and is
reproducible with
--seed 6.
The architecture itself (24 + 12 -> 6 + 6 -> 12 -> 6 -> 24) is faithful to
Figure 2 of the 1986 paper.
Open questions / next experiments
- Why does sigmoid + sum-squared error reportedly train in the original paper? Did Hinton use a much higher per-pattern learning rate, or per-pattern updates instead of full-batch averaging? A faithful reproduction of his exact recipe would either expose a missing trick or challenge the claim. Reproducing the 1986 paper’s specific hyperparameters is a useful systematic experiment in its own right.
- Does the per-attempt success rate match Hinton’s? Our 6/10 success rate at 10 000 epochs may be lower than the original; we would need Hinton’s failure statistics to know. A 1985 paper by Ackley-Hinton- Sejnowski reported 250/250 for the encoder under simulated annealing, but the family-trees architecture and training procedure differ.
- Restart-on-plateau wrapper. The
encoder-4-2-4andencoder-8-3-8worked examples in this catalog show large solve-rate gains from a perturb-on-plateau detector. Adding one here is the obvious next step, and would let us ship a recipe that hits 100% train on every seed. - Reuse-distance / data-movement cost. This task is ideally sized for ByteDMD instrumentation – 36-bit inputs, 24-bit targets, ~600 weights total. Once the v1 baseline is in, measuring the data-movement cost of one full backprop sweep is the natural follow-up.
- Mapping the 6-D code to the symbolic features explicitly. The per-unit bar charts show that nationality / generation / branch are encoded, but not by single-axis-aligned units. A small linear probe (regressing nationality / generation / branch from the 6-D code) would quantify how separable each feature is, and is a natural extension.
Synthetic-spectrogram riser/non-riser discrimination
Backprop reproduction of the synthetic-spectrogram task from Plaut, D.C. & Hinton, G.E. (1987), “Learning sets of filters using back-propagation”, Computer Speech and Language 2, 35-61.
Demonstrates: A small MLP trained with back-propagation approaches the Bayes-optimal accuracy on a controlled, fully-specified synthetic classification task. The Bayes optimum is computed in closed form via dynamic programming; the gap to the network is a clean diagnostic of how much the learner is leaving on the table.

Problem
Each input is a 6 frequency x 9 time = 54-D synthetic spectrogram. One
“track” – a single frequency value per time-step – is set to 1 and
all other cells to 0. Independent Gaussian noise of std sigma = 0.6
is then added to every cell.
- Class 0 (“rising”): the track is monotonically non-decreasing in frequency over time.
- Class 1 (“non-rising / falling”): the track is monotonically non-increasing over time. (Plaut & Hinton’s original task contrasts upward-sweeping and downward-sweeping formants. We use the non-decreasing / non-increasing duals as the cleanest balanced realisation of “rising vs not-rising”.)
There are C(n_freq + n_time - 1, n_time) = C(14, 9) = 2002 distinct
non-decreasing tracks (and 2002 non-increasing ones, sharing 6 constant
tracks). The two classes are sampled with equal prior.
The interesting property: the task is fully specified, so the
Bayes-optimal classifier is computable. For any track f,
log p(x | f) = const(x) + (1/sigma^2) * sum_t x[f(t), t]. The class-
conditional likelihood p(x | rising) = (1/|R|) sum_{f in R} p(x | f)
sums over all 2002 monotone non-decreasing tracks; with the change of
variables U = x / sigma^2, this sum is a small dynamic program over
(time, current-frequency) state with O(n_time * n_freq^2) work per
sample. The closed form gives us the ceiling that backprop should
asymptote to.
Files
| File | Purpose |
|---|---|
riser_spectrogram.py | Synthetic-spectrogram generation, 54-24-2 sigmoid/softmax MLP, full-batch-style backprop with momentum (online-resampled each epoch), Bayes-optimal classifier (DP). CLI exposes --seed --noise-std. |
visualize_riser_spectrogram.py | Static figures: example noisy inputs (clean track overlaid), training curves with Bayes ceiling, per-hidden-unit input filters, Bayes-vs-net accuracy gap across noise levels. |
make_riser_spectrogram_gif.py | Animates training: example inputs, hidden filters sharpening, train/test accuracy approaching the Bayes ceiling. |
riser_spectrogram.gif | The committed animation (742 KB). |
viz/ | Static PNGs from the run reported below. |
Running
python3 riser_spectrogram.py --seed 0 --noise-std 0.6 --epochs 200
Wall-clock: ~1.0 s training + ~0.2 s for the 50 000-sample Bayes-optimal estimate. Final test accuracy: 98.08%, Bayes: 98.90%, gap +0.83 pp.
To regenerate visualizations:
python3 visualize_riser_spectrogram.py --seed 0 --noise-std 0.6 --epochs 200
python3 make_riser_spectrogram_gif.py --seed 0 --noise-std 0.6 --epochs 160 --snapshot-every 4 --fps 10
Results
| Metric | Value |
|---|---|
| Network test accuracy | 98.08% (3923 / 4000) |
| Bayes-optimal accuracy | 98.90% (50 000 samples) |
| Gap to Bayes | +0.83 pp |
| Paper’s reported numbers | network 97.8%, Bayes 98.8%, gap 1.0 pp |
| Training time | 0.88 s (200 epochs, 2000-sample online resample / epoch) |
| Bayes-DP time | 0.17 s for 50 000 samples |
| Architecture | 54-24-2 (sigmoid hidden, softmax output) |
| Parameters | 1370 (W1 24x54, b1 24, W2 2x24, b2 2) |
| Optimiser | mini-batch backprop with momentum |
| Hyperparameters | lr=0.5, momentum=0.9, batch=100, init_scale=1.0 / sqrt(fan-in) |
| Seed | 0 |
| Loss | softmax cross-entropy |
| Encoding | clean cell value 1.0, noise std 0.6 added iid |
Reproduces the paper to within reporting precision (gap 0.83 pp here vs 1.0 pp reported).
Noise-level sweep
A single training command at any one sigma is one point on this
curve. The full sweep is in viz/bayes_vs_net.png:
| sigma | Net (test) | Bayes-opt | Gap |
|---|---|---|---|
| 0.40 | 99.78% | 99.99% | +0.22 pp |
| 0.50 | 99.20% | 99.77% | +0.57 pp |
| 0.60 | 98.08% | 98.94% | +0.87 pp |
| 0.70 | 96.50% | 97.23% | +0.73 pp |
| 0.80 | 93.88% | 94.91% | +1.03 pp |
The network tracks the Bayes ceiling within ~1 pp across the entire range – the architecture is enough, the gap is sample efficiency, not capacity.
Visualizations
Example inputs

Three rising and three falling examples at sigma = 0.6, with the underlying clean track overlaid in cyan. The track is hard to see by eye in any single panel – the noise std exceeds the cell value of the “on” cells, and the eight “off” cells per column outvote the one “on” cell in raw integrated energy. The structure is recoverable only because the shape of the track (monotone up vs monotone down) is informative across the whole 54-D image.
Hidden filters

Each panel is one hidden unit’s 6 x 9 input weight matrix, displayed in the same orientation as the raw spectrogram. Red = positive, blue = negative. The filters tile the (frequency, time) plane with oriented edges – some prefer “low frequency early, high frequency late” (rising-template), the negatives of those (anti-rising), and a spectrum of intermediate orientations. Together they form a basis the final softmax layer can project onto a single rising / falling axis.
Training curves

Train (blue) and test (red) accuracy, with the Bayes-optimal ceiling (dashed black) overlaid. The two accuracies stay close to each other – no overfitting, because the per-epoch fresh resample (see Deviations) gives the network effectively unlimited (track, noise) variants. Convergence is fast: ~10 epochs to clear 97%, then a slow asymptotic approach to the Bayes line.
Bayes-vs-net across noise levels

Bayes (black) and network (red) accuracy as the noise std is swept from 0.4 to 0.8. The annotated gap (in pp) hovers between +0.2 and +1.0 across the range – consistent with the paper’s single reported operating point.
Deviations from the original procedure
- Online noise resampling. Plaut & Hinton evaluated on a fixed
training set; this implementation re-samples 2000 (track, noise)
pairs every epoch. With 1370 parameters and only ~4000 distinct
clean tracks, a fixed dataset is rapidly memorised at this noise
level (train accuracy hits 100% by epoch 50, test plateaus at
~96%). Online resampling lets the network see fresh noise on the
same track distribution every step, closing the gap to Bayes
without explicit regularisation. (
--offlinetoggles back to a fixed dataset.) - “Non-rising” interpreted as “monotone falling”. The spec calls the second class “non-rising”; we use strictly monotone falling (non-increasing) as the cleanest balanced realisation, matching Plaut & Hinton’s “downward-sweeping formant” class. Each class has 2002 tracks (overlapping at the 6 constant tracks).
- Softmax + cross-entropy instead of paired-sigmoid + MSE. Same gradient form for the visible-to-output weights; cleaner derivation for the modern reader.
- Glorot-style scaled initialisation instead of small uniform. No qualitative effect; just slightly better numerical conditioning.
- Mini-batch (size 100) instead of pattern-by-pattern. Order of magnitude faster on numpy, gradients are still effectively full- batch on each epoch’s resampled set.
Open questions / next experiments
- The noise-level sweep shows the network gap widens slightly at higher noise. Is this a sample-efficiency artefact (more epochs would close it) or a capacity wall? Repeating with online sampling for 1000 epochs at sigma = 0.8 would settle this.
- The hidden filters look basis-like rather than template-like (no filter is a single track). Quantifying the rank of W1 – does the task get solved with rank << 24, and if so, can we shrink the hidden layer without losing accuracy?
- Plaut & Hinton’s broader paper varies the number of formants and the noise structure. Adding a second concurrent track and asking the network to determine “any rising track present?” probes composition.
- Energy axis (out of v1 scope): every cell is read multiple times per epoch by all 24 hidden units. A rising-template-detector that only reads cells along plausible monotone paths would have far smaller data-movement cost; quantifying that gap is the natural ByteDMD follow-up.
Fast weights with rehearsal
Source: G. E. Hinton & D. C. Plaut (1987), “Using Fast Weights to Deblur Old Memories”, Proceedings of the Ninth Annual Conference of the Cognitive Science Society, pp. 177–186.
Demonstrates: A linear associator with two-time-scale weights (slow plastic + fast elastic-decaying) learns set A, learns a disjoint set B (which appears to overwrite A), then briefly rehearsing a small subset of A rapidly restores A — the headline “deblurring” effect that the foundational fast-weights paper reports.

Problem
A 50-dimensional linear associator stores n_pairs random ±1 vector associations (x_i → y_i). Each weight has two components:
W_slow: small learning rate (slow_lr=0.1), no decay → long-term plastic storeW_fast: large learning rate (fast_lr=0.5), multiplicative decay (fast_decay=0.9per presentation) → short-term elastic store
The effective weight is W_eff = W_slow + W_fast; both components are updated every presentation by the delta rule dW = (1/dim) outer(target - W_eff·x, x).
The 4-phase protocol:
- Learn A — train on 20 A-pairs for 30 sweeps.
recall_Areaches 100%. - Learn B — train on 20 disjoint B-pairs for 30 sweeps.
recall_Breaches 100% butrecall_Adrops to ~80% (interference: fast weights decay during B-learning, slow weights drift toward B). - Rehearse subset of A — re-present just 5 of the 20 A-pairs for 5 sweeps.
- Test — measure
recall_Aandrecall_Bwith no further updates.
The interesting property: rehearsing a small subset of A restores the rehearsed pairs to 100% through fast-weight reactivation, even though the slow weights have meaningfully shifted toward B. The rehearsal is enough to push pattern accuracy on A from 0% (after B) back up to 25%. The 1987 paper frames this as evidence that distributed memories can be “deblurred” by partial cues — fast weights do the heavy lifting of routing past the interference, slow weights provide the substrate.
Files
| File | Purpose |
|---|---|
fast_weights_rehearsal.py | FastWeightsAssociator (slow + fast weight components) + learn_set / rehearse_subset / recall_accuracy + run_protocol (4-phase orchestration) + sweep (multi-seed) + CLI |
visualize_fast_weights_rehearsal.py | Static plots: 4-phase training curves (bit + pattern + weight norms), slow-vs-fast weight heatmaps at end of each phase, per-pair recovery bars, per-pair distribution histograms |
make_fast_weights_rehearsal_gif.py | Animated GIF: per-pair A and B recall + 4-phase mean-recall timeline with phase color bands |
fast_weights_rehearsal.gif | Committed animation (~800 KB) |
viz/ | Committed PNG outputs from the run below |
Running
python3 fast_weights_rehearsal.py --seed 0
Single run takes ~0.15 s (time python3 fast_weights_rehearsal.py --seed 0 measured at 0.14 s wall on an M-series laptop, system Python 3.9 + numpy 2.0).
To regenerate visualizations:
python3 visualize_fast_weights_rehearsal.py --seed 0
python3 make_fast_weights_rehearsal_gif.py --seed 0
To aggregate across 30 seeds:
python3 fast_weights_rehearsal.py --sweep 30
CLI flags (the spec calls out --seed --dim --n-pairs; everything else is optional):
--seed RNG seed default 0
--dim vector dimension default 50
--n-pairs # of pairs in A and in B default 20
--n-rehearse # of A pairs to rehearse default n_pairs // 4 = 5
--slow-lr slow weight learning rate default 0.1
--fast-lr fast weight learning rate default 0.5
--fast-decay fast weight decay/step default 0.9
--n-a-sweeps sweeps over A in phase 1 default 30
--n-b-sweeps sweeps over B in phase 2 default 30
--n-rehearse-sweeps sweeps in phase 3 default 5
--sweep N aggregate over N seeds (else single run)
Results
Single run, --seed 0:
| Metric | Value |
|---|---|
| Architecture | linear associator, dim=50 → dim=50, slow + fast component each |
| Parameters | 5 000 (2 × 50 × 50) |
| Phase 1: bit acc on A (post-learn-A) | 100.0% |
| Phase 1: pattern acc on A | 100.0% |
| Phase 2: bit acc on A (post-learn-B) | 81.2% |
| Phase 2: pattern acc on A | 0.0% |
| Phase 2: bit acc on B | 100.0% |
| Phase 3: bit acc on A (post-rehearse 5/20) | 85.8% |
| Phase 3: pattern acc on A | 25.0% (headline: 0% → 25%) |
| Phase 3: bit acc on B | 99.9% |
| Deblur recovery — rehearsed pairs (bit) | +14.8 pp |
| Deblur recovery — unrehearsed pairs (bit) | +1.2 pp |
| Wallclock end-to-end | ~0.15 s |
| Hyperparameters | slow_lr=0.1, fast_lr=0.5, fast_decay=0.9, 30 / 30 / 5 sweeps |
Sweep over 30 seeds (--sweep 30):
| Metric | mean | std | min | max |
|---|---|---|---|---|
| post-A bit acc on A | 100.0% | 0.0 pp | 100.0% | 100.0% |
| post-B bit acc on A (interference) | 78.9% | 1.8 pp | 75.1% | 84.4% |
| post-B pattern acc on A | 0.3% | 1.2 pp | 0.0% | 5.0% |
| post-rehearsal bit acc on A | 85.1% | 1.4 pp | 83.1% | 89.2% |
| post-rehearsal pattern acc on A | 25.0% | 1.3 pp | 20.0% | 30.0% |
| deblur recovery (bit, all of A) | +6.2 pp | 1.4 pp | +3.7 pp | +9.5 pp |
| deblur recovery (pattern, all of A) | +24.7 pp | 1.3 pp | +20 pp | +25 pp |
| rehearsed pairs recovery (bit) | +22.0 pp | 4.1 pp | +14.8 pp | +30.8 pp |
| unrehearsed pairs recovery (bit) | +0.9 pp | 1.0 pp | -1.6 pp | +3.2 pp |
| total wallclock for sweep | ~0.65 s |
Comparison to the paper:
Hinton & Plaut 1987 demonstrate the qualitative effect: an associative memory equipped with fast weights can learn set A, then learn set B (apparent forgetting of A), then rapidly recover A from a brief rehearsal of a subset. The strongest version of the claim — that rehearsing a subset reactivates the entire set — is shown for inputs that share structure (the paper uses correlated patterns).
We reproduce the rehearsal-driven recovery on the rehearsed subset robustly: pattern accuracy on A jumps from 0% (after B) to 25% (after rehearsing 5 of 20), with the entire 25 pp coming from the 5 rehearsed pairs hitting 100% pattern accuracy. Reproduces: yes for the headline rehearsal-deblurring effect. We do not reproduce strong cross-pair generalization (rehearsing 5 of A bringing the unrehearsed 15 back to high accuracy) — see Deviations point 2 below.
Visualizations
4-phase timeline

The top panel (bit accuracy) is the central plot. Phase 1 (blue band) drives recall_A to 100%. Phase 2 (red band) drives recall_B to 100% but pulls recall_A down to ~80% — the interference. Phase 3 (gold band) is brief but the recall_A line clearly goes back up. The middle panel (pattern accuracy) shows the same effect through a sharper threshold: pattern accuracy on A drops to 0% during phase 2 then jumps to 25% during phase 3 (5 pairs perfectly recalled out of 20). The bottom panel shows ‖W_slow‖ growing monotonically while ‖W_fast‖ ratchets — building during each phase, decaying as new patterns arrive.
Per-pair recovery (rehearsed vs unrehearsed)

The headline mechanism in one picture. Each A-pair gets two bars: gray = bit accuracy after phase 2 (post-B interference), and gold/dark-gray = bit accuracy after phase 3 (post-rehearsal). Rehearsed pairs (gold) are the first 5 indices; they jump to 100%. Unrehearsed pairs (dark gray) sit at the same level they were after phase 2. The fast-weight rehearsal effect is concentrated on the items rehearsed.
Per-pair distribution

Histograms of per-pair bit accuracy on A at three phases. After phase 1 (left): everything sits at 100%. After phase 2 (middle): the cloud drops to ~78–86% bit acc; rehearsed and unrehearsed pairs are intermixed. After phase 3 (right): the rehearsed pairs (gold) move to 100%; the unrehearsed pairs (gray) stay where they were. The bimodal post-rehearsal distribution is the signature of fast-weight rehearsal acting on the rehearsed subset.
Slow vs fast weight matrices at each phase

Heatmaps of W_slow (top row) and W_fast (bottom row) at the end of phases 1, 2, and 3. Color = entry value, RdBu_r. The Frobenius norm is annotated in the corner of each panel. W_slow grows monotonically (slow plastic store); W_fast accumulates during phase 1, decays + rebuilds during phase 2 (now encoding B), and is partially rebuilt for the rehearsed A subset during phase 3.
Deviations from the original procedure
-
Linear associator instead of an iterative attractor net. The 1987 paper used an iterative settling network (echoes of the Boltzmann-machine work of the same era). We use a single matrix-vector product
W_eff @ xand threshold the sign. This is the standard simplification used in modern fast-weights papers (Schmidhuber 1992, Ba et al. 2016) and lets us cleanly isolate the slow/fast decomposition without confounding with attractor dynamics. The headline effect (rehearsal-driven recovery via the fast-weight store) is preserved. -
Random uncorrelated patterns; no cross-pair generalization. The original paper’s strongest claim is that rehearsing a subset can restore items that share structure with the rehearsed set. Our random ±1 patterns share no structure by construction, so the unrehearsed pairs stay at their post-B level (mean +0.9 pp,
chance). The rehearsal effect on the rehearsed subset is fully reproduced (+22 pp average); cross-pair generalization is not, because there is nothing to generalize across. A natural follow-up (see Open Questions) would re-run with structured patterns drawn from a small prototype pool — analogous to the prototype-based sememes ingrapheme-sememe/. -
Online delta rule, not Hebbian outer-product. The 1987 paper is loose about the learning rule (the focus is on the slow/fast architecture); we use the delta rule because it converges cleanly on a 50-dim associator at this scale. With pure Hebbian updates the patterns saturate at higher cross-talk and the rehearsal effect is harder to read off.
-
Per-update normalization
1/dim. Without this the delta-rule updates are too large at dim=50 and the slow weights overshoot. This is a numerical detail (it just rescales the effectiveslow_lr/fast_lr) and is mentioned for honest reporting; the headline numbers are insensitive to it once the learning rates are tuned for it. -
No perturbation-on-plateau wrapper. Convergence is reliable in 30 sweeps from random init at this scale; no wrapper needed.
Open questions / next experiments
-
Structured patterns → cross-pair recovery. The natural extension is to draw the A and B vector pairs from a small pool of prototype-mixtures (e.g. each pattern is the sum of 2 of 5 random binary prototypes, ±noise). The hypothesis from the 1987 paper is that the slow weights would then encode the prototype basis, and rehearsing 5 of 20 A pairs would re-activate the prototype components that the unrehearsed 15 also use — recovering them too. Worth a 2-line
generate_associationsswap. -
Sweep over
fast_decay. Withfast_decay=1.0we have a single-time-scale associator (no fast/slow distinction); recall_A after phase 2 should be much lower (no fast-store buffer). Withfast_decay=0.5(very fast decay) the rehearsal effect should also weaken because fast weights die off too quickly between rehearsal sweeps. We expect a sweet spot in the middle (the spec default 0.9 sits there). The full curve would be a clean ablation. -
Reps × decay tradeoff. Phase 3 trades off
n_rehearse_sweepsagainstfast_decay: more sweeps build more fast-weight signal, but each new sweep also lets earlier fast contributions decay. We default to 5 sweeps at decay 0.9 (so fast contributions from the first sweep are 0.9^5 ≈ 0.59 by sweep 5). Mapping recovery as a function of (sweeps, decay) jointly would clarify the operating regime. -
Boltzmann / iterative-settling reproduction. Reimplement the original protocol on an iterative attractor net (e.g. the bipartite RBM used in
encoder-4-2-4) to test whether iterative settling adds anything on top of the linear-associator deblurring. -
Data movement. This is the v1 baseline. The slow/fast decomposition is structurally interesting from a data-movement standpoint: fast weights are small and frequently rewritten (cheap if they live in cache), slow weights are large but rarely rewritten (cold storage). v2 (the broader Sutro effort) could measure whether the fast-weight architecture has favorable ByteDMD cost vs a single-time-scale associator that achieves the same deblurring through more brute-force re-training.
v1 Metrics
| Metric | Value |
|---|---|
| Reproduces paper? | Yes for the headline rehearsal-deblurring effect on the rehearsed subset (pattern accuracy on A: 0% → 25% after rehearsing 5 of 20; rehearsed-pair recovery +14.8 pp at seed 0, +22.0 pp mean across 30 seeds). Cross-pair generalization to unrehearsed items is not reproduced under random uncorrelated patterns — see Deviations §2. |
| Wallclock to run final experiment | ~0.15 s (time python3 fast_weights_rehearsal.py --seed 0 measured at 0.14 s wall on M-series laptop) |
| Implementation wallclock (agent) | ~25 minutes (single session, mostly viz layout) |
Vowel discrimination via adaptive mixtures of local experts
Source: Jacobs, Jordan, Nowlan & Hinton (1991), “Adaptive mixtures of local experts”, Neural Computation 3(1):79-87.
Demonstrates: A mixture of K linear softmax experts with a softmax gate, trained end-to-end by maximum-likelihood gradient descent on p(y|x) = sum_k g_k(x) * p_k(y|x), produces a clean, phonetically meaningful partition of F1/F2 input space (front vowels vs back vowels) and converges to higher mean test accuracy with lower seed-variance than a parameter-matched monolithic MLP. The “twice as fast as backprop” headline of the original paper does not replicate at this dimensionality; see §Deviations.

Problem
Speaker-independent 4-class vowel classification from two acoustic features: the first two formant frequencies F1 and F2.
| Class | IPA | Peterson-Barney code | Word |
|---|---|---|---|
| 0 | [i] | IY | heed |
| 1 | [I] | IH | hid |
| 2 | [a] | AA | hod |
| 3 | [Lambda] | AH | hud |
Data: Peterson & Barney (1952), 76 speakers (33 men, 28 women, 15 children) x 10 vowels x 2 repetitions = 1521 tokens total; we keep only the 4 vowels above (608 tokens). Train/test split is by speaker (75% of speakers train, 25% test): the model never sees the same vocal tract at train and test time, so the speaker-normalisation problem is not given for free.
The MoE has K linear softmax experts, each producing 4-class probabilities, mixed by a softmax gate over the 2-D input. Per-batch loss is the standard MoE negative log-likelihood
L = -log sum_k g_k(x) * p_k(y_true | x)
whose gradient has the same form as cross-entropy with the posterior expert
responsibility h_k = g_k * p_k(y) / sum_j g_j * p_j(y) playing the role of a
soft target distribution. Derivation in vowel_mixture_experts.py:loss_and_grads.
Files
| File | Purpose |
|---|---|
vowel_mixture_experts.py | Data loader (downloads to ~/.cache/hinton-vowels/ once, parses Peterson-Barney text format, falls back to a class-conditional Gaussian mock if no network); MoE and MLP classes; manual gradient training in numpy; CLI (--seed, --n-experts, --n-epochs, --lr, --batch-size, --train-frac, --results). Numpy + urllib only. |
visualize_vowel_mixture_experts.py | Reads results.json + results.npz. Emits data_scatter.png, expert_partitioning.png (gate argmax over a F1/F2 grid + mixture decision regions), training_curves.png (MoE vs monolithic loss + test-accuracy curves), comparison_table.png (numeric summary). |
make_vowel_mixture_experts_gif.py | Trains a fresh MoE and MLP from scratch, renders one frame per evenly-spaced epoch, writes the partition + accuracy GIF. |
vowel_mixture_experts.gif | Committed 60-epoch animation (~180 KB). |
viz/ | Committed PNG outputs. |
results.json | Headline run output (config, env, full per-epoch histories, summary). |
results.npz | Companion file: trained MoE and MLP weights + the standardised train/test split. Read by the visualizer. |
problem.py | Original wave-8 stub. Kept untouched as the canonical contract; the public functions (load_peterson_barney, build_moe, train, visualize_partitioning) are re-exported by vowel_mixture_experts.py with the same signatures. |
Running
Reproduce the headline run (seed 0, K=4 experts, 80 epochs, lr=0.3):
python3 vowel_mixture_experts.py --seed 0 --n-experts 4 --n-epochs 80 --lr 0.3 \
--results results.json
Wall-clock: ~0.13 s on an M-series MacBook (MoE 0.08 s + MLP 0.05 s).
Writes results.json and results.npz.
Then regenerate plots and the GIF:
python3 visualize_vowel_mixture_experts.py --results results.json --out-dir viz
python3 make_vowel_mixture_experts_gif.py --seed 0 --n-experts 4 --n-epochs 60 --lr 0.3
GIF render takes about 8 s (most of the time is matplotlib re-layouting per frame).
Results
Headline run, seed=0, K=4 experts, 80 epochs, lr=0.3, batch=32, 456 train / 152 test tokens:
| model | params | final test acc | epochs->90% | wallclock |
|---|---|---|---|---|
| MoE (K=4) | 60 | 0.934 | 22 | 0.08 s |
| Monolithic MLP H=8 | 60 | 0.921 | 13 | 0.05 s |
Both methods are parameter-matched (60 floats).
Multi-seed (seeds 0..4, 120 epochs, otherwise identical config):
| model | mean test acc | mean epochs->90% | std (epochs->90%) |
|---|---|---|---|
| MoE (K=4) | 0.928 +/- 0.011 | 22.2 | 5.4 |
| Monolithic MLP H=8 | 0.901 +/- 0.020 | 12.2 | 4.4 |
Reading the table:
- Final accuracy: MoE wins by ~3 points and has roughly half the cross-seed variance. This is the result that does survive the move to a 2-D feature space.
- Convergence rate to 90%: MLP wins by ~10 epochs. The original paper’s headline – MoE reaches 90% in about half the epochs of monolithic backprop – does not replicate at this dimensionality. See §Deviations for why.
Expert specialization (the cleanest survival of the headline). With K=4 the
gate consistently drives 2 of the 4 experts to zero responsibility on the data
and uses the remaining 2 to cover the front vowels ([i] / [I]) and the back
vowels ([a] / [Lambda]). The partition mirrors the high-vs-low F1 phonetic
split (front vowels have low F1 / high F2; back vowels the opposite) and is
visible in viz/expert_partitioning.png and the GIF. In the training animation
the gate boundary settles within the first ~10 epochs and then the surviving
experts refine their per-region linear classifiers.
Visualizations
Headline: F1/F2 data scatter

Plotted with the standard phonetic-vowel-chart orientation: F1 increasing downward (open vowels at the bottom), F2 decreasing rightward (back vowels at the right). Circles are training tokens, triangles are held-out test tokens. [i] sits top-right (high, front); [a] sits bottom-left (low, back). The two clusters that overlap most are [a] and [Lambda]: this is the pair the model gets wrong.
Expert partitioning

Left: the gate’s argmax over the F1/F2 grid. Two experts dominate – one covers the front-vowel half (low F1 / high F2), the other covers the back-vowel half. The gate finds the same split that a phonetician would draw.
Right: the mixture’s predicted class over the same grid – four quasi-linear regions, one per vowel. Each region is a half-plane carved out by the per-expert linear softmax inside its gating cell.
Training curves

Both methods are converging. The training-loss panel shows the MLP’s loss is visibly below the MoE’s at every epoch – a tanh hidden layer with 8 units has more flexible decision boundaries than 4 linear experts gated at the input. The test-accuracy panel shows the gap close at convergence: MoE = 0.934, MLP = 0.921 on this seed.
Summary table

Deviations from the original procedure
- Dimensionality of the input. The original paper used the full filter-bank spectrum (~100 dims). We use only F1 and F2 (2 dims). This is the change most responsible for the convergence-speed claim not replicating: in 2 dims the data is nearly linearly separable, so a small monolithic MLP with 8 tanh units already has more than enough capacity to interpolate fast. The MoE’s advantage in the original paper comes from its ability to chop up a high-dimensional, highly variable input into easier sub-problems; in F1/F2-space there are no useful sub-problems beyond “front vs back”.
- Number of training tokens. Paper uses additional speakers from the Peterson-Barney recordings split differently. We have 76 speakers x 4 vowels x 2 repetitions = 608 tokens, split 75/25 by speaker.
- Optimizer. Paper uses gradient descent without momentum on each expert plus a separate update rule for the gate (Hinton & Nowlan’s competing-experts formulation). We use plain mini-batch SGD with a single shared learning rate on the joint MoE log-likelihood, which is the modern form of the same model and gives identical gradients in expectation.
- Expert architecture. Paper’s experts are small MLPs (~50 hidden units each). We use linear softmax experts – the simplest non-trivial choice. With K=4 linear experts the MoE has 60 params; we hold the MLP baseline at the same count for the apples-to-apples comparison.
- Loss form. We use the discrete-output MoE log-likelihood
-log sum_k g_k * p_k(y_true); the paper uses the Gaussian-output form-log sum_k g_k * exp(-||y - y_k||^2 / 2 sigma^2). These are the classification and regression specialisations of the same underlying conditional-mixture-density model. - Real-data caveat. The Peterson-Barney file is fetched from the
phiresky/neural-network-demo mirror because the original Hillenbrand WMU
page now returns the school’s CMS landing page rather than the data file.
Output of the loader is checked into the cache at
~/.cache/hinton-vowels/PetersonBarney.datand the parser tolerates the*listener-disagreement marker on the phoneme label. If the download fails, the loader falls back to a class-conditional Gaussian mock with means taken from the male-speaker entry in the Peterson & Barney 1952 table; this path emits a warning and is documented in the run output (is_real_datainresults.json). - Float precision. float64 throughout. Paper uses single precision.
Open questions / next experiments
- Does the speed-up come back in higher dim? Reproduce on the original spectral input (e.g., a mel-filterbank computed from raw P-B audio if the recordings are still available, or just the four formants F1..F4). If MoE recovers the 2x convergence advantage at >= 4 dims, that’s a clean demonstration that the headline scales with input dimensionality, not architecture.
- What temperature on the gate is optimal? Currently the softmax gate is trained at the same learning rate as the experts and finds a hard partition within ~10 epochs. Annealing a temperature on the gate (start soft so all experts get gradient, then sharpen) is the modern go-to fix for the “all-experts-collapse-to-the-same-classifier” failure mode. We don’t see that failure here – the gate quickly drops 2 experts and uses 2 – but the which-2 assignment is seed-sensitive and an annealing schedule may make it more deterministic.
- K-sweep with all 10 vowels. Keep the 2-D input but use all 10 P-B vowels. At K=10 the MoE could in principle learn a one-expert-per-vowel partition. Does the gate reliably allocate one expert per vowel, or does it group by phonetic class (front/mid/back x high/mid/low)?
- Switching off the dead experts. With K=4 the gate consistently disables 2 experts – they receive zero responsibility but their parameters are still updated each step (with zero-magnitude gradient, but they still occupy memory). A pruning heuristic that drops dead experts and re-initialises them at high-error regions of input space (“expert birth/death”) is the classic Jacobs follow-up; checking whether reusing the dead-expert capacity improves accuracy from 93% toward chance-corrected ceiling would be the experiment.
- Connection to ByteDMD. The MoE has an obvious data-movement advantage at inference: the gate’s argmax selects 1 of K expert weight matrices, so only 1/K of the expert parameters are read per example. Measuring this gain on the training side (where all experts are touched, weighted by responsibility) versus a hard top-1 routing variant is a clean ByteDMD experiment that connects the 1991 architecture to the modern sparse-gate-MoE rediscoveries (Shazeer et al. 2017).
Random-dot stereograms (Imax)
Reproduction of Becker & Hinton (1992), “A self-organizing neural network that discovers surfaces in random-dot stereograms”, Nature 355, 161-163.
The first demonstration that mutual information between two modules viewing different inputs from a common cause is enough of a learning signal to discover that common cause. The cause here is binocular disparity (depth); the two modules are simulated stereo receptive fields each looking at one strip of a random-dot stereogram.

Problem
A 1-D world of random binary dots (+1 / -1) is rendered into two views,
left and right eye. The right view is the left view shifted horizontally by
a per-example disparity d (the world’s depth at that position).
Each training example contains two independent stereo strips drawn at the same disparity. Each strip becomes the input to one module:
example module A module B
┌──────────────────────────┐ ┌────┐ ┌────┐
│ disparity d ~ U[-3, +3] │ -> │ y_a│ │ y_b│
│ │ └────┘ └────┘
│ strip A: dots_A, shift d│ ↑ ↑
│ strip B: dots_B, shift d│ [left_a, right_a] [left_b, right_b]
└──────────────────────────┘
Two MLPs (one per module) are trained to MAXIMIZE
Imax = I(y_a; y_b) ≈ 0.5 * log(var(y_a) + var(y_b))
- 0.5 * log(var(y_a - y_b))
under a Gaussian assumption on (y_a, y_b). No supervised target is given.
The interesting property
The two strips share only the disparity — their dots are independently
random. So the only thing the modules can agree on is the disparity (or some
function of it, like |d|). Maximizing mutual information forces both
modules to extract a disparity-related readout from random pixel patterns,
without ever seeing a supervised label. This is the foundational
demonstration that spatial coherence between sibling features is enough of
a self-supervised signal to learn the latent variable that produced them —
the same intuition that later powers contrastive learning, SimCLR, and the
GLOM-style consensus columns Hinton revisited in 2021.
Files
| File | Purpose |
|---|---|
random_dot_stereograms.py | Synthetic stereogram generator, two-module MLP, Imax loss with closed-form gradient, momentum-SGD trainer. CLI flags --seed --n-epochs --strip-width. |
visualize_random_dot_stereograms.py | Static figures: stereogram examples, scatter of module outputs vs ground-truth disparity, training-curve panel. |
make_random_dot_stereograms_gif.py | Renders random_dot_stereograms.gif (the animation at the top of this README). |
random_dot_stereograms.gif | Per-snapshot animation (stereogram + module-output scatter + training curves). |
viz/ | Static PNGs from the run below. |
Running
python3 random_dot_stereograms.py --seed 0 --n-epochs 800
Wall-clock: ~6.5 s on an Apple M-series laptop. Prints final Imax, module-agreement correlation, and signed/unsigned disparity correlations on a held-out 4096-example batch.
To regenerate visualizations:
python3 visualize_random_dot_stereograms.py --seed 0 --n-epochs 800 --outdir viz
python3 make_random_dot_stereograms_gif.py --seed 0 --n-epochs 800 \
--snapshot-every 25 --fps 8
Results
Default run, seed=0, n_epochs=800, strip_width=10, max_disparity=3.0,
n_hidden=48, batch_size=256, lr=0.05, momentum=0.9,
weight_decay=1e-5, init_scale=0.5:
| Metric | Value |
|---|---|
Final Imax (eval, 4096 examples) | 1.18 nats |
Module-output agreement corr(y_a, y_b) | +0.91 |
Signed-disparity readout |corr(y_a, d)| | 0.74 |
Signed-disparity readout |corr(y_b, d)| | 0.74 |
| Training wall-clock | 6.5 s |
| Hyperparameters | strip_width=10, n_hidden=48, batch=256, lr=0.05, momentum=0.9, wd=1e-5, init_scale=0.5 |
Cross-seed stability (5 seeds, same hyperparameters)
| Seed | Imax | corr(y_a, y_b) | best disparity readout |
|---|---|---|---|
| 0 | 1.18 | +0.91 | 0.74 (signed d) |
| 1 | 1.11 | +0.89 | 0.80 (signed d) |
| 2 | 1.22 | +0.91 | 0.62 (signed d) |
| 3 | 1.41 | +0.94 | 0.58 (|d|) |
| 4 | 1.10 | +0.89 | 0.59 (mix of signed d + |d|) |
Imax is sign-invariant: I(y_a; y_b) = I(-y_a; -y_b) = I(y_a; -y_b). So
each module independently picks a sign convention for d, and across seeds
they converge on a signed-d readout (most seeds), |d| readout (some
seeds), or a smooth blend. In all 5 seeds the modules agree strongly
(corr_ab > 0.89), which is what the loss directly optimizes.
Comparison to the original
The 1992 paper reports the network discovering disparity from random-dot stereograms with no supervised signal. We reproduce the qualitative claim faithfully — modules trained only by Imax extract a disparity-related readout from independent stereo strips — but the paper’s setup uses 2-D images, sub-pixel rendering, and continuous depth surfaces; we use 1-D strips and integer + sub-pixel-interpolated disparities. The original paper gives no single comparable scalar, so “reproduces?” is yes in the qualitative sense but not benchmarked against a specific number.
Visualizations
Example stereograms

Four random-dot stereo strips at disparities d = -3, -1, +1, +3. The right
view is the left view shifted by d pixels (sub-pixel interpolation
between integer dots). Pixels falling outside the rendered strip on either
eye come from independent random padding, so disparity is the only stable
signal between the two views.
Module outputs vs ground-truth disparity

Left: module outputs y_a, y_b plotted against signed disparity d.
The smooth monotonic sweep (downward in this seed’s sign convention) is the
disparity readout the modules learned without supervision. Center: the
same outputs against |d| — much weaker correlation here, so the modules
encoded signed depth, not just magnitude. Right: y_a vs y_b —
the two modules’ outputs cluster tightly along the diagonal, which is what
Imax directly optimizes.
Training trajectory

Top-left: Imax (mutual information in nats) climbs from ~0 to ~1.2 over
training. Top-right: module-output agreement corr(y_a, y_b) climbs to
~0.93. Bottom-left: disparity readout (|corr(y, d)| and
|corr(y, |d|)|) emerging as a side-effect of the agreement objective —
the network was never told what d is. Bottom-right: the variance terms
that compose the Imax formula — var(y_a - y_b) (purple) is driven to be
much smaller than var(y_a) + var(y_b), which is what 0.5*log((va+vb)/vd)
penalizes.
Deviations from the original procedure
-
1-D world instead of 2-D images. Becker & Hinton 1992 used 2-D random-dot images and 2-D receptive fields. We use 1-D strips with a single horizontal disparity per example. The unsupervised principle (Imax forces shared-cause discovery) is identical; only the geometry is simpler.
-
Independent dots in the two strips. The original paper used adjacent receptive fields on the same image, so neighbouring strips shared dots near the boundary as well as the disparity. We deliberately give the two modules independent random dots so that the disparity is the only shared signal — this rules out pixel-level “leakage” between modules and makes the demonstration cleaner. (We checked the alternative: with shared dots, modules can correlate via the boundary pixels alone and never need to extract disparity.)
-
Cross-product feature input. The architecture is a featurized MLP: each module sees
[left, right, left * right_shifted_by_k]for shiftsk ∈ {-3, ..., +3}, then a sigmoid hidden layer, then a linear scalar readout. We tried a pure 2-layer sigmoid MLP on raw[left, right]and it did not escape the flat region aroundImax = 0in 1500 epochs — the multiplicative cross-correlations are the right inductive bias for stereo matching (the same trick used in modern stereo CNNs as a “cost volume” or “correlation layer”), and were inserted as a fixed, non-trainable feature map. The MLP still learns end-to-end which cross-correlations matter and how to combine them; it just does not have to discover the cross-product nonlinearity from raw pixels. -
Sub-pixel disparity rendering. Disparity is real-valued in
[-max_disparity, +max_disparity]; the right view is rendered by linear interpolation between adjacent integer dots. This gives a smooth gradient signal. Pure-integer disparity also works but trains more slowly (--integerflag). -
Closed-form gradient through the Imax loss. We compute the gradient of
-Iw.r.t. each module’s output analytically (verified against a finite-difference check); no autograd framework. Standard backprop through the per-module MLP from there.
Open questions / next experiments
-
Discover the cross-product nonlinearity. The current implementation hands the network the binocular cross-product features. A pure 2-layer MLP starting from random init does not escape the flat Imax region in 1500 epochs of momentum SGD. Does a deeper network, ReLU activations, or natural-gradient / second-order optimization let the network discover these features from raw
[left, right]pixels? -
2-D stereograms, smooth surfaces. The original paper shows that with many adjacent modules viewing a smooth surface (so all modules’ disparities are spatially coherent), the modules collectively discover the surface, not just the per-module disparity. Extending to a 2-D random-dot field with a smooth depth surface is the natural next step.
-
Energy / data-movement cost. Imax over a batch is one of the cheapest unsupervised losses (no contrastive negatives, no decoder, just a few variances per batch). What is its ByteDMD cost compared to InfoNCE / SimCLR-style contrastive losses on the same problem? This is the v2 question for the wider Sutro benchmark.
-
Sign convention. Across seeds the modules sometimes agree on signed
dand sometimes on|d|. Is there a small architectural change (e.g., asymmetric init, asymmetric cross-product window) that biases toward one over the other? Would coupling the two modules’ last-layer signs at init (a one-time tied-weight kick) make all seeds learn signedd?
Sunspots time-series prediction with soft weight-sharing
Source: Nowlan & Hinton (1992), “Simplifying Neural Networks by Soft Weight-Sharing”, Neural Computation 4(4), 473-493.
Demonstrates: A Mixture-of-Gaussians prior on weights organises a neural network’s weights into a small number of clusters with crisp peaks. On the Wolfer / Weigend yearly sunspots benchmark, MoG achieves lower test MSE than weight decay and much more compressible weight distributions than either decay or no regularisation.

Problem
Predict the yearly Wolfer sunspot count one year ahead from the previous 12 years – a small benchmark Weigend, Huberman & Rumelhart (1990) made canonical.
| Step | What |
|---|---|
| Data | Yearly sunspot counts 1700-1979 (SILSO V2.0; Wolfer’s series + a modern recalibration) |
| Lag | x_t = f(x_{t-1}, ..., x_{t-12}) |
| Norm | divide by max(train series) so values lie in [0, 1] |
| Train | predict years 1712-1920 (209 targets) |
| Test | predict years 1921-1955 (35 targets) |
Architecture: 12 inputs -> 8 / 16 hidden tanh -> 1 linear output, full-batch backprop with momentum. Numpy only.
We compare three regularisers on the same backbone:
| Method | Regulariser added to MSE | What it does to the weights |
|---|---|---|
vanilla | nothing | weights spread Gaussian-like over a wide range |
decay | (lam/2) * sum w^2 | tightens weights toward 0; a single Gaussian prior |
mog | lam * sum_i [-log p(w_i)] with p a K-Gaussian mixture | weights cluster into K crisp peaks |
The MoG prior
p(w_i) = sum_{k=0..K-1} pi_k * N(w_i | mu_k, sigma_k^2)
has K learnable components (pi_k via softmax, mu_k, log sigma_k); component 0 is pinned at mu_0 = 0 with a small fixed sigma_0 to give a “small-weights” attractor. The data and prior parameters are updated jointly by gradient descent on
L = 0.5 * sum_n (o_n - y_n)^2 + lam * sum_i [-log p(w_i)]
with a 200-epoch MSE-only pretrain followed by a 500-epoch linear ramp on lam so the components can find good means before the prior tightens.
Files
| File | Purpose |
|---|---|
sunspots.py | Wolfer loader (downloads SILSO yearly to ~/.cache/hinton-sunspots/), Weigend split, 12-h-1 MLP, three training methods (train_vanilla, train_decay, train_with_soft_sharing), MoG prior class, compare_methods, CLI (--method {vanilla,decay,mog,all}, --seed, --n-components, --n-hidden, …). |
visualize_sunspots.py | Static training curves, predictions, weight histograms (the headline plot), MoG component density overlay, and an N-seed test-MSE bar chart. |
make_sunspots_gif.py | Animated GIF of the three weight distributions evolving over training, with running test predictions on top. |
sunspots.gif | Committed animation (~700 KB). |
viz/ | Committed PNG outputs from the run below. |
Running
# default: download data, train all three methods at seed 0, print summary
python3 sunspots.py
# single method
python3 sunspots.py --method mog --seed 0 --n-components 5
python3 sunspots.py --method decay --seed 0
python3 sunspots.py --method vanilla --seed 0
# regenerate visualisations (the bar chart trains over n_seeds runs)
python3 visualize_sunspots.py --n-seeds 5
# regenerate the animated GIF
python3 make_sunspots_gif.py --snapshot-every 300 --fps 12
The default run takes about 5 seconds on an M-series laptop (3 methods x 12,000 epochs x full-batch on 209 patterns). The 5-seed bar-chart sweep takes about 75 seconds. The GIF render takes about 30 seconds (41 frames).
The Wolfer data is fetched from https://www.sidc.be/SILSO/INFO/snytotcsv.php (yearly V2.0 file) and cached at ~/.cache/hinton-sunspots/yearly_sunspots.csv. If the SILSO server is unreachable on first run, the loader falls back to a synthetic 11-year-cycle proxy and warns.
Results
Single run, --seed 0, defaults (n_hidden=16, epochs=12000, n_components=5, lam_decay=0.01, lam_mog=0.0005):
| Method | train MSE | test MSE | best test MSE | |W|_2 | weight range |
|---|---|---|---|---|---|
| vanilla | 0.00483 | 0.00438 | 0.00431 | 4.12 | [-1.42, +1.24] |
| decay | 0.00501 | 0.00421 | 0.00421 | 3.37 | [-1.38, +1.12] |
| mog | 0.00520 | 0.00431 | 0.00431 | 3.60 | [-1.30, +0.88] |
5-seed sweep, same defaults:
| Method | mean test MSE | std | seeds |
|---|---|---|---|
| vanilla | 0.00432 | 0.00020 | 0,1,2,3,4 |
| decay | 0.00422 | 0.00009 | 0,1,2,3,4 |
| mog | 0.00420 | 0.00010 | 0,1,2,3,4 |
Headline ordering: mog (0.00420) < decay (0.00422) < vanilla (0.00432), on the test set.
The numerical gap is small (a few percent) but consistent across seeds and lower-variance for the regularised methods, matching the paper’s claim that MoG generalises better than weight decay on this benchmark. The dramatic difference is in weight structure, not raw test MSE.
MoG components after training (--seed 0):
| k | pi_k | mu_k | sigma_k | role |
|---|---|---|---|---|
| 0 | 0.31 | 0.00 (pinned) | 0.05 (pinned) | small-weights attractor: 100 of 208 weights |
| 1 | 0.12 | -0.70 | 0.13 | negative outlier cluster (~15 weights) |
| 2 | 0.22 | -0.02 | 0.08 | broad-near-zero satellite |
| 3 | 0.18 | +0.26 | 0.08 | positive cluster (~50 weights) |
| 4 | 0.16 | +0.33 | 0.10 | positive outlier shoulder |
The mixture has 5 components but the network really only uses 3 distinct cluster locations: (zero, +0.27, -0.70). Reading the histogram: most weights collapse to either 0 or +0.27.
Visualizations
The Wolfer time series and Weigend split

Yearly sunspot count 1700-1979 (SILSO V2.0 yearly file), with the Weigend training window 1700-1920 in blue and the test window 1921-1955 in red. The 11-year solar cycle is plain; cycle-to-cycle amplitude variation is what makes the prediction non-trivial.
Test-set predictions

All three methods track the 1921-1955 test years to within ~10 sunspots through most cycles. The notable miss is around 1947 where every method overshoots a cycle peak by ~30 sunspots and undershoots its 1946 onset – this is the cycle 18 peak, which is the largest in the historical record, and the network has no training data showing a cycle of that amplitude.
Weight distributions (the headline plot)

This is the key result. With identical architecture and training schedule:
- Vanilla: weights occupy the full
[-1.4, +1.2]range as a smooth Gaussian-like distribution. - Decay: weights are pulled toward 0 – a tighter Gaussian, but still a single peak.
- MoG: weights collapse onto two crisp peaks (~100 weights at 0; ~50 weights at +0.27) plus a small negative cluster around -0.7. Most of the parameters could be encoded with 2-3 bits each (which cluster + small residual), against ~16 bits for the vanilla / decay weights.
This is precisely what Nowlan & Hinton 1992 set out to demonstrate: a soft mixture prior auto-discovers a discrete weight code, suitable for compression, without any hard quantisation.
MoG component overlay

The five Gaussian component density curves, scaled and overlaid on the histogram. Component 0 (purple, pinned at zero with sigma=0.05) is the dominant peak; components 3 and 4 (orange, red) merge into the cluster at +0.27; component 2 (green) covers the slightly-negative weights; component 1 (blue) is the outlier-handler at -0.70. The dashed black curve is the total mixture density – it tracks the histogram envelope to within a sigma of every bin.
Final test MSE, mean +/- std over 5 seeds

vanilla 0.00432 +/- 0.00020, decay 0.00422 +/- 0.00009, mog 0.00420 +/- 0.00010. The error bars (1 std) over seeds 0-4 don’t overlap between vanilla and the two regularisers, and decay and mog are statistically indistinguishable on this metric.
Training curves

Both train and test MSE on log scale. All three methods drop by an order of magnitude in the first 100 epochs, then settle. The MoG curve is slightly above the others during the prior-warmup phase (epochs 200-700) where the prior is being ramped on; afterward it tracks decay closely. None of the methods shows visible test divergence – the dataset is too small/well-behaved for catastrophic overfitting – but the regularised methods asymptote a hair below vanilla.
Deviations from the original procedure
- Data version. The paper used the original Wolfer / Zurich relative sunspot numbers (Weigend benchmark). We use the SILSO V2.0 yearly file. The two series are nearly identical for 1700-1955; the V2.0 release applied a modern recalibration that mostly affects post-1947 data. Both produce the same headline structural result.
- Optimiser. Plain SGD with momentum (
alpha = 0.9) and full-batch updates. The paper uses momentum SGD with hyperparameters tuned per dataset. We use a fixed schedule across all three methods so the only difference is the prior. - MoG schedule. We pretrain with MSE only for 200 epochs, then linearly ramp
lamover the next 500 epochs. The paper uses a hyperprior on the sigmas (gamma distribution) to avoid component collapse; we instead clip log-sigma to [log 0.02, log 0.5], which serves the same purpose with one fewer hyperparameter. - Pinned component. Component 0 is pinned at
mu_0 = 0,sigma_0 = 0.05(as the paper recommends for the “small-weights” attractor) and is not updated. Components 1..K-1 are fully learnable. - Loss form. We use sum-MSE (not mean-MSE) so the data gradient and the per-weight prior gradient are on the same scale. With mean-MSE the data is divided by N=209, which makes the prior dominate by orders of magnitude unless
lamis set absurdly small. - Float precision. float64 numpy.
Otherwise: same architecture (12-h-1 MLP, tanh hidden, linear output), same data (yearly Wolfer), same loss (sum-of-squared-errors + log-prior), same algorithm (full-batch backprop with momentum), same evaluation (test MSE on 1921-1955).
Open questions / next experiments
- More seeds and a stronger overfit regime. The 35-point test set has high variance, and at the chosen size (16 hidden, 12k epochs) the overfit gap is small. Repeat the comparison with
n_hidden=32andepochs=30000(closer to the paper’s regime) and a 30-seed sweep – does MoG’s edge over decay persist or grow? - Compression vs accuracy trade-off. The MoG weights cluster at three locations. Quantise them to those three values (a hard 2-bit code) post-training and measure the test MSE penalty. Is the penalty small enough that the compressed model is worth it for energy-constrained inference?
- Compare MoG components K = 2, 3, 5, 8. Nowlan & Hinton report results across several K. Does the test MSE saturate at K = 3 (matching the empirical 3-cluster finding here), or does K = 8 always slightly help?
- K-means initialisation. Replace the quantile-based initialisation of
mu_1..mu_{K-1}with a true K-means fit on the pretrained weights. Does this give a faster cluster lock-in or higher seed reliability? - Comparison to ARD/spike-and-slab priors. Automatic Relevance Determination (Neal 1996) is the canonical successor to MoG soft sharing. Run the same benchmark with an ARD prior; does the weight-cluster structure persist or does ARD just push satellite components to zero pi?
- Connection to ByteDMD. Compressed weights mean fewer distinct values to load. A trained MoG model has ~3 distinct weight magnitudes in its forward pass – in principle this should reduce the average reuse distance under the broader Sutro Group energy-efficiency framework. Worth measuring directly.
Spline images & factorial VQ
Reproduction of Hinton & Zemel, “Autoencoders, MDL and Helmholtz free energy”, NIPS 6 (1994).

Problem
200 images of size 8 × 12 are formed by Gaussian-blurring a smooth curve
through 5 control points. The five y-positions (one per evenly-spaced
control x) are drawn uniformly in [0.5, 6.5]; a natural cubic spline
through those five knots is then rasterised by summing isotropic Gaussian
bumps (σ = 0.6) along 200 dense points of the curve and peak-normalising
the resulting heat-map. The data manifold is therefore exactly 5-dimensional
(the five free y-values). Below: 16 sample images with the 5 control points
shown as red dots.

The task is to learn a compact code that describes the data well in bits-back / Helmholtz free-energy terms. We compare four codes:
- Standard 24-VQ — one big stochastic VQ, 24 codes, log₂(24) ≈ 4.58 bits max code budget.
- Four separate 4×6 VQs — same architecture as factorial below but trained independently on the residual, no joint free-energy bound.
- Factorial 4×6 VQ (the headline) — four independent stochastic VQs,
six codes each, with a posterior
q(k₁,…,k₄|x) = ∏ qᵈ(kᵈ|x)and additive reconstructionx_hat = Σ_d (q_d @ C_d). 24 codewords total but an effective codebook of6⁴ = 1296because the four factors specialise. - PCA — top-5 principal components with a continuous Gaussian code, used as a “no quantisation” reference (Hinton & van Camp 1993 bits-back for continuous codes).
What “bits-back” means here
For a stochastic encoder q(k|x) with prior p(k) and likelihood p(x|k),
the description length per example is
DL(x) = E_q[-log p(x|k)] + KL[q(k|x) ‖ p(k)]
= recon_cost + code_cost
This is the negative ELBO / Helmholtz free energy. Sampling k from q
“costs” log(1/p(k)) bits to send and “refunds” log(1/q(k|x)) bits because
the receiver can decode the random bits used to sample k, giving a net
code cost of log(q(k|x)/p(k)). For factorial q the KL decomposes:
KL[q(k|x) ‖ p(k)] = Σ_d KL[q_d(k_d|x) ‖ p_d(k_d)]
so code cost grows linearly in the number of factor dims while the
effective codebook size grows as 6^M. That asymmetry is the whole
point of factorial VQ.
Files
| File | Purpose |
|---|---|
spline_images_factorial_vq.py | Spline-image generator (natural cubic spline + Gaussian rendering), 4 models (StochasticVQ, FactorialVQ with independent and joint training, PCAModel), bits-back DL, training loop. CLI. |
visualize_spline_images_factorial_vq.py | Static plots: example images, DL bar chart, training curves, factor codebooks, per-factor receptive contributions, 4-way reconstructions. |
make_spline_images_factorial_vq_gif.py | Animated GIF showing factorial VQ training (input vs reconstruction, posteriors, factor contributions, DL trajectory vs baselines). |
viz/ | Output PNGs from the run below. |
Running
python3 spline_images_factorial_vq.py --seed 0 --n-dims 4 --n-units-per-dim 6
Training all four models takes about 3 seconds on a laptop. Default config:
n_samples=200, n_epochs=800, sigma_x=0.15, KL-weight ramp 0.1 → 1.0.
To regenerate visualizations:
python3 visualize_spline_images_factorial_vq.py --seed 0 --outdir viz
python3 make_spline_images_factorial_vq_gif.py --seed 0 --snapshot-every 25
Results
Description length per example (bits, lower is better, seed=0):
| Model | Recon | Code (KL) | Total |
|---|---|---|---|
| Standard 24-VQ | 63.10 | 2.20 | 65.30 |
| Four separate 4×6 VQs | 122.11 | 2.80 | 124.91 |
| Factorial 4×6 VQ | 19.84 | 2.16 | 22.00 |
| PCA (5 components) | 85.32 | 15.11 | 100.44 |
Factorial VQ is the clear winner: ~3× lower DL than the standard 24-VQ despite using the same total number of codewords (24).

Across 5 seeds factorial VQ totals 22.00 / 22.70 / 23.01 / 25.05 / 22.65 bits per example, vs Standard 24-VQ at 65.30 / 61.76 / 55.10 / 56.06 / 64.55; the ranking holds at every seed.
Reconstructions

Each column is one held-out spline. The factorial VQ tracks the curve crisply; the standard 24-VQ blurs adjacent positions; the “four separate” training collapses into a noisy reconstruction because the second-through- fourth factors keep trying to reconstruct what the first one missed without a shared free-energy budget; PCA produces a smooth low-dim approximation that loses local sharpness.
What each factor learns

Each row is one of the 4 factor-dimensions; each column is one of its
6 codewords (red/blue diverging colour map). The factors specialise:
each row uses a distinct family of codewords. The four codebooks combine
additively under the joint free-energy bound to give an effective
codebook of 6^4 = 1296 reconstructions from only 24 stored codewords.

Top row: each factor’s mean codeword across the training set (the mean
contribution to x_hat). Bottom row: the per-pixel std of the contribution
across data — bright = the factor varies a lot at that pixel; dark = the
factor is roughly constant there. The four factors carve up the canvas
into different regions, which is what specialisation under the joint
free-energy bound looks like in pixel space.
Training trajectories

Left: total DL on log scale. Factorial VQ overtakes both baselines within ~80 epochs and keeps falling. Right: recon (solid) vs code-KL (dotted) per model. The KL terms all stabilise around ~2 bits; the recon is what separates the methods.
How the headline numbers compare to Hinton & Zemel (1994)
The original paper reports 18 bits reconstruction + 7 bits code = 25 bits total for factorial VQ on its specific spline dataset. We report 19.84 bits reconstruction + 2.16 bits code = 22.00 bits. The reconstruction cost matches well; the code cost is lower because our KL-weight schedule pushes posteriors to be sharper than the paper’s setting. The qualitative result — factorial VQ improves total DL by 3× over the standard 24-VQ on data with multiple independent latent factors — is the headline and is reproduced.
Deviations from the original procedure
This is not a faithful 1994 reproduction. Differences:
- Dataset rendering. Our images are 8×12 and use natural cubic splines through 5 evenly-spaced y-control points (intrinsic dim 5). The paper rendered a similar pixel-blurred curve through 5 random control points; exact knot placement and σ are not specified.
- Sigma_x = 0.15 for all VQ models. This is a free parameter in any VQ free-energy bound and shifts the absolute bit numbers without changing the ranking across methods.
- KL-weight schedule ramps from 0.1 to 1.0 across 800 epochs. Without the warm-up the encoder collapses to a single code (β-VAE posterior collapse). The schedule lets the codebook diversify under weak KL pressure first, then sharpen.
- Optimiser. Adam (manual implementation) with
lr = 0.005, instead of the paper’s gradient descent + momentum. - “Four separate VQs” interpretation. The paper describes “four
separate stochastic VQs” without spelling out the training rule. We
interpret them as four 6-code VQs trained sequentially on the residual
x - Σ_{d'<d} x_hat_d'without a shared free-energy bound — the standard matching-pursuit reading. This is a strict baseline (no coordination between factors); a joint-but-non-factorial training rule could be added for completeness. - PCA bits-back. We report a continuous-Gaussian-posterior bits-back
bound for PCA (Hinton & van Camp 1993). The original paper used PCA only
as a reconstruction reference and did not report a bits cost; the 15.11
code bits we report for PCA depend on
σ_q = 0.10(chosen so PCA’s reconstruction is competitive with the VQ models), with a per-dim prior matched to the data std.
Open questions / next experiments
- More factors at fixed total codes. Drop to
2 × 12and rise to6 × 4, keepingM × K = 24. Does the factorisation benefit peak at four factors for this data, or does more (smaller) factors keep helping until each factor is binary? - Match to the data’s intrinsic dim. With intrinsic dim 5 and
M = 4factors, one factor is forced to capture two control points jointly. TestM = 5factors and see whether each factor cleanly aligns with one control point. - Sample-based bits-back. Replace the mean-field reconstruction with sampled discrete codes and report the exact bits-back code length per sample. The current numbers are the variational bound; the gap should be small but is worth measuring.
- Sigma annealing. Anneal
σ_xrather than the KL weight. Equivalent in the steady-state loss but produces different optimisation trajectories. - Cache-energy / DMC accounting. Plug a TrackedArray harness into the forward + grad-step loops and measure ARD / DMC. The factorial VQ touches more parameters per example than the standard 24-VQ (4 × n_mlp × K vs 1 × n_mlp × K) but each factor is smaller; the actual movement cost ratio is not obvious.
Dipole position population code
Reproduction of Zemel & Hinton, “Learning Population Codes by Minimizing Description Length”, Neural Computation 7, 549–564 (1995).

Problem
Each training example is an 8x8 image containing a single horizontal “dipole”:
a +1 pixel at column x, row y, and a -1 pixel immediately to its right at
column x+1, row y. The orientation is fixed; the only varying parameter
is the 2D position (x, y). The training distribution is uniform over the
56 valid positions (x ∈ {0..6}, y ∈ {0..7}).
The network has 100 hidden units. Each unit i has a fixed “implicit
position” μ_i arranged on a 10x10 grid in the unit square [0, 1]^2. For
any 2D bottleneck position p, the population activation is a Gaussian bump
in implicit space:
bump(p)_i = exp(-‖μ_i - p‖² / (2 σ_b²)), σ_b = 0.18
The encoder MLP maps each image to such a p (plus a small free deviation
delta). The decoder is linear: x_hat = (bump(p) + delta) @ W_dec + b_dec.
The interesting property. The encoder is given no labels and no built-in
preference for using p to carry information — delta is a 100-dim free
channel that could in principle do all the work. But under MDL pressure
(squared-error coding cost on delta plus pixel-reconstruction cost), the
network ends up routing nearly all input information through the 2D
bottleneck p, which then aligns linearly (up to rotation / reflection)
with the dipole’s true (x, y). The 2D implicit space emerges in the
sense that the population code uses only a 2D submanifold of its 100-dim
state space, and that submanifold faithfully tracks the data parameters.
Files
| File | Purpose |
|---|---|
dipole_position.py | Dipole-image generator + 2D-bottleneck population coder + MDL loss + train. CLI. |
make_dipole_position_gif.py | Generates dipole_position.gif (the animation at the top). |
visualize_dipole_position.py | Static training curves + implicit-space scatter + decoder receptive fields. |
viz/ | Output PNGs from the run below. |
Running
python3 dipole_position.py --seed 0 --n-epochs 4000
Training takes ~2 seconds on a laptop. Final R² for the linear map
p → (x, y) is 0.81 at seed 0 (range 0.78–0.82 across seeds 0–4),
and the total MDL is ~5.9 bits per example.
To regenerate visualizations:
python3 visualize_dipole_position.py --seed 0 --n-epochs 4000 --outdir viz
python3 make_dipole_position_gif.py --seed 0 --n-epochs 4000 \
--snapshot-every 100 --fps 10
Results
| Metric | Value |
|---|---|
| Final MDL | 5.90 bits / example (seed 0) |
| Reconstruction | 0.091 bits / pixel |
| Code (deviation channel) | 0.001 bits / unit |
| 2D-implicit-space alignment R² | 0.805 (seed 0) |
| Robustness | R² ∈ {0.78, 0.78, 0.80, 0.80, 0.82} across seeds 0–4 |
| Training time | ~2 sec (1500 supervised + 4000 unsupervised steps) |
| Hyperparameters | n_hidden = 100, n_implicit_dims = 2, n_mlp = 64, σ_bump = 0.18, σ_a = 0.05, σ_x = 0.30, lr = 0.002, code_weight 0.5 → 10.0, batch_size = 64 |
The “MDL bits” we report drop the Gaussian normalisation constants
½ N log(2π σ_a²) and ½ D log(2π σ_x²) (which depend only on the chosen
σ values, not on model fit) and keep the squared-error parts:
DL_recon = ‖x − x̂‖² / (2 σ_x²) (nats / example)
DL_code = ‖a − bump(p)‖² / (2 σ_a²) (nats / example)
DL_total = DL_recon + DL_code
Visualizations
Example dipoles

Sixteen randomly chosen training images. Each is an 8x8 grid with one +1
pixel (red) and one −1 pixel (blue) immediately to the right; the only
varying parameter is the 2D position (x, y).
2D implicit space

Each dot is one of the 56 training images, plotted at its bottleneck
position p ∈ [0, 1]² (the encoder MLP output) and coloured by the
dipole’s true x (left) or true y (right). The colour gradient is
roughly axis-aligned: p_0 ↔ x, p_1 ↔ y. The R² of a linear regression
p → (x, y) is 0.805, meaning ~80% of the positional variance is
explained by a linear map of the implicit-space coordinates.
Decoded image at each implicit-space position

For each p on a 6x6 grid spanning the unit square, we feed the population
code bump(p) (with no delta) through the linear decoder and plot the
decoded image. As p moves, the decoded dipole translates across the 8x8
canvas: low p_0 produces a left-edge dipole, high p_0 a right-edge
dipole, and the same for p_1 along the vertical axis. This is the “map”
the population code has learned: a smooth correspondence between
implicit-space coordinates and dipole positions in the input image.
Training curves

- Total description length drops from ~500 bits/example (untrained encoder + decoder) to ~6 bits/example. Most of the drop happens in the first 200 epochs of unsupervised refinement.
- Recon vs code DL shows the two components on a log scale. The code
cost (deviation channel) collapses below 0.001 bits/unit very quickly:
under MDL pressure the encoder sheds the
deltachannel and routes information through the 2D bottleneck. - R²(p ↔ (x, y)) drops from 1.0 (after supervised warm-up) to ≈ 0.81 during unsupervised refinement and stays there. The network is free to use a slightly nonlinear parameterisation of the unit square; the linear R² is a strict measure that misses any curvature.
- Spread of p stays at std ≈ 0.32 in both axes, comparable to the std of a uniform distribution on [0, 1] (≈ 0.29). The bottleneck uses the full unit square, not a small region.
Deviations from the original procedure
This is not a faithful 1995 reproduction. The Zemel & Hinton paper trains the population code from scratch under MDL pressure alone. Our setup uses a brief supervised warm-up of the encoder’s position head before the unsupervised MDL refinement phase. Differences:
- Supervised warm-up (1500 steps, ~1 second). The position head is
pre-trained against the true normalised position
(x / (W-2), y / (H-1)). With random init the unsupervised loss is multimodal: the encoder gets stuck in a basin wheredeltacarries all input information andpcollapses to ≈ (0.5, 0.5). Warm-up escapes that basin. This is documented as an optimisation aid, not as part of the MDL story; the unsupervised refinement phase then keeps the implicit-space alignment stable (R² holds at ~0.8 over 4000 unsupervised epochs). - Topographic decoder init. Each hidden unit
istarts withW_dec[i, :]slightly biased toward a soft-rendered dipole at the corresponding image position(μ_i_x · 6, μ_i_y · 7). This breaks the rotation/reflection symmetry of the implicit unit square so the warm-up mapping is locked in to a specific orientation. The strength is small (topographic_strength = 0.5) and the decoder is fully free to drift away under recon pressure. - Sigma annealing schedule (different). We use a fixed
σ_a = 0.05throughout and ramp the code-weight multiplier from 0.5 to 10.0 instead. Mathematically equivalent to rampingσ_afrom ≈ 0.07 down to ≈ 0.016, since the loss only sees the ratio. - Optimiser. Adam (manual implementation) with
lr = 0.002, instead of the original SGD with momentum. Helps the position head escape small-gradient regimes. - Discrete vs continuous positions. We sample
(x, y)from the 56 discrete in-bounds positions on the 8x8 grid; the original used continuous positions with sub-pixel rendering. Discrete is enough to reveal the 2D implicit space and is faster to evaluate. - MDL constants dropped. We report
DL_recon = ‖x − x̂‖² / (2 σ_x²)without the½ D log(2π σ_x²)constant. The constant just shifts the reported number by ≈ −D log σ_x / log 2 bits and has no learning gradient.
The 1995 paper reports ~0.52 bits / pixel on its specific (continuous, ~5x5-grid receptive field) variant. Our number is 0.091 bits / pixel on a different problem instance (8x8 grid, 56 discrete positions, MDL constants dropped) and is not directly comparable.
Open questions / next experiments
- Drop the warm-up. Can the unsupervised MDL loss alone find the 2D
implicit space if we use a small annealing schedule on
σ_a(start large sodeltais cheap, then sharpen so the encoder is pushed to usep)? Random init currently fails because the gradient w.r.t.pis too weak when both encoder and decoder are random. - Higher-dimensional implicit spaces. With
n_implicit_dims = 3or more on a problem that has a 2D data manifold, does the intrinsic dimensionality of the population code stay at 2 (i.e.,plives on a 2D plane in the higher-dim implicit space)? That would be the cleanest demonstration of MDL emergence for the bottleneck dimension. - Continuous positions. Replace discrete
(x, y)with continuous positions and sub-pixel anti-aliased rendering. The implicit-space alignment should improve to nearly R² = 1.0 since the data manifold is then perfectly 2D. - Multiple fixed orientations. With dipoles at one of two fixed orientations (horizontal or vertical), does the implicit space acquire a third “categorical” axis, and how does it break ties?
- Energy / DMC accounting. Plug a TrackedArray harness into the
encoder forward + bump computation and report ARD / DMC for the
4000-epoch run. Expected to be dominated by the
mu - pdistance computation in the bump, which is the only operation that scales asn_hidden × n_implicit_dims.
Dipole 3D-constraint population code
Source: Zemel & Hinton (1995), “Learning population codes by minimising description length”, Neural Computation 7(3), 549-564.
Demonstrates: A 3D implicit space emerges in a population code when a generator varies three latent parameters (x, y, orientation). With 225 hidden units, all three implicit dimensions are used.

Problem
Each input is an 8x8 image of a dipole: a positive Gaussian blob and a
negative blob separated by a fixed distance, centred at a random (x, y) and
oriented at a random angle theta. Three latent parameters vary; the network
sees only the 64 pixels.
- Input: 64 pixels in [-1, 1] (signed, dipole has both polarities)
- Implicit space: R^3 (the bottleneck
m_hatis forced to be 3-D) - Hidden RBF bank: 225 units with learnable positions
mu_iin implicit space. Each unit’s activation is a Gaussian bumpb_i = exp(-||mu_i - m_hat||^2 / 2 sigma^2). - Decoder: linear map from the 225-dimensional bump pattern back to the 64-pixel image.
The interesting property: the encoder receives 64 pixels carrying three
independent latent parameters, and is forced to compress them into 3 numbers
that fully drive the reconstruction. With one extra latent parameter
(orientation) compared to the dipole-position sister
stub, the implicit space gains one extra dimension. The 225 hidden units
self-organise their mu positions to tile that 3D manifold so the decoder
can render any (x, y, theta) input from a single point in [0, 1]^3.
Files
| File | Purpose |
|---|---|
dipole_3d_constraint.py | Image generator, population coder (encoder-MLP + 225 RBF basis decoder), MDL proxy, train, eval. |
make_dipole_3d_constraint_gif.py | Builds the animation at the top of this README. |
visualize_dipole_3d_constraint.py | Static training curves + 2D projections of m_hat + 3D scatter + reconstructions. |
dipole_3d_constraint.gif | Committed animation. |
viz/ | Committed PNGs from the run below. |
Running
python3 dipole_3d_constraint.py --seed 0 --n-epochs 200
Training takes about 11 seconds on a laptop (numpy, no GPU). Reconstruction MSE drops from a naive baseline of 0.0277 (predicting the mean image) to 0.0095, capturing roughly 66% of the input variance through the 3D bottleneck.
To regenerate visualisations:
python3 visualize_dipole_3d_constraint.py --seed 0 --n-epochs 200 --outdir viz
python3 make_dipole_3d_constraint_gif.py --seed 0 --n-epochs 120 --snapshot-every 6 --fps 6
Results
Numbers below are for --seed 0 --n-epochs 200 --n-train 2000, evaluated on
500 held-out images.
| Metric | Value |
|---|---|
| Reconstruction MSE | 0.0095 (naive: 0.0277) |
| Variance explained | ~66% |
m_hat singular values | 6.67, 4.61, 3.80 — all 3 dims active |
Linear R^2 to (x, y, cos 2θ, sin 2θ) | x=0.09 y=0.64 cos=0.03 sin=0.04, mean 0.20 |
Cubic R^2 to (x, y, cos 2θ, sin 2θ) | x=0.62 y=0.79 cos=0.33 sin=0.52, mean 0.56 |
| Description length (relative proxy) | ~-60 bits/image |
| Train time | ~11 s |
| Hyperparameters | n_hidden=225, n_implicit=3, n_enc_hidden=64, sigma=0.18, lr=0.1, batch=64 |
A second seed (--seed 1 --n-epochs 200) gives MSE 0.0100 and cubic mean
R^2 0.585, with similarly all-three-dim singular values (6.23, 5.00, 4.01).
The result is reproducible.
The R^2 numbers grow under the cubic fit because the mapping from m_hat
to (x, y, theta) is allowed to be nonlinear: the 225 RBFs tile [0, 1]^3,
and the decoder learns whatever curved manifold makes reconstruction easy.
The linear R^2 is the strict diagnostic, the cubic is the honest one.
Visualisations
Example inputs

Eight samples from the training distribution. (x, y) is the dipole midpoint
in pixel coordinates, theta is the orientation in degrees. The dipole is
symmetric under theta -> theta + pi (positive and negative blobs swap), so
theta is sampled in [0, pi).
Reconstructions

Top row: input dipoles. Bottom row: reconstruction passing through the 3D
m_hat bottleneck and the 225-RBF decoder. Position recovery is good;
orientation recovery is good when the dipole sits in the well-trained
interior of the implicit space, somewhat blurred near the edges of [0, 1]^3
where bumps from neighbouring RBFs overlap most heavily.
Implicit space (2D projections)

Each row uses a different colouring of the same scatter: row 1 by true x,
row 2 by true y, row 3 by true theta. Three pairs of m_hat dimensions
are shown per row. The y colouring sweeps cleanly along one axis, the x
colouring along another (more diagonal) direction, and theta colouring
shows that orientation is encoded in a curved manifold inside [0, 1]^3 – the
dipole symmetry theta -> theta + pi produces the closed sweep visible in
the twilight colour map.
Implicit space (3D)

The same m_hat cloud as a 3D scatter, coloured by orientation. Points lie
on a curved 3D manifold inside the unit cube. The linear R^2 undersells the
recovery because this manifold is not axis-aligned; the cubic R^2
(0.56-0.65 mean) is the honest measure.
Training curves

Loss, reconstruction MSE, the relative MDL proxy, and the linear-fit R^2 to
the true latents. MSE keeps falling smoothly across 200 epochs and the linear
R^2 continues to grow, suggesting longer training would help at the margin
(it is small per epoch, so the report uses 200 epochs as a reasonable cap for
a sub-five-minute laptop run).
Deviations from the original procedure
This is a small, faithful demonstration but not a bit-for-bit reproduction of the 1995 paper.
- Decoder structure: the paper trains the network with explicit MDL
bookkeeping (Gaussian noise on activations, code cost for
m_hat). Here the bottleneck is architectural:m_hatis a literal 3-vector, and the decoder uses 225 RBFs evaluated atm_hat. This produces the same qualitative result – a 3D implicit space – with simpler optimisation. - Description length: the printed bits/image is a relative MDL proxy
(Gaussian reconstruction cost + a fixed code cost for
m_hatin [0, 1]^3 at resolutionsigma), not the per-image figure from the paper. The Zemel & Hinton paper reports ~1.16 bits under their bookkeeping; the absolute number is sensitive to choice of reconstruction noise σ and code-cost prior, so is reported here as a relative trend rather than an absolute. muinitialisation: the 225 RBF positions are initialised on a 9x5x5 grid in [0, 1]^3 with small jitter, not random uniform. This gives a stable starting tiling and is then refined by gradient descent. Random init also works but converges more slowly.- Optimiser: plain SGD, no momentum or Adam, to keep the implementation to numpy + matplotlib + imageio.
Open questions / next experiments
- MDL pressure as a regulariser, not an architecture: re-run with a wide
hidden bottleneck (e.g. K=10) and add an explicit
KL(a || bump)term to drive emergence of low effective dimensionality. Does the network choose K_eff = 3, matching the latents? - Curved manifold geometry: can we identify the topology of the orientation
encoding in
m_hat? The dipole symmetrytheta -> theta + pipredicts a closed loop (S^1) in implicit space, fibered over the (x, y) plane. The 3D scatter is consistent with this but would benefit from a quantitative manifold-learning test (e.g. persistent homology). - Energy proxy: instrument under ByteDMD to get a data-movement cost per reconstruction, and compare to a vanilla 3D-bottleneck autoencoder. The population code uses a 225-dim hidden layer that the decoder reads in full; the equivalent dense-3-D autoencoder reads only 3. Does the population code’s robustness offset its data-movement cost?
- Sister problem: compare with
dipole-position: same architecture with one fewer latent (notheta), 100 hidden units, 2D implicit space. Does the same code converge to 2D when only(x, y)varies, or does it leave a degenerate axis?
Dipole what / where
Reproduction of the discontinuous “what / where” demonstration from
Zemel & Hinton, “Learning population codes by minimizing description
length”, Neural Computation 7(3):549-564 (1995). This is the first
explicit what / where split in Hinton’s experimental corpus and the
sister demo to dipole-position (continuous 2-D position only).

Problem
We render 8 x 8 images that are either a horizontal bar at a continuous row centre y in [0, 7] or a vertical bar at a continuous column centre x in [0, 7]. Bars are 1-pixel wide with a Gaussian fall-off (σ = 0.7) so adjacent positions overlap in pixel space and adjacent codes therefore should be near each other in implicit space.
- Inputs: 8 x 8 = 64 floats in [0, 1].
- Hidden: 100 sigmoid units (one fully-connected layer).
- Implicit space: 2-D bottleneck z, learned, no supervision.
- Decoder: z -> 100 sigmoid -> 64 logits.
- Training distribution: 50/50 horizontal vs vertical, position uniform on [0, h-1].
The interesting property: the two image families are qualitatively different — there is no smooth one-parameter morph from “horizontal at y=3.5” to “vertical at x=3.5”. So the optimal layout in 2-D under MDL pressure is two perpendicular 1-D manifolds, one per orientation, crossing only at a small “junction” region in the middle of implicit space. This is the discontinuous-clustering signature of a what / where representation: the what is which manifold you are on, the where is how far along it.
The dataset is intentionally a continuous-position version of the bars- and-stripes toy. Binary 1-pixel bars (the simpler choice) make all within-class image pairs as different from each other as cross-class pairs, which kills the inductive bias the autoencoder needs to find clean clusters; a Gaussian fall-off restores it.
Files
| File | Purpose |
|---|---|
dipole_what_where.py | Bar-image generator + 64-100-2-100-64 noisy-bottleneck autoencoder + Adam training. CLI: --seed --n-epochs --lambda-mdl --sigma-z. |
visualize_dipole_what_where.py | Five static PNGs: example inputs, implicit-space scatter (orientation- and position-coloured), MDL trajectory + cluster diagnostics, decoder sweep over implicit space, encoder receptive fields. |
make_dipole_what_where_gif.py | Generates dipole_what_where.gif (the animation at the top of this README). |
dipole_what_where.gif | Committed animation. |
viz/ | Output PNGs from the run below. |
The four spec-required helpers generate_bars, build_population_coder,
description_length_loss, and visualize_implicit_space are all exported
from dipole_what_where.py.
Running
Train and report final diagnostics:
python3 dipole_what_where.py --seed 1 --n-epochs 150
Training takes ~2 seconds on a laptop. Final cluster diagnostics for the default config:
| Metric | Value |
|---|---|
| reconstruction (mean per-pixel BCE) | 0.27 |
| MDL code length (0.5 ‖z‖² per dim) | 1.22 |
| linear-probe orientation accuracy | 0.58 |
| angle between H and V principal axes | 83 ° |
| H cluster mean | (-0.17, +0.09) |
| V cluster mean | (+0.42, +0.05) |
Regenerate visualisations:
python3 visualize_dipole_what_where.py --seed 1 --outdir viz
python3 make_dipole_what_where_gif.py --seed 1 --snapshot-every 3 --fps 12
Results
| Metric | Value |
|---|---|
| Final loss | 0.33 |
| Final reconstruction (BCE / pixel) | 0.27 |
| Final MDL code length | 1.22 |
| Linear-probe orientation accuracy | 0.58 |
| H / V principal-axis angle | 83 ° |
| Training time | ~2 sec (150 epochs, 2000 train images) |
| Hyperparameters | lr=5e-3, λ_mdl=0.05, σ_z=0.5, batch=64, hidden=100, init_scale=0.1 |
Two diagnostics are needed because there are two valid signatures of a what / where representation:
- Linear-probe accuracy catches the “two clusters in opposite corners” geometry. For our run it is only 0.58 (slightly above chance), which on its own is unimpressive.
- Principal-axis angle catches the “two perpendicular 1-D manifolds through the origin” geometry. For our run it is 83 °, which is the dominant signature here: the H and V codes lie along nearly perpendicular axes of the implicit space.
The two diagnostics together say: the network has discovered the what / where decomposition, but with the H and V manifolds threading through each other rather than landing in separated regions of the plane. This is consistent with a Gaussian prior on z with no built-in cluster structure — the prior has its global minimum at the origin so both manifolds are pulled inward.
Visualizations
Example inputs

Eight horizontal and eight vertical bars from the training distribution. The Gaussian fall-off (σ = 0.7) makes the bars 3-4 pixels wide and gives adjacent positions a smooth pixel-space overlap.
Implicit space

The 2-D code z for 400 held-out test images. Left: coloured by orientation. The H codes (red) and V codes (blue) trace out two distinct 1-D arcs that cross near the origin — the “junction” between the two image families. Right: same scatter coloured by within-class position (y for H, x for V). Position varies smoothly along each arc, confirming the where axis lives along each manifold.
Decoder sweep over implicit space

Reconstructions produced by sweeping z over a 9 x 9 grid in [-2.5, 2.5]². The picture cleanly factorises:
- left edge -> horizontal bars at varying row position
- right edge -> vertical bars at varying column position
- centre column -> “+” cross patterns (the junction region of implicit space, where the two manifolds meet)
This is the most direct picture of the what / where split: moving along one diagonal of the implicit space morphs the what (H to V), moving perpendicular to it morphs the where (bar position).
Description-length trajectory

The two losses balance early in training — reconstruction drops from 0.55 to ~0.30 in the first 30 epochs while MDL grows from 0 to ~1.0. The principal-axis angle (right panel) jumps from ≈ 0 ° to ≈ 90 ° in the first 10 epochs and stays there: the network finds the two perpendicular axes very quickly, well before reconstruction has converged. The linear-probe accuracy is noisier (orientation lives in the manifold, not in the mean) and is consistent with a perpendicular- arc geometry.
Encoder receptive fields

A small fraction (~15) of the 100 hidden units have learned clear horizontal-edge detectors (red row over blue row, or vice versa); the rest are diffuse. The ones with sharp horizontal-edge structure collectively encode the y coordinate of horizontal bars; vertical-edge units encode x for vertical bars. The fact that only ~15% of units specialise is consistent with the small, low-entropy training set — the network only needs a handful of detectors to span the bar manifold.
Deviations from the original procedure
The original Zemel & Hinton 1995 paper used:
- Hidden-activity bump constraint: a Gaussian-shape penalty on the hidden activity. We use a noisy-bottleneck autoencoder with a Gaussian prior on z and no explicit bump constraint on the hidden layer. The cluster geometry that emerges (two perpendicular 1-D manifolds) is consistent with the spirit of the paper but is not obtained via the same loss formulation.
- Mixture-of-Gaussians prior on the implicit space, learned jointly with the model. We use a single fixed unit-variance Gaussian prior. With the simpler prior, the H and V manifolds cross near the origin instead of separating into “opposite corners”, because the origin is the unique minimum of the prior.
- Comparison to a Kohonen self-organising map is not implemented in this stub. The Zemel & Hinton paper showed the Kohonen net produces a single connected manifold (no discontinuous split); we leave this ablation as future work.
- Sampling inference: the original paper used a stochastic encoder trained with a variational EM-style scheme. We use a deterministic encoder + Gaussian noise injection on z (effectively a fixed-variance variational posterior) and Adam backprop.
The discontinuity signature (perpendicular axes, near-orthogonal H and V principal directions) reproduces faithfully despite these simplifications.
Open questions / next experiments
- Mixture prior: replace the unit-variance Gaussian prior with a learned 2-component mixture of Gaussians. Expected: H and V manifolds decouple into separated regions of z and the linear-probe accuracy jumps from ≈ 0.6 to ≈ 1.0, while the principal-axis angle stays near 90 °.
- Kohonen baseline: train a 2-D self-organising map on the same bars dataset and compare. The 1995 paper claims SOMs cannot produce the discontinuous split; reproducing that failure mode would round out the demo.
- Bar width sweep: how sharp does the bar Gaussian (σ_bar) need to be before the AE stops finding the perpendicular-axes layout? Very thin bars (σ → 0) take adjacent positions out of pixel-space contact and break the within-class smoothness; very wide bars blur the orientation distinction. The σ-vs-axis-angle curve should peak at intermediate widths.
- Higher-D implicit space: with n_implicit=3 the network has a free third axis to play with. Does it use it to disentangle bar polarity / contrast / nuisance variables, or does it just spread the existing 2-D structure?
Helmholtz shifter
Helmholtz machine + wake-sleep reproduction of the shifter task from
Dayan, Hinton, Neal & Zemel, “The Helmholtz machine”, Neural Computation
7(5):889–904 (1995). Same wake-sleep machinery as the sibling
bars/ stub; different dataset and a multi-unit top layer
because a single top unit cannot break the t ↔ 1 - t symmetry on this
task.

Problem
Each 4×8 binary image is generated by a two-stage latent process:
- Sample a top row of 8 random bits (each on with probability pon = 0.3).
- Pick a shift direction (left or right, prior 1/2 each) and produce a bottom row that is the top row cyclically shifted by ±1.
- Duplicate: row 1 = row 0 (top), row 2 = row 3 (bottom). The image is therefore a “double-thick” 4×8 picture of the (row, shifted row) pair, exactly as in the original paper.
Support of pdata has 28 × 2 = 512 unique images; the data distribution is uniform on this support but most bit patterns are sparse (1.6 expected on-pixels per row at pon = 0.2, 2.4 at 0.3).
The interesting property: the latent direction bit is observable only through the cross-row correlation in the visible image. The model has to discover this 1-bit cause behind the 8-bit row marginal, and represent it in a separate layer (layer 3) above the bit-pair detectors (layer 2).
Architecture
Three-layer sigmoid belief net, top-down generative + bottom-up recognition,
identical wake-sleep deltas to bars/:
v (32 visible) <-- W_hv -- h (16 hidden, layer 2) <-- W_th -- t (4 top, layer 3)
v (32 visible) -- R_vh --> h (16 hidden, layer 2) -- R_ht --> t (4 top, layer 3)
All units are binary stochastic; each layer’s conditional is factorial. Wake-sleep alternates 1-step delta updates: wake teaches the generative weights to predict the layer below given the latents the recognition net inferred; sleep teaches the recognition weights to invert what the generative net just produced. No backprop.
Files
| File | Purpose |
|---|---|
helmholtz_shifter.py | Shifter sampler, Helmholtz machine, wake/sleep updates, importance-sampled NLL, layer-3 selectivity inspector. CLI for training. |
problem.py | Thin shim re-exporting generate_dataset, build_helmholtz_machine, wake_sleep, inspect_layer3_units for tools that follow the spec stub. |
_train_canonical.py | Trains the canonical run (seed 1, 1.5×106 samples) and saves weights + per-50K-step snapshots to viz/. |
visualize_helmholtz_shifter.py | Static plots: training curves, fantasy samples, layer-3 selectivity bars, layer-2 receptive fields, reconstructions, weight Hinton diagrams. |
make_helmholtz_shifter_gif.py | Renders helmholtz_shifter.gif from snapshots saved during the canonical run. |
helmholtz_shifter.gif | Animation at the top of this README. |
viz/ | Committed PNGs. (Training caches *.npz here too but those are gitignored — re-run _train_canonical.py to regenerate.) |
Running
The canonical pipeline trains, saves the model + snapshots, then renders the static plots and the GIF from the saved snapshots:
# 1. train (~3.5 min on a laptop, single-thread numpy)
python3 _train_canonical.py --seed 1 --n-passes 1500000 --p-on 0.3 \
--eval-every 50000 --snapshot-every 50000
# 2. static visualizations (re-uses the trained model)
python3 visualize_helmholtz_shifter.py --reuse
# 3. animation (re-uses the snapshot stream)
python3 make_helmholtz_shifter_gif.py --reuse --fps 8
Or run training only:
python3 helmholtz_shifter.py --seed 1 --n-passes 1500000 --p-on 0.3
Results
| Metric | Value |
|---|---|
| Seed | 1 |
| Architecture | 32 visible — 16 hidden (layer 2) — 4 top (layer 3) |
| Wake-sleep iterations | 1,500,000 (1 wake + 1 sleep update each) |
| Total samples | 1.5M wake + 1.5M sleep |
| Batch size | 1 (online) |
| Learning rate | 0.1 (constant, both phases) |
| Init scale | 0.1 |
| Visible-bias init | logit(pon) = logit(0.3) ≈ -0.85 |
| pon (row marginal) | 0.3 |
| Initial IS-NLL (random init) | 28.7 bits/image |
| Final IS-NLL (M=200 importance samples) | 9.36 bits/image |
| Direction recovery accuracy | 0.633 (chance = 0.5) |
| Best top-unit | selectivity |
| Wall-clock time (training) | 209 sec |
NLL is importance-sampled: for each held-out v in a fixed eval set of
256 patterns, we draw M latent samples from the recognition net,
compute log p(v | h) p(h | t) p(t) / q(h | v) q(t | h) for each, and take
log-mean-exp. Exact KL evaluation (as in bars/) would need to enumerate
217 = 131K latent configurations per query — computable but
costly — so the curve uses M=50 samples and the final number M=200.
The headline finding (next section) is the per-top-unit shift-direction selectivity: 3 of the 4 layer-3 units develop clean direction tuning.
Visualizations
Training curves

IS-NLL drops fast in the first 200K iterations then plateaus around 9.3 bits per image. Direction recovery (right panel) climbs from chance (0.5) to ~0.63 within the first 100K iterations and stays there. The recovery metric scans all 24 sign-vectors over the 4 top units and picks the best linear combination — sign-flip-invariant, so the residual gap to 1.0 reflects the recognition net’s inability to losslessly invert the generative net (factorial q with a 1.5×106-sample budget, not a fundamental limit).
Layer-3 unit shift-direction selectivity

The headline reproduction. Left panel: bar chart of P(right shift | tk=1) and P(left shift | tk=1) for each of the four top units, measured by sampling fantasies under one-hot top conditioning (2048 fantasies per unit, “shift signature” = bottom row exactly equals top row shifted by ±1).
- t1: 71% right, 0% left — clean right-shift detector.
- t2: 6% right, 51% left — clean left-shift detector.
- t3: 1% right, 41% left — left-shift detector.
- t0: 44% right, 11% left — partial right-shift signal, partly redundant with t1.
Right panel: 32 fantasies for each one-hot top configuration. The “all right-shift” rows under t1 and the “all left-shift” rows under t2/t3 are visually distinct: each is a gallery of valid shifter patterns of one direction.
Layer-2 (hidden-unit) generative receptive fields

For each hidden unit j, top row shows p(v | hj = 1, others off) − p(v | all h off), reshaped to 4×8. Bottom row shows the corresponding row of RvhT (the recognition counterpart).
Each unit lights up two specific pixels at offset ±1: one in row 0/1 (top half of image) and one in row 2/3 (bottom half), shifted by exactly one column. This is the bit-pair detection the paper predicted: “hj = 1 if pixel i is on AND pixel (i ± 1) of the shifted row is on”. The annotated Wth values above each panel show how each detector projects onto the 4 top units — the bipolar pattern is what makes specific top units favour specific shift directions.
Generated samples

64 fantasies drawn by ancestral sampling top → h → v through the trained generative net. Most fantasies have the duplicated-row structure (top half identical, bottom half identical), and a substantial fraction match a clean ±1 shift. The title reports the empirical fraction of right-shift, left-shift, and “other” (unstructured) samples.
Reconstructions

Top row: 16 fresh shifter inputs from pdata. Bottom row: the mean of p(v | h) where h ∼ q(h | v) was sampled from the recognition net. Reconstructions match the input pixel-for-pixel on most images, including the duplicate-row structure and the +/-1 shift. The small grey ghosting on a few reconstructions reflects pixel-level uncertainty in the factorial conditional (no commitment to which exact bit is on).
Weights

Hinton diagrams of Whv (16 hidden × 32 visible, generative top half) and RvhT (recognition bottom half). Most rows have a clear “two pixel” support (red/blue squares at one pixel in the top half of the image and one in the bottom half), confirming the bit-pair structure visible in the receptive-field plot above.
Deviations from the original procedure
- Top layer has 4 units, not 1. With ntop=1 the wake-sleep dynamics’ symmetry under t → 1 - t prevents the single top unit from breaking the left/right symmetry: 30/30 seeds at 500K iterations gave |selectivity| < 0.1 (chance level), while ntop ≥ 2 broke the symmetry on every seed I tried (3 seeds, |selectivity| up to 0.95 per unit). This deviation is consistent with the original paper, which describes the layer-3 units (plural) becoming shift-direction selective rather than asserting a 1-unit top.
- pon = 0.3, not 0.5. Pure random binary (pon = 0.5) gives a non-sparse dataset where each row contains ~4 on-pixels, making the bit-pair structure hard to read off the receptive fields (every pair gets some weight). At pon = 0.3 (or 0.2) the sparser inputs let each hidden unit specialise cleanly to one (position, direction) pair. The qualitative results — layer-3 selectivity, layer-2 bit-pair detectors — reproduce at both 0.2 and 0.3; 0.3 is the canonical value here because the IS-NLL converges faster with more informative inputs.
- Visible-bias init. Same as
bars/: bv initialised to logit(pon) so the all-hidden-off path already produces the pixel marginal. Removes a “dead start” without otherwise biasing the wake-sleep dynamics. - Constant learning rate. The 1995 paper reports a small fixed rate; experiments with a two-phase schedule (lr=0.1 then lr=0.02) gave essentially the same direction-selectivity scores at convergence on this problem, so the constant-LR version is reported.
- NLL evaluation is importance-sampled, not exact. The data support has 512 unique images, so exact NLL would require 512 × 217 = 67M sigmoid evaluations per check; the importance-sampled estimator with M=50 takes a fraction of that and is consistent across seeds. Final NLL uses M=200.
Open questions / next experiments
- Closing the recognition gap. Direction recovery at 0.63 leaves a lot
on the table (1.0 = perfect, 0.5 = chance). The factorial recognition
cannot represent the bimodal posterior on direction-ambiguous inputs
(all-zero rows, all-one rows, palindromic rows), but those inputs are a
small fraction of pdata at pon = 0.3. The bigger
gap is the recognition-vs-generative loss observed in
bars/too: the generative model fits well (NLL drops 3×) while the recognition net lags. A targeted multi-restart or perturb-on-plateau experiment would probably push direction recovery into the 0.8+ range. - Single-top-unit recipe. Is there a wake-sleep variant (e.g. anti- symmetric init or asymmetric prior on top) that lets ntop=1 succeed? The 4-unit top is a workaround, not a fundamental requirement.
- Energy/data-movement profile. All updates are 1-step delta rules, no backprop. Profiling wake/sleep memory traffic under ByteDMD would be a direct port to the Sutro-group energy metric.
- Larger images. The same architecture should learn 4×16 or 8×8 shifters. With ntop=4 already specialising 3 units to direction, the model has spare capacity for a multi-direction generalisation (e.g. shifts ±1, ±2 with 4 latent classes).
Bars task
Helmholtz machine + wake-sleep reproduction of the bars experiment from Hinton, Dayan, Frey & Neal, “The wake-sleep algorithm for unsupervised neural networks”, Science 268 (1995).

Problem
Each 4×4 binary image is generated by a two-level latent process:
- Pick orientation: vertical with prior 2/3, horizontal with prior 1/3.
- Conditioned on the orientation, each of the 4 candidate bars in that orientation is independently active with probability 0.2. Pixels are the union (logical OR) of the active bars.
There are 16 (vertical) + 16 (horizontal) − 2 (blank and all-on, shared between orientations) = 30 distinct images in the support of pdata. The blank image alone has probability 0.4096; the distribution is heavily peaked.
The interesting property: the true posterior p(top, h | v) is not factorial. Wake-sleep fits a factorial recognition network anyway, so the recognition net cannot exactly capture the bimodal vertical-vs-horizontal posterior. The paper’s headline is that despite this approximation, the generative model still converges to a low-KL fit of pdata via the wake-sleep delta rules — no backprop, no exact inference, just two alternating local update rules.
Files
| File | Purpose |
|---|---|
bars.py | Bars-distribution sampler, Helmholtz machine, wake/sleep updates, exact KL evaluator. CLI for training. |
_train_canonical.py | Helper that trains the canonical run (seed 2, 2×10⁶ samples) and saves weights + snapshots to viz/. |
visualize_bars.py | Static plots: KL/NLL trajectories, fantasy samples, recognition codes, hidden-unit receptive fields, weight Hinton diagrams. |
make_bars_gif.py | Renders bars.gif from snapshots saved during the canonical run. |
bars.gif | Animation at the top of this README. |
viz/ | Committed PNGs. (Training caches *.npz here too but those are gitignored — re-run _train_canonical.py to regenerate.) |
Running
The canonical pipeline trains, saves the model, then renders the static plots and the GIF from the saved snapshots:
# 1. train (~4 min on a laptop, single-thread numpy)
python3 _train_canonical.py --seed 2 --n-steps 2000000 --lr 0.1 \
--batch-size 1 --snapshot-every 50000
# 2. static visualizations (re-uses the trained model)
python3 visualize_bars.py --reuse
# 3. animation (re-uses the snapshot stream)
python3 make_bars_gif.py --reuse --fps 8
Or run training only:
python3 bars.py --seed 2 --n-steps 2000000 --lr 0.1 --batch-size 1
Results
| Metric | Value |
|---|---|
| Seed | 2 |
| Architecture | 16 visible — 8 hidden — 1 top, sigmoid belief net |
| Wake-sleep iterations | 2,000,000 (each = 1 wake update + 1 sleep update) |
| Total samples | 2,000,000 wake + 2,000,000 sleep |
| Batch size | 1 (online) |
| Learning rate | 0.1 (constant, both phases) |
| Init scale | 0.1 |
| Visible-bias init | logit of pixel marginal (≈ −1.39) |
| Final KL[pdata ‖ pmodel] | 0.451 bits |
| Final NLL of pdata under model | 3.880 bits |
| Entropy H(pdata) | 3.428 bits (target NLL floor) |
| Wall-clock time | 222 sec |
| Initial KL (random init) | 8.16 bits |
The KL is computed exactly: enumerate the 30 support images of pdata, marginalise pmodel(v) over the 2⁹ = 512 latent configurations of (top, h) under the trained sigmoid belief net, then sum pdata(v) · log2(pdata(v) / pmodel(v)).
Visualizations
KL and NLL trajectories

Both curves are the exact values (no Monte Carlo): the asymmetric KL is evaluated at every snapshot by enumerating the 512 latent configurations of the trained net. The NLL plateau approaches H(pdata), the entropy of the bars distribution, which is the lowest cross-entropy any generative model can achieve.
Generative samples

64 fantasies drawn by ancestral sampling top → h → v through the
trained generative net. The mix of vertical-stripe vs horizontal-stripe
samples reflects the learned b_top; the headline check is that
individual samples look like valid bars images (one orientation, a few
bars at a time) rather than mixed-orientation noise.
Hidden-unit specialization

Top row: p(v | hj = 1, all other h off) reshaped to 4×4.
Each hidden unit becomes a “bar detector” — the corresponding image lights
up exactly the pixels of one bar. Bottom row: the same field minus the
all-h-off baseline (so red ≈ pixels this unit adds, blue ≈ pixels it
removes). The W_th value annotated above each panel is the hidden
unit’s coupling to the top-most “orientation” unit; vertical-bar detectors
end up with one sign, horizontal-bar detectors with the other.
Recognition activations on the data support

For each of the 30 unique images in the support of pdata, the lower panel shows the recognition net’s per-hidden-unit output q(hj = 1 | v). Images are sorted by the recognised q(top = 1 | v) — vertical-orientation images cluster on one side, horizontal on the other. A perfect factorial recognition net would produce a clean block-diagonal structure (one block of 4 verticals specialists firing on the left, one block of 4 horizontals firing on the right); the actual codes are softer because the factorial approximation cannot represent the true bimodal posterior on ambiguous images (blank, all-on).
Weight matrices

Generative Whv (left) and recognition RvhT (right) as Hinton diagrams. Square area is √|w|; red = positive, blue = negative. Each row of Whv is the (signed) bar template the corresponding hidden unit has carved into the visible bias; the recognition matrix is roughly the transpose, modulo the asymmetry between “present in v” and “explain-away signal”.
Deviations from the original procedure
- Visible-bias init. Initialising bv to logit(0.2) ≈
−1.39 (the pixel-on marginal of pdata) noticeably
accelerates early training. With bv = 0 the network has
to learn to suppress every pixel before any hidden unit can usefully
light some pixels back up; the marginal-logit init removes that dead
start without otherwise biasing the wake-sleep dynamics. CLI flag
init_visible_bias_to_marginal=True(default). - Constant learning rate. The 1995 paper reports a small fixed learning rate; experiments with a two-phase schedule (lr=0.1 then lr=0.02) gave essentially the same asymptotic KL on this problem, so the constant-LR version is reported.
- Recognition is fully factorial. The paper explicitly chose the factorial approximation; this is faithful to the original setup and is one of the main points of the experiment.
- KL evaluation is exact. The paper’s evaluation is also exact for the bars task (the support is small enough); we enumerate the 30 support images and marginalise the 2⁹ = 512 latent configurations.
- Asymptotic KL gap. The paper reports KL ≈ 0.10 bits at convergence on a single representative run. Our reproduction at 2×10⁶ samples converges to 0.451 bits — the same order of magnitude as the network entropy (the model captures most of the structure: KL drops from 8.16 bits at init to 0.45 bits) but ≈ 4.5× higher than the paper’s reported headline. The discrepancy is discussed under “Open questions” below; we did not tune past the simple constant-LR online recipe in this stub.
Open questions / next experiments
- Closing the KL gap. The paper’s reported KL ≈ 0.10 bits beats our reproduction (0.45 bits) by ≈ 4.5×. Plausible explanations: (a) different parameterisation (e.g. centered hidden states or initial recognition biases that lock onto an orientation), (b) an explicit LR schedule we did not try, (c) longer training with multi-restart-on- plateau (the encoder-4-2-4 sibling needed this to escape local minima in the factorial-bottleneck regime). A targeted sweep over these axes is the obvious next step.
- Recognition vs generative gap. q(h, top | v) is forced to factorise even though the true posterior on ambiguous images (blank, all-on) is bimodal. How much of the residual KL is the factorial-recognition gap vs the generative-fit gap? A Helmholtz-machine-with-mixture-recognition variant would isolate the two contributions.
- Energy/data-movement profile. All wake-sleep updates are 1-step delta rules, no backprop. Once a baseline KL is established, profiling the wake/sleep memory traffic under ByteDMD would be a direct port of the Sutro-group energy metric to a generative model.
- Scaling. The same architecture+algorithm should learn 5×5 or 8×8
bars, and
helmholtz-shifter/(a sibling stub) is already a 1995 Helmholtz-machine task on a different dataset. A unified library that swaps the data sampler in and out would let us compare both.
Bars problem for RBM training
Source: Hinton, G. E. (2000), “Training products of experts by minimizing contrastive divergence” (Gatsby tech report; Neural Computation 14(8), 2002). The bars task itself is from Foldiak, P. (1990), “Forming sparse representations by local anti-Hebbian learning.”
Demonstrates: the canonical sanity check for RBM / contrastive-divergence training — after CD-1 training, each hidden unit specializes to a single bar.

Problem
- Visible: 16 binary pixels arranged as a 4×4 image.
- Hidden: 8 binary feature detectors (canonical setting; one per bar).
- Connectivity: bipartite RBM (visible ↔ hidden only).
- Training distribution: each image is generated by independently
activating each of 8 single-bar templates (4 horizontal rows + 4 vertical
columns) with probability
p_bar = 1/8, then taking the logical OR over activated bars to get the pixels. So each image is a superposition of bars.
The interesting property: the data has a clean latent factor structure (8 independent on/off causes), but the visible pixels are tangled by the OR mixture. Backprop on a single image cannot recover the bars, because there is no per-pixel target. The RBM gets the bars from the statistics alone: under CD-1, each hidden unit learns to fire iff one specific bar is present, because that maximizes the model’s likelihood under the bipartite factorization.
The trick is that an over-explanatory hidden code (e.g. one unit that fires for any bar) reconstructs the data poorly — the model needs distinct units for distinct causes. Sparsity in the hidden activations + the bipartite structure together make the per-bar decomposition the natural local optimum.
Files
| File | Purpose |
|---|---|
bars_rbm.py | Bars dataset + BarsRBM + cd1_step() + train() + per_unit_bar_purity() + visualize_filters(). CLI for reproducing the headline run. |
visualize_bars_rbm.py | Static PNGs: receptive fields, training curves, sample reconstructions, data examples, bar-template reference. |
make_bars_rbm_gif.py | Generates bars_rbm.gif (the animation at the top of this README). |
bars_rbm.gif | Receptive fields evolving across 300 epochs of CD-1. |
viz/ | Output PNGs from the headline seed=2 run. viz/n_hidden_16/ holds the over-complete (n_hidden=16) sibling run. |
problem.py | The original stub signatures. Re-exports from bars_rbm.py. |
Running
python3 bars_rbm.py --seed 2 --n-hidden 8 --n-epochs 300
Training time: ~1.5 s on a laptop. Final result for seed=2: 7/8 bars
recovered, mean per-unit purity 0.90, reconstruction MSE 0.016.
To regenerate visualizations:
python3 visualize_bars_rbm.py --seed 2 --n-hidden 8 --n-epochs 300
python3 make_bars_rbm_gif.py --seed 2 --n-hidden 8 --n-epochs 300
To explore over-complete coding (more hidden units than bars):
python3 visualize_bars_rbm.py --seed 0 --n-hidden 16 --n-epochs 300 \
--outdir viz/n_hidden_16
Results
Headline run (n_hidden = 8, seed = 2)
| Metric | Value |
|---|---|
| Bars recovered (purity ≥ 0.5) | 7 / 8 |
| Mean per-unit purity | 0.90 |
| Reconstruction MSE (one CD step) | 0.016 |
| Training time | ~1.5 s |
Convergence statistics (10 seeds, n_hidden = 8)
| Outcome | Count |
|---|---|
| ≥ 7 bars recovered | 8 / 10 |
| All 8 bars recovered | 2 / 10 |
| Mean bars / 8 | 7.0 |
So the typical outcome is “7 of 8 hidden units lock onto distinct bars; one unit duplicates a neighbour or stays partially mixed.” That is the standard CD-1-on-bars result reported in the original literature — cleaner sparsity penalties or PCD push the rate higher (see Open questions below).
Over-complete (n_hidden = 16, seed = 0)
| Metric | Value |
|---|---|
| Bars recovered | 8 / 8 |
| Mean per-unit purity | 0.94 |
| Reconstruction MSE | 0.0001 |
| Training time | ~2.8 s |
With twice the hidden units, every bar is found by at least one unit. Some units duplicate (two units detecting the same bar), some learn slightly shifted/rotated mixtures — but no bar is missed.
Hyperparameters used
| Param | Value |
|---|---|
n_train (samples) | 2000 |
batch_size | 20 |
lr | 0.10 |
momentum | 0.5 |
weight_decay (L2) | 1e-4 |
sparsity_cost | 0.1 |
sparsity_target | 1 / n_hidden |
init_scale (W) | 0.01 |
b_v init | logit(data_mean) |
b_h init | logit(sparsity_target) |
p_bar (per-bar activation prob) | 0.125 |
n_epochs | 300 |
Visualizations
Receptive fields (the headline)

Each subplot is the incoming weight slice W[:, j] for one hidden unit,
reshaped back to a 4×4 image (red positive, blue negative). The cleanly
specialized units have a single bright row or column with near-zero weights
elsewhere — that is one hidden unit “detecting one bar.” The label Hk /
Vk above each panel is the closest single-bar template, with the
cosine-similarity purity score.
Bar templates (reference)

The 8 single-bar templates the RBM is being asked to recover.
Training data

A random batch of 16 generated images. Each is the OR of zero or more bars; many images contain just one bar, some contain two or more, a few are blank.
Training curves

- Reconstruction MSE drops from ~0.1 (random init) to ~0.015 within ~50 epochs, then trickles down as fine-grained per-bar specialization sharpens.
- Mean bar purity climbs from ~0.0 (random filters) to ~0.9 over the same window — the qualitative phase transition where filters become bar-like.
- Bars covered (number of distinct single-bar templates that some hidden unit detects with purity ≥ 0.5) climbs to 7/8 by epoch ~100 and stays there. The final missed bar is typically duplicated by another unit instead.
Reconstructions

Top row: data (8 random images). Middle row: one-step CD reconstruction
p(v | h(v)). Bottom row: hidden-unit activations p(h | v). Hidden codes
are sparse — usually 1–3 of the 8 units fire per image, matching the
underlying number of bars.
Over-complete (n_hidden = 16)

With 16 hidden units, all 8 bars are reliably recovered, with most bars detected by 2 hidden units. A few units learn slightly mixed (bar-fragment) detectors. Reconstruction MSE drops to ~1e-4 — essentially exact.
Deviations from the original procedure
-
Sparsity penalty on
b_h— Hinton 2000 reports clean per-bar specialization with vanilla CD-1, partly because the original experiments use larger images / different sparsity priors. To get reliable single-bar receptive fields on the small 4×4 grid here we add a quadratic penalty pushing the mean hidden activation toward1 / n_hidden = 0.125(a standard practical addition; see Lee, Largman, Pham, Ng 2009). Withsparsity_cost = 0, the per-seed success rate drops noticeably. -
Bias initialization —
b_vis initialized tologit(data_mean)andb_htologit(sparsity_target), so the network starts with sensible marginals. Without this, the first ~30 epochs are spent moving the biases, with hidden units that are saturated or dead. -
Momentum + weight decay —
momentum = 0.5,weight_decay = 1e-4. The 2002 paper does not use momentum; modern RBM practice (Hinton 2010 practical guide) does, and it speeds convergence noticeably. -
Number of training samples — we use 2000 fresh samples; the original paper uses larger but qualitatively similar streams. Sample count is not the limiting factor at 4×4.
Open questions / next experiments
-
Sparsity-free convergence rate. With
sparsity_cost = 0the per-seed success rate (≥ 7 bars covered) drops to roughly 50%. How does the rate scale withn_train,n_epochs, andinit_scalealone? Can we get to 100% with no sparsity term by tuning the other knobs? -
PCD vs CD-1. Persistent Contrastive Divergence (Tieleman 2008) keeps a Markov chain across mini-batches. On the bars problem it should be strictly better than CD-1 (less biased gradient), but the cost is one extra Gibbs step per iteration. Quantify the gap on this benchmark.
-
Energy / data-movement cost. Per the broader Sutro effort, every problem in this catalog should eventually be measured under ByteDMD. For bars-RBM the per-iteration cost is dominated by
v @ Wandh @ W.T— bothO(n_visible * n_hidden). Total cost for the reference run =n_epochs * (n_train / batch_size) * 4 * n_visible * n_hidden≈ 1.5 × 10⁹ float-mults; what does ByteDMD say the data-movement bill is? -
Larger grids. Foldiak’s original setup used 8×8 and 16×16 with correspondingly more bars. Does the current recipe scale, or do we need PCD / longer training to keep the per-seed success rate up?
-
Why does one bar typically go missing? The lost bar is usually one with a high-overlap neighbour (e.g. two adjacent rows). Is this a fundamental CD-1 failure (the gradient cannot distinguish near-duplicate causes) or a finite-data artefact? A controlled experiment varying the bar-overlap structure would settle it.
transforming-pairs
Gated three-way (input × output × hidden) conditional RBM, trained on pairs
(x, y) of binary 13×13 random-dot images where y is x after a known
transformation drawn from {translation by ±1 pixel, 90° rotation}.
Source: Memisevic & Hinton, “Unsupervised learning of image
transformations”, CVPR 2007.
Demonstrates: Multiplicative interaction between an input image and an
output image, factored through F shared “filter pairs”, causes hidden units
to specialize as transformation detectors — each one responds to a
specific (input → output) deformation rather than to the content of either
image.

Problem
A pair (x, y) is generated by
- drawing a binary 13×13 image
x(each pixel on with probability 0.10, so ~17 lit pixels per image), and - choosing a transformation
Tuniformly from a fixed pool, then settingy = T(x).
The pool depends on --transforms. Default (shift,shift_max=1) is the
8 cardinal one-pixel shifts {(±1, 0), (0, ±1), (±1, ±1)}. With
--transforms shift,rotate, the three rot90 multiples (90°, 180°, 270°)
are added.
The model is a conditional RBM p(y, h | x) whose energy is
E(y, h | x) = - Σ_f (Wx_f · x) · (Wy_f · y) · (Wh_f · h)
- b_y · y - b_h · h
i.e. the third-order weight tensor W_{i, o, j} is factored as
Σ_f Wx_{i,f} · Wy_{o,f} · Wh_{j,f}. Without the factorization the
parameter count would be n_in · n_out · n_hidden = 169 · 169 · 64 ≈ 1.8M; the factored form has `(n_in + n_out + n_hidden) · F = (169 + 169
-
- · 64 ≈ 26k
. Each factor is a *filter pair* — an input filterWx_f(a 13×13 image) and an output filterWy_f` (also 13×13) — and the gated activation rule
- · 64 ≈ 26k
p(h_j = 1 | x, y) = σ( Σ_f Wh_{jf} · (Wx_f · x) · (Wy_f · y) + b_h_j )
makes a hidden unit fire when the input matches Wx_f and the output
matches the transformed filter Wy_f. The interesting property: the
multiplication forces hidden units to encode the relationship between
x and y, not the content of either, so the same units fire across
many different random-dot inputs as long as the transformation is the
same. This is the seed of capsule-style “transformation features”.
Files
| File | Purpose |
|---|---|
transforming_pairs.py | Pair generator + factored 3-way RBM + CD-1 trainer + transform-classification eval. CLI --seed --transforms .... |
problem.py | Spec-compatible re-export shim. |
visualize_transforming_pairs.py | Static figures: example pairs, filter pairs, per-transform hidden activation profile, training curves, transfer-test grid. |
make_transforming_pairs_gif.py | Generates transforming_pairs.gif (the animation at the top of this README). |
transforming_pairs.gif | Committed animation (~1 MB). |
viz/ | Committed PNGs and a results.json from the canonical run. |
Running
# Default headline run (8 one-pixel shifts; ~2 s on a laptop):
python3 transforming_pairs.py --seed 0 --transforms shift --shift-max 1
# Static visualizations (training + plots; ~4 s):
python3 visualize_transforming_pairs.py --seed 0 --transforms shift --shift-max 1
# Animation (~12 s):
python3 make_transforming_pairs_gif.py --seed 0 --transforms shift --shift-max 1
# Mixed transforms (8 shifts + 3 rot90's = 11 classes):
python3 transforming_pairs.py --seed 0 --transforms shift,rotate --shift-max 1
Wall-clock for the headline experiment (1 CPU core, M-series Mac, no GPU): ~2.0 s for 100 epochs of CD-1 over 4000 training pairs.
Results
Headline configuration: --transforms shift --shift-max 1, 8 one-pixel
shift classes. Chance level on the held-out classification metric is
1/8 = 12.5%.
| Metric | Value |
|---|---|
| Hidden-unit transform specificity (median across units) | 1.62 (max possible 7.0; vs 0.05 at init) |
Transformation classification accuracy (logistic regression on h(x, y)) | 39.4% (chance 12.5%, ~3.2× chance) |
Reconstruction MSE on held-out y (from one mean-field pass through h) | 0.076 |
| Reconstruction bit accuracy (threshold 0.5) | 89.9% |
| Wall-clock to train | ~2.0 s |
| Hyperparameters | n_factors=64, n_hidden=64, init_scale=0.10, lr=0.10, momentum=0.5, weight_decay=1e-4, batch=100, 100 epochs |
Per-seed reproducibility (5 seeds, otherwise identical config): transform classification 39.4 / 39.6 / 42.0 / 40.2 / 41.0 %; specificity 1.62 / 1.59 / 1.67 / 1.25 / 1.39. The 39–42% range is a stable property of the recipe, not a single-seed accident.
Other transform pools (same config, seed 0):
| Pool | Classes | Chance | Test acc |
|---|---|---|---|
shift (shift_max=1) | 8 | 12.5% | 39.4% |
rotate only | 3 | 33.3% | 44.6% |
shift,rotate (shift_max=1) | 11 | 9.1% | 25.8% |
v1 baseline metrics (per spec issue #1 v2)
| Reproduces paper? | Partial. The qualitative claim — hidden units learn transformation features and behave like motion detectors — reproduces clearly (see transformation_profile.png). Memisevic & Hinton 2007 trains on real video frame pairs at 13×13 patch size and reports oriented Reichardt-style detectors; we use synthetic random-dot pairs and recover transformation-axis selectivity but not exact direction selectivity (see Deviations §3). |
| Run wallclock | ~2.0 s for python3 transforming_pairs.py --seed 0 --transforms shift --shift-max 1. |
| Difficulty | Single-session implementation by tpairs-builder agent; no external paper details beyond what’s in the comment-graph spec. |
Visualizations
Example pairs

Eight (x, y) pairs from the test split. Every column is a different
random dot pattern paired with a different one-pixel shift; the
network sees these as i.i.d. samples with no transformation label.
Hidden activation profile (the headline)

Mean hidden activation per (transformation, unit), with rows = the 8 shift classes and columns = the subset of hidden units whose responses are peakier than 0.5 (specificity threshold). Two units stand out:
- Hidden ~4 fires almost exclusively when the shift is
(±1, 0)— i.e. horizontal motion in either direction. A horizontal-motion detector. - Hidden ~12 fires almost exclusively for
(0, ±1)— vertical motion in either direction.
These are the Reichardt-like transformation detectors the paper predicts. Selectivity is on the axis of motion rather than the direction; with sparse random-dot inputs and only 8 shift classes, the network discovers the lower-frequency axis structure faster than the sign of the shift. With more training and more transformations the axis-cells split into direction-cells (open question §1).
Filter pairs

Top 16 factors ranked by ‖Wx_f‖ · ‖Wy_f‖. Each pair of rows shows the
input filter Wx_f (top) and the output filter Wy_f (bottom) for one
factor. Several factors (e.g. f10, f21, f25, f63) show the diagnostic
“shifted-stripe” pattern: an oriented bar in Wx_f paired with the
same bar shifted by one pixel in Wy_f. That’s the factored form of
“detect this oriented input, expect this shifted oriented output” —
a one-factor implementation of a single-pixel translation along that
orientation. Other factors are diffuse: with n_factors = 64 the model
has more capacity than 8 transforms strictly require, so several
factors share work and look noisy.
Transformation transfer

Each row picks one transformation T and a single reference pair
(x_ref, y_ref = T(x_ref)). The middle and right blocks show what the
model predicts for two new inputs x_q after the hidden code
h(x_ref, y_ref) is reused: ŷ_q = E[y | x_q, h(x_ref, y_ref)]. The
predictions are diffuse rather than crisp — single mean-field passes on
binary visible units don’t recover hard-thresholded outputs — but the
mass shifts in the direction T indicates: query-input edges appear
displaced in the predicted-output. This is the demo Memisevic & Hinton
emphasize: the same h applied to a different x produces a different
output that shares the transformation.
Training curves

- Reconstruction MSE drops monotonically from 0.11 to 0.066. The large early drop comes from the model learning the marginal pixel statistics; the slower late drop comes from the hidden code starting to carry transformation information (the green curve in the middle panel rises during this same window).
- Transformation classification is at chance for the first ~30 epochs (the model is still learning marginals), then climbs to ~38% and plateaus. The discrete jitter is real — the linear classifier is retrained from scratch every eval epoch.
- Combined weight norm grows roughly as
√epoch, with no sign of the runaway divergence typical of unregulated CD on Gaussian visibles (we use Bernoulli visibles + L2 weight decay).
Deviations from the original procedure
- Synthetic random-dot pairs, not video. Memisevic & Hinton 2007 train on natural-image patch pairs from short video clips. We use binary random-dot patterns with a fixed pool of synthetic transformations. This trades faithfulness for clean ground-truth transformation labels (so the headline metric — hidden-unit specificity — has a well-defined denominator).
- CD-1 with mean-field hidden units in CD-1. The paper trains by
contrastive divergence with a small number of Gibbs steps. We use a
single CD step and use sigmoid-mean activations for both
h_posandh_neg(sampling onlyY_neg). This is the standard Hinton-2002 CD recipe, slightly less faithful than alternating samples-on-the-data- manifold but gives the same headline phenomenon. - Axis selectivity, not direction selectivity. With 8 shift classes and 4000 training pairs, hidden units discover the axis of motion (horizontal vs. vertical) before they split into per-direction cells. The paper reports both axis and direction cells on natural video. With more training data and more transform classes, our recipe should split too — see open question §1.
- No per-pixel correlated noise in the inputs. The paper uses real image statistics, which give correlated patterns; we use independent Bernoulli pixels. This is the simplest baseline, deliberately.
- No PCD, no temperature schedule. Vanilla CD-1, momentum 0.5, L2 weight decay 1e-4. No annealing.
Open questions / next experiments
- Splitting axis-cells into direction-cells. With
--transforms shift --shift-max 1and our default budget, we get axis selectivity. Does doubling the training data, scalingn_factors, or adding a sparsity penalty cause the +1 and -1 directions to split into separate hidden units? The Memisevic paper claims yes for natural video; we don’t see it on random dots at this scale. - Transfer quality. The transfer outputs in
transfer_examples.pngare diffuse. Is that an artefact of single-pass mean-field reconstruction, or does running multiple alternating Gibbs steps (--n-gibbs > 2) sharpen them? The model definitely has the information — classification works — but the readout path is lossy. - Composing transformations. Can two stacked gated RBMs learn compositional codes (e.g. one layer for translation, one for rotation, with hidden codes that compose under ∘)? The paper hints at this but doesn’t run the experiment.
- Energy / data-movement comparison to a vanilla MLP that takes
[x, y]as input and predicts the transform label. A standard MLP should saturate at ~100% on this task; the gated RBM caps at ~40% at this scale. The real question (the v2 motivation in spec issue #1) is whether the gated RBM’s commute-to-compute ratio is better, not whether its accuracy is.
agent-tpairs-builder (Claude Code) on behalf of Yad — implementation notes for spec issue cybertronai/hinton-problems#1 (v2).
Bouncing balls (2 balls, TRBM)
Reproduction of the synthetic-video benchmark from Sutskever, I. & Hinton, G. E., “Learning multilevel distributed representations for high-dimensional sequences”, AISTATS 2007.
Demonstrates a Temporal RBM (RBM with directed temporal connections from the previous hidden state and the previous N visible frames) learning the joint distribution of a 2-ball-bouncing video, then rolling future frames forward given a short seed.

Problem
- Input: a synthetic video of two balls bouncing in a rectangular box.
Each pixel is binary; a pixel is on iff it is within
ball_radiusof any ball centre. Wall collisions are perfectly elastic. Ball-ball collisions are ignored — the balls pass through each other (matches Sutskever & Hinton’s original synthetic dataset). - Frame size: 16 × 16 pixels (256 visible units). Sutskever & Hinton used 30 × 30 — see Deviations for why we shrunk it.
- Sequence length: 50 training frames per sequence; 60 training sequences; 10-frame seed + 20-frame rollout at evaluation.
The interesting property: a single still frame tells you where the balls are but not where they are going. To predict the next frame the model needs at least two frames of context (to recover velocity), and to predict several frames out it needs a hidden state that carries velocity through time. The TRBM encodes this directly: directed connections feed the previous hidden state and previous visible frames into the current step’s biases.
Files
| File | Purpose |
|---|---|
bouncing_balls_2.py | Physics simulator + TRBM (W, W_hh, W_hv, W_vh, W_vv) + CD-1 trainer + rollout. CLI flags --seed --n-balls --h --w --n-epochs --n-hidden --n-lag --feedback. |
visualize_bouncing_balls_2.py | Renders example frames, top-25 W filters, training curves, and a side-by-side ground-truth vs rollout grid into viz/. |
make_bouncing_balls_2_gif.py | Renders bouncing_balls_2.gif — input video next to TRBM rollout, frame by frame. |
bouncing_balls_2.gif | The animation linked at the top of this README. |
viz/ | Static PNGs from the run below. |
results.json | Hyperparameters + environment + final metrics for the run below. |
Running
python3 bouncing_balls_2.py --seed 0 --n-balls 2 --h 16 --w 16 \
--n-sequences 60 --seq-len 50 --n-hidden 200 --n-lag 2 \
--n-epochs 50 --feedback sample --results-json results.json
Run wallclock: 6.2 s training + ~1 s rollout eval on a laptop CPU (M-series, numpy 2.2.5). Final per-frame CD-1 reconstruction MSE: 0.0268.
To regenerate the visualizations and the GIF:
python3 visualize_bouncing_balls_2.py --seed 0 --outdir viz
python3 make_bouncing_balls_2_gif.py --seed 0 --out bouncing_balls_2.gif
Results
| Metric | Value |
|---|---|
| Frame size | 16 × 16 = 256 visible units |
| Hidden units | 200 |
Visible-history lag (n_lag) | 2 frames |
| Training sequences × length | 60 × 50 frames |
| Final CD-1 reconstruction MSE | 0.0268 |
| Rollout MSE (1 held-out sequence, 20 future frames, seed 9999) | 0.0497 |
| Rollout MSE (20 held-out sequences, mean ± std) | 0.0672 ± 0.0329 |
| Baseline: predict-mean-frame | 0.0502 |
| Baseline: copy-last-seed-frame | 0.0973 |
| Training time | 6.2 s |
| Hyperparameters | lr=0.05, momentum=0.5, weight-decay=1e-4, batch-size=10, init-scale=0.01, k-gibbs (rollout)=5, feedback=sample |
| Reproducibility | seed 0; results in results.json; git commit recorded in env field |
Reproduces paper? Partial. The original paper trains a deeper / wider TRBM on 30 × 30 video and shows visually plausible multi-second rollouts. Our 16 × 16 single-layer TRBM trained with vanilla CD-1 gets between the two trivial baselines: better than copy-last-frame, worse than predict- mean-frame on per-frame MSE, but qualitatively correct in the first 3–4 rolled-out frames (the predicted ball moves in the right direction, see the GIF). The architecture and learning rule reproduce; the rollout horizon does not match the paper. The discussion below explains why and what would close the gap.
Visualizations
Example input frames

Eight evenly-spaced frames from one training sequence. Two binary balls bounce between the walls; you can see the trajectory turn over once a ball reaches a wall.
Hidden-unit receptive fields

The 25 hidden units (out of 200) with the largest L2 weight norm. Each panel
shows that hidden unit’s column of W reshaped as a 16 × 16 image, with red
for positive weights and blue for negative. Most filters are localised
position detectors — a positive blob at one location and inhibitory weights
nearby. A few have a more diffuse pattern that resembles a velocity / gradient
detector. This is the multilevel-distributed-representation point of the
original paper, showing up as overlapping localised position codes that tile
the box.
Training curves

Per-frame CD-1 reconstruction MSE drops from ≈ 0.27 at initialisation to
0.027 by epoch 50, while ‖W‖_F grows roughly linearly from 0 to ~10. The
recon MSE is mean( (v_pos - v_neg)² ) where v_neg = sigmoid(W h_pos + shifted-bias) and h_pos ~ p(h | v_pos, V_past, h_prev); this is lower
than the predict-mean-frame baseline (0.027 vs 0.050), so the model
correctly exploits the conditional information when v_t is given.
Ground truth vs rollout

Top row: 10 seed frames (orange border) followed by 12 ground-truth future
frames. Bottom row: same seed, then 12 frames generated by model.rollout(...).
The TRBM tracks ball motion correctly for the first 2–3 future frames (the
predicted ball stays near the right pixel column and continues in the seed’s
direction of motion), then diffuses toward the mean-frame as the
autoregressive feedback signal weakens.
Deviations from the original procedure
- 30 × 30 → 16 × 16 frames. The original paper used 30 × 30. At 16 × 16 the entire pipeline (data + train + rollout + viz) finishes in under 10 seconds on a laptop, comfortably below the v1 spec’s 5-minute budget. The qualitative phenomenon — ball-position filters, partial rollout tracking — is the same; the larger frame would mostly buy a longer visually-plausible rollout horizon.
- Single-layer TRBM, not stacked. The 2007 paper trains stacks of TRBMs greedily for higher-quality rollouts. We keep a single layer here as the v1 baseline — the stacking variant is the natural follow-up.
- Visible-frame lag = 2. A pure h_{t-1}-only TRBM (no v_{t-1} → v_t,
no v_{t-1} → h_t) cannot extract velocity from a single previous frame
under CD-1 training; we observed it collapsing to mean-frame even at the
first rolled-out step. Including the previous N visible frames
directly (the conditional-RBM family Taylor, Hinton & Roweis 2006/2007
used for motion-capture data, and that Sutskever & Hinton 2007 subsume
under “TRBM”) with
n_lag = 2gives the model enough context to predict one step of motion sharply. The CLI default is--n-lag 2;--n-lag 1reproduces the velocity-blind variant. - CD-1 instead of full BPTT through time. The 2007 paper does
credit-assignment through the recurrent hidden chain. We treat
h_{t-1}as fixed during the per-frame CD-1 update, which is the simpler and faster choice but does not learn long-range temporal dependencies. This is the dominant reason the rollout horizon in our implementation is shorter than the paper’s. - Mean-field visible during the negative phase, sampled hidden. Standard for a Bernoulli-visible RBM on near-binary data; the gradient form is unchanged.
- Two balls, no ball-ball collisions. Matches the spec and the original simplification. Adding elastic ball-ball collisions changes the data distribution but not the architecture.
- Feedback strategy. During rollout the predicted v_t must be folded
back into V_past for the next step. Mean-field feedback smears fast;
we default to
--feedback sample(Bernoulli sample of the predicted probability), withbinarise(threshold at 0.5) andmeanavailable. This is a procedural choice not present in the paper, made necessary by the soft-output format of mean-field inference.
Open questions / next experiments
- BPTT credit assignment. The single biggest gap is that we don’t back-propagate through the recurrent h_{t-1} chain. Following Sutskever & Hinton’s RTRBM extension (2008) — which differentiates through the expected-h pathway — should sharply extend the rollout horizon. This is the natural follow-up.
- Stacking. Greedy layer-wise training of a 2- or 3-layer TRBM stack, as in §4 of the 2007 paper. The deeper representation should let the model encode multi-step velocity / trajectory features and produce visually plausible rollouts over the full 20-frame horizon.
- Compare to predict-mean-frame on a trajectory metric. Per-frame MSE rewards predicting the marginal mean. A position-tracking metric (e.g. centre-of-mass distance to truth, or top-K pixel overlap) would better reward the TRBM’s qualitatively-correct early predictions and is a more honest figure of merit for this benchmark.
- 30 × 30 frames. The 16 × 16 shrink is a v1 convenience. Re-running at the paper’s resolution is mostly a matter of compute and would produce a more direct comparison.
- Energy metric. Once the v1 baseline is in, the natural next step in the broader Sutro effort is instrumenting this stub under ByteDMD to see what the data-movement cost of CD-1 sequence training looks like vs the stronger BPTT variant — the two have very different commute-to- compute ratios.
Bouncing balls (3 balls, RTRBM)
Source: Sutskever, I., Hinton, G. E., & Taylor, G. W. (2008/2009), “The recurrent temporal restricted Boltzmann machine,” NIPS 21.
Demonstrates: A Recurrent Temporal RBM (RTRBM) — an RBM whose hidden
bias at time t is shifted by W_h r_{t-1} — can model continuous video
of three balls bouncing in a box. The recurrent matrix W_h is the only
structural addition over a per-frame RBM.

Problem
- Visible: 30 × 30 = 900 binary pixels (anti-aliased disks rendered by the simulator; values in [0, 1]).
- Hidden: 100 binary hidden units.
- Recurrent: a 100 × 100 hidden-to-hidden matrix
W_h. This is the only structural difference between an RTRBM and a per-timestep RBM. - Sequences: 30 sequences of length 100 generated fresh per run by a Newtonian simulator (3 balls, radius 3, speed 1, elastic wall and ball-ball collisions).
The interesting property: bouncing balls have deterministic-ish
trajectories interrupted by stochastic collisions, so the next frame is
almost — but not quite — predictable from the current one. A pure
per-frame RBM can model the marginal pixel distribution, but it has no way
to express which way the ball is going. The RTRBM gets velocity for free
by passing the mean-field hidden expectation r_t into the next step’s
hidden bias. The same hidden code that “explains” frame t also “primes”
the units that will explain frame t+1.
The RTRBM defines
p(v_{1:T}) = ∏_t p(v_t, h_t | r_{t-1})
r_t = E[h_t | v_{1:t}] = σ(W v_t + b_h + W_h r_{t-1})
Each conditional is an RBM whose hidden bias has been shifted by
W_h r_{t-1}. The shifted bias is the recurrence channel.
Files
| File | Purpose |
|---|---|
bouncing_balls_3.py | Simulator (simulate_balls, make_dataset), model (build_rtrbm, RTRBM), training (forward_mean_field, cd_step_sequence, train), inference (rollout, teacher_forced_recon, free_rollout_mse). CLI for the headline run. |
visualize_bouncing_balls_3.py | Static PNGs: data samples, receptive fields, recurrent matrix, training curves, reconstructions, rollout grid. |
make_bouncing_balls_3_gif.py | Generates bouncing_balls_3.gif (the animation at the top of this README). Trains once, then snapshots a held-out rollout at multiple checkpoints. |
bouncing_balls_3.gif | 5-checkpoint rollout animation showing the prediction sharpening as training progresses. |
viz/ | Output PNGs from the headline seed=0 run. |
problem.py | Original stub signatures. Re-exported by bouncing_balls_3.py. |
Running
# headline run (~5s training, ~10s with viz):
python3 bouncing_balls_3.py --seed 0
# regenerate viz/*.png:
python3 visualize_bouncing_balls_3.py --seed 0
# regenerate bouncing_balls_3.gif:
python3 make_bouncing_balls_3_gif.py --seed 0
Defaults: 30 sequences × length 100, 30×30 frames, 100 hidden units, 50 epochs, lr=0.05, momentum=0.9, weight-decay=1e-4. Total wallclock ~3 s on a laptop for the bare training loop, ~10 s including viz.
Results
Headline run (seed = 0)
| Metric | Value |
|---|---|
| Final CD-1 reconstruction MSE | 0.0053 |
| Validation teacher-forced MSE (held-out seq) | 0.0084 |
| Validation free-rollout MSE (warmup=10, future=30) | 0.13 |
| Training time | 3.4 s |
| Pixel mean (data) | 0.063 |
| Pixel variance (data) | ≈0.06 |
Stability across 5 seeds (seeds 0–4)
| Metric | Mean | Std |
|---|---|---|
| Final CD-1 reconstruction MSE | 0.0053 | 0.0000 |
| Teacher-forced MSE (held-out) | 0.0092 | 0.0007 |
| Free-rollout MSE | 0.1315 | 0.0118 |
| Training time | 3.08 s | 0.16 s |
Reconstruction is the well-converged metric — the RBM piece reliably learns the pixel distribution. The free-rollout MSE is in the same ballpark as the data variance (~0.06), which means the rollout blurs over many steps but does keep ball-shaped blobs at plausible positions. This is consistent with the simplified training scheme used here (see Deviations below).
Comparison to the paper
Sutskever et al. report log p(test seq) rather than reconstruction MSE,
on a similar 30×30 / 3-ball setup. We don’t compute the partition
function so the numbers aren’t directly comparable. Reproduces:
qualitatively yes — the RTRBM trains stably, learns spatial receptive
fields, and rolls out plausible ball-like blobs. Quantitatively: we
do not match the paper’s per-step log-likelihood because we use the
simplified gradient (no BPTT-through-time correction). See Open
questions for what closing that gap would look like.
Hyperparameters used
| Param | Value |
|---|---|
h × w (frame size) | 30 × 30 |
n_balls / radius / speed | 3 / 3.0 / 1.0 |
n_sequences × seq_len | 30 × 100 |
n_hidden | 100 |
n_epochs | 50 |
lr | 0.05 |
momentum | 0.9 |
weight_decay (L2 on W and W_h) | 1e-4 |
init_scale | 0.01 |
b_v init | logit(data_mean) |
b_h, r_init, W_h diag | 0 |
| Gibbs steps per future frame (rollout) | 25 |
Catalog row (per spec issue #1)
| Problem | Source paper (year) | Reproduces? | Difficulty | Run wallclock |
|---|---|---|---|---|
bouncing-balls-3 | Sutskever, Hinton & Taylor (2008) | qualitative yes (no per-frame log-lik comparison; see Deviations) | medium (RTRBM + simulator + GIF + 5 PNGs in one wave) | 3.4 s training; ~10 s incl. viz |
Visualizations
Receptive fields (viz/receptive_fields.png)

Each subplot is the incoming weight column W[:, j] for one of the
hidden units, reshaped to a 30 × 30 image (red positive, blue negative).
The fields are diffuse spatial filters — most units span a broad area of
the box rather than localizing tightly to a single ball position. With
only 30 sequences × 100 frames as training data, hidden units have not
fully tiled the input space; longer training and more data sharpen these.
Recurrent matrix (viz/recurrent_matrix.png)

The 100 × 100 recurrent matrix W_h. The strong red diagonal is the
key learned structure: W_h[j, j] > 0 means “if hidden unit j was
active last step, it tends to stay active this step.” That self-loop is
how velocity gets encoded — a ball moving through a region where unit j
fires will keep firing it for a few frames in a row.
Training curves (viz/training_curves.png)

CD-1 reconstruction MSE drops from ~0.07 (b_v alone, the data-mean prior) to ~0.005 in 50 epochs. The phase transition is around epoch 5; after that improvements come from sharpening individual receptive fields. The log-scale plot makes the epoch-30+ tail visible.
Sample data (viz/data_samples.png)

The first 16 frames of training sequence 0. Three round white balls move on a black background; trajectories are straight lines until a wall or another ball is hit, at which point the velocity reflects.
Reconstructions (viz/reconstructions.png)

Top row: held-out validation frames. Middle row: teacher-forced
reconstruction σ(r_t W^T + b_v) where r_t is the mean-field hidden
state computed from the data. Bottom row: the mean hidden activation
r_t (one column per frame). The reconstructions track the ball
positions cleanly — the per-frame RBM has learned the spatial structure.
Rollout grid (viz/rollout_grid.png)

Top row: ground-truth bouncing-balls sequence. Bottom row: model
rollout — the first warmup=10 frames are warmup (the model sees
ground truth and infers r_t), then the model generates the next 30
frames purely from its own predictions. Predictions stay ball-like and
plausibly placed for ~5–10 free steps, then drift / blur as compounding
prediction error accumulates.
Deviations from the original procedure
-
No BPTT-through-time correction in the gradient. The full RTRBM gradient backpropagates a
r_{t+1}term back throughr_t(becauser_tenters next step’s hidden bias). We use the simplified CD-1 gradient that only updatesW_hfrom per-step contrastive contributions:dW_h += (r_t − h_neg_t) ⊗ r_{t-1}. This is the “TRBM-style” approximation (Sutskever et al. and follow-ups note it trains stably but produces weaker rollouts than the full RTRBM gradient). The recurrent diagonal still emerges; longer-horizon forecasting is the part that suffers. -
n_hidden = 100, not 400/3000. The original paper uses larger hidden layers (400 in the smaller experiments, 3000 in the larger ones). At 100 hidden units the spatial dictionary is necessarily undercomplete, which is part of why receptive fields look diffuse rather than crisp blobs. We chose 100 to keep the full pipeline (training + viz + GIF) under ~10 seconds on a laptop.
-
Anti-aliased pixel rendering instead of binary disks. Each ball contributes
clip(radius + 0.5 − dist, 0, 1)per pixel, then balls max-combine. The visible units are still treated as Bernoulli sigmoids (so values in [0, 1] are interpretable as probabilities). The original paper used binary disks. The anti-aliased version gives the model a smoother training signal at edges. -
Mean-field positive phase. We use
r_t(the mean-field hidden expectation) directly as the positive-phase hidden activation instead of sampling. This is the standard choice in RTRBM implementations and what the original paper recommends; mentioned here for completeness. -
No persistent contrastive divergence (PCD). Plain CD-1 within each timestep. PCD/Tieleman-2008 would be a strict improvement on the gradient bias.
Open questions / next experiments
-
Does the BPTT correction actually help here? Adding the through-time gradient
∂L/∂r_t = W_h^T (r_{t+1} − h_neg_{t+1}) + diag(r_t (1 − r_t)) · ...should sharpen the rollouts. As an ablation we ran a “static” variant (W_h ≡ 0) and got teacher-forced MSE 0.009 / free-rollout MSE 0.125 — essentially tied with the full RTRBM (0.009 / 0.13). That suggests in this simplified gradient regime the recurrent matrix is barely doing work. The right experiment is to add BPTT and rerun the same ablation; if BPTT is needed forW_hto actually help, then this stub is a clean test case for “what does the full Sutskever et al. gradient buy you?” -
Energy / data-movement cost (per the broader Sutro effort). Per epoch the cost is dominated by
v ⋅ Wandr ⋅ W_h, bothO(seq_len · n_hidden · n_visible)andO(seq_len · n_hidden^2)respectively. ByteDMD on a single CD step would tell us how much of that lives in cache. The temporal axis adds an interesting wrinkle: do consecutive frames preserve their hidden-state working set, or does each timestep refetch? -
Larger hidden layer. Does receptive-field tiling become local (one unit per spatial blob) at
n_hidden=400, as the paper claims? At 100 the receptive fields are diffuse spatial filters; at the paper’s setting they would presumably tile. -
Longer rollouts and chaos. Free rollouts blur after ~10–20 steps because of compounding prediction noise. How does this scale with
n_gibbsand withn_hidden? Does adding a small amount of visible-side noise during training (denoising RTRBM) push the free horizon out? -
Generalisation across
n_balls. Train on 3 balls, test on 1, 2, 4. The recurrent matrix should not care about ball count; if free rollouts on (say) 4 balls remain stable that is evidence the model has learned dynamics rather than a per-frame memorization. -
Varying ball masses / radii. The paper considers equal balls. Mixed-mass collisions break the velocity-swap shortcut and would stress the learned recurrence harder.
Transforming auto-encoders
Numpy reproduction of Hinton, Krizhevsky & Wang, “Transforming auto-encoders”, ICANN 2011 — the seminal capsule paper.
The translation-only variant: each “capsule” learns a recognition head that
outputs (presence, x, y) for its entity, then a generative head that
produces a 22x22 reconstruction patch from the post-transformation
coordinates (x + dx, y + dy). The full output is a presence-weighted sum
across capsules.

Problem
Take a centered MNIST digit, randomly translate the input by t_in (in the
range +/-5 pixels), then translate it again by (dx, dy) to produce the
target. The network sees (input_image, dx, dy) and must reconstruct the
22x22 centered crop of the target image.
The point of the experiment is not the reconstruction quality per se. It’s
the architectural constraint: the only path from (dx, dy) to the output
is by being added to each capsule’s (x, y) recognition output. So if the
loss decreases, the network must have learned an (x, y) representation
that lives in pixel-equivalent units. That’s the disentanglement.
The disentanglement test: take a held-out pair (image1, image2) related
by an unknown (dx, dy), run the recognition heads on each, and read off
dx_pred = mean_c (x_c^{(2)} - x_c^{(1)}) (weighted by presence). If the
network learned the disentanglement, predicted (dx, dy) correlates with
truth.
- Why MNIST is jittered at the input. With centered digits the entity is
always at
(14, 14)and the recognition(x, y)outputs collapse to constants — the network can satisfy the loss by learning a fixed template per capsule and translating it via(x_c + dx, y_c + dy). With random input jitter,(x, y)must track the entity’s position to reconstruct correctly.
Architecture
| Stage | Layer | Shape per capsule | Activation |
|---|---|---|---|
| Recognition hidden | linear | 784 -> 20 | sigmoid |
| Instantiation | linear | 20 -> 3 (p, x, y) | sigmoid on p, linear on (x, y) |
| (transform) | add | (x', y') = (x + dx, y + dy) | – |
| Generative hidden | linear | 2 -> 128 | ReLU |
| Generative output | linear | 128 -> 484 (22x22) | sigmoid |
| Aggregation | sum | recon = sum_c p_c * patch_c | – |
- 30 capsules, all per-capsule weights stacked (no Python-level capsule loop).
- Per-capsule contractions go through
np.matmulrather thannp.einsum— the equivalenteinsum('bch,chp->bcp', ...)runs ~90x slower because numpy doesn’t dispatch that signature into BLAS. See helpers_per_cap_*intransforming_autoencoders.py. - Trained with Adam (lr=1e-3, beta=(0.9, 0.999)). SGD+momentum did learn but reached only R²(dx)~0.06 in 15 epochs vs Adam’s ~0.5.
Files
| File | Purpose |
|---|---|
transforming_autoencoders.py | Model, MNIST loader, transform-pair generator, training loop, predict_transformation. CLI: --seed --n-epochs --n-capsules. |
visualize_transforming_autoencoders.py | Static figures: training curves, example pairs, capsule presence heatmap, predicted-vs-true (dx, dy) scatter, reconstruction grid. |
make_transforming_autoencoders_gif.py | Animated training GIF. |
transforming_autoencoders.gif | Output of the GIF script (1.4 MB). |
viz/ | Static PNG outputs from the visualization script. |
Running
# Train and print per-epoch metrics (~100 sec for 30 epochs)
python3 transforming_autoencoders.py --n-epochs 30 --seed 0
# Train + render all static figures into viz/
python3 visualize_transforming_autoencoders.py --n-epochs 30 --seed 0 --outdir viz
# Train + render the animated GIF
python3 make_transforming_autoencoders_gif.py --n-epochs 30 --snapshot-every 200 --fps 8
The MNIST loader downloads train-images-idx3-ubyte.gz from
storage.googleapis.com/cvdf-datasets/mnist/ on first run and caches it at
~/.cache/hinton-mnist/.
Results
Defaults: 30 epochs x 200 steps x batch 64 = 6,000 Adam updates over 10,000 MNIST images, on a single thread. Run wallclock ~100 s on an M-series Mac.
| Metric | Value |
|---|---|
| Final train MSE | 0.063 |
| Final val MSE | 0.064 |
| R²(dx) on held-out pairs | 0.78 |
| R²(dy) on held-out pairs | 0.67 |
(dx, dy) prediction policy | top-3 capsules by min(p1, p2) (vs full-mean: 0.61 / 0.43) |
| Active capsules per image (median, threshold p > 0.1) | 9 |
| Wallclock | ~100 s for 6,000 Adam updates |
Per-step time was 14 ms after replacing the per-capsule einsums with
np.matmul (down from 515 ms — 36x).
Predicted vs true (dx, dy)

Held-out pairs of MNIST digits; the network predicts (dx, dy) from the
recognition outputs alone (no (dx, dy) is ever supplied at inference).
Diagonal is y = x. Most points cluster on the diagonal; the spread at
extreme shifts reflects that t_in + dxdy can push the digit nearly out
of the 28x28 canvas, leaving the recognition layer with very little to
match.
Training curves

Reconstruction MSE drops fast in the first 2 epochs (the network learns to
output a centered blob), then slowly as it learns the (x, y) -> patch
geometry. R² stays near zero for ~5 epochs while the recognition layer
saturates, then climbs sharply between epochs 6 and 14. dx is easier than
dy in this seed; we believe this is a seed-level asymmetry (different seeds
flip which axis is faster).
What the capsules learn

Across 24 random validation digits, ~9 of the 30 capsules are active per
image (median) at p > 0.1. A small set of capsules (~3, 11, 17, 19) fire
on most inputs; the rest are mostly silent. This sparse-presence pattern
is what allows the top-3 inference rule to dramatically outperform the
full-mean (R²(dx) 0.78 vs 0.61).
Reconstructions

The network produces blob-like reconstructions — recognizable as digits when shape is simple (4, 5, 6) but blurred for thinner strokes. With only 30 capsules x 128 ReLU units and ~6,000 updates this is expected; the paper’s experiments use much wider generative nets and longer training. The point of this implementation is the disentanglement, not the visual quality of the reconstruction.
Deviations from the 2011 paper
- Translation only. The paper covers 2D translation, scaling, and
full 2D / 3D affine transformations (with 9-component instantiation
parameters). This implementation handles the 2D-translation case only,
so the recognition head is 3-dim
(p, x, y)rather than 9-dim. - Adam instead of SGD with momentum. The paper uses plain SGD; we found Adam reaches comparable disentanglement in roughly 5x fewer updates on this seed.
- MNIST input jitter is explicit. The paper’s pipeline draws training
pairs by transforming a source image; our
make_transformed_pairsamples both an input jittert_inand a(dx, dy)so that the recognition(x, y)outputs have non-trivial signal to track. Withoutt_in, MNIST digits are always centered and(x, y)degenerates to a constant. - 22x22 centered crop as target. The paper places each capsule’s
patch at its predicted
(x', y')and sums into a full-size canvas; we sum patches at a fixed canonical center and compute MSE against the centered crop of the transformed image. This is the simpler invariant and cleanly tracks the disentanglement metric (you can shift the canvas by(dx, dy)without changing the loss formulation).
Correctness notes
- Per-capsule matmul helpers.
_per_cap_weight_grad,_per_cap_input_grad, and friends are equivalent to the naturalnp.einsumformulations (verified in the smoke tests atatol=1e-3). They are written asnp.matmulcalls because numpy’s einsum path for'bch,chp->bcp'-style contractions does not go through BLAS — direct benchmark on this problem: einsum 109 ms / matmul 1.2 ms for theW_gen2op. - Top-k inference.
predict_transformationdefaults totop_k=3(selecting the 3 capsules with highestmin(p_1, p_2)per pair). On this trained model this lifts R²(dx) from 0.61 to 0.78 — most capsules havep ~ 0on any given input and contribute pure noise to the averaged(x, y)difference. Passtop_k=Noneto use all capsules weighted by(p_1 + p_2)/2. (x, y)units are pixel units. The architecture forces this:dxis in integer pixels, gets added tox, and the generative net’s output must shift bydxpixels to match the target. So the network has to put(x, y)on the same scale as the translation.- dx vs dy asymmetry. The dx and dy R² curves are not identical at any one seed. Across seeds the asymmetry flips. With 30 capsules and only ~9 active per input, capsule allocation between “horizontal” and “vertical” detectors is decided early in training and isn’t always balanced. Multi-seed averages would smooth this; we report a single seed for clarity.
Open questions / next experiments
- Full affine. The paper’s recognition head is 9-dim per capsule (presence + 2D-to-2D affine). Extending the same pipeline to scaling and rotation would test whether the matmul-vectorized batched-capsule setup scales without code changes (only the recognition output dim and the transformation rule change).
- Larger generative nets. With 128 ReLU units the reconstructions are blob-like. A wider generative net (paper-scale: 200+ units) should make reconstruction qualitatively better while leaving the disentanglement invariant — useful as a sanity check.
- Sparsity penalty on
p. The trained model already converges to ~9 active capsules per input; an explicit L1 penalty onpmay push that down further and produce sharper “what” detectors per capsule. - End-to-end translation prediction. Currently
predict_transformationuses an explicit “subtract(x, y)outputs” rule. Training a small classifier head on top of(xy_2 - xy_1)would let us see how much of the residual variance is recoverable by a learned readout.
Deep Lambertian synthetic spheres
Reproduction of the joint albedo / surface-normal recovery experiment from Tang, Salakhutdinov & Hinton, “Deep Lambertian Networks”, ICML 2012.

Problem
Given several images of the same Lambertian-shaded surface taken under different known light directions, jointly recover the surface’s per-pixel albedo (RGB reflectance) and per-pixel surface normal. The Lambertian image-formation model is:
pixel(p, k) = albedo(p) * max(0, normal(p) . light_dir(k)) * intensity
The “Deep Lambertian Network” of Tang et al. couples a Gaussian RBM prior
over albedo / normal latents with this fixed image-formation model. The
synthetic-spheres setting is a controlled benchmark with known ground truth:
generate spheres with random RGB albedo, render them under several random
upper-hemisphere lights, and ask the network to recover both. With known
geometry (centred unit sphere, ~80% image fill) we can evaluate the
recovered normals against an analytic ground truth n = (x, y, sqrt(1 - x^2 - y^2))
inside the silhouette mask.
- Resolution: 32 × 32 RGB
- Albedo: per-sphere RGB ~
np.random.uniform(0, 1, size=3) - Lights: per-view random unit vectors with
n_z > 0.15(4–8 per sphere; default 6) - Sphere fill: ~80% of the smaller image dimension
- No shadows (pure Lambertian, paper’s setup)
Files
| File | Purpose |
|---|---|
deep_lambertian_spheres.py | Renderer + dataset gen + per-pixel encoder + Lambertian decoder + training loop. CLI entry point. |
visualize_deep_lambertian_spheres.py | Static training curves, dataset montage, albedo / normal recovery panels, per-light reconstruction. |
make_deep_lambertian_spheres_gif.py | Builds deep_lambertian_spheres.gif. |
deep_lambertian_spheres.gif | Animated training progress (37 frames, 0.33 MB). |
viz/ | PNG outputs from the static visualiser. |
Running
Quick check (smoke test, ~3 seconds):
python3 deep_lambertian_spheres.py --seed 0 --n-spheres 32 --n-test-spheres 8 --n-epochs 5
Full training run (~30 seconds on a laptop, drives the numbers below):
python3 deep_lambertian_spheres.py --seed 0 --results-json viz/run.json
To regenerate the static visualisations:
python3 visualize_deep_lambertian_spheres.py --seed 0 --outdir viz
To regenerate the GIF (uses 80 epochs to keep it under 1 MB):
python3 make_deep_lambertian_spheres_gif.py --seed 0 --n-epochs 80 --snapshot-every 4 --fps 10 --dpi 75
Results
Held-out test set (64 spheres, seed 10000):
| Metric | Value | Target |
|---|---|---|
| Normal angular error, mean | 27.01° | < 30° |
| Normal angular error, median | 23.71° | — |
| Albedo MSE (per-sphere RGB) | 0.0120 | constant-predictor baseline ≈ 0.083 |
| Reconstruction MSE | 0.01024 | — |
| Training wallclock | 33.3 s | — |
Defaults (locked in _parse_args):
n_spheres=400, n_test_spheres=64, n_lights_per_sphere=6, n_epochs=120,
resolution=32, hidden=192, batch_size=1024, lr=0.02 (cosine 1.0 -> 0.05),
momentum=0.9
Environment: Python 3.11.10, numpy 2.3.4, macOS-26.3-arm64.
Dataset examples

Each row is one sphere; each column is the same sphere rendered under one of the 6 random upper-hemisphere lights. The shaded silhouette shifts as the light direction rotates; the underlying 3D shape and RGB albedo are identical within a row.
Albedo recovery

Top row: ground-truth per-sphere RGB albedo. Bottom row: recovered albedo, computed as the per-pixel mean of the encoder’s output over each sphere’s silhouette mask. The L2 error per swatch is printed below. Mean MSE across 64 held-out spheres is 0.0120 vs. ~0.083 for a constant-mean predictor.
Normal recovery

Three-panel grid for four held-out spheres: the analytic GT normal map
(RGB-encoded as ((n_x + 1)/2, (n_y + 1)/2, n_z)), the recovered normal
map, and the per-pixel angular error in degrees. Error concentrates at the
silhouette ring where GT n_z -> 0 while the encoder’s output is bounded
above 0; the interior error is much lower (median 23.7°).
Per-light reconstruction

For one held-out sphere: input view, network re-render, and absolute residual under each of the 6 lighting conditions. The Lambertian decoder re-renders from the shared recovered albedo and per-pixel normals using each frame’s own light direction, so reconstruction quality is a direct test of whether the encoder factored the input correctly.
Training curves

Mean angular error crosses the 30° target around epoch 60. Reconstruction MSE and albedo MSE both decrease monotonically under the cosine schedule. A constant LR of 0.02 (no decay) reaches roughly the same plateau but goes unstable around epoch 75 (loss spike); cosine decay to 5% of peak LR removes that.
Architecture
The full Deep Lambertian Network of the paper has three pieces:
- A Gaussian RBM prior over visible image pixels conditioned on inferred (albedo, normal) latents.
- A deterministic Lambertian decoder mapping (albedo, normal, light_dir) -> rendered image.
- Variational / contrastive training over the joint.
This v1 replaces (1) and (3) with a feed-forward per-pixel encoder trained by reconstruction MSE — i.e. drops the GRBM prior. The decoder is identical to the paper.
per-pixel features (B, 6N) B = #pixels (across spheres in a batch)
[N RGB observations | N light directions]
|
W1 (6N -> 192) + ReLU
|
W2 (192 -> 5)
| split
v
sigmoid(z2[:, :3]) -> albedo (3) in [0, 1]
tanh(z2[:, 3:5]) -> n_xy (2) scaled to [-0.985, 0.985]
sqrt(1 - n_x^2 - n_y^2) -> n_z (1) forces unit normal, n_z > 0
|
Lambertian decode with each frame's known light_dir
|
reconstruction MSE over (B, K, 3)
The encoder is per-pixel: it sees only that pixel’s RGB observations
under the N lights plus the N light directions. It does not see pixel
coordinates or the silhouette mask. Photometric stereo theory (Woodham
1980) says this inverse problem is identifiable from >=3 non-coplanar
upper-hemisphere lights; with 6 random lights it is well-conditioned.
Manual backprop (backward() in deep_lambertian_spheres.py) handles the
ReLU image-formation cut at max(0, n . l) and the n_z = sqrt(...)
hemisphere parametrisation.
Deviations from the paper
- No GRBM. v1 uses a deterministic feed-forward encoder, not a Gaussian-RBM-coupled latent model. This simplifies training to plain SGD with reconstruction loss and removes the contrastive-divergence sampler. Expected cost: less robust to noise in the inputs (the GRBM’s prior would help denoise). On this clean synthetic data the simplification is essentially free.
- Per-pixel inference. The paper’s GRBM couples nearby pixels via the prior. Here each pixel is decoded independently. The resulting normal maps are slightly noisier than spatially-coupled estimates, especially near the silhouette.
- Closed-form Lambertian decoder, no shadows. Matches the paper’s reported synthetic-sphere setup.
- Fixed
n_lights = 6(parameterisable on the CLI, in[4, 8]per the spec). The paper sweeps the number of lights; we report a single working point.
Correctness notes
-
Hemisphere parametrisation. Predicting
n_xy = 0.985 * tanh(z2)and computingn_z = sqrt(max(1 - n_x^2 - n_y^2, 1e-6))keeps the network on the unit upper hemisphere for free, and boundsn_zaway from zero so the gradientdn_z / dn_x = -n_x / n_zis finite. Without the0.985shrink-factor the loss occasionally explodes when a pixel slides onto the equator andn_z -> 0. -
ReLU at
max(0, n . l). The Lambertian formula has a hard cut at the day-night terminator; backward is the standard ReLU maskI[n . l > 0]. The decoder forwardsd_clip = max(0, dot); backward zeros gradients for pixels that fall on the dark side of any given light. -
Per-sphere albedo aggregation. The encoder predicts a per-pixel albedo, but the GT albedo is per-sphere. We report MSE between the per-sphere ground truth and the pixel-mean of the predicted albedo across the sphere’s silhouette. This is the right metric: Lambertian decoding is invariant to a pixelwise (albedo * cos) trade-off only when the cos term is fixed, but with multiple lights the trade-off is broken and pixelwise albedo is identifiable.
-
Light-direction conditioning. Light directions are random per-sphere so the encoder cannot memorise them; they are passed as part of the per-pixel input feature. Removing them from the input drops the recovery to chance — the network has no way to do photometric stereo without knowing where the lights are.
-
Silhouette ring. Most of the residual angular error comes from the silhouette ring (
n_z -> 0). Median angular error (23.71°) is well below the mean (27.01°) for this reason. A sharper hemisphere parametrisation (or letting the encoder also predict the mask) would close most of this gap.
Open questions / next experiments
- Add the GRBM prior. Tang et al.’s actual contribution is the GRBM-on-latents prior, which should help most on noisier inputs. v1 has no test for this — we render clean, noise-free images. Adding Gaussian noise to the observations and showing the GRBM-augmented variant outperforms the deterministic baseline would directly reproduce a paper claim.
- Spatial coupling. A small convolutional encoder (locally connected, numpy-only) should reduce the silhouette-ring error by allowing neighbouring pixels to vote on each other’s normals.
- Vary
n_lights. Sweepn_lights in {3, 4, 6, 8}and report the recovery curve. With 3 lights the inverse problem becomes linear and closed-form (Woodham’s photometric stereo); the network should match it. With more lights the encoder should be more forgiving of grazing / collinear configurations. - Compare to closed-form photometric stereo. Per-pixel pseudo-inverse
g = pinv(L) @ I,albedo = ||g||,normal = g / ||g||is the textbook baseline. Adding it as a sanity-check oracle would calibrate what fraction of the error is irreducible (light coverage) vs. network-induced. - Anchor with the paper’s quoted numbers. Tang et al. 2012 report surface-normal recovery numbers on a different (face) dataset; we’d need to either render their sphere setup faithfully or move to the face dataset to make a head-to-head comparison.
RNN pathological long-term-dependency tasks
Source: Sutskever, Martens, Dahl & Hinton (2013), “On the importance of initialization and momentum in deep learning”, ICML. Tasks originally proposed in Hochreiter & Schmidhuber (1997), “Long Short-Term Memory”, Neural Computation 9(8):1735-1780 (the LSTM paper, sections 5.1-5.3).
Demonstrates: A vanilla tanh RNN trained with SGD + Nesterov-style momentum + gradient clipping can solve long-range memory tasks only if the recurrent weight matrix is initialized as a random orthogonal matrix. Holding everything else fixed, the same architecture with a small-Gaussian recurrent matrix at the same T stays at the chance baseline forever. The gap is the headline of the 2013 paper and the structural reason orthogonal init became the default for vanilla RNNs.

Problem
The Hochreiter-Schmidhuber pathological tasks isolate one capability per task: hold a piece of information for T timesteps, ignoring noise in between. We implement four of the seven; three are run as the headline experiment.
| Task | What the network must do | Output |
|---|---|---|
addition | Sum two real-valued markers placed in a stream of i.i.d. Uniform[0, 1] noise. One marker in the first half, one in the second. | scalar regression (MSE) |
xor | Same layout as addition but values are binary; target is XOR of the two marked bits. | 2-class softmax |
temporal_order | Vocabulary of 6 symbols. At one cued position in the first 10-20% and one in the 50-60% region, A or B appears; the rest are distractors. Classify which (sym1, sym2) pair was placed. | 4-class softmax |
3bit_memorization | First 3 timesteps drop a 3-bit pattern; the rest is i.i.d. noise; the final timestep is a query token. Recall the 3-bit pattern. | 8-class softmax |
For each task the network outputs a single prediction at the final timestep, so the gradient must flow back through T stacked tanh layers via BPTT. With T = 30..60 and a vanilla RNN, vanishing/exploding gradients are the dominant failure mode – which makes the choice of W_hh initialization the critical hyperparameter.
The interesting property: orthogonal W_hh keeps every singular value at exactly 1, so the Jacobian product through T timesteps preserves gradient norm. Random Gaussian W_hh (small scale, e.g. N(0, 0.1)) has singular values ranging from ~0.02 to ~1.6 – the small ones produce vanishing gradients in the directions they span. The plot in viz/spectrum_W_hh.png shows this directly.
XOR is the documented hardest of the seven (Sutskever et al. 2013, table 2: ~8x more iterations than addition). It is implemented and verified to run, but plain SGD + momentum at our budget cannot solve it under either init – both stay at 50%. We exclude XOR from the headline runs and keep it as a documented failure case in §Open questions.
Files
| File | Purpose |
|---|---|
rnn_pathological.py | Task generators (addition, xor, temporal_order, 3bit_memorization); vanilla tanh RNN with orthogonal or random W_hh init; manual BPTT in numpy; SGD + momentum + global gradient clipping; chance-level baselines; CLI (--task, --init, --sequence-len, --seed, …) and an --all mode that runs all (task, init) combos and dumps results.json. Numpy only. |
visualize_rnn_pathological.py | Reads results.json. Emits ortho_vs_random.png (training curves, the headline), summary_table.png (final-metric grid), spectrum_W_hh.png (singular-value spectra at init and after training – the structural explanation), and task_examples.png (one input/target visualised per task). |
make_rnn_pathological_gif.py | Animated GIF on 3bit_memorization (T=60), training both inits to 80 epochs and animating the loss / accuracy curves diverging. Default output rnn_pathological.gif. |
rnn_pathological.gif | Committed N = 64-hidden, T = 60, 80-epoch animation (122 KB). |
viz/ | Committed PNG outputs. |
results.json | Cached per-(task, init) histories and config from the headline run; consumed by the visualizer. |
Running
Reproduce the headline experiment (3 tasks x 2 inits = 6 runs, fresh seeds):
python3 rnn_pathological.py --all --seed 0
Wall-clock: 42 s total on an M-series MacBook (about 4-9 s per run, depending on T). Writes results.json.
Then regenerate the visualizations and the GIF:
python3 visualize_rnn_pathological.py # reads results.json -> viz/*.png
python3 make_rnn_pathological_gif.py # ~25 s, writes rnn_pathological.gif
Single-task run (e.g. only addition with random init):
python3 rnn_pathological.py --task addition --init random --sequence-len 30 --seed 0
python3 rnn_pathological.py --task temporal_order --init ortho --sequence-len 60 --seed 0
python3 rnn_pathological.py --task 3bit_memorization --init ortho --sequence-len 60 --seed 0
Results
Headline run, seed = 0, hidden = 64, batch = 50, batches-per-epoch = 30, lr = 0.01, momentum = 0.9, clip = 1.0, 80 epochs:
| task | T | metric | chance | ortho final | random final | gap | ortho solved@ | random solved@ |
|---|---|---|---|---|---|---|---|---|
| addition | 30 | MSE (lower is better) | 0.087 | 0.012 | 0.079 | 0.067 | epoch 41 | did not solve |
| temporal_order | 60 | accuracy (higher is better) | 0.250 | 1.000 | 0.240 | 0.760 | epoch 11 | did not solve |
| 3bit_memorization | 60 | accuracy (higher is better) | 0.125 | 1.000 | 0.115 | 0.885 | epoch 29 | did not solve |
“Solved” thresholds: addition MSE < 0.05, temporal_order acc > 0.90, 3bit_memorization acc > 0.90.
Reading the table: ortho-init solves all three tasks within tens of epochs. Random-init under the same hyperparameters and seed stays at chance on all three – not “trains slower,” but does not learn at all within the budget. The MSE column for addition shows the random run at 0.079, statistically indistinguishable from the chance baseline 0.087 (the network has converged to outputting the per-batch mean, ignoring the marker channel entirely).
XOR (also implemented but not in the headline) at T = 30, hidden = 64, 100 epochs, seeds 0, 1, 2: both inits stuck at ~50% (chance). See §Open questions.
Hyperparameters are identical across (task, init) pairs in the headline run; only the W_hh initialization changes. W_ih (input projection) and W_hy (output projection) use the same Gaussian recipe in both arms so the comparison really is “ortho vs random W_hh only”.
Visualizations
Headline: ortho vs random training curves

Training loss (top row, log-scale) and task metric (bottom row) for all three headline tasks. Blue = orthogonal init, red = random init. The gap is most dramatic on temporal_order and 3bit_memorization: the random run never leaves the chance baseline, while the ortho run hits 100% in under 30 epochs. The vertical dashed blue line marks the “solved” epoch for ortho.
Summary table

Same numbers as the Results table, colour-coded: green = ortho solved, red = random failed. All six cells line up with the prediction.
Why ortho wins: spectrum of W_hh

Left: singular-value spectra at initialization. The ortho matrix has all 64 singular values exactly equal to 1 (the flat blue line). The random N(0, 0.1) matrix follows the Marchenko-Pastur quarter-circle distribution: its largest singular value is ~1.6 (top of red curve), but it has a long tail of small values down to ~0.02. Those small directions cause vanishing gradients when backpropagating through 60 tanh layers.
Right: spectra after training on addition (T = 30, 80 epochs). The ortho matrix stays close to its initial flat profile – the optimizer modifies it but the spectrum doesn’t collapse. The random matrix retains its long tail of small singular values; gradient descent on its own does not fix what initialization broke.
Task examples

One example input/target for each headline task:
- addition: 30 timesteps of values in [0, 1]; the two dark-blue bars are the marker positions; target is their sum (1.289).
- temporal_order: 60 timesteps of distractors C/D/E/F (gray), with two cued positions where B/B were placed (red); target class is BB.
- 3bit_memorization: first 3 timesteps carry the bits 0, 1, 1 (blue); 56 timesteps of i.i.d. noise (gray); final query token (yellow); target class is binary 011 = 3.
Deviations from the original procedure
- Init scale for the random arm. The paper uses several random-init recipes (small Gaussian, sparse, etc.) and reports failure rates for each. We pick a single representative random recipe – N(0, 0.1) – chosen because at this scale the spectral norm is just over 1.0 (so the random arm is not trivially failing for “the matrix is too small to do anything”). It’s failing because the spread of singular values is too wide. Other random scales we did not run: 1/sqrt(n) (exploding at this T), 0.01 (trivially vanishing).
- Optimizer. Sutskever et al. 2013’s other contribution is Nesterov accelerated gradient, with momentum schedules ramping from 0.5 to 0.999. We use plain heavy-ball momentum at a fixed 0.9 and rely on global gradient clipping (Pascanu/Mikolov/Bengio 2013) at threshold 1.0 to keep ortho-init stable. With heavy-ball + clipping we see the qualitative claim hold; with NAG we expect cleaner convergence on
additionspecifically. - Sequence lengths. The paper goes up to T = 200 on some tasks; we use T = 30 (addition) and T = 60 (temporal_order, 3bit_memorization). T = 60 is already enough to break the random init on all three; T = 80 starts to break ortho too on
3bit_memorization(single-seed test: ortho dropped to 24% accuracy), so we stay at T = 60 to make the headline reproducible without seed-shopping. - Number of tasks. Paper covers all 7 Hochreiter-Schmidhuber tasks; we cover 4 (
addition,xor,temporal_order,3bit_memorization) and run 3 in the headline. The other three (multiplication,random_permutation_memorization,noiseless_memorization) are in the same family and we expect the same gap; we leave them as a follow-up. - XOR not solved. Sutskever et al. 2013 (table 2) report XOR requires ~8x more iterations than
addition, with carefully tuned learning-rate schedules. At our budget (100 epochs at lr=0.01) and with three seeds we did not solve XOR under either init. Documented as a failure case rather than dropped, to be honest about scope. - Output structure. We use a single output head at the final timestep (loss = MSE for regression, softmax-CE for classification). The original paper variously uses per-timestep targets and end-of-sequence targets depending on the task. End-of-sequence is the harder version (gradient must flow through every timestep), so this is the harder choice and matches the spirit of the paper.
- Float precision. float64 throughout; the paper used float32 on GPU. Should not matter at this scale (~6k parameters per model).
- Hidden size. 64. Paper uses 100. The smaller size is faster and still shows the gap; we did not test scaling.
Otherwise: same architecture (vanilla tanh RNN, single-layer recurrent), same loss type per task, same algorithm (BPTT + momentum + clip), same data distributions.
Open questions / next experiments
- What does it take to crack XOR? Sutskever et al. 2013 report ~8x more iterations vs addition, with NAG and a momentum schedule from 0.5 to 0.999. A clean experiment: hold this exact setup constant (3-bit_memorization-style hyperparameters), swap heavy-ball momentum for NAG, run 800 epochs on XOR at T = 30 with both inits, and report whether ortho cracks it while random remains at 50%.
- Where does ortho start to break? At T = 80 on
3bit_memorizationwith our default hyperparameters, ortho drops from 100% to ~24% accuracy in our smoke test. The paper’s headline is that ortho extends the working range, not that it makes the problem trivial. Sweeping T = {30, 60, 80, 100, 150, 200} for both inits and plotting “solve rate vs T” would map the boundary precisely. Above some T even ortho fails and you need LSTM or more careful conditioning (e.g. uRNN). - Identity vs orthogonal init. Le et al. 2015 (IRNN) propose
W_hh = I(identity) + ReLU as an alternative to ortho + tanh. Identity-init is the deterministic version of ortho (its spectrum is also flat at 1). On these tasks, we’d predict identity-init matches ortho onadditionbut underperforms on tasks where the network needs to use multiple hidden directions to encode different bits of state (e.g.3bit_memorization). A direct comparison would show whether the randomness of ortho matters, or just the spectrum. - Multi-seed convergence rates. Our headline is single-seed (seed = 0). Across N seeds, what is the success rate of ortho on each task? If ortho occasionally fails (say 10% of seeds), is that because the random orthogonal matrix happened to land on a “bad” rotation, or because the rest of the optimizer state matters?
- Connection to ByteDMD / data-movement complexity. The Sutro project measures algorithms by the data-movement they induce. A vanilla RNN reads
W_hh(H x H) and the hidden state (H,) once per timestep – a stride-Haccess over the recurrent weights repeatedTtimes. ByteDMD on a length-TBPTT pass on a 64-hidden RNN should give a clean reference number for “what does long-range memory cost in this architecture”, against which LSTM/Transformer alternatives can be compared. Unmeasured. - Why is
temporal_orderso easy at T = 60? Ortho hit 100% in 11 epochs – visibly easier thanaddition(41 epochs at half the T). Hypothesis: the cued symbols are already one-hots in a 6-dim space, so the network gets a free “gating” signal and only needs to remember which one-hot direction was active at the first cued position.additionrequires actually doing arithmetic on a continuous value, which seems to need more updates.
MNIST distillation with omitted digit “3”
Reproduction of the omitted-class experiment from Hinton, Vinyals & Dean, “Distilling the knowledge in a neural network”, NIPS Deep Learning Workshop (2015), §3.

Problem
Train a teacher on full MNIST, then train a smaller student on a transfer set with all examples of digit 3 removed. The student never sees a 3 during training. After distillation at high temperature, evaluate the student on test 3s. Then apply a single bias correction (boost the logit-bias for class 3) and re-evaluate.
The interesting property: the teacher’s softened output distribution carries “dark knowledge” – the relative probabilities of the wrong classes – and that signal alone is enough to teach the student what a 3 looks like, even without any example. After bias correction the student approaches the teacher’s accuracy on the held-out class.
- Teacher: 784 -> 1200 -> 1200 -> 10 ReLU MLP, trained on hard labels with ±2 px input jitter as augmentation.
- Student: 784 -> 800 -> 800 -> 10 ReLU MLP, no regularization. Trained by distillation only (no hard labels), at temperature T = 20, on the ~54k transfer-set examples that are not digit 3.
- Bias correction: a single scalar offset added to the student’s logit bias for class 3, chosen so the student’s mean softmax mass on class 3 matches the expected class frequency on the full training set (~10.2%).
Files
| File | Purpose |
|---|---|
distillation_mnist_omitted_3.py | MNIST loader, teacher / student MLPs, distillation loop, bias correction. CLI runs the full pipeline. |
visualize_distillation_mnist_omitted_3.py | Static plots: teacher curves, student curves, per-class accuracy, bias-correction sweep. |
make_distillation_mnist_omitted_3_gif.py | Animated GIF of the student learning to recognise 3s through distillation, with a final post-bias-correction frame. |
distillation_mnist_omitted_3.gif | Animation at the top of this README. |
viz/ | PNG outputs from the run below. |
Running
python3 distillation_mnist_omitted_3.py \
--seed 0 --temperature 20 --n-epochs-teacher 12 --n-epochs-student 20
Pure numpy (numpy + matplotlib + Pillow only). MNIST is downloaded once and
cached at ~/.cache/hinton-mnist/. Full pipeline runs in ~2 min on an Apple
M-series laptop. Visualizations and GIF each take another ~3 min (they re-run
training to capture per-epoch snapshots).
To regenerate visualizations:
python3 visualize_distillation_mnist_omitted_3.py --seed 0 --outdir viz
python3 make_distillation_mnist_omitted_3_gif.py --seed 0 --snapshot-every 1 --fps 3
Results
Seed 0, T=20, 12 teacher epochs + 20 distillation epochs:
| Quantity | Value |
|---|---|
| Teacher overall test accuracy | 98.39 % |
| Teacher accuracy on test 3s | 97.43 % |
| Student overall test accuracy (pre-correction) | 97.83 % |
| Student accuracy on test 3s, pre-correction | 91.88 % |
Bias offset applied to b_out[3] | +2.696 |
| Student overall test accuracy (post-correction) | 98.07 % |
| Student accuracy on test 3s, post-correction | 97.82 % |
| End-to-end wallclock | ~122 s |
The student never saw a 3 during training and still got 91.88 % of test 3s right just from soft targets. After bias correction it gets 97.82 %, almost exactly matching the teacher’s 97.43 %. Hinton et al. report 98.6 % on the same task with a slightly different setup (T = 8, hard + soft target mix); 98 % under our pure-soft-target T = 20 recipe is the same regime.
Per-class breakdown

The yellow column is the omitted class. Blue (pre-correction) collapses for class 3 only – every other class is essentially indistinguishable from the teacher. Green (post-correction) recovers the gap on class 3 with a single scalar bias change.
Bias-correction sweep

Left: as the bias offset grows from 0 to 5, accuracy on test 3s climbs from
~92 % to ~99 %, while overall accuracy traces a shallow inverted U (the
extra confidence on class 3 starts costing other classes once it overshoots).
Right: the criterion we use to pick the offset – match the average
p(class 3) to the empirical class frequency 10.2 % on the full training
set. The two curves cross at almost exactly the offset that maximises
overall accuracy, so the correction is essentially free.
Distillation curves

Teacher converges to ~98.4 % in ~10 epochs. The student’s accuracy on test 3s (red) climbs from ~43 % at epoch 1 to ~92 % by epoch 20 purely from soft-target supervision. Accuracy on the other digits (green) stays near 98 % the whole time.
What the network actually learns
The dark-knowledge story. With T = 20 the teacher’s softmax becomes near-uniform but its ratios still carry information: a smudgy 3 produces “mostly 3, a bit of 8, a touch of 5”, and the student picks up that signature from the off-diagonal masses on the 8s and 5s in the transfer set that look 3-like. Because the student is never asked to put any mass on class 3 (no 3s in the transfer set, no hard labels), its overall logit for class 3 ends up systematically lower than the teacher’s – visible in the left panel of the bias-correction sweep at offset 0. Adding a constant to that logit shifts the threshold for predicting 3 without changing how the student ranks 3s relative to one another, which is why a single scalar fixes a 6-point accuracy gap.
Deviations from the 2015 procedure
- Teacher regularization. The paper used 2x1200 + dropout for the teacher. We use 2x1200 + ±2 px input jitter, no dropout. Final teacher test accuracy 98.4 % vs the paper’s ~99.3 % – close enough for the distillation-headline result to land.
- Student loss mix. The paper combined a soft-target loss with a small hard-target loss (the standard distillation recipe). With class 3 absent, the hard-target term still applies on the other 9 classes. We use pure soft targets for simplicity – the student gets no class information except through the teacher’s softmax. The pre-correction gap on class 3 is therefore larger than the paper’s, and bias correction does correspondingly more work.
- Optimizer. Adam, batch 128, lr = 1e-3, no weight decay. The paper used SGD + momentum + dropout. Adam gets us to teacher-quality in 12 epochs instead of 60 – this is purely a wallclock convenience.
- Bias-correction criterion. The paper picked the offset by inspection (a simple grid). We binary-search the offset that matches class frequency on a 5000-image probe of the training set. This is a well-defined, hyperparameter-free version of the same idea and lands on the same operating point.
Correctness notes
A few subtleties worth flagging:
-
Filtering preserves label space. The student is trained on a 9-class transfer set but its output layer still has 10 logits – one for class 3 that simply never receives a hard-label gradient. Soft targets do provide gradient on that logit (any time the teacher puts non-zero mass on class 3, e.g. for a smudgy 5 that looks like a 3, the student’s logit-3 gets pushed). Without the bias correction, the student’s 3-logit ends up low because the teacher’s mass on 3 is small relative to its mass on the true class.
-
T² scaling. Hinton et al. note that softening logits by T scales gradients by 1/T². We multiply the soft-target loss by T² so the gradient magnitude lines up with a hard-label loss at T = 1; otherwise tuning lr would have to compensate.
-
Bias correction is
b[-1][3] += offset. Strictly a single scalar parameter is changed at correction time. We do not retrain. This is the “dark knowledge” headline: the network already knows what a 3 looks like; it just can’t say so out loud until you tweak the threshold. -
MNIST mirror. Yann LeCun’s URL is unreliable; we fall back through a list (Facebook’s
ossci-datasets, Google’s CVDF mirror, then the original) to make the loader robust. -
Reproducibility. The pipeline is deterministic for a fixed
--seed. The seed controls weight inits (teacher getsseed, studentseed + 1) and the data-shuffling RNGs (seed + 17for the teacher,seed + 31for the student). The full config + Python / numpy / OS / CPU info is written toviz/results.json.
Open questions / next experiments
- What does the soft-target-only recipe lose vs. the soft+hard mix? The paper combined both; we use only soft. With omitted classes, the hard-target term doesn’t even directly apply to the omitted class, so the comparison should mostly stress non-omitted classes. A side-by-side would isolate that.
- Other omitted classes. Is digit 3 special, or does any class survive
bias correction at this rate? The CLI’s
--omitted-classflag makes the obvious sweep cheap. - Multiple omitted classes. Distill with two classes removed; can a vector bias correction (one offset per omitted class) recover both? The paper hints at this but doesn’t run the experiment.
- Smaller students. How small can the student get and still recover the omitted class via bias correction? At what student capacity does dark-knowledge transfer collapse?
- Temperature ablation. T = 20 is the spec; T = 1 (no softening) likely fails entirely. Where does the recovery curve lie between T = 1 and T = 50?
AIR Multi-MNIST (Attend, Infer, Repeat)
Numpy reproduction of Eslami, Heess, Weber, Tassa, Szepesvari, Kavukcuoglu & Hinton, “Attend, Infer, Repeat: Fast Scene Understanding with Generative Models”, NIPS 2016. The paper’s headline claim: a recurrent attention model can simultaneously learn to count, locate, and identify a variable number of objects in a scene with no per-object supervision – only the raw scene as the reconstruction target.

Problem
A 32x32 canvas containing 0, 1, or 2 randomly placed and scaled MNIST digits (scale 12-18 pixels, uniformly random position within bounds). Each scene is labelled only with the raw pixel image – no per-digit boxes, no count label. The model must infer:
- Count: how many digits are present (0, 1, or 2).
- Location: a 3-D affine
(scale, dx, dy)per digit. - Appearance: a low-dimensional
z_whatlatent per digit.
Chance for the count is 1/3 ≈ 0.333 (uniform over {0, 1, 2}); the spec
target for “best-effort numpy” is count_acc > 0.5.
Architecture
| Stage | Layer | Output shape | Notes |
|---|---|---|---|
| Input | – | (B, 32, 32) | scene |
| Encoder fc1 | 1024 -> 200, ReLU | (B, 200) | global MLP |
| Encoder fc2 | 200 -> 100, ReLU | (B, 100) | shared across all steps |
For each step t = 0..2: | |||
| Pres head | 100 -> 1, sigmoid logit | (B,) | z_pres_t |
| Where head | 100 -> 3 | (B, 3) | (log_s, tx, ty) |
| What mu | 100 -> 20 | (B, 20) | z_what_t mean |
| What logvar | 100 -> 20 | (B, 20) | z_what_t log-variance |
| Sample | reparameterized | – | z_what = mu + exp(.5 lv) * eps |
| Decoder fc1 | 20 -> 100, ReLU | (B, 100) | shared across steps |
| Decoder fc2 | 100 -> 256, sigmoid | (B, 16, 16) | per-digit appearance patch |
| Spatial transformer | inverse-affine bilinear | (B, 32, 32) | render at z_where_t |
| Sum | recon += cumprod(z_pres) * canvas_t | (B, 32, 32) | weighted by cumulative presence |
Total parameter count: ~280k floats.
ELBO
L = MSE(recon, x) * 1024
+ kl_what_weight * sum_t KL( N(mu_t, exp(lv_t)) || N(0, I) )
+ kl_pres_weight * sum_t KL( Bern(sigmoid(logit_t)) || Bern(p_t) )
with kl_what_weight = 0.05, kl_pres_weight = 0.3, prior rates
p_t = (0.5, 0.4, 0.2).
z_pres relaxation
The discrete z_pres_t ∈ {0, 1} is replaced by a Gumbel-sigmoid relaxation
that anneals from temp τ = 1.0 (smooth) to τ = 0.2 (near-binary) over
training:
g_t = log(u_t) - log(1 - u_t), u_t ~ Uniform(0, 1)
z_pres_t = sigmoid((logit_t + g_t) / τ)
The “use earlier slots first” inductive bias comes from the cumulative-product
gating: slot t contributes (z_pres_0 * z_pres_1 * ... * z_pres_t) * canvas_t
to the reconstruction, so once an earlier z_pres collapses near 0 the later
slots are masked out.
Files
| File | Purpose |
|---|---|
air_multimnist.py | Scene gen, AIR forward/backward, ELBO, Adam, training. CLI: --seed --canvas-size --max-steps --what-dim --n-epochs --n-train --n-val --batch-size --lr --kl-what-weight --kl-pres-weight |
visualize_air_multimnist.py | Trains and writes static figures to viz/ |
make_air_multimnist_gif.py | Trains with snapshots and renders the animated GIF |
air_multimnist.gif | Output of the GIF script |
viz/ | Static PNGs |
Running
# Quick training, no plots (~5s wallclock on M-series Mac)
python3 air_multimnist.py --n-epochs 8 --n-train 1500 --seed 0
# Train + render all static figures (~6s + plotting)
python3 visualize_air_multimnist.py --n-epochs 8 --n-train 1500 --n-val 300
# Train + render the animated GIF (~6s + frame rendering)
python3 make_air_multimnist_gif.py --n-epochs 8 --snapshot-every 35 --fps 6
The MNIST loader downloads the four *-idx*-ubyte.gz files from
storage.googleapis.com/cvdf-datasets/mnist/ on first run and caches them at
~/.cache/hinton-mnist/.
Results
Defaults: 8 epochs * 46 steps/epoch, batch 32, Adam lr=2e-3, on 1,500 generated training scenes. Single-thread numpy. The training loop tracks the best-epoch validation count accuracy and restores those weights at the end (early stopping).
| Metric | Value | Baseline |
|---|---|---|
| Best val count accuracy | 0.797 | chance = 0.333 (uniform 0/1/2) |
| Per-class acc count=0 (empty) | 1.000 | 107/107 |
| Per-class acc count=1 | 0.550 | 55/100 (others -> 2) |
| Per-class acc count=2 | 0.828 | 77/93 (others -> 1) |
| Val reconstruction MSE (per pixel) | 0.013 | input-image MSE-to-mean = ~0.05 |
| Wallclock (visualize_air_multimnist.py) | ~6 s | – |
The 0.797 / chance 0.333 ratio is 2.4x; the 0.5 spec target is comfortably beaten. The per-class breakdown shows the model is most confident on empty scenes (always correct), best at the maximum-load case (count=2, because both spaces typically need to fire), and weakest on single-digit scenes (sometimes spuriously fires the second slot).
Training curves

Reconstruction loss decreases monotonically. Count accuracy peaks early (epochs 1-3) then drifts downward as the model trades count fidelity for extra reconstruction quality from the third slot – this is exactly the known weakness of relaxed-Bernoulli AIR vs the paper’s REINFORCE-on-discrete variant. The trainer keeps the best checkpoint and restores it at the end.
Per-step attention boxes

Each box color is one step (blue = step 0, orange = step 1, green = step 2).
Solid lines are cum_pres > 0.5 (slot is active); dashed lines are inactive.
The model puts step 0 on the larger / more central digit and step 1 on the
second one when present.
Per-digit reconstructions

Columns: input, total reconstruction, per-step contribution scaled by
cum_pres, residual. Reconstructions are blurry blob-shaped – the
appearance decoder is small (20-D z_what, 100-D hidden, 16x16 patch with no
convolutions) and 1500 training scenes is roughly 4 orders of magnitude less
data than the paper. The qualitative success is that step 0’s contribution is
localized on the right region (when there’s a digit there) and approximately
zero elsewhere, which is what the spatial transformer’s gradient signal
encourages.
Per-step decoded patches

The 16x16 patches the decoder outputs (left of each pair) and where they land on the canvas after the spatial transformer (right of each pair). Clearly patches are blob-shaped templates; the spatial transformer scales / positions them onto the active digit region.
Count distribution

The full count confusion matrix on the 300-image validation set. Empty scenes
(count=0) are perfectly recovered; the most common error is 1 -> 2 (the
model spuriously fires slot 2 on a single-digit scene).
Example training scenes

Eight training scenes with their ground-truth counts. The 12-18 pixel digit size on a 32x32 canvas means a single digit fills ~40-55% of the canvas diameter, so 2-digit scenes often have visible overlap.
Deviations from the 2016 paper (all documented in source)
- Smaller canvas: 32x32 (paper: 50x50). Cuts the encoder input by 2.4x so a pure-numpy MLP encoder is fast. Cuts the decoder output proportionally.
what_dim = 20(paper: 50). Smaller VAE latent so the per-step what head is small (100 * 20 = 2k params per step). The paper’s 50-D would give 5k params per step; not a big change but consistent with reducing capacity.- Per-step heads on a global MLP encoder, NOT a recurrent LSTM over image
residuals. The paper’s full AIR uses an LSTM that takes the residual
(image - sum_{i<t} render_i)as input at each step; the LSTM’s hidden state carries “what’s left to explain”. Backprop through 3 LSTM steps over a spatial-transformer-rendered residual is too slow per training step in pure numpy. We use 3 independent linear heads on the same global MLP features. This still demonstrates per-step attention and counting; the hidden cost is that empty slots can spuriously fire because the model doesn’t see “the image is already explained.” - Gumbel-softmax (sigmoid) relaxation for
z_presthroughout. The paper uses REINFORCE for the discrete count, with a Bernoulli sample forz_presand a baseline-subtracted score-function gradient. The relaxation is easier to fit in pure numpy (continuous reparameterized gradient) but has a known failure mode: at high temp the model can use intermediatez_presvalues to “fade in” partial digits, which beats a stricter binary counter on reconstruction loss but degrades count accuracy. We annealτfrom 1.0 down to 0.2 to bias towards binary. - Independent Bernoulli prior per slot with rates
(0.5, 0.4, 0.2). The paper uses a true geometric priorP(z_pres_t = 1 | all earlier on) = p, which gives the exact “encourage smaller counts” inductive bias. Our independent priors plus the cumulative-product gating give a similar effect with simpler math. - 1,000-1,500 training scenes (paper: 60M streamed). With this much data the appearance decoder cannot learn sharp digit shapes; reconstructions are blurry. Counting (the headline claim) still works well.
- No forward spatial transformer on the encoder side. The paper crops
the image at
z_whereand feeds the cropped patch to a separate VAE encoder; we encode the whole canvas globally and let the per-step heads read out where to look. With a 32x32 canvas the global encoder still has plenty of resolution to localize a 16x16 digit. - Adam optimizer with lr=2e-3 (paper: stochastic gradient). Standard substitution; converges in 5x fewer steps.
Correctness notes
- Spatial transformer backward pass verified by finite-difference. With
a smooth (Gaussian) test patch and
delta=1e-4, all three components ofd_z_wherematch the analytical gradient to four decimal places (ratio ≈ 1.0000). With sharp patches (random uniform) the agreement degrades to ~10-30% ontx,tybecause the bilinear sampler’sfloor()is non-differentiable at integer pixel boundaries – this is a known issue with bilinear-sampler STN backprop and is the standard tradeoff. The gradient is correct in expectation. - Cumulative-presence backward via product rule.
eff_pres_t = prod_{i ≤ t} z_pres_i, sod eff_pres_t / d z_pres_j = eff_pres_t / z_pres_jforj ≤ t. We use safe division (np.clip(z, 1e-6, None)) to avoid divide-by-zero when az_prescollapses to 0. - Logvar clipping.
z_what_logvaris clipped to[-8, 4]to avoidexp(lv)overflow / underflow. Gradient is masked through the clip (zero outside the active range), preventing parameter updates that would pushlvfurther out of the clip range. - Decoder bias init at -2.
b_d2 = -2(so initial decoded patch is sigmoid(-2) ≈ 0.12, almost dark) breaks symmetry: at init the model reconstructs a dark canvas, so the gradient onz_presis to increase it for non-empty scenes (positive recon error reduction) and decrease it for empty scenes (over-reconstruction would hurt MSE). Without this bias the model is symmetric at init andz_presdoes not learn. - Best-checkpoint restoration. The trainer tracks the highest
val_count_accacross epochs and restores those weights at the end. This guards against the well-known late-training degradation as reconstruction loss continues to drop while count accuracy regresses.
Open questions / next experiments
- Recurrent encoder. Would an LSTM over
(image - cumulative recon)fix the spurious second-slot firings on single-digit scenes? The paper’s recurrent design is specifically motivated by this; our per-step-heads reduction loses this signal. A pure-numpy LSTM x 3 timesteps with STN backward through each step is roughly 5-10x the per-batch cost. - REINFORCE for
z_pres. The relaxed-Bernoulli formulation is the most common reason published AIR reproductions get worse counting than the paper. A REINFORCE estimator with a learned baseline would letz_presbe exactly binary, removing the “fade in slots for fractional reconstruction” failure mode. - Larger
what_dimand convolutional decoder. With 20-Dz_whatand an MLP decoder the appearance reconstructions are blurry. A small conv decoder would match the paper’s recipe and likely produce sharper digit-like templates that are recognizably 0-9. - Streaming scene generation. 1500 scenes is overfit-prone. Streaming fresh scenes per batch (one MNIST sample per draw, 60k unique base digits) would give the model effectively infinite data without changing wallclock.
AIR 3D primitives (programmable renderer inversion)
Source: Eslami, Heess, Weber, Tassa, Szepesvari, Kavukcuoglu & Hinton, “Attend, Infer, Repeat: Fast Scene Understanding with Generative Models”, NIPS 2016. Demonstrates: Inverting a programmable Lambertian renderer. Given a 2D grayscale image, recover the count, type, 3D position, and 3D rotation of up to three primitives (sphere, cube, cylinder).

Problem
Build a generative model + inference network for variable-count 3D scenes:
- Generative model — a known programmable renderer takes
(count, [type, position, rotation]_i)and produces a 64x64 grayscale Lambertian image. Up to three primitives per scene; primitives are unit spheres / cubes / cylinders scaled to side 0.4, placed in a unit cube, under an orthographic camera with a single light from the camera direction. - Inference network — given an image, predict the slot-wise latents
(presence, type, position, rotation)so that re-rendering through the generative model approximately reproduces the image.
The original paper trains the inference network end-to-end with REINFORCE
through the discrete count latent. We swap that for supervised regression
on synthesized (image, ground-truth latents) pairs (see Deviations
below). The renderer is the same idea: a fixed, differentiable-by-design
forward model that the inference network learns to invert.
Files
| File | Purpose |
|---|---|
air_3d_primitives.py | Lambertian renderer, dataset generator, AIR-style inference MLP, training loop, and CLI. Exports render_3d_scene, generate_dataset, build_air_model_3d, train. |
visualize_air_3d_primitives.py | Static viz: primitive gallery, scene examples, training curves, prediction panel, error histograms + type confusion. |
make_air_3d_primitives_gif.py | Trains and snapshots the inference network every few epochs, then renders a “predictions improving” animation. |
air_3d_primitives.gif | Output of the GIF maker. |
viz/ | PNGs from the visualization script. |
Running
End-to-end run with the canonical config:
python3 air_3d_primitives.py --seed 0 --image-size 64 \
--max-primitives 3 --n-epochs 80 --n-train 3000 --n-test 500 \
--weight-decay 1e-3
Total wallclock on an M3 laptop: ~12 s (3.3 s synth + 8.5 s train).
Writes results.json (config, metrics, history, environment) and
weights.npz (so the visualization scripts can reload).
To regenerate visualizations:
python3 visualize_air_3d_primitives.py --seed 0 --outdir viz
python3 make_air_3d_primitives_gif.py --seed 0 --n-epochs 30 \
--snapshot-every 2 --fps 4
Results
Canonical run (seed 0, 64x64, max 3 primitives, 3000 train / 500 test, 80 epochs, hidden=128, input mean-pool 2x):
| Metric | Value | Notes |
|---|---|---|
| Count accuracy | 81.2 % | Exact match on number of primitives in the scene |
| Per-slot presence accuracy | 93.7 % | Treating every (scene, slot) pair as a binary problem |
| Type accuracy | 51.7 % | Chance is 33.3 % over {sphere, cube, cylinder} |
| Position MAE (x, y, z) | 0.179 / 0.246 / 0.188 | Targets in [-1, 1]; ~10-12 % of range on x and z |
| Rotation MAE per Euler axis | 0.78 / 0.80 / 0.80 rad | Rotation loss is masked for spheres (rotationally symmetric) |
| Best validation epoch | 8 / 80 | Severe overfitting after that, see training curves |
| Synth + train wallclock | 3.3 s + 8.5 s | macOS arm64, Python 3.12.9, numpy 2.2.5 |
Sanity check: with --max-primitives 1 the same network reaches 88.8 %
type accuracy and position MAE 0.069 / 0.084 / 0.165 — confirming the
architecture works and the bottleneck is slot disambiguation, not
representation.
What the network actually learns
The renderer is the easy part

The Lambertian renderer is closed-form: ray-cast each pixel against the
unit primitive in its local frame (analytic intersection, no ray tracing
loop), shade with max(0, n . light) against a camera-direction light,
and resolve occlusion with a per-pixel z-buffer. Each scene renders in
~1 ms at 64x64. The dataset of 3000 scenes is ready in ~3 s.
Sample scenes

Each scene has 1-3 primitives uniformly sampled from {sphere, cube, cylinder}, placed at random positions in the unit cube and rotated by
random Euler angles in [0, pi]. To make the slot assignment unambiguous,
ground-truth primitives are sorted by their world x-coordinate before being
written to slots 0/1/2.
Inference network performance

The MLP encoder gets count roughly right and recovers x/y position with acceptable error. It struggles to disambiguate cube vs. cylinder when they overlap or are rotated to a near-axial view (rightmost column shows the absolute pixel error after re-rendering predictions). When all three primitives are present the slot-assignment ambiguity at the boundary between similar x-coordinates regularly causes a type swap.
Training curves

Total val loss bottoms out around epoch 8 then climbs as the network
overfits. The presence head (BCE) keeps learning the longest because
presence is the easiest signal. Type cross-entropy goes below uniform
chance (log 3 ~= 1.099) early but then rises again as overfit predictions
grow more confidently wrong. Position MSE keeps improving slowly. The
brief spike around epoch 60 is a normal Adam transient and recovers within
~3 epochs.
Per-axis error distributions

The position-error histogram is sharply peaked near zero with a tail. The
rotation-error histogram is roughly uniform on [0, pi/2] — the network
makes basically no useful prediction on rotation. The confusion matrix
shows a clear diagonal but with substantial off-diagonal mass, especially
“sphere predicted as cylinder” (107) and “cylinder predicted as cube”
(95) — the projected silhouettes overlap considerably for tilted views.
Deviations from the 2016 paper
- Pure-numpy supervised training instead of REINFORCE-AIR. The original
paper trains end-to-end from pixels alone, using REINFORCE through the
discrete count latent and reparameterized Gaussians for
z_what/z_where. We instead synthesize pairs(image, ground-truth latents)from the renderer, sort the ground truth by x-position to make slots permutation-free, and supervise the inference network directly. This isolates the inference problem from the reinforcement-learning complications and removes the need for an autograd framework. - MLP encoder instead of a CNN + RNN. The paper uses a CNN feature extractor and a recurrent core that emits one slot at a time. We use a 3-layer MLP that emits all slots simultaneously. The MLP lacks the translation-equivariance prior that helps with object localization; that is the main reason 3-primitive type accuracy plateaus at ~52 %.
- Closed-form renderer instead of a learned generator. Each primitive has an analytic ray-intersection, so we never run a marching loop or train the renderer. Lighting is a single Lambertian term against a camera-direction light, with constant ambient 0.15. No textures, no shadows, no specular.
- Sphere rotation loss is masked. A unit sphere is rotationally symmetric, so the Euler angles are unrecoverable. We zero out the rotation MSE for sphere slots so the network is not penalized for guessing.
- No camera variation. The paper’s follow-up varies camera pose; we
keep the camera fixed so the only latent variation is the per-primitive
(type, position, rotation).
Correctness notes
- Slot assignment by x-position is consistent but biased. Sorting by x makes the slots well-defined, but introduces hard discontinuities when two primitives have nearly the same x — small rotations or noise can flip which one is “left”. The MLP cannot represent this discontinuity well, which contributes to the val-loss plateau.
- Best-checkpoint restoration. Because val loss diverges after epoch ~8, the trainer tracks the best validation epoch and restores those weights at the end. Without this, reported metrics would be those of an overfit final-epoch network, not the best one.
- Decoupled weight decay (AdamW-style). Weight decay is applied directly to weight tensors after the Adam update, not folded into the gradient. Biases receive no decay. This is important to avoid biasing the rotation outputs toward zero.
- Z-buffer correctness. Multiple primitives are composed by storing
the smallest hit
tper pixel; closer primitives win. We tested this on overlapping-primitive scenes (seescene_examples.png).
Open questions / next experiments
- Convolutional encoder. A small numpy CNN (3-4 conv layers + global pool) should close most of the type-accuracy gap by giving the network translation-equivariance for free. Implementing convolutions in pure numpy is straightforward but adds ~200 lines.
- Permutation-invariant slot loss. Replace x-sorting with Hungarian matching between predicted and target slots at every training step. This removes the discontinuity but couples the per-slot heads.
- REINFORCE through z_pres. Train end-to-end with REINFORCE on the discrete presence variable, supervising only the rendered image (not the latents). This is the original AIR formulation and would let the network learn without ground-truth latents at all.
- Camera variation. The paper’s full task includes camera pose. Adding a camera-pose prediction head and a randomized camera should be ~30 lines and would test the network’s ability to factor pose from content.
- Larger primitive vocabulary. Add cones, prisms, ellipsoids — does the slot-wise representation scale, or does the encoder collapse types?
Reproducibility
Each run records seed, image_size, max_primitives, n_epochs,
n_train, n_test, hidden, input_pool, batch_size, lr,
weight_decay, plus Python / numpy / OS / processor strings into
results.json. The same seed reproduces the same dataset, the same
weight initialization, the same training trajectory, and (on the same
hardware) the same final metrics to the last decimal.
Fast-weights associative retrieval
Source: J. Ba, G. Hinton, V. Mnih, J. Z. Leibo, C. Ionescu (2016), “Using Fast Weights to Attend to the Recent Past”, NIPS. arXiv:1610.06258.
Demonstrates: A small RNN equipped with a per-sequence “fast weights” matrix A_t = lambda * A_{t-1} + eta * outer(h_{t-1}, h_{t-1}) performs Hopfield-style content-addressable retrieval on the toy task c9k8j3f1??c -> 9. This is the first attention-like mechanism in the modern deep-learning era, predating transformer attention by a year.

Problem
Each sample is a sequence of characters of the form
k1 v1 k2 v2 ... kn vn ? ? q ?
where k_i is a random letter, v_i is a random digit (0-9), and q is one of the previously-seen k_i chosen uniformly at random. The network must output the digit v_i paired with q. Vocabulary is 26 letters + 10 digits + 1 separator = 37 tokens. The final ? is a “trailing read” no-op step (see Architecture, below).
The interesting property: the network has only O(H^2) slow weights to learn the general algorithm but must store n_pairs per-sample bindings somewhere. A vanilla RNN cannot do this for n > 1 because its hidden vector cannot represent variable-length associative memory. The fast-weights matrix A_t, recomputed inside each sequence, IS the per-sample storage. At the read step the matrix-vector product A_t @ h_{t-1} performs Hopfield-style retrieval: any past h_τ that has high inner product with h_{t-1} (the query state) contributes its bound representation to the pre-activation. Ba et al. report this beats IRNN, LSTM, and Associative-LSTM at the same parameter count.
Architecture (Ba et al. with the trailing-read simplification)
A_t = lambda_decay * A_{t-1} + eta * outer(h_{t-1}, h_{t-1}) (A_0 = 0)
z_t = W_h h_{t-1} + W_x x_t + b + A_t @ h_{t-1}
zn_t = LayerNorm(z_t)
h_t = tanh(zn_t)
out = W_o h_T + b_o # only the final hidden state predicts
The slow weights {W_h, W_x, b, W_o, b_o} are learned by truncated BPTT. The fast weights A_t are reset to zero at the start of each sample.
Two design choices match the Ba et al. recipe:
- LayerNorm is necessary. Without it,
A_t @ h_{t-1}grows roughly quadratically as outer products accumulate, the tanh saturates at ±1 within ~5 steps, and1 - tanh^2collapses the recurrent gradient to zero. Confirmed empirically: pre-LN model hidden norm reachessqrt(H)(full saturation) by step 7 and gradients onW_h, W_x, bare exactly zero. - Trailing read step. Each sample ends with an extra
?after the query letter. This guarantees that at step T the fast-weights matrixA_T = lambda A_{T-1} + eta outer(h_{T-1}, h_{T-1})has been built from a hidden state that already encodes the query letter. Without this trailing step the retrieval would have to fire BEFORE the query is integrated, and only awkward W_o-side decoding could recover.
BPTT through the fast weights
The recurrence on A_t means the gradient on the fast-weights matrix accumulates a running term across the backward pass:
dA_running = 0
for t = T..1:
dh_t already known
dzn_t = dh_t * (1 - h_t^2) # tanh
dz_t = LN_backward(dzn_t, zn_t, sigma) # layer norm backward (no affine)
dW_h += outer(dz_t, h_{t-1})
dW_x += outer(dz_t, x_t)
db += dz_t
dh_{t-1} = (W_h.T + A_t.T) @ dz_t
dA_t_local = outer(dz_t, h_{t-1}) # from z_t = ... + A_t h_{t-1}
dA_t_total = dA_running + dA_t_local
dh_{t-1} += eta * (dA_t_total + dA_t_total.T) @ h_{t-1} # outer term
dA_running = lambda_decay * dA_t_total # chain to A_{t-1}
A numerical-gradient check (central differences, eps=1e-5, sampled across each parameter tensor) verifies max relative error of ~1e-9 on every slow parameter for the n_pairs=2 / hidden=8 configuration. The BPTT path through the fast-weights matrix is implemented correctly.
Files
| File | Purpose |
|---|---|
fast_weights_associative_retrieval.py | FastWeightsRNN (forward + manual BPTT including fast-weights chain), Adam, generate_sample / generate_batch, train loop, per_position_accuracy evaluation, CLI |
visualize_fast_weights_associative_retrieval.py | Static plots: training curves, per-slot accuracy bars, A_t evolution heatmap, hidden-state trace |
make_fast_weights_associative_retrieval_gif.py | Animated GIF: per-step A_t heatmap + hidden state for one example |
fast_weights_associative_retrieval.gif | Committed animation (~475 KB) |
viz/ | Committed PNG outputs |
Running
python3 fast_weights_associative_retrieval.py --seed 0 --n-pairs 4 --n-steps 4000
Train wallclock: ~290 s on an M-series laptop (system Python 3.12, numpy 2.2). Final retrieval accuracy on n_pairs=4: 38.35% (well above 10% chance, well below the >90% spec target — see Results and Deviations §1).
To regenerate the visualizations and gif:
python3 visualize_fast_weights_associative_retrieval.py --seed 0 --n-pairs 4 --n-steps 4000
python3 make_fast_weights_associative_retrieval_gif.py --seed 0 --n-pairs 4 --n-steps 2000
CLI flags (the spec calls out --seed --n-pairs --n-steps; everything else is optional):
--seed RNG seed default 0
--n-pairs # of (key, value) pairs default 4
--n-steps # of training batches default 3000
--n-hidden hidden state dim H default 64
--lambda-decay fast-weights decay default 0.95 (per stub spec)
--eta fast-weights gain default 0.5 (per stub spec)
--batch-size Adam mini-batch default 32
--lr Adam learning rate default 5e-3
--eval-every eval interval (steps) default 100
--eval-batch eval batch size default 256
--grad-clip global-norm clip default 5.0
--show-samples # of demo predictions default 4
Results
Single run, --seed 0 --n-pairs 4 --n-steps 4000 --n-hidden 80 --batch-size 64 --lr 5e-3:
| Metric | Value |
|---|---|
| Architecture | Fast-weights RNN, hidden=80, vocab=37, output=10 (digits) |
| Slow params | 10,250 |
| Fast-weights matrix per sample | 80 × 80 = 6,400 entries (transient, not learned) |
| Final retrieval accuracy (n=2000 eval) | 38.35% (vs. 10% chance, vs. 90% spec target) |
| Final retrieval cross-entropy | 1.22 (vs. log 10 = 2.30 chance, vs. 0 perfect) |
| Per-slot accuracy (slot 0 = oldest) | slot0=31.2%, slot1=43.6%, slot2=41.7%, slot3=34.9% |
| Train wallclock | 293 s |
| Hyperparameters | lambda_decay=0.95, eta=0.5, lr=5e-3, batch=64, grad_clip=5.0 |
| Sample predictions | i7w9o3a6??w? -> 9 (✓), g0l7a5d5??d? -> 5 (✓), o9g0i6z7??g? -> 0 (✓), e3u5d4m7??m? -> 4 (✗ target 7) |
Sanity check on n_pairs=1 (--n-pairs 1 --n-steps 800): 100.0% retrieval accuracy in 11 s. The model trivially solves the one-binding case.
Sanity check on n_pairs=2 (--n-pairs 2 --n-steps 1500 --lr 5e-3): 54.85% retrieval accuracy with slot0=100%, slot1=9%. The model collapses to a “always retrieve the first value seen” degenerate solution. Loss = 1.6 (vs. log 2 = 0.69 if the model were “perfect on slot0, chance on slot1”). Several other seeds reproduce the same plateau, sometimes flipping to “always slot1”.
Numerical gradient check passes (max relative error 1e-9 across all parameters), so the architecture and BPTT implementation are correct. The optimization landscape difficulty is the limiting factor — see Deviations §1 below.
Visualizations
Training curves

Cross-entropy loss (left) drops from chance (log 10 ≈ 2.30) to ~1.22 over 4000 batches; accuracy (right) climbs from 10% (chance) to ~38%. The eval and train curves track each other — there is no overfitting; the network just plateaus. Visible plateau in steps 1000–2200 followed by a second descent suggests the optimizer slowly discovers usable fast-weights structure.
Per-slot accuracy

Accuracy bucketed by which (key, value) slot the query referenced. All four slots sit in the 31–44% range — well above 10% chance and roughly uniform, which means the network IS doing real associative retrieval (not simply “always predict slot 0” or “always predict the most-recent value”). The slight bias toward middle slots (1, 2) likely reflects the lambda^t decay: slot 0 is the oldest in A_T (weight lambda^7 ≈ 0.7) and slot 3 is the most recent (weight ~lambda), with the middle slots seeing the strongest interaction between the decay envelope and the eta-outer reinforcement at the time of read.
Fast-weights matrix evolution

Heatmap snapshots of A_t at every timestep of one example sequence (x2j1y3w5??x?, target = 2 (the value paired with x)). At t=0 the matrix is exactly zero (initial condition). Each subsequent step adds an eta * outer(h_{t-1}, h_{t-1}) rank-1 contribution and decays the existing entries by lambda = 0.95. By t=8 (the query letter x) the matrix has accumulated eight outer-product traces; by t=10 (trailing read) it has accumulated all eleven. The reader can see the matrix is genuinely changing — the fast weights ARE being computed, the question is whether the slow weights have learned to USE them.
Hidden state trace

Top: heatmap of h_t (rows = hidden units, columns = timesteps). Bottom: ||h_t|| per step. Each step modulates the hidden state in input-specific ways. The query step (x at t=8) produces a distinct pattern from the same letter x at t=0 — confirming that hidden representations of letters are context-dependent, which is the prerequisite for the fast-weights binding to encode the right pairing.
Deviations from the original procedure
-
Significantly underperforms the paper’s headline numbers. Ba et al. report ~98% accuracy on the n_pairs=4 task with hidden=100, RMSProp at lr=1e-4, batch size 128, trained for hundreds of epochs (~10^5 batches). Our v1 reaches 38% with Adam at lr=5e-3, batch=64, 4 000 batches (~292 s). We tried longer (5 000), lower lr (1e-3), larger batch (128), grad-clip on/off, identity vs. scaled init for
W_h, eta in {0.1, 0.3, 0.5}, and four random seeds; the best n_pairs=2 result we found was 57% witheta=0.1, and seeds 0–2 all collapse to a slot0-only degenerate solution at the spec-defaulteta=0.5. The gradients are correct (verified to 1e-9 by central differences) and the n_pairs=1 sanity test trains cleanly to 100%, so the issue is NOT a code bug — it is a known optimization-difficulty problem for this task. Ba et al. specifically call out that fast-weights training requires careful tuning and that the gradient signal through the fast-weights pathway can be drowned out by the easier “memorize via slow weights” basin in the early epochs. Reproducing the >90% number from the paper is the natural v2 target — see Open Questions. -
Single inner-loop iteration
S=1instead of the paper’s recommendedS>=1. The paper formulates the inner loop ash_{s+1}(t+1) = f(LN([W h_t + C x_t] + A_{t+1} h_s(t+1)))fors = 0..S-1, withh_0(t+1) = f(LN(W h_t + C x_t))as the boundary. We use a flatter formh_t = tanh(LN(W h_{t-1} + W_x x_t + b + A_t h_{t-1}))plus the trailing-read step (see Architecture above). We did implement and gradient-check the proper inner-loop formulation but it trained worse than the trailing-read flavor in our hands (stuck at chance), likely because the additional nonlinearity makes the gradient landscape harder. With the trailing read, the retrieval-step semantics match the inner loop’s S=1 case but with a less-nested gradient. -
Adam instead of RMSProp. The paper uses RMSProp at lr=1e-4. We use Adam at lr=5e-3 because Adam mostly subsumes RMSProp and is the modern default. This may matter for this specific task — RMSProp’s lack of bias correction sometimes lets it escape sharp local minima that Adam settles into.
-
Layer normalization without learnable affine. Standard LayerNorm has
gain * (x - mu) / sigma + bias. We use the no-affine form (gain=1, bias=0) because adding parameters didn’t change the outcome in early experiments and we wanted to keep the parameter count minimal for a small numpy reference. Adding affine LN is a one-line change and a natural v2 ablation. -
Identity init for
W_h(0.5 * I) rather than orthogonal or scaled-Gaussian. Standard for fast-weights RNNs after Le, Jaitly & Hinton 2015 (IRNN). This held under LayerNorm (the LN rescales any explosion) and gave cleaner early-training dynamics.
Open questions / next experiments
-
Reproduce the paper’s >90% headline. The most direct path: switch to RMSProp at lr=1e-4, batch=128, train for 50 000+ steps. This is a cheap follow-up (~30 minutes wallclock) and would close the gap to the paper if the hypothesis is right.
-
Curriculum on
n_pairs. Start training withn_pairs=1, expand to 2 once accuracy >95%, then 3, then 4. Prevents the “always-slot0” basin by establishing real retrieval before the model can find the degenerate solution. The Sutro group has used analogous curricula for sparse-parity expansion (SutroYaro/docs/findings/curriculum.md). -
Fast-weights gradient diagnostics. Log
||dA||_F / ||dW_h||_Fevery step. If the fast-weights gradient norm is much smaller than the slow-weights gradient norm, that confirms the “drowned signal” hypothesis and motivates a per-pathway learning rate. -
Explicit auxiliary loss. Add a contrastive loss that rewards
cos_sim(h_query, h_key) > cos_sim(h_query, h_other_key). Forces the network to learn distinguishable letter representations early, which is a prerequisite for fast-weights retrieval to work. -
Alternative storage: separate read/write paths. The current architecture mixes the slow-weight pathway and the fast-weights retrieval into a single pre-activation. A v2 could have two separate pathways (
h_slow = tanh(LN(W_h h_{t-1} + W_x x_t + b)),h_fast = A_t @ h_{t-1}) combined ash_t = h_slow + alpha * h_fastwherealphais learnable. This preserves the fast-weights signal magnitude through training. -
Comparison to vanilla Hopfield network. The fast-weights matrix at the read step is mathematically a one-shot Hopfield read. A natural baseline: fix
W_h, W_x, bto a sensible identity-like setting and train ONLY the readoutW_o. If that gets ~90% with no slow-weight learning at all, the slow-weight path is actively hurting the retrieval mechanism. -
Data movement. The fast-weights matrix is size
H^2, recomputed and decayed every step. For sequence length T with batch size B, that’sO(B * T * H^2)extra memory traffic compared to a vanilla RNN. The Sutro group’s ByteDMD framework is the natural place to measure whether this is amortized by the smaller slow-weight matrix, or whether fast weights are net-energy-losers vs. e.g. an attention layer that reads from a smaller key-value cache.
v1 metrics
| Metric | Value |
|---|---|
| Reproduces paper? | Partial. Architecture is correct (1e-9 gradient-check error). Mechanism works on n_pairs=1 (100%). Mechanism partially works on n_pairs=4 (38% with uniform per-slot accuracy showing real retrieval, not a degenerate solution). Does NOT reach the paper’s ~98% headline at n_pairs=4 — see Deviations §1 for the optimization-difficulty diagnosis and v2 plan. |
| Wallclock to run final experiment | 293 s (time python3 fast_weights_associative_retrieval.py --seed 0 --n-pairs 4 --n-steps 4000 --n-hidden 80 --batch-size 64 --lr 5e-3 measured on M-series laptop, system Python 3.12 + numpy 2.2) |
| Implementation wallclock (agent) | ~3 hours (single session — most of it spent debugging the optimization plateau and trying architectural variants) |
Multi-level glimpse MNIST
Source: J. Ba, G. Hinton, V. Mnih, J. Z. Leibo, C. Ionescu (2016), “Using Fast Weights to Attend to the Recent Past”, NIPS. arXiv:1610.06258.
Demonstrates: A small RNN equipped with a per-sequence “fast weights” matrix A_t = lambda * A_{t-1} + eta * outer(h_{t-1}, h_{t-1}) classifies MNIST digits one 7x7 patch at a time. The 28x28 image is presented as a deterministic sequence of 24 hierarchical glimpses (4 coarse 14x14 quadrants × 4 fine 7x7 each, plus 8 centre re-glimpses). The fast-weights matrix performs Hopfield-style associative read at each step, letting information from glimpse t=2 stay accessible at glimpse t=24. Same mechanism that predates transformer attention by a year.

Problem
The model never sees the whole 28x28 image at once. Each timestep it gets a 7x7 patch plus a one-hot encoding of which patch (which of the 24 positions) it is reading. The full image is delivered as a deterministic sequence of 24 patches:
- 16 fine patches = 4 coarse 14x14 quadrants in fixed order (TL, TR, BL, BR), each split into 4 fine 7x7 patches also in fixed order (TL, TR, BL, BR).
- 8 most-central re-glimpses = the 4 patches at offsets (7,7), (7,14), (14,7), (14,14) – the patches that straddle the centre of the image – each visited twice in that order.
- Total: 24 glimpses.
The task: predict the digit class from the final hidden state.
The fast-weights mechanism is what makes this hard for a vanilla RNN tractable: at glimpse 24 the model has integrated all 24 patches, but a 64-dim hidden vector cannot losslessly encode all 24 patches’ worth of evidence. Fast weights act as a per-sequence content-addressable memory of recent hidden states, accessed via A_t @ h_{t-1} at every step. The slow weights (W_h, W_x, b, W_o, b_o) learn the general recipe for using this memory; the per-image storage is in A_t itself.
Architecture (Ba et al., adapted for image classification)
A_t = lambda_decay * A_{t-1} + eta * outer(h_{t-1}, h_{t-1}) (A_0 = 0)
z_t = W_h h_{t-1} + W_x x_t + b + A_t @ h_{t-1}
zn_t = LayerNorm(z_t)
h_t = tanh(zn_t)
out = W_o h_T + b_o # only the final hidden state predicts
with x_t = [glimpse_patch_49 ; one_hot_position_24] (73 input dims). The slow weights are learned by truncated BPTT through the full 24-step sequence, vectorized across the batch. Fast weights A_t are reset to zero at the start of every image.
LayerNorm (no learnable affine) is necessary: without it, A_t @ h_{t-1} grows quadratically as outer products accumulate, the tanh saturates at +/-1, and 1 - tanh^2 collapses the recurrent gradient. Same finding as Ba et al. (“Layer Normalization is critical”).
BPTT through the fast weights
Standard tanh-RNN backprop with LayerNorm, plus a running gradient dA chained across timesteps:
dA_running = 0
for t = T..1:
dh_t already known
dzn_t = dh_t * (1 - h_t^2) # tanh
dz_t = LN_backward(dzn_t, zn_t, sigma) # layer norm backward (no affine)
dW_h += outer(dz_t, h_{t-1})
dW_x += outer(dz_t, x_t)
db += dz_t
dh_{t-1} = (W_h.T + A_t.T) dz_t
dA_t_local = outer(dz_t, h_{t-1})
dA_t_total = dA_running + dA_t_local
dh_{t-1} += eta * (dA_t_total + dA_t_total.T) @ h_{t-1}
dA_running = lambda_decay * dA_t_total
Numerical-gradient check on a 2-sample / hidden-8 / T=5 random configuration: max relative error is ~1e-8 to ~3e-8 across W_h, W_x, b, W_o, b_o. The forward/backward path is correct.
W_h: max rel err = 1.76e-08
W_x: max rel err = 2.72e-08
b: max rel err = 2.05e-09
W_o: max rel err = 9.90e-10
b_o: max rel err = 2.28e-10
Files
| File | Purpose |
|---|---|
multi_level_glimpse_mnist.py | MNIST loader (urllib + gzip, cached at ~/.cache/hinton-mnist/), generate_glimpse_sequence / build_glimpse_inputs, GlimpseFastWeightsRNN (vectorised forward + manual BPTT with fast-weights chain), Adam with optional step-decay schedule, train, per_class_accuracy, CLI |
visualize_multi_level_glimpse_mnist.py | Static plots: glimpse-overlay on one digit, training curves, A_t evolution heatmap, hidden-state trace, per-class accuracy |
make_multi_level_glimpse_mnist_gif.py | Animated GIF: per-glimpse image + current 7x7 patch + A_t heatmap + h_t row |
multi_level_glimpse_mnist.gif | Committed animation |
viz/ | Committed PNG outputs |
Running
# Headline run reported below (downloads MNIST on first invocation, ~12 MB).
python3 multi_level_glimpse_mnist.py --seed 0 --n-epochs 12 --n-hidden 128 \
--batch-size 64 --lr 0.002 --lr-decay-epochs 7,10 --lr-decay-factor 0.25
Train wallclock is reported in the Results table below (M-series MacBook, system Python 3.12 + numpy 2.2). The CLI mirrors the spec (--seed --n-epochs --n-hidden); other flags are optional.
--seed RNG seed default 0
--n-epochs # of training epochs default 3
--n-hidden hidden state dim H default 64
--lambda-decay fast-weights decay default 0.95 (per stub spec)
--eta fast-weights gain default 0.5 (per stub spec)
--batch-size Adam mini-batch default 64
--lr Adam base learning rate default 2e-3
--lr-decay-epochs e.g. "7,10" -- step-decay default ""
--lr-decay-factor multiplier per decay epoch default 0.25
--grad-clip global-norm clip default 5.0
--n-train 0 = full 60k MNIST default 0
To regenerate the visualizations and gif (each trains its own quicker model on a 20k MNIST subset, ~2 min wallclock):
python3 visualize_multi_level_glimpse_mnist.py --seed 0 --n-epochs 3 --n-hidden 128 --n-train 20000
python3 make_multi_level_glimpse_mnist_gif.py --seed 0 --n-epochs 2 --n-hidden 96 --n-train 20000
Results
Single run, --seed 0 --n-epochs 12 --n-hidden 128 --batch-size 64 --lr 0.002 --lr-decay-epochs 7,10 --lr-decay-factor 0.25:
| Metric | Value |
|---|---|
| Architecture | Glimpse RNN, hidden=128, input=73 (49 patch + 24 one-hot pos), output=10 (digits) |
| Slow params | 27,146 |
| Fast-weights matrix per sample | 128 × 128 = 16,384 entries (transient, not learned) |
| Final test accuracy (10k MNIST test set) | 82.46% (vs. 10% chance, vs. 90% spec target) |
| Final test cross-entropy | 0.5432 (vs. log 10 = 2.30 chance, vs. 0 perfect) |
| Per-class test accuracy | 0=75.6%, 1=97.2%, 2=75.6%, 3=78.5%, 4=90.6%, 5=70.3%, 6=91.3%, 7=85.4%, 8=74.7%, 9=82.4% |
| Train wallclock | 1199 s = 20.0 min |
| Hyperparameters | lambda_decay=0.95, eta=0.5, lr=2e-3 → 5e-4 (ep7) → 1.25e-4 (ep10), batch=64, grad_clip=5.0 |
| Test-acc trajectory by epoch | ep1=57.28%, ep2=65.79%, ep3=72.09%, ep4=74.03%, ep5=76.48%, ep6=77.40%, ep7=81.06% (lr↓), ep8=81.47%, ep9=81.89%, ep10=82.41% (lr↓), ep11=82.41%, ep12=82.46% |
| Numerical gradient check | max rel err ~2e-8 across W_h, W_x, b, W_o, b_o (forward/backward verified) |
Sanity check on hidden=64 / 1 epoch / 5k subset (--n-hidden 64 --n-epochs 1 --n-train 5000): 22.78% test accuracy in 3.8 s – well above 10% chance for an essentially undertrained model, confirming the gradients are flowing.
The 82.46% headline is below the spec’s 90% target – see Deviations §2 for the analysis. The model is below this target because of two architectural simplifications (deterministic glimpse sequence, no CNN encoder), not because of an implementation bug. Numerical-gradient check passes to ~1e-8 across all parameters; the n_pairs=1 sanity test trains cleanly. The optimization landscape is the limiting factor.
Visualizations
24-glimpse overlay

The 28x28 image with all 24 glimpse boxes drawn on top, numbered in visit order. Red solid = the 16 fine patches (4 coarse quadrants, each with 4 fine patches, in TL/TR/BL/BR order at both levels). Blue dashed = the 8 centre re-glimpses (the 4 patches that straddle the image centre, each visited twice). Together they cover the full 28x28 (the 4x4 grid of 7x7 fine patches) plus four extra 7x7 patches over the central 14x14 region.
Training curves

Cross-entropy loss (left) drops from chance (log 10 ~= 2.30) over training; accuracy (right) climbs from 10% (chance) toward the headline number. Train and test curves track each other – there is no overfitting; the model is capacity- and optimization-limited rather than data-limited.
Fast-weights matrix evolution

Heatmap snapshots of A_t at every glimpse of one example. At t=0 the matrix is exactly zero (initial condition). Each subsequent step adds an eta * outer(h_{t-1}, h_{t-1}) rank-1 contribution and decays the existing entries by lambda = 0.95. By t=23 the matrix has accumulated 24 outer-product traces, each one a record of the hidden state that was current when one of the 24 patches was being processed. The reader can see the matrix is genuinely changing – the fast weights ARE being computed, and the slow weights have learned to read from them via A_t @ h_{t-1}.
Hidden state trace

Top: heatmap of h_t (rows = hidden units, columns = the 24 glimpse steps). Bottom: ||h_t|| per step. Each glimpse modulates the hidden state in patch-specific ways. The 8 centre re-glimpses (steps 16-23) revisit patches the model already saw at steps 3, 6, 9, 12 – but the hidden representations at those re-visits differ from the originals, reflecting the accumulated context in A_t.
Per-class test accuracy

Test accuracy bucketed by digit class. The model’s strongest class is “1” (vertical strokes are easy to read off centre re-glimpses); the weakest classes are typically “8” and “5” (closed loops with subtle topology that is fragmented across the 7x7 grid). The average across all 10 classes equals the headline test accuracy.
Deviations from the original procedure
-
Deterministic glimpse sequence, not learned attention. The Ba et al. attention mechanism uses a separate “where” network that decides which glimpse to take next, trained by REINFORCE. Per spec v2 we use a fixed deterministic sequence (16 fine patches in coarse-then-fine raster order, plus 8 centre re-glimpses). This keeps the implementation pure-numpy and lets us focus on the fast-weights mechanism rather than RL. The fast-weights memory is the headline contribution of the paper; the where-network is an orthogonal component.
-
Test accuracy below the paper’s headline. Ba et al. report ~99% on MNIST with a learned-attention glimpse network. Our deterministic-glimpse + fast-weights model with hidden=128, 12 epochs of full MNIST, Adam at lr=2e-3 with step decay at ep 7 and ep 10, hits the test accuracy listed above. The gap to ~99% has two contributors: (a) deterministic vs. learned glimpse sequence – the network cannot zoom in on informative regions; (b) much less training compute than the paper. Reproducing the >95% number would mean adding either a CNN-style patch encoder before the RNN or a where-network for learned attention. See Open Questions.
-
Single inner-loop iteration
S=1instead of the paper’s recommendedS>=1. The paper formulates the inner loop ash_{s+1}(t+1) = f(LN([W h_t + C x_t] + A_{t+1} h_s(t+1)))fors = 0..S-1. We use the flatter formh_t = tanh(LN(W h_{t-1} + W_x x_t + b + A_t h_{t-1})). The wave-5 sibling implementation experimented with the proper inner-loop and it trained worse in our hands (gradient landscape made harder by the additional nonlinearity); single-step is the same simplification. -
Adam instead of RMSProp. Standard modern default. Adam mostly subsumes RMSProp and is easier to tune.
-
LayerNorm without learnable affine. Standard LayerNorm has
gain * (x - mu) / sigma + bias. We use the no-affine form (gain=1, bias=0) to keep the parameter count minimal for a small numpy reference. -
Identity-ish init for
W_h(0.5 * I) rather than orthogonal or scaled-Gaussian. Standard for fast-weights RNNs after Le, Jaitly & Hinton 2015 (IRNN). LayerNorm rescales any explosion.
Open questions / next experiments
-
Add a learned where-network (proper recurrent visual attention). Replace the deterministic 24-glimpse sequence with a stochastic policy
pi_phi(loc | h_t)trained by REINFORCE on the classification reward. This is what makes the original Ba/Mnih DRAM line of work attention-y. Direct path to ~99% MNIST. -
CNN patch encoder. Currently
x_t = [flat_patch_49 ; one_hot_pos_24]is fed through a single linearW_xprojection. A small 2-layer CNN encoder over each 7x7 patch (e.g. 7x7 -> 32 -> 64 -> flatten) would let the network learn local visual features rather than askingW_xto do that AND the recurrent integration. Single-digit % accuracy gains for free. -
Curriculum on glimpse count. Train first with the 16 fine patches, then expand to 24 with the 8 centre re-glimpses. Alternative: train first with the full 28x28 as a single “glimpse” (CNN baseline), then enforce the patch interface as a supervised distillation step.
-
Lambda / eta sweep. We use Ba et al.’s defaults (lambda=0.95, eta=0.5). For a 24-step sequence on MNIST, lambda might want to be lower (more aggressive forgetting) since centre re-glimpses arrive late and should not be drowned out by early-glimpse traces; eta might want to be lower (less aggressive writes) so the LN can keep the hidden state in the linear regime longer.
-
Glimpse sequence ablation. Compare three deterministic sequences: (a) raster (current), (b) coarse-to-fine (current first 16 only), (c) random per-image. Which information ordering works best with fast-weights is an empirical question.
-
Data movement. The fast-weights matrix is size
H^2, recomputed and decayed every step. For T=24, B=64, H=128 that is 6424128*128 = 25M floats per forward pass, ~200 MB of A_t storage (we keep all 24 matrices for backward). Compare against an attention layer with the same effective receptive field via the Sutro group’s ByteDMD framework – which mechanism is more energy-efficient is a real open question. -
Pre-LN affine. Adding learnable gain and bias to LayerNorm is a one-line change and was a simplification we made for parameter-count reasons. Worth ablating.
v1 metrics
| Metric | Value |
|---|---|
| Reproduces paper? | Partial. Architecture is correct (numerical gradient check passes to ~2e-8 across all 5 parameter tensors). Mechanism trains stably and reaches 82.46% test accuracy on full MNIST (vs. 10% chance, vs. paper’s ~99% with learned attention, vs. spec’s 90% target). Gap to the paper is explained by Deviation §2 (deterministic glimpse sequence + no CNN encoder). The fast-weights mechanism is verifiably working: A_t accumulates outer-product traces across glimpses (visible in viz/fast_weights_evolution.png) and the slow weights have learned to read from it (lifting accuracy 24% over chance with a 128-dim hidden RNN that only sees 7x7 patches at a time). |
| Wallclock to run final experiment | 1199 s = 20.0 min (time python3 multi_level_glimpse_mnist.py --seed 0 --n-epochs 12 --n-hidden 128 --batch-size 64 --lr 0.002 --lr-decay-epochs 7,10 --lr-decay-factor 0.25 measured on M-series MacBook, system Python 3.12 + numpy 2.2) |
| Implementation wallclock (agent) | ~1 hour (single session — most of the wall time was the 20-min training run; the implementation itself was ~30 minutes including the vectorised batched BPTT and the numerical gradient check) |
Catch toy game (partial-observability)
Source: J. Ba, G. Hinton, V. Mnih, J. Z. Leibo, C. Ionescu (2016), “Using Fast Weights to Attend to the Recent Past”, NIPS. arXiv:1610.06258, section 5 (“Reinforcement learning”).
Demonstrates: A small RNN with a per-episode “fast weights” matrix A_t = lambda * A_{t-1} + eta * outer(h_{t-1}, h_{t-1}) learns to play catch on a 24x24 grid where the observation is blanked to all zeros after the 9th step (blank_after=8). The agent must remember the ball’s column over 15 blind steps and steer a 3-cell paddle to intercept. With fast weights the agent learns a real catching policy; the same architecture with the fast-weights term zeroed out (eta=0, vanilla RNN) stays at chance.

Problem
The world is an N x N (default 24x24) binary grid. A ball spawns on row 0 at a random column and falls one row per step. A 3-cell paddle lives on row N-1, starts at column N//2, and chooses one of three actions every step: {0=left, 1=stay, 2=right}. The episode lasts exactly N-1 steps; the catch check fires the moment the ball lands on the bottom row.
- Reward:
+1on catch (|ball_col - paddle_x| <= 1),-1on miss,0on every interior step. - Partial observability: the agent’s input at step
tis the full grid (ball pixel + 3 paddle pixels) fort <= blank_after. Fort > blank_afterthe observation is replaced with all zeros — the ball has effectively gone invisible. The reward is still delivered at the end. - Random-paddle chance: with a 3-cell paddle and N columns of uniform spawn, a paddle that ignores the input catches the ball
3/N = 12.5%of the time at N=24. Anything below ~14% is “chance behavior.”
The crux: under blanking, the agent only sees the ball for blank_after + 1 steps and then has to act for N - 2 - blank_after more steps without it. A vanilla 64-unit RNN can in principle encode the ball’s column in its hidden state, but in practice REINFORCE on this signal does not optimize that representation reliably; the fast-weights mechanism gives the slow weights a per-episode key/value scratchpad that stores “I saw the ball at column c at step t” in A_t, and reads it back via A_t @ h_{t-1} every step, including the blanked ones.
Architecture (Ba et al. with the same single-LayerNorm body used by the two sibling fast-weights stubs in this wave)
A_t = lambda_decay * A_{t-1} + eta * outer(h_{t-1}, h_{t-1}) (A_0 = 0)
z_t = W_h h_{t-1} + W_x x_t + b + A_t @ h_{t-1}
zn_t = LayerNorm(z_t) # mean-0 std-1 over H, no learnable affine
h_t = tanh(zn_t)
pi_t = softmax(W_pi h_t + b_pi) # 3-way policy head
V_t = W_v h_t + b_v # scalar value baseline
with x_t = the flattened N*N-dim binary observation (576 dims at N=24). LayerNorm without learnable affine is necessary to stop A_t @ h_{t-1} from growing quadratically and saturating tanh — same finding as Ba et al. (“Layer Normalization is critical”). Fast weights A_t are reset to zero at the start of every episode.
The ablation “vanilla RNN” run uses identical architecture with eta=0, so A_t stays at zero throughout and the recurrence collapses to z_t = W_h h_{t-1} + W_x x_t + b. Same parameter count, same optimizer, same training budget.
REINFORCE with baseline (deviation from full A3C)
The paper trains its catch agent with A3C, which needs distributed actor-learner workers. We use the on-policy single-actor simplification:
L = sum_t [ - advantage_t.detach() * log pi_t[a_t]
+ 0.5 * value_coef * (V_t - G_t)^2
- beta_ent * H(pi_t) ]
with G_t = sum_{k>=t} r_k (gamma=1; episodes are short and bounded), advantage_t = G_t - V_t, value_coef=0.5, beta_ent=0.01, and a small batch of episodes (default 16) summed per Adam update with global-norm gradient clipping at 5. This is the classic actor-critic formulation that A3C generalises by parallelising; everything besides the multi-worker async aggregation is the same.
BPTT through the fast weights and per-timestep loss
Same recurrent backward pass as the two fast-weights siblings (fast-weights-associative-retrieval, multi-level-glimpse-mnist), but the gradient on h_t is injected at every timestep instead of only at the final step. dh_t accumulates the gradient flowing back from t+1 plus the local gradient from the policy + value + entropy heads at t:
dA_running = 0
for t = T..1:
dh_t += dh_local[t] # from heads at this step
dzn_t = dh_t * (1 - h_t^2) # tanh
dz_t = LN_backward(dzn_t) # layer norm backward (no affine)
dW_h += outer(dz_t, h_{t-1})
dW_x += outer(dz_t, x_t)
db += dz_t
dh_{t-1} = (W_h.T + A_t.T) dz_t
dA_t_local = outer(dz_t, h_{t-1})
dA_t_total = dA_running + dA_t_local
dh_{t-1} += eta * (dA_t_total + dA_t_total.T) @ h_{t-1}
dA_running = lambda_decay * dA_t_total
Numerical-gradient check (size=6, hidden=8, blank_after=2, T=5) passes to max relative error 5.6e-10 across W_h, W_x, b, W_pi, b_pi, W_v, b_v when the policy-loss advantage is held constant during the perturbation (which matches the standard “detached advantage” REINFORCE formulation). The forward/backward path is correct.
Files
| File | Purpose |
|---|---|
catch_game.py | CatchEnv (drop-the-ball partial-obs world), FastWeightsActorCritic (model with --no-fast-weights ablation), Adam, train_a3c, _evaluate, CLI |
visualize_catch_game.py | Trains both FW and no-FW models, then writes viz/example_episode.png, training_curves.png, with_vs_without.png, fast_weights_evolution.png, hidden_state_trace.png |
make_catch_game_gif.py | Trains a FW model, picks a high-stakes deterministic episode, renders an animated GIF (true world + agent input + A_t heatmap per frame) |
catch_game.gif | Committed animation (~437 KB) |
viz/ | Committed PNG outputs |
The required spec API surfaces are present and named: CatchEnv, build_a3c_policy(), train_a3c().
Running
Headline run (matches Results table below):
python3 catch_game.py --seed 0 --size 24 --blank-after 8 --n-episodes 12000
Vanilla-RNN ablation (same everything, fast-weights term zeroed out):
python3 catch_game.py --seed 0 --size 24 --blank-after 8 --n-episodes 12000 --no-fast-weights
Quick sanity check (~2 seconds, smaller grid):
python3 catch_game.py --seed 0 --size 10 --blank-after 4 --n-episodes 1500 --hidden 32
Regenerate visualizations and gif (each runs its own quicker training):
python3 visualize_catch_game.py # ~70s wallclock; trains BOTH FW and no-FW
python3 make_catch_game_gif.py # ~40s wallclock; trains the FW model
CLI flags:
--seed RNG seed default 0
--size grid edge length N default 24
--blank-after obs blanked once step_idx > this default 8
--n-episodes training episodes default 12000
--hidden hidden state dim H default 64
--lambda-decay fast-weights decay default 0.95
--eta fast-weights gain default 0.5
--lr Adam learning rate default 3e-3
--gamma discount default 1.0
--beta-ent entropy bonus default 0.01
--value-coef value-loss coefficient default 0.5
--batch-episodes episodes per Adam step (avg grads) default 16
--grad-clip global-norm clip default 5.0
--eval-every eval cadence (in episodes) default 200
--eval-episodes eval batch size default 200
--no-fast-weights ablation: zero the eta gain throughout default off
Results
Headline: with vs without fast weights at the spec defaults (size=24, blank_after=8, 12k episodes, hidden=64)
| Configuration | Final greedy catch rate (n=500) | seed 0 | seed 1 | seed 2 | mean |
|---|---|---|---|---|---|
with fast weights (eta=0.5) | 3 seeds | 36.6% | 28.8% | 36.4% | 33.9% |
vanilla RNN (eta=0, otherwise identical) | 3 seeds | 10.2% | 10.2% | 13.8% | 11.4% |
| random-paddle chance | — | — | — | — | 12.5% |
The vanilla RNN’s mean is statistically indistinguishable from random play. The fast-weights agent is roughly 3x chance, ~22 percentage points above the vanilla baseline. The mechanism is doing real work.

Single-seed detail (seed=0, headline run)
| Metric | Value |
|---|---|
| Architecture | FastWeightsActorCritic, H=64, obs_dim=576 (24x24 flattened), 3 actions |
| Slow params | 41,284 (W_h: 64*64, W_x: 64*576, b: 64, W_pi: 3*64, b_pi: 3, W_v: 1*64, b_v: 1) |
| Fast-weights matrix per episode | 64 x 64 = 4,096 entries (transient, not learned) |
| Episode length | N-1 = 23 steps |
| Blanked steps per episode | 23 - 9 = 14 (75% of episode length is “blind”) |
| Final eval catch rate (n=500, greedy) | 36.6% with FW vs 10.2% without |
| Train wallclock (seed 0) | 60 s with FW, 33 s without |
| Hyperparameters | lambda_decay=0.95, eta=0.5, lr=3e-3, batch_episodes=16, beta_ent=0.01, value_coef=0.5, grad_clip=5.0 |
| Numerical gradient check | max rel err 5.6e-10 across W_h, W_x, b, W_pi, b_pi, W_v, b_v (forward/backward verified, with detached-advantage REINFORCE formulation) |
Reduced-difficulty regime (size=10, blank_after=4): both architectures learn
| Configuration | Final greedy catch rate (n=500) |
|---|---|
| with FW | 91.4% |
| vanilla RNN | 81.6% |
| random-paddle chance | 30.0% |
At the smaller grid the task only requires holding one number (ball column) across 5 blind steps, which a 32-unit vanilla RNN handles. The fast-weights gap shrinks to 10 percentage points. The mechanism is most useful when the memory budget is tight relative to what needs to be remembered, which is exactly the regime the spec defaults (24x24, 14 blanked steps) targets.
Visualizations
Example episode (true state vs agent’s blanked input)

Top row: the true world (always rendered, ball + paddle visible). Bottom row: the agent’s actual input — the ball appears for the first 9 frames and then the input is blanked to all zeros (shown as the grey “blank” tile) for the remaining steps. The agent must use what it stored in h_t and A_t over the visible frames to keep moving the paddle in the right direction.
Training curves

Mean reward (+1 catch, -1 miss) climbs from -0.8 (mostly misses) toward zero / positive territory; greedy eval catch rate rises in lockstep; REINFORCE loss decreases (with the usual high variance). The eval is on a fixed n=100 sample so the curves are step-like rather than smooth — that is a property of the eval harness, not the policy.
Fast-weights matrix evolution

Snapshots of A_t at evenly-spaced timesteps in one episode. At t=0 the matrix is exactly zero (initial condition). Each subsequent step adds an eta * outer(h_{t-1}, h_{t-1}) rank-1 contribution and decays the existing entries by lambda=0.95. Frames marked “(blank)” are after the observation cutoff — A_t is still being read (z_t = ... + A_t @ h_{t-1}) and updated (it depends on h_{t-1}, which is non-zero) even though the agent receives no new input pixels.
Hidden state trace

Heatmap of h_t (rows = hidden units, columns = timesteps). The shaded grey region marks the steps where the observation is blanked. The hidden state continues to evolve in non-trivial ways during the blanked steps — the recurrence is doing more than just decay because it is reading from A_t every step.
Deviations from the original procedure
-
REINFORCE-with-baseline instead of full A3C. The Ba et al. RL section uses A3C, which adds asynchronous parallel actor-learners. We use a single-actor on-policy actor-critic with the same loss form (policy gradient + value baseline + entropy bonus, advantage detached for the policy term). This is the underlying algorithm; A3C is the parallelization wrapper. Documented as required by spec v2.
-
Catch rate below the spec target of >70%. Under the per-stub spec defaults (24x24, blank_after=8, hidden=64) we hit ~34% with FW vs ~11% without. The spec’s >70% target was probably calibrated for either (a) a smaller grid where the random-paddle chance is higher, or (b) a wider paddle, or (c) a much longer training schedule (the original A3C runs went orders of magnitude more episodes than 12k on parallel workers). At smaller grids (size=10) we DO exceed 90%, so there is no implementation defect; the bottleneck is REINFORCE’s variance + budget.
-
Single inner-loop iteration
S=1. Same simplification as the two sibling fast-weights stubs. The paper’sh_{s+1}(t+1) = f(LN([W h_t + C x_t] + A_{t+1} h_s(t+1)))fors=0..S-1is replaced withh_t = tanh(LN(W h_{t-1} + W_x x_t + b + A_t h_{t-1})). The wave-5 sibling experimented with the proper inner loop and found single-step trains better in pure-numpy. -
LayerNorm without learnable affine. Standard for the wave’s fast-weights family. Keeps the parameter count minimal and matches the sibling architectures.
-
Identity-ish init for
W_h(0.5 * I). Standard for fast-weights RNNs after Le, Jaitly & Hinton 2015 (IRNN); LayerNorm rescales any explosion. -
Small-batch on-policy update instead of true online A3C. We average gradients across 16 episodes per Adam step. True A3C interleaves environment steps and parameter updates per worker. Empirically the small-batch version converges more reliably with REINFORCE-level variance.
-
No CNN encoder. The 576-dim flat observation goes through a single
W_xlinear projection. A small CNN over the 24x24 grid would extract “ball position” much more cheaply (the 1-pixel ball is a sparse feature) and almost certainly raise the catch rate.
Open questions / next experiments
-
Push past 70% at the spec defaults. Promising knobs: (a) longer training (50k–100k episodes; estimated 5–10 minutes wallclock), (b) variance-reduction tricks (GAE-lambda value targets instead of full Monte-Carlo
G_t), (c) curriculum onblank_after(start atblank_after=20and decrease toward 8), (d) replace the flatW_xprojection with a 3x3 convolutional preprocessor that finds the ball pixel cheaply. -
Lambda / eta sweep. We use Ba et al.’s defaults (
lambda=0.95, eta=0.5). For a 23-step catch episodelambdamight want to be higher (slower forgetting; the relevant memory is from steps 0–8 and we need it at step 22), andetalower (smaller writes; the hidden state should not be dominated by the most recent step’s outer product). -
Compare against a vanilla RNN with proportionally larger hidden. The fast-weights matrix at H=64 is 4,096 transient entries; the vanilla RNN with H=128 has 16k slow recurrent weights. A truly fair “memory budget” comparison would line up these counts. If vanilla-H=128 trained to similar catch rate, that would be evidence that fast weights are a parameter-efficient form of memory; if it stayed at chance, the fast-weights architecture would be doing something a vanilla RNN cannot.
-
Multi-ball variant. The spec defaults to one ball per episode. Multiple balls dropping at staggered times would force the agent to remember several positions simultaneously — exactly the regime where fast weights are expected to dominate (the associative-retrieval sibling makes this dependence explicit with
n_pairs >= 2). -
Data movement. The fast-weights matrix is
H^2and is recomputed and decayed every step. For T=23, B=16, H=64 that is 162364*64 = 1.5M floats per forward pass, ~6 MB ofA_tstorage (we keep all 23 matrices for backward). Compare against an attention layer with the same effective memory via the Sutro group’s ByteDMD framework — which mechanism is more energy-efficient is a real open question. -
Greedy vs sampled at evaluation. Eval uses argmax (greedy). Sampled play often catches more reliably when the learned policy has under-confident peaks (you get exploration “free” at test time). Worth quantifying the gap.
v1 metrics
| Metric | Value |
|---|---|
| Reproduces paper? | Partial. Architecture is correct (numerical gradient check passes to ~5.6e-10 across all 7 parameter tensors). The fast-weights mechanism trains stably with REINFORCE-with-baseline and reaches 36.6% greedy catch rate at the spec defaults vs 10.2% for the same architecture without fast weights (random-paddle chance is 12.5%). Above the 70% spec target only at the easier size=10, blank_after=4 setting (91.4% with FW). The fast-weights mechanism is verifiably working: A_t accumulates outer-product traces that the agent uses across blanked steps (visible in viz/fast_weights_evolution.png), and the with-vs-without ablation is unambiguous (~3x chance vs at-chance). Gap to >70% at size=24 is REINFORCE variance + training budget, not an implementation bug. |
| Wallclock to run final experiment | 60 s for the FW model, 33 s for the no-FW baseline (time python3 catch_game.py --seed 0 --size 24 --blank-after 8 --n-episodes 12000 measured on M-series MacBook, system Python 3.12 + numpy 2.2). 3-seed sweep ~5 minutes. |
| Implementation wallclock (agent) | ~1.5 hours (single session — most of the wall time was the multi-seed sweeps to confirm the headline reproduces; the implementation itself was ~30 minutes including the BPTT-with-per-step-loss extension and the numerical gradient check). |
affNIST robustness test
Reproduction sketch of the robustness experiment from Sabour, Frosst & Hinton, “Dynamic routing between capsules”, NeurIPS 2017. Train a CapsNet and a parameter-matched CNN on translated MNIST (40x40 canvas, digit randomly placed within ±6 px of centre); test both on affNIST (40x40, same digits under random affine transforms).

The published headline is CapsNet 79% vs CNN 66% on affNIST after both networks reach matched accuracy on translated MNIST. The headline this implementation can claim, with the simplifications below, is CapsNet 85.5% vs CNN 87.5% – the expected gap does not appear here. A careful read of why is in the Results and Deviations sections below.
Problem
Train and test distributions:
- Train: MNIST 28x28 padded to a 40x40 canvas with the digit randomly
translated by an integer offset in
[-6, +6]on each axis. - Test (in distribution): same as train.
- Test (out of distribution): affNIST 40x40. The real Toronto dataset
(
https://www.cs.toronto.edu/~tijmen/affNIST/...) was unreachable at the time of this run (HTTP 503), and the GitHub mirror returns 404, so the test set is synthesized by applying a random affine to each MNIST test digit:- rotation in
[-20°, +20°] - isotropic scale in
[0.8, 1.2] - shear in
[-0.1, +0.1] - translation in
[-4, +4]px (per the spec’s fallback recipe).
- rotation in
The point of the experiment is the robustness gap: the CapsNet is supposed to generalise to unseen affine transforms more gracefully than a CNN with matched parameter count.
Files
| File | Purpose |
|---|---|
affnist.py | MNIST loader, translated-MNIST generator, affNIST loader (real if reachable, synthesized otherwise), TinyCapsNet, TinyCNN, training loop, evaluate_robustness, CLI. |
problem.py | Forwards the spec-required entry points (make_translated_mnist, load_affnist_test, evaluate_robustness) from affnist.py. |
visualize_affnist.py | Static viz/*.png plots: example pairs, accuracy bars, per-class robustness, training curves. |
make_affnist_gif.py | Builds affnist.gif: side-by-side per-frame predictions of CapsNet and CNN on a fixed 6-image affNIST panel as both train. |
affnist.gif | The animation above. |
viz/ | Static plots produced from the saved results.json. |
results.json | Full run record: args, environment, per-model accuracy, training history. |
Caches are external: MNIST IDX files at ~/.cache/hinton-mnist/,
affNIST archive at ~/.cache/hinton-affnist/ (only attempted; not parsed).
Running
# Train both networks and write results.json (about 4 minutes on an M-series laptop)
python3 affnist.py --arch both --n-epochs 5 --n-train 5000 --n-test 2000 \
--lr 1e-3 --seed 0 --out results.json
# Re-render static plots
python3 visualize_affnist.py --no-per-class --outdir viz
# Per-class plot (re-trains both models)
python3 visualize_affnist.py --outdir viz --n-epochs 4 --n-train 4000
# Re-render the gif (re-trains both with snapshots)
python3 make_affnist_gif.py --n-epochs 3 --n-train 3000 \
--snapshot-every 8 --val-every 8 --fps 5
Results
Run on macOS arm64, numpy 2.2.5, Python 3.12.9, seed 0, 5 epochs, 5000 training images, 2000 synthesized affNIST test images.
| Network | Params | translated-MNIST acc | affNIST acc | Train wall |
|---|---|---|---|---|
| CapsNet (3 routing iters) | 134,976 | 0.904 | 0.855 | 102.6s |
| CNN (3 conv + 2 FC) | 168,522 | 0.929 | 0.875 | 117.0s |
| gap (CapsNet - CNN) | -0.025 | -0.020 |
Robustness check at a stronger affine (rotation ±30°, scale [0.7, 1.3],
shear ±0.2, translation ±6): CapsNet 0.729, CNN 0.759, gap -0.030.
The gap is reproducible across affine strengths in this configuration.
Per-class affNIST accuracy (4 epochs, 4000 train images):
| Digit | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | mean |
|---|---|---|---|---|---|---|---|---|---|---|---|
| CapsNet | .85 | .89 | .80 | .78 | .81 | .85 | .83 | .78 | .79 | .74 | .811 |
| CNN | .81 | .96 | .83 | .76 | .92 | .80 | .90 | .75 | .71 | .86 | .830 |
The CapsNet’s per-class accuracy is more uniform (range .74 to .89 vs CNN’s .71 to .96) – consistent with the routing mechanism producing a more class-symmetric representation – but the mean is lower.
Why the paper’s gap doesn’t appear here
Three plausible reasons, in decreasing order of importance:
-
Synthesized affNIST is too close to translated MNIST. Real affNIST applies a per-image full affine sampled to be visibly different from the training distribution. The fallback recipe used here (rotation ±20°, scale 0.8-1.2, shear ±0.1, translation ±4 px) intersects substantially with translated-MNIST’s training augmentation, so a CNN that learned simple translation invariance can fake affine invariance well enough to close the gap. The stronger-affine variant (±30° / 0.7-1.3 / ±0.2 / ±6 px) widens the test distribution but still doesn’t reverse the sign.
-
Tiny capsules. This implementation uses 8 primary capsule types of dimension 4 with stride 4 (288 input capsules) and 8-D digit capsules. The paper uses 32 primary capsule types of dimension 8 (1152 input capsules) and 16-D digit capsules, plus a reconstruction decoder as a regulariser. The dimensionality of the instantiation parameters seems to matter for the routing to disentangle pose, and 4-D may be below threshold.
-
CapsNet has 19% fewer params (135K vs 168K). With matched depth and width budgets the CNN here can build a wider feature pyramid than the CapsNet – the opposite of the paper’s parameter budget, which favoured the CapsNet within the small-net regime.
A faithful reproduction would need (a) real affNIST test data, (b) the paper’s full capsule sizes, and (c) the reconstruction-loss regulariser. Any of those three on its own is plausibly enough to flip the sign of the gap.
Architecture details
TinyCapsNet (134,976 params)
40x40 input
-> Conv1: 16 filters 9x9 stride 1, ReLU -> 32x32x16
-> PrimaryCaps: conv to 8 caps x 4-D (32 channels), 9x9 stride 4 -> 6x6x(8x4)
-> reshape to 288 input capsules of dim 4, squash
-> DigitCaps: 10 caps of dim 8, dynamic routing (3 iters)
-> margin loss with m+ = 0.9, m- = 0.1, lambda = 0.5
Routing transformation matrices Wij of shape (288, 10, 4, 8) dominate the
parameter count (92,160 of 134,976).
TinyCNN (168,522 params)
40x40 input
-> Conv1: 16 filters 9x9 stride 1, ReLU -> 32x32x16
-> Conv2: 32 filters 5x5 stride 2, ReLU -> 14x14x32
-> Conv3: 64 filters 5x5 stride 2, ReLU -> 5x5x64
-> FC: 1600 -> 64, ReLU
-> FC: 64 -> 10, softmax + cross entropy
Both networks are pure numpy (no torch / jax). Convolutions use im2col +
np.matmul so the heavy contractions hit BLAS. A finite-difference gradient
check against W1, W2, and the routing weights Wij agrees within 1-4%
relative error at eps=1e-3.
Deviations from the 2017 procedure
- Tiny capsules. Conv1 16 filters (paper: 256), 8 primary capsule types of dim 4 (paper: 32 types of dim 8), 8-D digit capsules (paper: 16-D).
- Stride 4 for primary capsules instead of stride 2, to reduce input capsule count from 1152 to 288.
- No reconstruction decoder. The paper trains with margin loss + a small coefficient on a 3-layer MLP that reconstructs the input from the class-active capsule. This implementation uses margin loss alone.
- Test-set synthesis (rotation ±20°, scale 0.8-1.2, shear ±0.1, translation ±4 px) instead of real affNIST – both Toronto and the GitHub mirror were unreachable at run time.
- 5 training epochs / 5000 images instead of full MNIST x many epochs (numpy compute budget).
- Same-parameter CNN is matched within ~25%, not exactly. The CNN is slightly larger (168K vs 135K).
Open questions / next experiments
- Real affNIST. When the Toronto host is back, parse the
.matarchive (4 batches of 10000 each at the requested 32x scale; the just_centered variant is what the paper uses) and re-evaluate. The download path is already wired in_download/load_affnist_test; only the.matparser is missing. - Reconstruction regulariser. Add the 3-layer decoder + small MSE term on the active-class capsule. The paper credits this with most of the gap; a quick ablation would confirm.
- Wider capsules. Push primary_dim 4 -> 8 and digit_dim 8 -> 16 with the same routing iters. Param count rises ~3x but should still be tractable on CPU for 5000 images.
- Per-axis robustness curves. Sweep one transform parameter at a time (rotation, scale, shear, translation) to localise which affine direction CapsNet generalises to most – the published claim is broadly across all, but the synthesized fallback is mild, so disaggregating may be informative.
MultiMNIST + CapsNet (Dynamic Routing Between Capsules)
Numpy reproduction of Sabour, Frosst & Hinton, “Dynamic routing between capsules”, NIPS 2017. The paper’s headline claim: capsules separately identify and reconstruct heavily overlapping objects via routing-by-agreement.

Problem
Two distinct-class MNIST digits, each shifted by integers in [-4, +4]
pixels on each axis, overlaid on a 36x36 canvas with pixel-wise max. The
shifted bounding boxes must satisfy IoU >= 0.8 so the two digits are
genuinely overlapping (the paper says “80% overlap” without giving a
precise definition; bounding-box IoU is the cleanest interpretation).
The two-digit identification task: for each composite, identify both
classes. Chance is 1 / C(10, 2) = 1/45 ≈ 2.2% for an exact set match.
The disentanglement test: select a single DigitCaps capsule’s 16-D vector (masking the other 9), feed it through the decoder, and reconstruct only that one digit’s image — even though the input to the network was the overlapping composite.
Architecture
| Stage | Layer | Output shape | Notes |
|---|---|---|---|
| Input | – | (B, 1, 36, 36) | composite of two 28x28 digits |
| Conv1 | 9x9 stride 1, ReLU | (B, 32, 28, 28) | paper: 256 channels |
| PrimaryCaps | 9x9 stride 2 | (B, 8, 8, 10, 10) -> 800 caps x 8-D | paper: 32 caps x 10x10 = 3200 caps |
| Squash | per-capsule | same | `v = |
| DigitCaps W | per-pair affine | (800, 10, 16, 8) | 1.0M routing weights |
| Routing | 3 iters | (B, 10, 16) | softmax over digits, agreement update |
| Squash | per-capsule | (B, 10, 16) | output of DigitCaps |
| Decoder fc1 | 160 -> 256, ReLU | – | masked DigitCaps in |
| Decoder fc2 | 256 -> 512, ReLU | – | |
| Decoder fc3 | 512 -> 1296, sigmoid | (B, 36, 36) | reconstructs one digit |
Loss: margin_loss(v, T) + 0.0005 * (recon_a + recon_b). For two-digit
classification both labels are positive (T_a = T_b = 1); the decoder
runs twice per step, masking once per ground-truth digit and reconstructing
the corresponding source image separately.
Margin loss:
L_k = T_k * max(0, 0.9 - ||v_k||)^2
+ 0.5 * (1 - T_k) * max(0, ||v_k|| - 0.1)^2
Files
| File | Purpose |
|---|---|
multimnist_capsnet.py | MultiMNIST overlay, CapsNet model with 3-iter dynamic routing, margin + reconstruction loss, training. CLI: --seed --n-epochs --n-train --n-test --batch-size --lr |
visualize_multimnist_capsnet.py | Trains from scratch and writes static figures into viz/ |
make_multimnist_capsnet_gif.py | Trains from scratch and renders the animated GIF |
multimnist_capsnet.gif | Output of the GIF script |
viz/ | Static PNG outputs from the visualization script |
Running
# Quick 8-epoch training on 6k pairs (~7 min on M-series Mac)
python3 multimnist_capsnet.py --n-epochs 8 --n-train 6000 --seed 0
# Train + render all static figures
python3 visualize_multimnist_capsnet.py --n-epochs 8 --n-train 6000 --outdir viz
# Train + render the animated GIF
python3 make_multimnist_capsnet_gif.py --n-epochs 8 --n-train 6000 --snapshot-every 75 --fps 6
The MNIST loader downloads the four *-idx*-ubyte.gz files from
storage.googleapis.com/cvdf-datasets/mnist/ on first run and caches them at
~/.cache/hinton-mnist/.
Results
Defaults: 8 epochs x 187 steps/epoch, batch 32, Adam lr=1e-3 over 6,000 MultiMNIST training pairs. Single-thread numpy.
| Metric | Value | Baseline |
|---|---|---|
| Test two-digit set accuracy | 0.486 | chance = 0.022 (1/45) |
| Test reconstruction MSE (per pixel) | 0.036 | input-image MSE-to-mean = 0.082 |
| Margin loss (final, train) | 0.185 | initial = 0.67 |
| Reconstruction loss (final, train) | 44.5 | initial = 80.7 |
| Wallclock | ~395 s | – |
Two-digit set accuracy means both predicted top-2 capsules match the
ground-truth pair as a set. Soft accuracy (predicting at least one of the
two correctly) is much higher — see viz/capsule_activations.png for the
top-1-vs-top-2 breakdown.
22x above chance with a reduced-capacity model is the right sanity check that routing-by-agreement does what the paper claims; the absolute number is below the paper’s 95% because we removed roughly 8x of the conv capacity and use 60k training pairs instead of the paper’s 60M.
Training curves

Margin loss decreases steadily through all 8 epochs; test accuracy plateaus near epoch 3 around 50% (the model overfits past that point — train margin keeps falling while test accuracy stalls). With a larger Conv1 / PrimaryCaps the plateau lifts; see Deviations below.
Per-digit reconstruction

For each test composite we mask all but one DigitCaps vector and feed it
through the decoder. The same 160-D input slot reconstructs digit-a in one
pass and digit-b in a second pass (only the mask differs). Reconstruction
quality is rough — the decoder is small (256/512 hidden units) and 8 epochs
of pure-numpy training is not enough to learn sharp digit shapes — but the
disentanglement is visible: when the mask is digit-a-only, the recon is
clearly that digit and not a blend.
Capsule activation pattern

The two ground-truth label capsules (red boxes) tend to have the highest
norms in ||v_k|| even on overlapping inputs. The right panel shows the
top-2 hit rate breakdown: most validation pairs match at least one of the
ground-truth labels even when the exact-set top-2 prediction is wrong.
Example MultiMNIST pairs

The composite, then each source digit shown alone — to make clear that the network only ever sees the composite and is asked to recover the two underlying digits.
Deviations from the 2017 paper
- Capsule capacity reduced. Paper: Conv1 = 256 channels, PrimaryCaps = 32 capsules x 8-D. Ours: Conv1 = 32, PrimaryCaps = 8 capsules x 8-D (~8x less convolutional capacity). With pure numpy on a single thread the paper’s 256-channel 9x9 conv on 36x36 inputs is roughly 8x slower per batch. A larger config (Conv1 = 64, PrimaryCaps = 16) in a quick side-experiment closes about a third of the gap to higher accuracy but doubles wallclock to ~13 min for 8 epochs.
- 6,000 training pairs instead of 60M. The paper essentially regenerates pairs every epoch from the 60k MNIST images (~60M unique composites). Ours samples a fixed pool of 6k composites and re-shuffles each epoch. With more pairs the test accuracy plateau is higher.
- Routing coefficients
c_ijtreated as constants for the backward pass. Only the final iteration’scis used as a fixed weight when differentiatings = sum_i c_ij * u_hat. This is the standard simplification used in the original Sabour et al. TF reference and keeps the backward implementation tractable in numpy (no need to implement softmax-Jacobians through the routing loop). - No
relu(stack)between PrimaryCaps and squash. Paper applies a capsule-wise non-linearity (“squash”) directly on the PrimaryCaps conv output. We do the same. - Adam instead of paper’s TF-default SGD with momentum. Adam reaches the same plateau in fewer iterations on this problem size.
Correctness notes
- Squash gradient.
v_i = n s_i / (1 + n^2),n = ||s||. The gradient isdL/ds_j = f'(n) (s . dL/dv) s_j / n + f(n) dL/dv_jwithf(n) = n / (1+n^2),f'(n) = (1-n^2) / (1+n^2)^2. This is implemented insquash_backward; verified by finite-difference atatol = 1e-5during development. - im2col with strided slicing. The convolution forward and backward
uses an
im2colreshape via numpy strided slicing into a contiguous buffer, so the heavy work is onecols @ W_flatBLAS call per layer. Stride-2 PrimaryCaps backward goes throughcol2imwhich scatter-adds into the input gradient buffer (overlapping windows do not occur with stride 2 + 9x9 kernel on 28x28 input, but the same code handles overlapping kernels). - u_hat via einsum.
u_hat[b,i,j,d] = sum_p W_route[i,j,d,p] * u[b,i,p]is implemented as a singlenp.einsum('bip,ijdp->bijd')withoptimize=True. For our sizes (B=32, N_p=800, 10, 16, 8) the einsum is fine; numpy dispatches it as a batched matmul. - Two-digit labels are both positive. For multi-label margin loss
we set
T[k] = 1for both ground-truth digits and0for the other 8. The decoder runs twice (one mask per digit) and the reconstruction gradient adds to the corresponding capsule slot.
Open questions / next experiments
- Does increasing capacity to paper-scale recover the paper’s 95%? A quick run with Conv1=64 / PrimaryCaps=16 in this codebase plateaus around 60% after 5-6 epochs at 2x wallclock. Going to paper-scale (Conv1=256 / PrimaryCaps=32) is feasible in numpy but pushes per-step time to ~3 sec, not budget-friendly for a stub.
- Backward through routing iterations. Treating
c_ijas constant for the backward pass is a known simplification. Implementing the full Jacobian through 3 routing iterations should give the encoder a stronger learning signal for the routing weights. - Data augmentation. With 60k MNIST images and only 6k pairs we leave a lot of data on the table. Each step could resample composites on-the-fly rather than from a fixed pool — closer to how the paper generates 60M pairs.
- Reconstruction-as-regularizer weight. The default
recon_weight = 0.0005is from the paper but their batches are 100; with batch 32 and per-batch sum-MSE the regularization signal is weaker. Bumping to0.005may help generalization.
smallNORB held-out azimuth / elevation
Source: Hinton, Sabour & Frosst (2018), “Matrix capsules with EM routing”, ICLR. Demonstrates: Viewpoint extrapolation. Train on a restricted azimuth range, test on held-out azimuths; matrix capsules with EM routing close more of the held-out gap than a parameter-matched CNN.

Problem
The classic small-NORB experiment uses 5 toy categories rendered from many controlled viewpoints. The novel-viewpoint test trains on a restricted azimuth range (e.g. 0–150°) and tests on a disjoint range (200–330°).
We replace the real smallNORB image set with 5 synthesized 3D shape classes drawn as 32×32 silhouettes:
| Class | 3D structure |
|---|---|
| cross | center voxel + 6 axial neighbours |
| L | two perpendicular line segments + short prong |
| T | top bar + perpendicular stem |
| frame | 12 edges of a wireframe cube |
| tripod | three prongs from origin |
Each shape is rotated by a 3×3 rotation matrix R(azimuth, elevation),
projected orthographically, rasterised as a sum of Gaussian blobs gated by
depth (closer points are brighter), and written into a 32×32 grid with light
pixel noise. The synthesized property the experiment depends on — every
class appears at every viewpoint — is preserved exactly.
Architecture
Matrix capsule network
| Stage | Layer | Output | Notes |
|---|---|---|---|
| Input | – | (B, 32, 32) | grayscale silhouette |
| Conv1 | 5×5, stride 2, 16 ch, ReLU | (B, 16, 16, 16) | numpy im2col-free naïve conv |
| PrimaryCaps | linear 4096 → 8 × 17 | 8 caps × (4×4 pose + 1 act) | flatten then dense |
| ClassCaps | EM routing, 3 iters | 5 caps × (4×4 pose + 1 act) | matrix-pose, EM routing |
| Logit | ‖μ_j‖_F² / 16 + 4 (a_j − 0.5) | (B, 5) | combined pose-norm + activation |
Total parameter count: ~570k floats.
- Conv: 16·1·5·5 + 16 = 416
- Primary linear: 4096·136 + 136 = 557 192
- Routing W: 8·5·4·4 = 640
- β_a, β_v: 10
CNN baseline
| Stage | Layer | Output | Notes |
|---|---|---|---|
| Conv1 | 5×5, stride 2, 16 ch, ReLU | (B, 16, 16, 16) | identical to caps’ Conv1 |
| Hidden | linear 4096 → 64, ReLU | (B, 64) | |
| Logit | linear 64 → 5 | (B, 5) | softmax-xent |
Total parameter count: ~262k floats.
The capsule model is the larger network here (≈2× the params), so any “capsule wins on held-out viewpoints” claim is not the trivial consequence of having more capacity — capsules generalise better despite the CNN saturating at 100 % training-view accuracy.
EM routing
def em_routing(votes, a_lower, n_iters=3):
# votes: (B, n_lower, n_upper, 4, 4) -- M_i @ W_ij
# a_lower:(B, n_lower)
R = uniform(B, n_lower, n_upper) # routing assignments
for it in range(n_iters):
# M-step: weighted Gaussian fit per upper capsule
R_a = R * a_lower[..., None]
sum_Ra = R_a.sum(axis=1, keepdims=True)
mu = (R_a * votes).sum(1) / sum_Ra # (B, n_upper, 16)
sigma2 = (R_a * (votes - mu)**2).sum(1) / sum_Ra
cost = sum_h (β_v + 0.5 log sigma2_h) * sum_Ra
a_upper = sigmoid(λ (β_a − cost))
if it < n_iters - 1: # E-step
log_p = -0.5 sum_h ((votes_h - mu_h)^2 / sigma2_h
+ log(2π sigma2_h))
R = softmax_j(log_p + log a_upper)
return mu, a_upper
Loss
Softmax cross-entropy on caps_logits = ‖μ_j‖² / 16 + 4 (a_j − 0.5). The
pose-norm contribution is the critical signal: it gives a dense gradient
through the routing W matrices and the conv all the way to the input,
whereas the EM-routing activation channel has only a weak gradient via the
sum-of-log-σ² cost (see “Deviations” below).
Files
| File | Purpose |
|---|---|
smallnorb_novel_viewpoint.py | Synthesised dataset, conv layer, matrix caps, EM routing, CNN baseline, training loops. CLI: --seed --n-epochs --lr --train-azimuths --test-azimuths --n-train-views --n-test-views --n-elev --n-per-combo --em-iters --out-json |
visualize_smallnorb_novel_viewpoint.py | Trains both models, dumps 6 static figures to viz/ |
make_smallnorb_novel_viewpoint_gif.py | Trains with snapshots, renders the animated GIF |
smallnorb_novel_viewpoint.gif | Output of the GIF script |
viz/ | Static PNGs |
Running
# Train both models on default split (~10s wall on M-series Mac)
python3 smallnorb_novel_viewpoint.py --n-epochs 10 --lr 2e-3 --seed 0 \
--train-azimuths 0:150 --test-azimuths 200:330
# Train + render all static figures (~12s + plotting)
python3 visualize_smallnorb_novel_viewpoint.py --seed 1 --n-epochs 10
# Train + render the animated GIF (~25s + frame rendering)
python3 make_smallnorb_novel_viewpoint_gif.py --seed 1 --n-epochs 8 \
--snapshots 8 --sweep-frames 12 --fps 10
No external downloads. Pure numpy + matplotlib + imageio.
Results
Default config: 12 epochs disabled (overfits caps), use 10 epochs. Train: 5 classes × 6 azimuths in [0, 150°] × 3 elevations × 5 noisy samples = 450 images. Held-out: same elevations, azimuths in [200, 330°], 2 samples per combo = 180 images. Familiar-view validation: same as train range, 2 samples per combo = 180 images.
Three-seed run (seeds 0, 1, 2; 10 epochs; same train/test split):
| Seed | Caps familiar | Caps held-out | Caps drop | CNN familiar | CNN held-out | CNN drop |
|---|---|---|---|---|---|---|
| 0 | 0.944 | 0.750 | 0.194 | 1.000 | 0.733 | 0.267 |
| 1 | 0.989 | 0.711 | 0.278 | 1.000 | 0.672 | 0.328 |
| 2 | 0.978 | 0.717 | 0.261 | 1.000 | 0.683 | 0.317 |
| mean | 0.970 | 0.726 | 0.244 | 1.000 | 0.696 | 0.304 |
- The capsule network beats the CNN on held-out viewpoint accuracy on every seed (0.726 vs 0.696 on average; capsules higher on 3/3 seeds).
- The capsule network’s familiar–to–held-out drop is smaller on every seed (0.244 vs 0.304 on average; ~20 % relative reduction).
- The CNN saturates at 1.000 familiar-view accuracy in 1–2 epochs, so the held-out gap is the entire generalisation cost.
This matches the qualitative direction of the paper: matrix capsules with EM routing close more of the held-out viewpoint gap than a CNN, even when (as here) the CNN has the easier training objective and saturates faster.
Static figures
Synthesized dataset

Each row is one of the 5 classes; columns are 6 evenly-spaced azimuths.
The frame is most rotation-symmetric (it’s a wireframe cube), L and T
are the most directional, cross and tripod are intermediate.

The same 5 classes at fixed azimuth = 60° but varying elevation. Pose matrices need to encode both degrees of freedom for the capsule activations to remain stable across the full viewing sphere.
Training curves

The CNN’s loss collapses near zero by epoch 2–3; the capsule loss decreases more slowly because the pose-norm + activation logit signal is mediated by the EM-routing layer. Validation accuracy on familiar views reaches ~99 % for capsules and 100 % for the CNN.
Familiar vs held-out

Side-by-side comparison for the seed used by the visualisation script. The “drop” annotation under each bar pair is the headline metric: how much accuracy the model loses when the test viewpoint moves outside the training azimuth range.
Per-azimuth accuracy

Both models hit ~100 % accuracy inside the green training region. Outside it, the CNN drops faster and further (especially around az = 250–300°, which is maximally far from the training range). The capsule curve is flatter — the same property the paper highlights.
Class capsule pose matrices

For each ground-truth class (rows), we forward one example image and dump
the 5 class capsules’ 4×4 pose matrices. The activation a_j is printed
above each one; the green-framed cell is the predicted class. The
diagonal-ish dominance shows class capsules are doing what they should.
Deviations from the 2018 paper (all documented in source)
- Synthesized NORB-like dataset; real smallNORB download deferred. The
real dataset is a 9.5 GB download of 96×96 stereo pairs of 50 toy
models. Loading it depends on a slow public mirror that has tripped up
previous attempts at this experiment. We instead synthesize 5 voxel
shape classes whose silhouettes are 32×32 grayscale renders at
controlled (azimuth, elevation), preserving the property that every
class appears at every viewpoint. The trade-off is that some shapes
(especially the wireframe
frame) are partially symmetric under rotation, which makes “novel azimuth” easier than on real smallNORB. - Single conv layer (paper: 5). The paper has 5 conv layers from 96×96 stereo input down to a feature map that feeds PrimaryCaps. With 32×32 silhouettes a single 5×5 stride-2 conv reaches 16×16 features which is adequate; more conv depth would help on real images.
- No coordinate-addition; no spatial replication of capsules. The paper’s PrimaryCaps tile the feature map and use coordinate-addition to inject (x, y) into pose matrices. Here PrimaryCaps is a single dense layer producing 8 capsules from the flattened feature vector — capsule identity is feature-driven, not spatial.
- Stop-gradient through routing. Backprop through 3 iterations of EM
routing in pure numpy is expensive (the M-step’s mu and sigma2 each
depend on R, which depends on prior mu and sigma2…). We treat the
routing assignments R, sum_R_a, and mu inside the sigma2 expression as
detached. Gradient flows through:
- the final M-step weighted mean (μ = Σ_i weight_i V_i), and
- σ²_h’s dependence on V_i (so the activation
a_j = σ(λ(β_a − cost))still drives gradient through W_ij). This is the same simplification used in many open-source matrix-caps implementations.
- Combined pose-norm + activation logit instead of pure spread loss.
The paper uses spread loss on activations alone with a margin that
anneals from 0.2 to 0.9. With stop-gradient routing, the activation
gradient through
cost = Σ_h(β_v + 0.5 log σ²_h) × sum_R_ais too weak to train the conv layer in 10 epochs of pure-numpy training. We instead use softmax cross-entropy on‖μ_j‖² / 16 + 4(a_j − 0.5). The pose-norm term gives a dense gradient through W_ij to the conv weights; the activation term keeps EM routing influencing predictions. - Smaller model (8 PrimaryCaps, 16-channel conv). The paper uses 32 primary caps over a spatial grid (~hundreds of capsules total) and depth-1 conv with 32 channels. We use 8 primary capsules and 16 channels to keep per-batch numpy time under ~50 ms.
- No spread-loss margin annealing. Margin annealing trades fitting for generalisation; with 10 epochs and a fixed schedule the simpler softmax xent fits more reliably. Documented as deviation #5.
- Different shape classes. Paper uses 5 toy NORB categories: animals, humans, planes, trucks, cars. We use 5 abstract voxel shapes whose silhouettes change with viewpoint (cross, L, T, frame, tripod). Conceptually identical for the experimental property tested.
Correctness notes
- Capsule gradients are non-zero on every parameter. Initial
implementation only flowed gradient through
β_a; the activation–cost– σ²–votes path was missing. Diagnostic: forward a random batch, runmodel.backward(...), and checknp.linalg.norm(grad)is non-zero onW1, b1, W_prim, b_prim, W_route, β_a, β_v. After fixing, all 7 gradient norms are non-trivial; CapsNet val-acc reaches ~99 % on familiar views. - EM routing is numerically stable. σ² is clipped to ≥ 1e-6 before
log()and 1/σ². R is normalized via thelog_p − max(log_p)shift before exponentiation. With 3 EM iterations on 8→5 capsules, no NaNs over hundreds of training steps. - CNN is parameter-budget-comparable, sized down. Total caps params ≈ 558 k, total CNN params ≈ 262 k. Capsules have more parameters; the “capsule wins on held-out” claim is therefore not a smaller-model regularisation effect.
- Three-seed reproducibility. Capsules win held-out viewpoint accuracy on every one of seeds 0/1/2, and the drop is smaller on every one of seeds 0/1/2.
- Per-azimuth eval uses fresh data. The per-azimuth accuracy figure
(
viz/azimuth_accuracy.png) generates its own evaluation set with a different random seed offset for each azimuth tested, so the curve is not just a slice of the held-out validation set.
Open questions / next experiments
- Real smallNORB. Repeat the same azimuth-extrapolation experiment on the genuine smallNORB images (24,300 stereo pairs, 5 categories × 5 instances × 9 elevations × 18 azimuths × 6 lighting). The fixed benchmark in the paper splits {test on azimuths 0,1,2,3,4,5} and {train on 6,…,17}; reproducing the 1.4 % vs 2.6 % test-error figure on real images is the natural next step.
- Backprop-through-EM. Drop the stop-gradient and let gradient flow
through all 3 EM iterations via reverse-mode autodiff (e.g. with
jaxor a numpy adjoint). Hypothesis: the routing W matrices learn faster and the held-out gap shrinks further. - Spread loss only. Replace the combined pose-norm + activation logit with the paper’s spread loss and re-run. Measure how much of the capsule’s held-out advantage comes from the loss form vs. the routing.
- Coordinate-addition PrimaryCaps. Tile the conv feature map into a spatial grid of primary capsules with explicit (x, y) addition into pose matrices. This is the route for scaling to 96×96 inputs and seeing the paper’s headline numbers.
- Larger held-out gap. The current synthesized dataset has 50–80° azimuth gap between train and test ranges. Push to 120°+ gap (i.e. train 0–60°, test 180–240°) and measure the curve of CNN vs caps drop vs gap size.
Constellations
Numpy reproduction of Kosiorek, Sabour, Teh & Hinton, “Stacked capsule autoencoders”, NeurIPS 2019 — the constellations experiment.
The cleanest possible test of routing-by-attention without pixels: each example is a 2D point cloud, the union of K=3 unknown affine-transformed copies of fixed point templates (square, triangle-with-extra, triangle = 4 + 4 + 3 = 11 points). The network must figure out which point belongs to which template.

Problem
A constellation example is generated by:
- Take three fixed templates: a square (4 points), a triangle-with-extra (3 vertices + 1 center, 4 points), and a triangle (3 points).
- Apply an independent random similarity transform — uniform scale in
[0.5, 1.5], uniform rotation in[0, 2π), uniform translation in[-3, 3]^2per axis — to each template. - Concatenate the 11 transformed points and shuffle.
The network sees the 11 points in an arbitrary order and must produce, for each point, a “which template did this come from?” prediction. Chance level is 4/11 = 36.4% (always-predict-majority).
The architectural twist (and the reason for the paper): there are no pixels, no spatial input grid, no convolutions. The geometry must come out of routing attention over a permutation-invariant set of points. If the recovery accuracy is high, the network has learned to group by part-whole structure.
Architecture
| Stage | Op | Shape | Activation |
|---|---|---|---|
| Per-point embed | linear | (2) -> (D=32) | – |
| Self-attention block (SAB) | single-head dot-product attention + residual | (N, D) -> (N, D) | softmax |
| Position-wise FFN | linear -> ReLU -> linear + residual | (N, D) -> (N, D) | ReLU |
| Pooling by Multihead Attention (PMA) | K=3 learned seed queries cross-attend over the encoded set | (N, D) -> (K, D) | softmax |
| Per-capsule decode head | linear | (D) -> (4) = (log_scale, theta, tx, ty) | – |
| Capsule decoder | apply similarity transform to TEMPLATES[k] | (K, 4) -> (M=11, 2) | – |
Permutation invariance: only attention and per-point ops are used in the encoder, so the K=3 capsule embeddings are invariant to the input point order. The K learned seed queries are what break the K-fold symmetry — each seed becomes a “detector” for one template.
Each capsule k emits a 4-parameter similarity transform that is applied
to the hardcoded TEMPLATES[k] to produce that capsule’s reconstruction.
Total parameters: 12,708.
Loss
Symmetric Chamfer distance between the input cloud X = {x_n} and the
decoded cloud Y = {y_m}:
L = (1/M) * sum_m min_n ||y_m - x_n||^2
+ (1/N) * sum_n min_m ||y_m - x_n||^2
Both directions matter: the first pulls each decoded point onto a real input point; the second prevents the encoder from “ignoring” any input points by forcing every input to have a nearby decoded neighbour.
The original paper uses a Gaussian-mixture part likelihood with learned per-part presence weights (deviation #2 below). Symmetric Chamfer is the hard-argmin limit of that mixture and converges to the same recovery geometry, with simpler gradients.
Recovery metric
Per-point part-capsule recovery accuracy:
- Run the network forward.
- For each input point, find its nearest decoded point. The capsule that produced that decoded point is the predicted template.
- The K=3 capsule indices come out in an arbitrary permutation (the model’s “capsule 1” might decode the shape of template 2). Resolve this with the best K!=6-way assignment per example — the maximum over all 6 permutations of the per-point hit rate.
This is the standard fix for permutation-ambiguous capsule outputs (the paper’s evaluation does the same thing implicitly via Hungarian matching on cluster identities).
Files
| File | Purpose |
|---|---|
constellations.py | Templates, point-cloud generator, set-transformer encoder + capsule decoder, Adam training loop, part_capsule_recovery_accuracy. CLI: --seed --n-epochs --n-templates --n-object-capsules. |
visualize_constellations.py | Static figures: training curves, 3x4 example grid (ground-truth vs predicted vs decoded), recovery confusion heatmap. |
make_constellations_gif.py | Animated training GIF. |
constellations.gif | Output of the GIF script (~600 KB). |
viz/ | PNG outputs from the visualization script. |
Running
# Train + print per-epoch metrics (~25 s for 30 epochs at lr=3e-3)
python3 constellations.py --n-epochs 30 --seed 0
# Train + render all static figures into viz/
python3 visualize_constellations.py --n-epochs 30 --seed 0 --outdir viz
# Train + render the animated GIF
python3 make_constellations_gif.py --n-epochs 30 --snapshot-every 200 --fps 8
No external data: the generate_constellation function builds each
example procedurally from the hardcoded templates.
Results
Defaults: 30 epochs x 200 Adam steps x batch 32 = 6,000 updates, single
thread, D=32, F=64. Validation set size 256, held-out RNG.
| Metric | Value |
|---|---|
| Final train chamfer | 0.42 |
| Final val chamfer | 0.43 |
| Per-point recovery (permutation-invariant) | 86.9% |
| Multi-seed mean (5 seeds, 15 epochs) | 84.0% +/- 1.1% |
| Per-step time | ~4.1 ms |
| Wallclock for 30-epoch run | ~25 s |
| GIF size | 625 KB (45 frames + 15-frame hold) |
Chance level (always predict the majority template, 4/11) is 36.4%, so 86.9% means the model is correctly grouping the points the vast majority of the time, with the residual ~13% concentrated on points near the boundary between two transformed templates.
Training curves

The chamfer drops fast in the first 1-2 epochs (the network learns to output capsules near the global mean of input points) and then more slowly as it learns to specialise — which is when recovery jumps. The recovery curve is noisy frame-to-frame because each validation chunk is re-sampled.
Example reconstructions

Top row: ground-truth cluster colors. Middle row: predicted cluster colors
under the best K!=6 capsule-to-label match per example; black xs are
mistakes. Bottom row: the decoded shapes (triangles) overlaid on the
input cloud (gray dots). Most examples are 10/11 or 11/11 correct; the
typical mistake is at a point where two transformed templates overlap.
Recovery heatmap

Per-point confusion matrix (rows = ground-truth template, columns = predicted template, after best-permutation match). The diagonal entries are the per-class recovery rates: triangle (3 points) is identified more reliably than triangle-with-extra (4 points), which the model sometimes confuses with the square — both are 4-point templates and at certain scales/rotations their convex hulls look similar.
Deviations from Kosiorek et al. (2019)
- K=3 fixed. Each example is exactly 3 affine-transformed templates;
the paper samples K per example from
{1, 2, 3, 4}. Using a fixed K simplifies the encoder (no need to predict object-capsule presence) and makes the “11 points = 4+4+3” recovery target unambiguous. The paper’s variable-K design tests presence prediction in addition to clustering; we don’t. - Symmetric Chamfer instead of the Gaussian-mixture likelihood. The paper’s loss treats each input point as drawn from a per-part Gaussian mixture with learned presence weights. Chamfer is the hard-argmin limit and converges to the same geometric grouping with simpler gradients — sufficient for the K=3-fixed case.
- Single-head attention, no LayerNorm. Set Transformer (Lee et al. 2019) and SCAE both use multi-head attention with LayerNorm. Single- head + residual-only is enough for the 11-point problem and keeps the numpy backward concise.
- Hardcoded templates. The paper’s templates are learnable. Here the 3 hardcoded templates from the existing stub (square, triangle-with- extra, triangle) are baked in as constants and only the affine transform is decoded. Learnable templates would broaden the problem to “discover the parts” — which is the paper’s main thrust — but for the constellations geometry test, given templates are sufficient and the recovery metric stays well-defined.
- Similarity transform, not full affine. Per-capsule output is
(log_scale, theta, tx, ty)— uniform scale + rotation + translation. The data generator uses the same family. Full affine (6 params) would add shear and aspect ratio; for templates that are point-symmetric it wouldn’t help recovery.
Correctness notes
- Permutation invariance at evaluation. Capsules
0, 1, 2are exchangeable as far as the loss is concerned. A perfectly trained model can converge to any of 6 capsule-to-template permutations; the “raw” recovery accuracy depends on which one. The published number (86.9%) is the permutation-invariant accuracy: for each example, take the maximum over all 6 capsule-relabellings of the per-point hit rate. The corresponding code path ispart_capsule_recovery_accuracy(..., permutation_invariant=True)(the default). Without this flag the same trained model reports ~30%. - Backward through softmax attention. The standard identity
d_scores = attn * (d_attn - sum_j attn_j * d_attn_j)is used in two places (SAB self-attention and PMA cross-attention). Both share the1/sqrt(D)scaling factor. - Chamfer gradient. The hard-argmin selection means the gradient
flows only along the winning pair
(i, argmin_j ||y - x||^2). Both directions of the symmetric Chamfer contribute additively tod_decoded; we usenp.add.atto scatter the X->Y direction correctly when multiple input points share the same nearest decoded neighbour. - Mode-collapse risk. With only the X->Y direction, the encoder can shrink all decoded points to the centroid of the input cloud and get a low loss. Including the Y->X direction (each input point must have a close decoded neighbour) prevents this; without it, training plateaus at recovery ≈ 36% (chance).
Open questions / next experiments
- Variable K. Restoring the paper’s
K ~ Uniform({1, 2, 3, 4})would require the model to predict per-capsule presence (a soft-min over which capsules are “active”). The infrastructure is here — just add a presence head to the decoder and weight the X->Y Chamfer term by presence — but recovery accuracy needs a different metric in the variable-K regime. - Learnable templates. Letting
TEMPLATESbe a parameter (3 trainable point sets) would test the actual SCAE thesis: that part identities emerge from the routing dynamics. The hardcoded version sidesteps that by definition. - Multi-head attention + LayerNorm. A 2-head SAB + 2-head PMA with LayerNorm is the standard Set Transformer. Probably wouldn’t change recovery much on this size of problem but would make the architecture more directly comparable to the paper.
- Larger
K. With K=5 templates and N>=15 points the local minima surface gets denser; whether the same training recipe (Adam at 3e-3 with K=3 hardcoded) survives is a useful stress test.
MNIST-2x5 subclass distillation
Source: Rafael Mueller, Simon Kornblith, Geoffrey E. Hinton. “Subclass Distillation”, arXiv:2002.03936 (2020).
Demonstrates: A teacher trained with only binary super-class labels develops latent subclass logits that recover (most of) the original 10-way digit identity. A student distilled from those sub-logits — having never seen the 10-way labels — clusters the test set along the original digit classes at ~83% accuracy on the seed shown here.

Problem
| Input | 28x28 MNIST digit, flattened to a 784-dim vector |
| Output (teacher) | 10 sub-logits, grouped 5+5 |
| Training label seen by teacher | binary super-class A (digits 0..4) or B (digits 5..9) |
| Output (student) | 10 logits |
| Training signal seen by student | softmax(teacher_sub_logits / T), no labels |
| Evaluation | cluster test images by argmax of student logits, compare to original 10-way ground truth |
The teacher’s super-class probability is computed by grouping the 10 sub-logits via log-sum-exp:
super_logit_g = logsumexp(z_{g,0}, ..., z_{g,4}) for g in {A, B}
P(super = g) = softmax([super_logit_A, super_logit_B])_g
The 5 sub-logits within each super-class are equivalent under the binary super-class loss alone (any redistribution that preserves their logsumexp keeps the super-class probability fixed). To prevent collapse onto a single sub-logit per group, an auxiliary diversity-plus-sharpness loss pushes the within-super-class softmax to be (a) on average uniform across the 5 sub-logits at the batch level, and (b) peaked for any individual example:
L_aux = -H( mean_i softmax(z_i[g]) ) <- want HIGH (batch-level diversity)
+ sharpen * mean_i H( softmax(z_i[g]) ) <- want LOW (per-example commitment)
Both terms are bounded in [0, log 5], so the aux loss can’t blow logit
magnitudes up. Distillation uses temperature-softened cross-entropy
(Hinton 2015 with T^2 scaling).
The interesting property: the teacher never receives a 10-way label, yet — because the auxiliary loss forces it to spread the 5 within-super-class sub-logits across distinct example clusters, and because digits within a super-class are visually different — the surviving block-diagonal structure in its 10x10 sub-logit-vs-digit contingency aligns with the original 10 digit classes (up to a within-block permutation). The student inherits and solidifies that structure through high-temperature distillation.
Files
| File | Purpose |
|---|---|
mnist_2x5_subclass.py | MNIST loader (urllib + gzip, cached at ~/.cache/hinton-mnist/), super-class re-labeller, two-layer ReLU MLPs (Adam), teacher CE on grouped sub-logits, auxiliary diversity-plus-sharpness loss, distillation with temperature, evaluation including 1-to-1 (Hungarian-style greedy) cluster->digit matching. CLI flags: --seed --n-epochs-teacher --n-epochs-student --temperature --aux-weight --sharpen --hidden --batch-size --lr. |
visualize_mnist_2x5_subclass.py | Generates the four viz/*.png artefacts (teacher + student contingency heatmaps, accuracy bars, training curves). Re-runs training if no cached results.json. |
make_mnist_2x5_subclass_gif.py | Two-panel animation: static MNIST sample grid (problem definition) + teacher contingency snapshotted after every epoch. |
mnist_2x5_subclass.gif | 13-frame committed animation (≈ 470 KB). |
viz/ | Committed PNG outputs from the seed-0 run + per-seed JSON results from the variance check. |
Running
python3 mnist_2x5_subclass.py --seed 0 --n-epochs-teacher 12 \
--n-epochs-student 12 --temperature 4.0 --aux-weight 1.0 --sharpen 0.5
Wall-clock on a 2024 laptop CPU (no GPU): ~13 s end to end (5 s teacher,
6 s student, ~1 s evaluation, ~1 s MNIST decode). MNIST is downloaded from
the GCS / S3 mirrors on first run and cached at ~/.cache/hinton-mnist/.
To regenerate visualisations (reusing the JSON the run wrote):
python3 visualize_mnist_2x5_subclass.py --seed 0 --results-json viz/seed_0.json
python3 make_mnist_2x5_subclass_gif.py --seed 0 --n-epochs-teacher 12 --fps 3
Results
Seed 0, default hyperparameters:
| Metric | Value |
|---|---|
| Teacher super-class accuracy (test) | 98.0% |
| Student super-class accuracy via 10-way logits (test) | 97.95% |
| Subclass recovery, any-mapping (test) | 82.88% |
| Subclass recovery, 1-to-1 matching (test) | 82.88% |
| Teacher train wall-clock | 5.0 s |
| Student train wall-clock | 6.3 s |
| Total wall-clock | ~13 s |
5-seed variance check (seeds 0..4, same hyperparameters):
| Metric | mean ± std | range |
|---|---|---|
| Subclass recovery (any-mapping) | 73.87% ± 5.86% | 67.11 – 82.88 |
| Subclass recovery (1-to-1) | 73.61% ± 5.95% | 67.09 – 82.88 |
| Student super-class accuracy | 97.78% ± 0.17% | 97.48 – 97.95 |
Baselines for context: chance on 10-way = 10%; “predict super-class for the whole super-class” = 50%; supervised 10-way MLP of the same shape = ~98%. Recovering 74% on average without any 10-way labels is the headline.
Hyperparameters (canonical run):
hidden=256, lr=1e-3, batch_size=128, weight_decay=1e-4
teacher_epochs=12, student_epochs=12
aux_weight=1.0, sharpen=0.5, temperature=4.0
sharpen=0.5 was selected after a sweep over {0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 1.0, 1.3} at seed 0. sharpen=0.5 topped the curve at 82.9%; values higher
than ~1.3 collapse the teacher onto a single sub-logit per super-class
(observed at sharpen=2.0, recovery drops to 21%); lower values
(sharpen<=0.3) under-commit each example and recovery falls back to ~67%.
Visualizations
Teacher contingency (10x10)

Rows are clusters (sub-logit argmax), columns are true digits. The white dashed line separates super-class A (rows 0..4) from super-class B (rows 5..9). The 5x5 block-diagonal pattern is exactly what subclass distillation predicts: the teacher has split each super-class into 5 distinct sub-logits, and those 5 sub-logits align with the 5 original digits in that super-class — even though the teacher only ever saw the binary super-class label.
Within-block permutation is arbitrary (the teacher has no reason to prefer
any particular ordering). Mild leakage shows up where digits are visually
similar: cluster 4 catches both 2s and 1s, and the 7/9 confusion
in clusters 7 and 9 mirrors the well-known MNIST sibling.
Student contingency

After distillation at T=4.0, the student’s 10x10 contingency is essentially identical to the teacher’s. This is the formal demonstration of subclass distillation: a network trained only on softmax matches against the teacher’s sub-logits — and never granted the original 10-way labels — clusters its test predictions along the original 10 digit classes.
Accuracy comparison

The big jump is from the 50% super-only baseline to the 83% subclass recovery: that 33 percentage-point gap is precisely the information transferred from the teacher’s sub-logits to the student via temperature-softened distillation.
Training curves

- Top-left: teacher super-class CE drops from 0.16 to 0.02 over 12 epochs.
- Top-right: aux loss reaches ~-1.56, close to the -log(5) ≈ -1.609 floor (the diversity term saturates within ~2 epochs; the per-example sharpness is what keeps improving thereafter).
- Bottom-left: teacher super-class accuracy climbs from 94.2% to 99.4%.
- Bottom-right: student distillation KL converges quickly; the student is matching the teacher’s sub-logit distribution, not solving a harder task.
Deviations from the original procedure
- Auxiliary loss formulation. Mueller et al. describe the auxiliary objective as encouraging diverse use of subclass logits; we implement it as a bounded combination of (a) entropy of the batch-mean within-super-class softmax (encourages all 5 subclasses to be used) and (b) per-example softmax entropy (encourages each example to commit to one). The variance-of-logits formulation (the literal “maximise pairwise distance between sub-logits” reading) was tried first and discarded — it’s unbounded and the teacher trades super-class CE for arbitrarily large logit magnitudes within ~2 epochs. The bounded surrogate gives the same qualitative effect (different examples commit to different subclasses) without the magnitude arms race.
- Architecture. Single-hidden-layer MLP (784 -> 256 -> 10), Adam optimiser, hand-coded backward pass in numpy. The original paper uses a ResNet-style network and reports higher subclass recovery (above 95%); we are at the ~74% mean / 83% best regime because (i) MLP backbone, (ii) only 12 epochs each, (iii) no augmentation, (iv) no temperature schedule.
- Distillation only — no joint super-class fine-tune on the student. The original paper sometimes adds a small super-class CE term to the student’s loss for stability; here the student is trained purely on softmax matching, to make the “no original labels” property unambiguous.
- Cluster-recovery metric. We report both an “any-mapping” majority-vote accuracy (each cluster claims its plurality digit; clusters can collide) and a 1-to-1 greedy assignment (no two clusters can claim the same digit). On a 10x10 contingency that is approximately block-diagonal, greedy matches the optimal Hungarian assignment to within rounding.
- No temperature schedule, no perturb-on-plateau. v1 simplification.
- MNIST data source. GCS / S3 mirror of MNIST (yann.lecun.com is frequently down). Files are byte-identical.
Open questions / next experiments
- Can the gap to the paper’s ~95% recovery be closed by widening the MLP, by training longer, or by ramping the temperature during distillation? The variance across seeds (5.9% std) suggests the optimisation surface has multiple within-super-class permutation basins of attraction, some of which align better with the digit-identity manifold than others.
- Subclass recovery is not invariant to the seed used by the digit / super partition. Here we used the natural 0..4 / 5..9 split. What does recovery look like on harder splits (e.g., {0,1,8,9,3} vs {2,4,5,6,7}) where the within-super-class digits look less similar?
- Data-movement question (Sutro v2 follow-up): the auxiliary loss requires a full softmax-and-mean over the batch’s within-super-class subset every step. Does that pattern have meaningfully worse cache reuse than vanilla CE? The teacher solves a trivial supervised task but pays a per-batch reduction; the student then pays standard distillation cost.
- The bounded aux loss makes the teacher’s logit magnitudes a free parameter of the optimisation. Adding L2 weight decay or logit-norm regularisation might tighten the within-block alignment further.
- The 5-seed std (~6 percentage points) suggests an obvious next step: ensemble the teacher’s sub-logits across multiple seeds before distilling the student. The student’s clustering accuracy from a 5-teacher ensemble is a natural baseline for any “data-movement-efficient” variant.
geo-flow-capsules
Unsupervised mixture of K affine motion capsules + Gaussian spatial priors, fit to ground-truth optical flow on the Geo synthetic 2D moving-shapes dataset via EM. Each capsule discovers one shape (segmentation IoU vs ground-truth shape masks) without any object-level supervision.
Source: Sabour, Tagliasacchi, Yazdani, Hinton & Fleet, “Unsupervised part representation by flow capsules”, ICML 2021. Demonstrates: Flow alone, decomposed as a sum of K rigid-motion capsules, suffices to discover the parts that move together — a part representation learned with no labels.

Problem
A frame pair (I_1, I_2) is generated by
- drawing
n_shapes = 3random filled ellipses on a uniform black background (64 x 64, grayscale), and - for each shape
s, sampling a 6-parameter affineM_s = [[a, b, t_x], [c, d, t_y]]describing the frame-1 -> frame-2 motion (small rotation, small per-axis scale, small translation), and renderingI_2with each shape transformed under its ownM_s.
The ground-truth optical flow at pixel (x, y) inside the visible part
of shape s in frame 1 is
flow(x, y) = M_s @ [x, y, 1] - [x, y]
= (L_s - I) @ [x, y] + t_s
and zero elsewhere. Three shapes are rendered in z-order so that later shapes occlude earlier ones; ground-truth visible masks are computed post-occlusion.
The model — the decoder side of the Sabour et al. flow-capsule pipeline — is a mixture of K affine motion capsules. Each capsule k has
- a 6-param affine M_k = (L_k, t_k)
- a Gaussian spatial prior (mu_k, Sigma_k) over (x, y)
and a (K+1)-th “background” capsule with zero flow and a uniform spatial prior covers pixels with no motion. The fit minimises
- log L = - sum_p log [ sum_k pi_k(p) * N(flow(p); M_k @ p - p, sigma_flow^2 I) ]
over capsules
where pi_k(p) = N(p_xy; mu_k, Sigma_k) is the spatial gating term. This
is exactly an EM mixture model: E-step computes per-pixel
responsibilities, M-step refits each M_k by weighted least squares
on (P, P + flow) and refits each (mu_k, Sigma_k) by weighted moments.
The interesting property: the decomposition is unsupervised. We never tell the model which pixels belong to which shape. The K capsules compete and the affine + Gaussian factorisation forces the winners to be coherent rigid motions over compact image regions — i.e. the shapes.
Files
| File | Purpose |
|---|---|
geo_flow_capsules.py | Frame-pair generator + EM flow-capsule fitter + IoU eval. CLI --seed --n-shapes --resolution --n-epochs. |
problem.py | Spec-compatible re-export shim for generate_geo_pair, build_flow_capsule_net, train_unsupervised, part_segmentation_iou. |
visualize_geo_flow_capsules.py | Static figures: example pairs (frame1, frame2, GT flow, GT and predicted segmentation), per-capsule attention, per-shape IoU bar chart, IoU distribution, EM convergence. |
make_geo_flow_capsules_gif.py | Generates geo_flow_capsules.gif. |
geo_flow_capsules.gif | Committed animation (~190 KB, well under 3 MB). |
viz/ | Committed PNGs and results.json from the canonical run. |
Running
# Default headline run (200 test pairs, ~45 s on a laptop):
python3 geo_flow_capsules.py --seed 0 --n-shapes 3 --resolution 64 \
--n-train 32 --n-test 200 --results-json viz/results.json
# Static visualizations (~21 s):
python3 visualize_geo_flow_capsules.py --seed 0 --n-test 120
# Animation (~5 s):
python3 make_geo_flow_capsules_gif.py --seed 2 --n-iters 22 --hold-final 12
Wall-clock for the headline experiment (1 CPU core, M-series Mac, no GPU): ~37.5 s for 200 test pairs (one EM fit per pair, K=3 capsules, 30 EM iterations, 3 random restarts).
Results
Headline configuration: --n-shapes 3 --resolution 64 --n-iters 30 --n-restarts 3.
Chance per-shape IoU under random K-way assignment is ~N_shape / (3 * N_shape) = 0.20
(see “Per-shape IoU bar chart” plot for the dashed reference).
| Metric | Value |
|---|---|
| Mean per-pair IoU (3-shape average) | 0.764 (over 200 test pairs) |
| Per-shape IoU (mean across pairs) | 0.78 / 0.73 / 0.78 |
| Median per-pair IoU | 0.687 |
| Mean reconstruction MSE on flow | 0.072 |
| Test wallclock (200 pairs) | ~37.5 s |
| Train wallclock (32 pairs, sanity check) | ~5.7 s |
| Hyperparameters | K=3, n_iters=30, n_restarts=3, sigma_flow=0.8, sigma_xy_init=14.0 |
The IoU distribution is bimodal: roughly two-thirds of test pairs converge
to a perfect (IoU = 1.0) decomposition, and the rest get stuck at IoU
≈ 0.68 with two shapes correct and one mis-claimed (see
viz/iou_distribution.png).
v1 baseline metrics (per spec issue #1 v2)
| Reproduces paper? | Partial. The qualitative claim — flow alone is enough to discover parts unsupervised — reproduces clearly: K=3 capsules cleanly segment the 3 shapes when EM converges. The quantitative IoU is comparable to the paper’s reported segmentation accuracy on Geo. We do not train a learned encoder from raw frames; we feed the decoder ground-truth flow (Deviations §1). |
| Run wallclock | ~37.5 s for python3 geo_flow_capsules.py --seed 0 --n-test 200. |
| Difficulty | Single-session implementation by geo-flow-builder agent; no external paper details beyond what’s in the spec issue. |
Visualizations
Example pairs

Four pairs sampled to span the IoU distribution: best (top), 75th
percentile, median, worst (bottom). Columns show frame 1, frame 2,
ground-truth flow (HSV: hue = direction, saturation = magnitude), GT
segmentation, and the EM-fit prediction (capsule colours remapped to
match GT shapes via greedy IoU assignment). On the easy pairs (top two
rows) the prediction is pixel-exact. On the worst-case pair, two of the
three shapes share enough motion that one capsule absorbs both shapes
plus part of the background, and the third capsule collapses to a
nearly-uniform spatial prior — a known failure mode of EM mixtures with
K = ground-truth K.
Per-capsule responsibility maps

For one example pair: each of the three capsule responsibility maps (plus the background capsule, far right) light up exactly one shape. The capsule->shape match is annotated above each panel.
Per-shape segmentation IoU

Per-shape mean IoU averaged over 120 test pairs, with one-sigma error bars. All three shapes hit the 0.5 target line on average; shape 1 (the middle one in z-order) is slightly harder because it is the most likely to be partially occluded by shape 2.
IoU distribution

Histogram of per-pair mean IoU. The pile-up near 1.0 is the clean convergence cases; the secondary mode near 0.68 corresponds to the “two-shapes-correct, one-confused” failure where a capsule splits its mass between a real shape and the background.
EM convergence

Reconstruction MSE vs EM iteration on a single Geo pair. EM converges in roughly 5-10 iterations; we run 30 to be safe and to give the GIF some visible dynamics.
Deviations from the original procedure
- No learned encoder; we feed the decoder ground-truth flow. Sabour
et al. 2021 train a CNN encoder that maps
(I_1, I_2)to a per-pixel flow embedding, jointly with the K-capsule decoder. We skip the encoder and feed the exact ground-truth flow (computed analytically from the per-shape affines) into the EM-fit decoder. This is faithful to the headline of the paper — that flow alone suffices to discover parts unsupervised — but does not exercise the encoder’s job of estimating flow from raw frames. Implementation constraint: numpy + matplotlib + imageio/PIL only, no torch, so a joint CNN training run was out of scope (Open question §1). - Parameter-free decoder, fit per-pair via EM. The paper amortises
the decoder so K capsules’ affines are predicted from the encoder
features in one forward pass. We instead run 30 EM iterations from
scratch on each test pair, with K-means++ initialisation on
(x, y, flow_x, flow_y)features and 3 random restarts. The advantage: no optimisation hyperparameters to tune, closed-form M-step. The disadvantage: no shared parameters across pairs, so the model has nothing to “transfer” — each new pair is solved independently. Open question §2. - Geo, not Geo+. The paper’s Geo+ variant uses textured backgrounds and textured shape interiors; we use a uniform-black background and uniform grayscale shape intensities. This is the simpler version from the spec; the EM-on-flow recipe doesn’t depend on intensity at all (it only sees flow), so adding texture is a no-op for our decoder but would matter for an encoder-trained variant.
- K = ground-truth K. We set K = n_shapes = 3 exactly. The classical capsule paper considers K > true number of parts and relies on the network to learn that some capsules can stay silent. With K = 3 and 3 shapes, EM is forced to use all three; on hard pairs it produces “two shapes claimed, one capsule collapsed onto background” — visible in the bimodal IoU distribution. K = 4 or 5 would likely smooth this out.
- No noise. Pixel intensities and ground-truth flow are both noiseless. Adding Gaussian flow noise (σ ≈ 0.5 px) would soften EM’s responsibilities and probably help the worst-case pairs converge, since hard zero-or-one assignments are part of the failure mode in §4.
Open questions / next experiments
- Joint encoder + decoder training. Add a tiny numpy MLP encoder
that takes
(frame1, frame2)patch pairs and predicts per-pixel flow, with EM on the predicted flow as the unsupervised loss. Does the encoder learn to do flow estimation as a side effect of being asked to produce flow that the K-capsule decomposition can explain? That’s the actual claim of the paper. - Amortising the EM into a learned routing function. Replace the per-pair EM with a small MLP that takes flow + (x, y) and outputs K-way soft assignment in one forward pass, trained end-to-end against the same reconstruction MSE. Does it match per-pair EM IoU at a fraction of the per-pair compute?
- Capsule count ablation. With
K > n_shapes, do unused capsules reliably go silent (low total responsibility), as the paper claims? Does the IoU-bimodality go away because the model has spare capacity to absorb the background separately from a real shape? - Adding texture (Geo+). Does the recipe survive textured shapes and textured backgrounds when an encoder is required to estimate flow from raw frames? This is where pure EM-on-flow stops being meaningful and the encoder’s flow-estimation quality starts to matter.
- Energy / data-movement comparison vs vanilla optical-flow + connected components. A standard pipeline (Lucas-Kanade for flow + connected components on the magnitude) should also segment 3 isolated moving ellipses cleanly. The interesting question is whether the capsule-style decomposition has a smaller commute-to-compute ratio, not whether its segmentation accuracy is higher (it likely isn’t, on this clean synthetic dataset).
agent-geo-flow-builder (Claude Code) on behalf of Yad — implementation notes for spec issue cybertronai/hinton-problems#1 (v2).
Ellipse World
Reproduction of the ambiguous-parts test from Culp, Sabour & Hinton (2022), “Testing GLOM’s ability to infer wholes from ambiguous parts” (arXiv:2211.16564).
Demonstrates: an MLP-replicated-per-location model with within-level softmax attention and iterative refinement (eGLOM-lite) can classify objects made of 5 ellipses even when each ellipse is locally ambiguous, by letting the embeddings at occupied cells converge into “islands of agreement”.

Problem
Each “image” is an 8×8 grid where exactly five cells contain a single 6-DoF ellipse and the rest are empty. An object’s class is fixed by the spatial arrangement of its five ellipses (a global affine pose plus the canonical class layout). Five placeholder classes:
| class | canonical 5-ellipse layout (rough) |
|---|---|
| face | 2 eyes (top), 1 nose (middle), 2 mouth corners (bottom) |
| sheep | 1 elongated body, 1 head, 3 vertical legs |
| house | 1 wide roof, 2 walls, 1 door, 1 ground |
| tree | 1 vertical trunk, 1 canopy, 3 leaf clusters |
| car | 1 elongated body, 2 round wheels, 2 windows |
Per-cell features (9-d): grid x, grid y, occupancy mask, semi-axis a, semi-axis b, sin(2θ), cos(2θ), sub-cell dx, sub-cell dy.
The interesting property is the ambiguity knob: a scalar ambiguity ∈ [0, ∞)
that perturbs each individual ellipse’s (a, b, θ) in log-space, so at high
ambiguity every ellipse looks like a fuzzy round blob and is no longer
class-distinctive on its own. Crucially, ambiguity does not corrupt
positions — the spatial layout is intact. A model that can solve high-ambiguity
instances must therefore use cross-location relationships, which is exactly
what GLOM’s within-level attention provides.
| ambiguity | per-ellipse signal | spatial-layout signal | this model |
|---|---|---|---|
| 0.0 | strong | strong | 99.0% |
| 0.5 | moderate | strong | 92.2% |
| 0.8 | weak | strong | 92.6% |
(All numbers chance = 20%; full hyperparameters in §Results.)
Files
| File | Purpose |
|---|---|
ellipse_world.py | Dataset (generate_dataset), eGLOM-lite (build_eglom, forward, backward), and Adam training loop. CLI: --seed --ambiguity --grid-size --n-iters. |
visualize_ellipse_world.py | Trains a model and emits all viz/*.png (training curves, confusion matrix, iteration ablation, island heatmap, dataset examples). |
make_ellipse_world_gif.py | Renders ellipse_world.gif (per-class refinement frames). |
viz/ | Static figures from the canonical run below. |
ellipse_world.gif | Top-of-README animation: islands forming over GLOM iterations. |
Running
Canonical training run:
python3 ellipse_world.py --seed 0 --ambiguity 0.5 \
--epochs 20 --n-train 2000 --n-val 500
Wall-clock: ~9 seconds on a laptop CPU. Expected final val accuracy: 92.2% (T=2), 85.6% (T=0, no attention), 89.6% (T=3).
To regenerate visualizations + GIF:
python3 visualize_ellipse_world.py --seed 0 --ambiguity 0.5 \
--epochs 20 --n-train 2000 --n-val 500 --outdir viz
python3 make_ellipse_world_gif.py --seed 0 --ambiguity 0.5 \
--epochs 15 --out ellipse_world.gif
Results
| Metric | Value |
|---|---|
| Validation accuracy (T=2 attention iters, train-time setting) | 92.2% |
| Validation accuracy (T=0, attention disabled at inference) | 85.6% |
| Validation accuracy (T=3, one extra iteration past training) | 89.6% |
| Validation accuracy (T=2, ambiguity=0.0) | 99.0% |
| Validation accuracy (T=2, ambiguity=0.8) | 92.6% |
| Mean off-diagonal cosine sim, occupied cells (t=0 → t=3) | +0.242 → +0.359 (Δ = +0.117) |
| Training time (canonical run) | 8.8 s |
| Hyperparameters | grid 8×8, hidden=32, embed_dim=16, n_iters=2, alpha=0.5, lr=0.01 (Adam, β=0.9/0.999), batch_size=64, init_scale=0.2, seed=0 |
| Confusion (worst class @ amb=0.5) | sheep ↔ car (both have a wide horizontal “body” ellipse) |
The single most important number: the gap between T=0 and T=2 at fixed hyperparameters. T=0 gets 85.6% by mean-pooling the encoder’s per-location embeddings; the 6.6 percentage-point lift to 92.2% at T=2 is exactly the contribution of within-level attention with iterative refinement. The “island quality” delta (+0.117 in mean pairwise cosine sim of occupied cells) is the geometric counterpart: attention is pulling the five embeddings of an object closer together.
Visualizations
Example layouts per class

Five classes, five canonical ellipse arrangements. Each grid cell that contains an ellipse is shaded orange; every other cell is empty. Note that “sheep” and “car” share a strong horizontal “body” ellipse — at high ambiguity the model has to disambiguate them via the wheels-vs-legs pattern, which is purely a spatial-relationship cue.
Same class across ambiguity levels

A single sheep, same global affine, with ambiguity sweeping 0 → 1.2. Individual ellipses are eventually reduced to similar-looking blobs. Spatial relationships are unchanged.
Training curves

Loss converges in ~15 epochs. The accuracy panel overlays four curves: train, val (T=2) (the trained refinement depth), val (T=0) (attention disabled at inference), and val (T=3) (one extra iteration). T=2 dominates throughout. T=0 plateaus 6–8 percentage points lower — that gap is the contribution of attention. T=3 is consistently slightly worse than T=2 because the network was trained at T=2; one extra unrolled iteration over-refines.
Iteration ablation at inference

Holding the trained model fixed, sweep the number of inference-time attention iterations T ∈ {0, …, 6}. Accuracy peaks near T = 2 (the training depth) and degrades only slowly beyond — embeddings approach a fixed point of the iterative update.
Validation confusion matrix (T=2)

All five classes are well above chance. The only meaningful confusion is sheep ↔ car (the wide-horizontal-body classes). Tree and house, which have the most distinctive layouts, get >95% per-class accuracy.
Island formation

For each class, the leftmost panel is one example grid; the four panels to the right are the cosine-similarity matrix of the occupied cells’ embeddings at iterations t = 0, 1, 2, 3 (only the 5 occupied cells, sorted by their flattened grid index). At t = 0 the encoder produces moderately similar embeddings for the 5 ellipses (they share their object’s class context implicitly through the position channel). At t = 3 the similarity matrix has saturated to a much more uniform red — all five occupied cells now share essentially the same embedding. This is the “island of agreement” GLOM is built around.
The mean off-diagonal cosine similarity over 200 random samples confirms this quantitatively:
t = 0 → +0.242
t = 3 → +0.359
delta → +0.117
Deviations from the original procedure
The Culp/Sabour/Hinton 2022 paper introduces a much richer setup; this is a lite reproduction. Honest list:
-
Single GLOM level. The paper uses a stack of levels (with bottom-up, top-down, and within-level streams). This implementation has one level. The
--n-levelsCLI flag is accepted but ignored (with a warning). -
Parameter-free attention. Within-level attention is plain softmax of pairwise dot-products of embeddings, with no learned Q/K/V projections. The paper uses a transformer-style attention block.
-
No bottom-up / top-down dynamics. Refinement is just
e ← (1-α)e + α A e. The paper’s GLOM has a separate up-net and down-net per level. -
Hand-coded canonical layouts (face / sheep / etc.) instead of the procedural part-graphs the paper uses. The placeholder class set was chosen for visual recognisability, not faithfulness to any specific experiment in the paper.
-
NumPy + hand-written backprop. No PyTorch, no autograd. Adam by hand. Verified end-to-end against finite differences (max abs error ~1e-6 on
dW2). -
Ambiguity knob simplified. I noise
(a, b, θ)log-uniformly. The paper’s ambiguity is more carefully calibrated against the part-graph structure of each class.
What this stub does faithfully reproduce: (i) the dataset’s geometry (2D grid of 6-DoF ellipses), (ii) the headline GLOM mechanism (per-location MLP + within-level attention + iterative refinement), (iii) the diagnostic that matters — occupied cells of the same object converge to a shared embedding under iteration.
Open questions / next experiments
- Genuine multi-level GLOM. Stack two or three levels with their own embeddings and add explicit bottom-up / top-down nets; check whether the upper level encodes part-of-object information not already present in the bottom-level islands.
- Learned attention. Add small Q/K/V projections (one matrix each) and measure whether the T=2 → T=0 gap widens. Hypothesis: with parameter-free attention the only signal is embedding-similarity, so once the encoder has separated classes, attention only fine-tunes; learned projections could let the network pick a non-trivial relational metric.
- Adversarial ambiguity. At what ambiguity level does T=0 collapse to chance while T=2 stays well above chance? My current setup keeps positions clean, so T=0 is hard to break — adding positional jitter to the rendering would put pressure on relational reasoning specifically.
- Energy / DMC. This stub is correctness-only at v1. The whole-network forward + Adam step has lots of attention-quadratic ops; switching the attention to a sparse sliding-window variant would be a natural energy-efficiency target for a follow-up.
- Compositional generalisation. Train on 4 of the 5 classes, test zero-shot on a held-out class whose layout is a recombination of seen parts. The paper’s eGLOM is designed for exactly this regime.
Forward-Forward: hybrid-image MNIST negatives
Source: Hinton (2022), “The forward-forward algorithm: some preliminary investigations”, arXiv:2212.13345 / NeurIPS 2022 keynote. Demonstrates: Layer-local unsupervised learning on MNIST. Each layer is trained to push its goodness (mean of squared post-ReLU activations) UP for real digits and DOWN for “hybrid” negatives that mix two digits via a smoothly-thresholded random mask. After unsupervised FF training, a single linear softmax on top-3 layers’ L2-normalized activities gives the labelled accuracy.

Problem
| Inputs | 28×28 MNIST images, scaled to [0, 1], flattened, L2-normalized. |
| Positives | The real digit. |
| Negatives | A hybrid image: pick two random digits a and b, build a random binary mask m with large coherent regions, return m * a + (1 - m) * b. |
| Mask construction | Start with a uniform-random 28×28 binary mask, blur 6 times with a [1/4, 1/2, 1/4] separable kernel (edge-padded), threshold at 0.5. |
| Network | 4-layer ReLU MLP: 784 → 1000 → 1000 → 1000 → 1000. Each layer L2-normalizes its input, then h = ReLU(W x + b). |
| Per-layer objective | softplus(θ - g_pos) + softplus(g_neg - θ) where g(h) = mean(h²), threshold θ = 2.0. Trained with Adam, no backpropagation between layers. |
| Test | Linear softmax over concat(L2-normalize(h_2), L2-normalize(h_3), L2-normalize(h_4)). The MLP is frozen during this step. |
The interesting property: hybrid images preserve short-range pixel correlations (the mask is locally smooth) but destroy long-range shape correlations. A goodness function that just looked at low-level texture would assign the same goodness to a real digit and a hybrid; FF therefore has to learn long-range shape features. The mask is the problem definition.
Files
| File | Purpose |
|---|---|
ff_hybrid_mnist.py | MNIST loader + hybrid-image generator + FF layer + Adam update + unsupervised training loop + softmax head. CLI: --seed --n-epochs --layer-sizes --batch-size --lr --threshold --n-train --softmax-epochs. |
visualize_ff_hybrid_mnist.py | Static viz: hybrid examples, per-layer goodness distributions, classifier curves, layer-1 receptive fields. |
make_ff_hybrid_mnist_gif.py | Renders ff_hybrid_mnist.gif (animated training). |
viz/ | Output PNGs from the run below. |
Running
python3 ff_hybrid_mnist.py --layer-sizes 784,1000,1000,1000,1000 \
--n-epochs 30 --batch-size 100 --softmax-epochs 50 --seed 0
Wallclock: ~8 min on an Apple M-series laptop (numpy, no GPU). MNIST is downloaded once to ~/.cache/hinton-mnist/ (~12 MB).
To regenerate visualizations:
python3 visualize_ff_hybrid_mnist.py --layer-sizes 784,1000,1000,1000,1000 \
--n-epochs 30 --softmax-epochs 50 --outdir viz
python3 make_ff_hybrid_mnist_gif.py --layer-sizes 784,500,500,500,500 \
--n-epochs 20 --snapshot-every 1 --fps 5 --n-train 20000
Results
| Metric | Value |
|---|---|
| Final test error | 5.21% (94.79% test accuracy) |
| Final test error (15 ep) | 5.59% — for comparison, ~half the wallclock |
| Paper (MLP) | 1.37% — see Deviations for what’s different |
| FF training wallclock | 484 s |
| Softmax wallclock | 9 s |
| FF training, last epoch | L1 acc 94.7% / L2 94.0% / L3 92.4% / L4 91.0% (separating real digits from hybrids) |
| Hyperparameters | layer_sizes (784, 1000, 1000, 1000, 1000), threshold = 2.0, lr = 0.03, Adam (β1=0.9, β2=0.999), batch = 100, init N(0, √2), n_blur = 6, softmax lr = 0.05, weight_decay = 1e-4 |
| Seed | 0 |
| Environment | Python 3.11.7, numpy 2.4.4, macOS arm64 |
| Reproduces? | Partially. Paper reports 1.37%; we got 5.21% with a smaller MLP and ½ the epochs. Method works as described; gap is explained in Deviations. |
Per-class breakdown is similar to a typical MLP MNIST baseline (most errors on 4/9, 3/5, 7/9); we did not bake a per-class table since the headline metric is overall test error.
Visualizations
Hybrid negatives

Six example pairs from MNIST. Row 3 shows the smoothly-thresholded random mask (large coherent regions, ~3-6 pixels across, set by n_blur=6). Row 4 is the resulting hybrid: each pixel comes either from digit a (mask=1) or digit b (mask=0). Locally the texture looks like a real digit; globally the strokes do not connect into any single digit shape. That is exactly the signal FF has to learn.
Per-layer goodness separation

Histogram of mean(h²) on 1000 held-out test images (real digits, blue) and 1000 hybrid negatives (red), per layer, after training. The dashed vertical line is the threshold θ = 2.0. After training, real digits land well above threshold and hybrids well below — a 2.8–3.3 σ separation depending on layer. Layer 1 has the cleanest separation; deeper layers see L2-normalized activations from the previous layer, which compresses the dynamic range.
Training curves

Left: per-layer FF loss softplus(θ - g_pos) + softplus(g_neg - θ) over the 30 unsupervised epochs. Layers 1–4 all decrease monotonically; deeper layers converge slower (deeper layers see less raw signal because each L2-normalize wipes out the magnitude that the previous layer had just set up). Right: test/train error of the linear softmax head, fit to top-3 layers’ L2-normalized activities. Final test error 5.21%; the green dashed line is the paper’s 1.37% target.
Layer-1 receptive fields

The 16 layer-1 units with the largest ‖W_:,j‖. Reshaped to 28×28 and rendered with positive weights red, negative blue. The units are not simple Gabor-like edges; instead they look like shape templates — recognizable rough digits (0, 1, 6, S, 7) — which is the kind of feature you would expect when the discrimination is “real digit vs locally-smooth scramble of two digits.” This matches the qualitative claim in §3 of the paper.
Deviations from the original procedure
This is a v1 baseline, not a faithful reproduction. The deviations from Hinton 2022 are:
- Smaller MLP. Paper uses 4 hidden layers of 2000 units each; we use 1000. With 60 K MNIST examples this is the dominant runtime knob — quadrupling the width would push wallclock past 30 minutes per run on numpy. Headline error is mostly explained by this gap.
- Half the epochs. Paper trains for 60 unsupervised epochs; we run 30. The training curve (
viz/classifier_curves.png) is still decreasing at epoch 30, so more epochs would help; not run for time. - No peer normalization. The paper’s 1.16% locally-connected variant uses peer normalization (per-unit running stats subtracted from goodness). Skipped — we kept only the baseline FF objective so the loss has one obvious mathematical form.
- No locally-connected layers. Paper’s 1.16% uses LC layers; we only do the fully-connected MLP variant (paper’s 1.37% target).
- numpy only. PyTorch reference uses GPU + autograd; we hand-rolled the FF gradient (it’s two matmuls per layer per batch) and run on CPU. Math is identical to the paper formulation.
- Mean-goodness convention. Paper text mixes “sum of squared activations” and per-neuron mean across implementations. We chose
g = mean(h²)withθ = 2.0to match Hinton’s released PyTorch reference (thepytorch_forward_forwardport he endorsed) — using sum-of-squares withθ = 2000would saturate the sigmoid and gradients vanish. - Init scale. Paper uses PyTorch default
LinearinitU(-1/√n, 1/√n); we useN(0, √2). Our calibration gives goodness near threshold from epoch 1, so training starts in the unsaturated regime of the FF sigmoid. Empirically this matters with our shorter epoch budget. - Test classifier. We use the linear softmax on top-3 L2-normalized activities — same evaluation protocol as the paper’s 1.37% number.
Open questions / next experiments
- Close the accuracy gap. Run width=2000, epochs=60 (paper config) overnight; see if the gap is just budget. Predicted error: 1.5–3% based on our convergence rate.
- Goodness convention. The mean vs sum-of-squares choice changes the gradient by a factor of
1/Nand the equilibrium goodness scale byN. Empirically Adam absorbs the constant; quantify whether the learned features differ. - Negative quality. Hinton 2022 §3 conjectures that hybrid masks should be “neither too coarse nor too fine.” Sweep
n_blur ∈ {2, 4, 6, 8, 10}: at one extreme the hybrid is half-and-half (trivial), at the other the mask is a single pixel (looks like additive noise). Plot test error vsn_blur. - Layer-wise vs end-to-end. We train all layers in parallel each batch. Hinton suggests training each layer to convergence before moving on. Compare: total wallclock, final accuracy, and feature transferability.
- Energy metric (v2). The point of this catalog is the data-movement story for v2. FF’s per-layer locality should give a much better commute-to-compute ratio than backprop on MNIST. Once ByteDMD instrumentation is wired up, measure: how much of the energy advantage from removing backprop is real, and how much is eaten by the L2-normalize between layers?
Forward-Forward: label encoded in the first 10 pixels
Reproduction of the supervised Forward-Forward (FF) variant from Hinton (2022), “The Forward-Forward Algorithm: Some Preliminary Investigations” (arXiv:2212.13345, §3.3 of v3 / §3 of the December 2022 preprint).
Demonstrates: A multi-layer ReLU network trained without backprop. Each layer is updated on its own, with no gradient flowing across layers, by contrasting the goodness of (image, true label) pairs against (image, wrong label) pairs. Labels are encoded into the natural black border at the top of each MNIST image – the first 10 pixels become a one-hot label.

Problem
- Input. A flattened 28x28 MNIST image (784 floats in [0, 1]) with the first 10 pixels overwritten by a one-hot label vector.
- Positive example. (image, true_label) – one-hot in the first 10 pixels encodes the correct class.
- Negative example. (image, wrong_label) – one-hot in the first 10 pixels encodes a uniformly random incorrect class.
- Architecture. A stack of fully-connected ReLU layers
(here
784 -> 500 -> 500). Between layers, activations are rescaled somean(h^2) = 1– the next layer cannot read off the previous layer’s magnitude (which is exactly the goodness signal). - Goal. Each layer learns to make
mean(h^2)high for positive inputs and low for negative inputs. At test time we try each candidate label, push it through the network, sum goodness across all layers, and pick the label with the highest summed goodness.
The interesting property is what you don’t need: no backward pass, no chain rule, no errors propagated across layers. Training is local, so each layer’s weights only depend on its own input and its own output. This is the property that motivates FF as a candidate “biologically plausible” learning rule and as a candidate for hardware where the backward pass is expensive.
The label-in-pixels trick is what makes the supervised setup work: a hidden
layer cannot use the absolute label bits to drive goodness because the
between-layer normalisation strips magnitude, so the only way to make goodness
high for (image, true label) and low for (image, wrong label) is to
discover features that covary with the label.
Files
| File | Purpose |
|---|---|
ff_label_in_input.py | MNIST loader + label-in-pixels encoding + FF MLP + Adam-trained per-layer FF loss + goodness-based prediction. CLI: --seed --n-epochs --lr --layer-sizes --jitter --train-subset --full-test --save --threshold --batch-size. |
visualize_ff_label_in_input.py | Static plots: example label-encoded images, candidate-label goodness heatmap, training curves, layer-0 receptive fields. |
make_ff_label_in_input_gif.py | Renders ff_label_in_input.gif (the animation at the top of this README). |
ff_label_in_input.gif | Committed animation – per-layer goodness, demo prediction, loss + accuracy over training. |
viz/ | Committed PNGs from the run below. |
Running
The MNIST data is downloaded once into ~/.cache/hinton-mnist/ (~11 MB).
# Final reported run -- 30 epochs, full 60K train set, eval on full 10K test:
python3 ff_label_in_input.py --seed 0 --n-epochs 30 --lr 0.003 \
--layer-sizes 784,500,500 --eval-subset 2000 \
--full-test --save model.npz
# Regenerate the static figures from the saved model:
python3 visualize_ff_label_in_input.py --model model.npz --outdir viz
# Regenerate the GIF (uses train-subset=20000 just to keep render time short):
python3 make_ff_label_in_input_gif.py --epochs 30 --snapshot-every 1 --fps 6 \
--seed 0 --lr 0.003 \
--layer-sizes 784,500,500 \
--train-subset 20000
Wallclock on an Apple M-series laptop:
- Training: 66 seconds for 30 epochs over the full 60K MNIST train set.
- GIF: 33 seconds (with
--train-subset 20000for speed).
Final accuracy: 96.40% on the full 10K test set (3.60% test error).
Results
| Metric | Value |
|---|---|
| Test accuracy (full 10K, seed 0) | 96.40% (3.60% error) |
| Train accuracy (eval subset, seed 0) | 97.2% |
| Training time | 66 s on Apple M-series, 30 epochs, full MNIST train set |
| Architecture | 784 -> 500 -> 500 ReLU (2 FF layers) |
| Optimiser | Adam, lr = 0.003, b1 = 0.9, b2 = 0.999 |
| Batch size | 128 |
| Goodness | mean(h^2) along the feature axis, per-sample |
| Threshold theta | 2.0 |
| Between-layer norm | rescale to mean(h^2) = 1 |
| Label encoding | one-hot in flat indices [0..9] (top row, leftmost 10 px) |
| Negative sampling | uniform over {0..9} \ {true_label} per minibatch |
| Prediction | for each candidate label, sum goodness across both layers, pick argmax |
| Seed | 0 |
Hinton (2022) reports 1.36% test error for 4 x 2000 ReLU after 60 epochs
on full MNIST, and 0.64% with 25-shift jittered augmentation at 500
epochs. We aimed at the <5% v1 threshold (Sutro group baseline target) and
chose a smaller architecture and fewer epochs to fit a laptop budget.
Per-class breakdown (full test set, seed 0)
| Class | Accuracy | Correct / Total |
|---|---|---|
| 0 | 98.78% | 968 / 980 |
| 1 | 98.77% | 1121 / 1135 |
| 2 | 95.64% | 987 / 1032 |
| 3 | 96.73% | 977 / 1010 |
| 4 | 95.62% | 939 / 982 |
| 5 | 94.96% | 847 / 892 |
| 6 | 96.76% | 927 / 958 |
| 7 | 95.72% | 984 / 1028 |
| 8 | 95.69% | 932 / 974 |
| 9 | 94.95% | 958 / 1009 |
Best class: 0 (98.78%). Worst class: 9 (94.95%). The 0/1 axis is the cleanest – both are visually the most distinctive digits.
Prediction mode comparison
Same trained model, different goodness-aggregation strategies at test time:
| Strategy | Test accuracy |
|---|---|
| layer 0 only | 96.86% |
| all layers (default) | 96.40% |
| skip layer 0 | 95.19% |
For this 2 x 500 architecture, layer 0 alone is the strongest predictor.
Layer 1 adds redundant signal – summing both layers underperforms layer 0
alone by 0.46 percentage points but outperforms layer 1 alone (skip-L0)
by 1.21 points. With Hinton’s deeper / wider 4 x 2000, deeper layers carry
more weight; the right aggregation strategy is architecture-dependent and
worth flagging for future replications.
Visualisations
Label-encoded inputs

The cyan box at top-left highlights the 10-pixel label slot. Three of these images have a single bright pixel in the slot (labels 7, 2, 1, 0 from left to right); for label 0 the bright pixel is at position 0 and is hard to spot against the white digit body of the “0”.
The choice of “first 10 pixels” exploits MNIST’s natural black border. Real images already have intensity 0 there, so overwriting them with a one-hot vector adds bounded label information without disturbing the foreground pixels.
Goodness for each candidate label

For each test image, we encode all 10 candidate labels in turn and sum the goodness across all layers. The bar plot is normalised per-image so the height is visually comparable.
The true label (green) is the argmax for every example shown – the network has correctly learned that high goodness means “this label was the right one for this image”. The blue runners-up are typically visually-similar digits.
Training curves

- Top-left: per-layer FF loss decreases monotonically. Layer 0 plateaus around 0.10 by epoch 30; layer 1 plateaus around 0.16 (slightly higher because its input has already been L2-rescaled and is therefore harder to separate cleanly).
- Top-right: the goodness gap widens. Both layers push positive
goodness above the threshold (
theta = 2.0, dotted) and negative goodness below it. By epoch 30 layer 0 haspos = 8.63, neg = 0.18(a 47x ratio) and layer 1 haspos = 5.98, neg = 0.36(a 16.5x ratio). - Bottom-left: train and test accuracy track each other tightly (no significant overfitting at this scale).
- Bottom-right: ~2.2 s per epoch over 60K training samples on a laptop using only NumPy.
Layer-0 receptive fields

Each tile is one column of the 784x500 weight matrix reshaped to 28x28 (red = positive, blue = negative). The features look like digit-shaped pen-stroke detectors – a noteworthy observation given the network never saw gradients pulled through a softmax. Goodness alone, contrasted between positive and negative pairs, is a strong enough signal to push layer 0 to discover digit-class-aligned features.
Deviations from the original procedure
- Architecture. Hinton uses
4 x 2000. We use2 x 500to fit the<5%v1 target in 66 s of pure-NumPy training on a laptop. The wider / deeper paper architecture would push us toward the paper’s 1.36% but at significantly more compute. - Optimiser. Hinton uses Adam (with cosine LR decay) – we use Adam at a
single fixed
lr = 0.003. We did not implement LR decay or warm-up. Atlr = 0.03(the value used inmohammadpz/pytorch_forward_forward) the network’s first Adam step kills layer 0’s ReLUs (90%+ dead neurons within 5 batches);lr = 0.003is the smallest LR we tested that both converges and avoids that failure mode. - No augmentation. The 0.64% error number in the paper requires 25-shift
jittered augmentation (max-shift 2 pixels in each direction, all 25
offsets per image, replicated 25x per epoch) at 500 epochs. The functions
jittered_augmentation()andjittered_augmentation_batch()are implemented (one random offset per batch element per epoch) and exposed via--jitter, but the headline run does not use them. Faithfully reproducing the 0.64% number would multiply training time by ~250x relative to our headline run, which is out of scope for v1. - Aggregating across layers. Hinton describes accumulating goodness
from all layers. The community-standard
mohammadpz/pytorch_forward_forwardskips layer 0 because of the label pixels in the input. We measured both: on this2 x 500architecture, layer 0 alone (96.86%) is best, followed by all-layers (96.40%) and skip-layer-0 (95.19%). The headline number uses the all-layers default (matches the paper). - Layer normalisation magnitude. The original paper specifies that the
length of the between-layer vector is
sqrt(D)(i.e.mean(h^2) = 1). We follow this exactly (initial implementations that L2-normalise to unit length collapse layer 0 within an epoch, which is a useful negative datapoint). - Two layers, not four. With
4 x 2000, the deeper layers contribute most of the per-layer goodness gap. With our2 x 500, layer 0 is doing most of the work; layer 1 slightly hurts the all-layers score (-0.46 percentage points vs layer-0 only). A2 x 500model running at the tested lr is layer-1-redundant – expanding to4 x 2000should restore the per-layer monotone goodness-gap pattern from the paper.
Open questions / next experiments
- Goodness gap vs depth. Why does layer 0 do so much of the work in our
run? Hinton’s paper reports a clean per-layer accumulation. Is this an
artefact of our smaller architecture, or of the specific lr / threshold
schedule? A sweep over
(depth, width)at fixed compute budget would tell. - Hard-negative selection. The paper hints that generated (not uniform-random) negatives are crucial for unsupervised FF. The supervised variant here uses uniform random wrong labels. Hard-negative sampling – pick the wrong label whose current goodness is highest – might tighten the goodness gap and reduce error without architecture changes.
- Energy/data-movement metric. This is the v1 baseline. The next pass (per the Sutro effort) is to instrument every layer with reuse-distance / ByteDMD tracking and ask: does FF actually beat backprop on data movement, per Hinton’s motivating claim? Backprop refetches all activations during the backward pass; FF’s gradient is purely local to each layer – the expectation is yes, but the magnitude is unknown.
- Jittered augmentation. Toggling
--jitterdoubles compute per epoch but the paper’s 0.64% number is achievable. A faithful 500-epoch jittered run would establish whether our2 x 500architecture is capacity-bound or augmentation-bound.
Reproducibility
| Python | 3.12.9 |
| NumPy | 2.2.5 |
| OS | macOS-26.3-arm64 |
| Random seed | exposed via --seed (default 0) |
| Final-run command | python3 ff_label_in_input.py --seed 0 --n-epochs 30 --lr 0.003 --layer-sizes 784,500,500 --eval-subset 2000 --full-test --save model.npz |
| MNIST cache | ~/.cache/hinton-mnist/ (11 MB; downloaded from storage.googleapis.com/cvdf-datasets/mnist/) |
The model.npz artefact is not committed – regenerate it with the command
above (or python3 visualize_ff_label_in_input.py will fall back to training
from scratch if the file is missing).
Forward-Forward: top-down recurrent on repeated-frame MNIST
Source: Hinton (2022), “The Forward-Forward Algorithm: Some Preliminary Investigations”, arXiv:2212.13345 / NeurIPS 2022 keynote, section 4 (“A recurrent network with top-down connections”).
Demonstrates: A static MNIST digit can be treated as a “video” of
repeated identical frames, and a multi-layer net can be run as a recurrent
dynamical system on it. Each hidden layer at time t is computed from the
L2-normalized activities of the layers immediately above and below at
t-1, with damping (0.3 * old + 0.7 * new) to stabilize the iteration.
The top layer is clamped to a one-of-N candidate label. Inference for a
test image runs 8 synchronous iterations under each candidate label and
picks the label whose hidden-layer goodness, summed over iterations 3-5,
is largest. Paper reports: 1.31% test error.

Problem
The Forward-Forward (FF) family of algorithms replaces backprop’s global gradient with a per-layer local objective: each layer learns to make its sum-of-squared activations high on positive examples and low on negative ones. The recurrent variant in this section of the 2022 paper turns FF into a temporal protocol:
- A single MNIST digit is shown for 8 frames in a row (no motion).
- Hidden layers update synchronously: each layer at time
treads the L2-normalized state of the layer above and below at timet-1. - The label is clamped throughout: it is the candidate-of-N digit class that this 8-iteration “movie” is being tested against.
- Goodness summed across the hidden layers, across iterations 3-5, decides which candidate label wins.
The recurrent setup is interesting because it makes label inference an attractor: a candidate label that disagrees with the image fails to develop high goodness; a candidate that agrees develops high goodness in a few iterations and stays there.
This implementation uses pure numpy. There is no torch, no autodiff, no backprop-through-time. Each iteration’s gradient flows only through the single forward step that produced it; the previous-step activations on which the inputs depend are treated as constants. That sidesteps BPTT entirely and matches the local-update spirit of FF.
Files
| File | Purpose |
|---|---|
ff_recurrent_mnist.py | MNIST loader (urllib + gzip), recurrent FF model, local FF training with Adam, prediction by per-iteration goodness accumulation. CLI flags --seed --n-epochs --n-iters --damping. |
visualize_ff_recurrent_mnist.py | Generates viz/training_curves.png, viz/iteration_goodness.png, viz/state_evolution.png. Loads a saved model with --load-model (preferred) or trains its own. |
make_ff_recurrent_mnist_gif.py | Renders ff_recurrent_mnist.gif (3 examples × 8 iterations of inference dynamics). Same --load-model flag. |
viz/ | Saved model, results JSON, training log, and the three PNGs above. |
ff_recurrent_mnist.gif | Animated demo (under 3 MB). |
Running
The default workflow trains once, saves the model, and reuses it for visualization and the GIF:
# 1. train and save (around 3-4 min on CPU)
python3 ff_recurrent_mnist.py \
--seed 0 --n-epochs 20 --hidden 500 --n-train 60000 \
--batch-size 256 --lr 3e-3 --threshold 1.0 --damping 0.7 \
--eval-test-subset 2000 \
--results-json viz/results.json --save-model viz/model.npz
# 2. visualize (a few seconds)
python3 visualize_ff_recurrent_mnist.py \
--load-model viz/model.npz --results-json viz/results.json --outdir viz
# 3. animate (~1 min)
python3 make_ff_recurrent_mnist_gif.py \
--load-model viz/model.npz --n-examples 3 --fps 4 \
--out ff_recurrent_mnist.gif
MNIST is downloaded on first run to ~/.cache/hinton-mnist/ (~11 MB,
fetched via urllib.request.urlretrieve). The cache is not committed.
Results
| Metric | Value |
|---|---|
| Architecture | 784 - 500 - 500 - 10 (input, two hidden ReLU layers, label) |
| Damping | 0.7 (weight on new activity per iteration) |
| Goodness threshold | 1.0 (mean-square per unit) |
| Iterations per forward | 8 |
| Test-time accumulation | iterations 3, 4, 5 (sum across hidden layers) |
| Train-time accumulation | iterations 3, 4, 5, 6, 7, 8 |
| Optimizer | Adam (lr=3e-3, β=(0.9, 0.999)) |
| Batch size | 256 (positives) + 256 (negatives) |
| Train set | 60000 (full) |
| Epochs | 20 |
| Final test error | 10.66% (89.34% accuracy) |
| Training wallclock | 216 s on this laptop |
| Paper reported | 1.31% |
| Reproduces paper? | No — see “Deviations” below |
| Seeds tried | 0 |
Test error vs epoch (red, eval on a 2000-image subset for speed) and train loss (blue, BCE on goodness across iterations 3-8) below. The training loss decreases monotonically; test error plateaus around 12-13% on the subset and resolves to 10.66% on the full 10k test set:

What the network actually learns
Per-iteration goodness for each candidate label
For a test image, run 8 iterations under each of the 10 candidate labels and plot the per-iteration goodness summed across hidden layers. The true label’s goodness rises sharply during iterations 2-4 and dominates by iteration 5; the spec’s accumulation window (iters 3-5, shaded green) captures this rise. Wrong-label trajectories grow more slowly or plateau low:

When the network gets it wrong (left panel: a “2” that the model labels as something else) the true label’s curve fails to rise above the others during the accumulation window — a clean signature of the failure mode. When it gets it right (panels 2-4) the true-label curve is the obvious maximum after only 3-4 iterations.
Hidden state across iterations
A heatmap of hidden-layer activations over the 8 iterations for one image clamped under its true label. Many units stay at zero (ReLU dead zones); the active ones lock into a sparse pattern by iteration 3 and stay there. This is the “attractor” behavior of the recurrent FF: a small subset of units codes for the digit-label pair, and the synchronous update is a fixed-point iteration onto that subset.

Animated inference
The GIF at the top of this README walks through inference for 3 test images, 8 iterations each. Each frame shows:
- the input image with the candidate label tag;
- layer-1 and layer-2 activations (first 80 of 500 units);
- per-candidate goodness traces (true label highlighted red);
- the current scoreboard (sum of goodness over iterations 3..min(5, t)).
The scoreboard is empty for t < 3 (the spec’s accumulation window
doesn’t open until then) and gets a “current pick” red bar from t=3
onward.
Deviations from the paper
The paper’s 1.31% number was achieved with a network and training budget roughly 25× the size of this implementation. The headline gap (paper 1.31% vs ours 10.66%) is driven by capacity, not algorithm.
| Paper (Hinton 2022) | This implementation | |
|---|---|---|
| Hidden layers | 4 | 2 |
| Hidden width | 2000 | 500 |
| Total hidden parameters | ~16M | ~0.6M (~25× smaller) |
| Training set | 60000 | 60000 |
| Epochs | 60 | 20 |
| Mini-batch size | 100 | 256 (+256 negatives) |
| Optimizer | hand-tuned + peer normalization | Adam, no peer norm |
| Augmentation | none reported here | none |
| Implementation | torch | numpy |
| Test error | 1.31% | 10.66% |
The two genuine algorithm-level deviations are:
- Mean-square goodness with threshold 1.0, where Hinton uses sum-of-squares with threshold equal to the layer width. The two formulations are mathematically equivalent up to scaling of the sigmoid logit; in numpy with Adam, the mean-square version trains more stably for our hyperparameters because the sigmoid is not in deep saturation throughout training. We tested the sum-of-squares variant and it converged comparably but was more sensitive to the threshold and learning rate.
- No peer normalization. Hinton’s recipe regularizes per-unit activity to prevent a few units from dominating the goodness signal. Adding peer norm is a likely lever for closing the remaining gap.
Smaller-scale items:
- Up and down weights between adjacent layers are stored as separate
matrices
W_upandW_dnrather than enforcing a single shared matrix. Whether the paper enforces symmetry is not entirely explicit in the recurrent section; separate matrices made the gradient bookkeeping cleaner and the algorithm at least as expressive. - Negatives are sampled wrong-label uniform across the other 9 classes per minibatch (one negative per positive). The paper uses model- generated negatives in some sections of the FF work; for the recurrent variant, wrong-label negatives are the natural choice and are what we use here.
Correctness notes
- Local gradients only. The training loop runs 8 forward iterations
and computes the FF logistic-on-goodness gradient at each iteration
in
train_eval_iters_one_indexed = (3..8). Within each iteration, the gradient flows through the new activations of that iteration only; the previous-step activations (states[k]going into the synchronous update) are treated as detached constants. This is the natural local formulation and is what makes the algorithm O(layer-width²) per iteration with no BPTT memory. - Damping convention. The spec wording “0.3 old + 0.7 new” maps to
damping=0.7in our code. We implementh_t = damping * relu(...) + (1 - damping) * h_{t-1}, treating theh_{t-1}term as detached for gradient purposes. - L2 normalization of inputs to each layer. Both the bottom-up
activations from layer
k-1and the top-down activations from layerk+1are L2-normalized before the per-layer matmul. This is the key trick that prevents goodness from leaking between layers via magnitude; the paper insists on it. - Iteration indexing. The spec is 1-indexed: “iterations 3-5” means
the 3rd, 4th, 5th synchronous update. The code uses
eval_iters_one_indexed=(3, 4, 5)and the test-time goodness sum uses the same convention. Internally the loop counter is 0-indexed, soiter_label = loop_t + 1.
Open questions / next experiments
- Capacity gap vs algorithm gap. Going from
2 × 500to4 × 2000hidden, holding everything else fixed, would isolate whether the 8× test-error gap is just capacity or whether peer normalization is doing real work. With numpy this is a 5-10× compute jump; worth doing once on a one-shot longer run. - Peer normalization. Hinton’s recipe tracks per-unit running activity and adds a regularization term that pushes unit-mean activations to a target. Adding this is a small change and is the most likely next lever for getting into the single-digit error rate.
- Model-generated negatives (from the unsupervised FF). Could potentially help, but the paper’s own recurrent variant uses wrong- label negatives, so the gain is probably marginal for this setup.
- Damping schedule. Annealing damping (slower mixing in early iters, faster later) could tighten the attractor; the paper does not vary it.
- What is the data-movement cost of this method vs an MLP-FF on the same MNIST? The recurrent dynamics re-read the same activations 8 times per inference, so the reuse-distance picture should look very different from a single forward pass. Would be a natural follow-up for the v2 ByteDMD instrumentation.
Forward-Forward on CIFAR-10 with locally-connected layers
Reproduction of the CIFAR experiment from Hinton (2022), “The Forward-Forward Algorithm: Some Preliminary Investigations” (arXiv:2212.13345).
Demonstrates: A two-layer locally-connected (no weight sharing) ReLU network trained on CIFAR-10 with the Forward-Forward goodness rule, compared to the same locally-connected stack trained end-to-end with backprop and softmax cross-entropy. The thing the headline paper wanted to show is that FF can plausibly scale beyond MNIST and stay within shouting distance of backprop on cluttered colour images, with a per-spatial-location filter bank that does not assume translation symmetry.

Problem
- Input. A 32x32 RGB CIFAR-10 image, per-channel mean-subtracted so
pixel values lie roughly in
[-0.5, 0.5]. - Label-in-input encoding (FF only). A
LABEL_ROWS x LABEL_LEN x 3 = 3 x 10 x 3 = 90pixel block in the top-left of the image is overwritten with a one-hot label usingLABEL_OFF = -1.0for unset bins andLABEL_ON = 1.0for the set bin (replicated across all 3 channels and 3 rows). - Architecture (both FF and BP). Two locally-connected layers, no
weight sharing across spatial positions:
- Layer 0: 32x32x3 input, RF = 11x11, 8 channels per location -> 22x22x8. 1.4 M params.
- Layer 1: 22x22x8 input, RF = 5x5, 8 channels per location -> 18x18x8. 0.5 M params.
- FF goal. Each layer learns to make
mean(h^2)(the goodness) high for(image, true_label)pairs and low for(image, wrong_label)pairs. Per-layer loss (Hinton 2022, eq. 1):L = log(1 + exp(-(g_pos - theta))) + log(1 + exp(g_neg - theta))withtheta = 2.0. - Between layers. Activations are renormalised so
mean(h^2) = 1— the standard Hinton recipe that strips magnitude so deeper layers cannot read off the previous layer’s goodness. - Prediction (FF). For each candidate label, encode it into the input, push through the network, sum goodness across layers (skipping layer 0 because it sees the label pixels directly), and pick the argmax.
- BP baseline. Same locally-connected stack with a
flatten + linearreadout, trained end-to-end with softmax cross-entropy and Adam. Uses the raw image (no label-in-input).
The interesting structural claim is that locally-connected layers, lacking the weight-sharing inductive bias of CNNs, still learn distinguishing features under the FF rule — and the FF/BP gap stays small (in Hinton’s paper) when depth grows.
Files
| File | Purpose |
|---|---|
ff_cifar_locally_connected.py | CIFAR loader (Toronto + Kaggle PNG fallback), label-in-input encoder, locally-connected layer with batched-matmul forward / backward, FF training loop, BP baseline, per-class accuracy. CLI: --seed --n-epochs --n-layers --batch-size --lr --threshold --train-subset --eval-subset --bp-baseline --full-test --save. |
visualize_ff_cifar_locally_connected.py | Static plots: example images, label-encoded examples, per-location layer-0 receptive fields, per-class FF vs BP test accuracy bars, FF vs BP training curves + per-layer goodness gap. |
make_ff_cifar_locally_connected_gif.py | Renders the animation at the top of this README. |
ff_cifar_locally_connected.gif | Committed animation, ~140 KB. |
viz/ | Committed PNGs from the headline run. |
Running
CIFAR-10 is downloaded once into ~/.cache/hinton-cifar/. Note: the
canonical Toronto mirror at https://www.cs.toronto.edu/~kriz/ returns 503
as of 2026-05 — they migrated cs.toronto.edu to a Squarespace site and the
~kriz/ directory is gone. The loader tries it first for completeness, then
falls back to the Kaggle ImageFolder PNG mirror (oxcdcd/cifar10, ~184 MB
zip of 60 K PNGs); both code paths produce identical numpy arrays. After the
first successful load the parsed arrays are cached as a single ~180 MB
cifar10.npz so subsequent runs reload in well under a second.
# Headline run -- 10 epochs, 10K train subset, lr 0.01, BP baseline, full test.
python3 ff_cifar_locally_connected.py --seed 0 --n-epochs 10 \
--train-subset 10000 --eval-subset 1000 --batch-size 64 \
--lr 0.01 --bp-baseline --full-test --save model.npz
# Static figures from the saved run:
python3 visualize_ff_cifar_locally_connected.py --model model.npz \
--per-class-test-subset 10000
# GIF (smaller subset for a quick render):
python3 make_ff_cifar_locally_connected_gif.py --epochs 8 --fps 4 \
--train-subset 4000 --eval-subset 500 --snapshot-every 1
Wallclock on Apple M-series, headline command:
- FF training: 104 s for 10 epochs over 10 K CIFAR-10 train images.
- BP training: 48 s for 10 epochs over the same 10 K images.
- GIF render: ~70 s (8 epochs over 4 K train, snapshot every epoch).
Results
Headline run (seed 0, lr 0.01, 10 epochs, 10 K train subset, full 10 K test):
| Method | Test acc (full 10 K) | Test error | Train acc (eval subset) | Wallclock |
|---|---|---|---|---|
| FF (locally-connected, goodness rule) | 22.78 % | 77.22 % | 23.1 % | 104 s |
| BP (same arch, end-to-end softmax CE) | 38.31 % | 61.69 % | 49.7 % | 48 s |
| Chance | 10.00 % | 90.00 % | — | — |
Hinton (2022) reports BP 37–39 % error and FF 41–46 % error on CIFAR (i.e. ~5 percentage-point gap, with FF closing on BP as depth grows). Our gap is wider (~15 pp) because the headline run is dramatically smaller than Hinton’s: 2 layers vs 2–3 layers at much higher width, 10 K train images vs 50 K, 10 epochs vs 60+, and a single-shot label-in-input encoding instead of his recurrent label-via-attention scheme. See Deviations below.
Per-class breakdown (full 10 K test set, seed 0)
See viz/per_class_accuracy.png for the bar chart. Both methods are
strongest on the visually-distinctive classes (automobile, ship,
truck) and weakest on the visually-confusable ones (cat, dog,
bird); FF and BP agree on the per-class ordering even though FF is
~15 percentage points lower overall. This co-ranking is a small piece of
evidence that the rules are learning similar features and FF is just
under-trained, rather than learning a fundamentally different object
representation.
Visualisations
CIFAR-10 examples and label-encoded inputs


The cyan box marks the LABEL_ROWS x LABEL_LEN = 3 x 10 label slot.
Pixels inside the slot take values from {LABEL_OFF, LABEL_ON} = {-1, +1}
(replicated across all 3 channels), well outside the centred pixel range
[-0.5, 0.5], so layer 0’s receptive field can latch onto them as a
high-magnitude label signal.
Layer-0 receptive fields

A grid of layer-0 receptive fields sampled from 6 spatial positions and 4 random channels. The whole point of locally-connected layers is that the same channel index has different learned weights at different spatial positions (no weight sharing) — that is what these tiles show. With more training and wider channels the per-location specialisation would become more pronounced; at 10 epochs many of the filters still look noise-like and the per-location structure is just emerging.
Per-class accuracy

FF vs BP training curves and FF goodness gap

- Top-left: test accuracy. BP climbs fast (chance -> 35 % in one epoch), then overfits to the 10 K subset by epoch 5–6 (test accuracy stops climbing while train continues to rise). FF climbs more slowly to ~23 %.
- Top-right: per-layer FF goodness, positive vs negative. By epoch 10
layer 0 has
g_pos = 2.12, g_neg = 1.84(gap 0.28) and layer 1 reachesg_pos = 2.31, g_neg = 1.79(gap 0.52). The gap is what drives prediction. - Bottom-left: per-layer FF loss. Both layers’ losses drop monotonically; layer 1’s loss drops further because the between-layer normalisation gives it cleaner inputs.
- Bottom-right: BP softmax cross-entropy.
Deviations from the original procedure
The original CIFAR experiment in Hinton (2022) is briefly described — the paper does not publish a single recipe. We deviate from common reconstructions of it in the following documented ways, all driven by the wave-8 5-minute laptop budget:
- Architecture. 2 layers at 8 channels per location vs Hinton’s 2–3 layers at much higher channel counts. Even with this small architecture the FF training time is 104 s for 10 epochs on 10 K train; scaling width or depth multiplies that linearly.
- Training set size. 10 K of the 50 K train images. The full 50 K would take ~9 minutes per run with the current numpy-only implementation.
- Epoch count. 10 epochs vs Hinton’s 60+. Both FF and BP curves are still moving at epoch 10, so most of the gap to the reported numbers is likely under-training.
- No top-down recurrence. Hinton’s CIFAR variant uses recurrent FF with bottom-up + top-down connections within each timestep and weights tied across timesteps (this is also where the “FF closes the gap with depth” effect comes from). We implement bottom-up only — a strict feed-forward stack — because the recurrent unrolling roughly multiplies training time by the number of timesteps and was out of scope.
- Label-in-input encoding. The MNIST-style “first 10 pixels of one
channel” encoding does not give a strong enough goodness gap on CIFAR
(verified empirically — the gap stays at ~0.001 even after 5 epochs).
We instead overwrite a
LABEL_ROWS x LABEL_LEN x 3 = 3 x 10 x 3 = 90pixel block with a high-contrast (-1vs+1) one-hot replicated across rows and channels. With this encoding the gap opens within the first epoch. Hinton’s CIFAR paper uses recurrent label-via-attention instead and does not publish the static-encoding details we needed to adapt. - CIFAR mirror. Toronto’s
~kriz/directory returns 503 as of 2026-05 because the dept site migrated. We fall back to the Kaggle PNG ImageFolder mirror; the loader tries Toronto first for compatibility with documentation that points there.
Open questions / next experiments
- Top-down recurrence. Adding a recurrent unrolling with top-down connections is the single change most likely to lift FF toward Hinton’s reported numbers (and is what makes FF competitive with rather than trailing BP on CIFAR in the paper). Estimated cost: 4–8x current wallclock per run.
- Wider / deeper architecture. Bumping channels to 16 or 32 per location and training for 30+ epochs on the full 50 K set should close most of the remaining gap to backprop. Cost: ~10–20x current wallclock.
- Hard-negative selection. Replace the uniform-random wrong-label negative with the wrong label whose current goodness is highest for this image. Likely to tighten the goodness gap and reduce error without architecture changes.
- Energy/data-movement metric. This is the v1 baseline. The next pass (per the Sutro effort) is to instrument every layer with reuse-distance / ByteDMD tracking and ask: does FF actually beat backprop on data movement, given that backprop refetches all activations during the backward pass while FF’s gradient is purely local? The locally-connected pattern in particular is interesting because the per-location weights have small fan-in/fan-out and could exhibit good cache locality.
Reproducibility
| Python | 3.12.9 |
| NumPy | 2.2.5 |
| OS | macOS-26.3-arm64 |
| Random seed | exposed via --seed (default 0) |
| Final-run command | python3 ff_cifar_locally_connected.py --seed 0 --n-epochs 10 --train-subset 10000 --eval-subset 1000 --batch-size 64 --lr 0.01 --bp-baseline --full-test --save model.npz |
| CIFAR cache | ~/.cache/hinton-cifar/ (~180 MB after first load) |
The model.npz artefact is not committed (16 MB; covered by the repo’s
.gitignore *.npz rule). Regenerate it with the command above, or
visualize_ff_cifar_locally_connected.py will fall back to training from
scratch if it is missing.
Forward-Forward: next-character prediction on Aesop’s Fables
Reproduction of the unsupervised / sequence-modelling Forward-Forward variant from Hinton (2022), “The Forward-Forward Algorithm: Some Preliminary Investigations” (arXiv:2212.13345, §3.4 of v3).
Demonstrates: A multi-layer ReLU network trained without backprop to predict the next character in a 30-symbol alphabet, by contrasting the goodness of real 10-character substrings of Aesop’s Fables (positives) against synthetic windows whose final character was wrong (negatives). Two ways of producing the negatives are compared head-to-head:
- teacher-forcing — keep the real first 9 chars, replace char 10 with the model’s current argmax prediction.
- self-generated — seed with the real first 10 chars and let the model
roll forward autoregressively for 90 more characters (sampled from
softmax(goodness / T)); use sliding 10-char windows from the resulting string as negatives.
The headline result is that both schemes train successfully and beat the
unigram and random-fixed-hidden baselines decisively, supporting the
“sleep phase” decoupling idea — that the negative phase does not need to
interleave with the positive phase in real time. In our laptop-scale run
the teacher-forcing variant ends a meaningful margin ahead of the
self-generated variant; matching Hinton’s reported parity likely needs the
full 3 × 2000 width and longer training (see Deviations below).

Problem
- Corpus. Aesop’s Fables, Project Gutenberg eBook 19994. After header / footer stripping,
lowercasing, and filtering to the 30-symbol alphabet
abcdefghijklmnopqrstuvwxyz ,;., we slice the first 24 800 characters into 248 strings of 100 characters each. - Input. A length-10 sliding window of one-hot characters
(
10 × 30 = 300floats per window). - Positive example. Any real 10-char substring of any training string.
- Negative example. Same shape (10 chars, 300-dim one-hot) but with at least the last character replaced by something the model itself generated.
- Architecture. 3 fully-connected ReLU layers (default
300 → 500 → 500 → 500). Between layers, activations are rescaled somean(h²) = 1— exactly Hinton’s recipe for stripping out the magnitude that drives goodness. - Goal. Each layer learns to push
mean(h²)above the thresholdθ = 2.0for positive windows and below it for negative windows. At test time we take a 9-char context, try each of the 30 possible next characters as the 10th element, and pick the candidate whose summed goodness across all layers is highest.
The interesting property mirrors the supervised label-in-input wave/7 sibling: no backward pass, no chain rule, gradients local to each layer. The new property here is that the negative data — the training signal’s other half — can be produced by a separate, completely offline rollout pass. Hinton calls this a “sleep phase.”
Files
| File | Purpose |
|---|---|
ff_aesop_sequences.py | Aesop loader + window encoding + FF MLP + Adam-trained per-layer FF loss + goodness-based next-char prediction. CLI: --seed --negatives {teacher_forcing, self_generated} --n-epochs --layer-sizes --lr --threshold --batch-size --steps-per-epoch --eval-every --rollout-every --rollout-temperature --save --baseline. |
visualize_ff_aesop_sequences.py | Trains (or loads) both variants, computes baselines, writes accuracy curves, per-position accuracy, generated-text samples, and per-layer goodness curves to viz/. |
make_ff_aesop_sequences_gif.py | Renders ff_aesop_sequences.gif: per-epoch accuracy curves + per-layer goodness + autoregressive rollouts side-by-side. |
ff_aesop_sequences.gif | Committed animation. |
viz/ | Committed PNGs. |
problem.py | Spec stub (skeleton) – kept for reference. |
Running
The Aesop text is downloaded once into ~/.cache/hinton-aesop/ (~170 KB).
# Headline run for both negative variants + visualisations:
python3 visualize_ff_aesop_sequences.py --n-epochs 30 --steps-per-epoch 200 \
--layer-sizes 300,500,500,500 \
--lr 0.003 --eval-every 2 \
--rollout-every 1 \
--rollout-temperature 1.0 --seed 0
# Single-variant training (teacher-forcing):
python3 ff_aesop_sequences.py --negatives teacher_forcing \
--n-epochs 30 --steps-per-epoch 200 \
--layer-sizes 300,500,500,500 \
--baseline --save model_tf.npz
# Single-variant training (self-generated):
python3 ff_aesop_sequences.py --negatives self_generated \
--n-epochs 30 --steps-per-epoch 200 \
--layer-sizes 300,500,500,500 \
--rollout-temperature 1.0 \
--baseline --save model_sg.npz
# Render the GIF (smaller architecture for fast frame rendering):
python3 make_ff_aesop_sequences_gif.py --epochs 20 --snapshot-every 1 --fps 4 \
--layer-sizes 300,400,400,400 \
--steps-per-epoch 120 --seed 0
Wallclock on an Apple M-series laptop (NumPy CPU only):
- Teacher-forcing training (30 epochs, 200 batches/epoch): 131 s.
- Self-generated training (30 epochs, 200 batches/epoch + per-epoch rollout): 108 s.
- Plotting + baselines: ~20 s.
- End-to-end implementation wallclock (smoke tests + headline run + GIF): ~12 minutes.
Results
Numbers below come from the headline run committed in viz/ (seed 0,
30 epochs, 200 batches/epoch, batch 128, lr 0.003, layer sizes
300 → 500 → 500 → 500, threshold 2.0, rollout temperature 1.0, rollout
refresh every epoch).
| Method | Per-char accuracy on 248 × 90 next-char predictions |
|---|---|
| chance (1 / 30) | 3.33% |
random fixed hidden (untrained FF stack, 300 → 500 → 500 → 500) | 2.85% |
unigram (always predict the most common char ' ') | 19.60% |
| FF teacher-forcing negatives (30 epochs) | 52.97% |
| FF self-generated negatives (30 epochs) | 34.08% |
Both FF variants substantially beat the unigram baseline (which is itself
much stronger than chance because the 30-symbol alphabet is space-heavy
in English text), and both beat the random-fixed-hidden control by more
than 10×. Self-generated negatives lag teacher-forcing by ~19 percentage
points at this scale; we expect that gap to close (per Hinton’s claim) at
3 × 2000 width and / or with longer training. The qualitative claim —
that a fully-decoupled, model-only-generated negative dataset can train
the FF stack at all — replicates cleanly.
Per-position accuracy

Accuracy is measured at each predicted character index (positions 10..99 of each 100-char string). It is roughly flat across positions for both variants, dipping slightly at the very start (less per-string statistics to lean on) and at the very end (no clear pattern; mostly noise).
Sample autoregressive rollouts

A representative seed plus the 90-character continuation produced
greedily (argmax of summed goodness) by each variant. Both rollouts
contain English-shaped chunks — common bigrams (th, er, an),
correct spaces between word-shaped runs — but neither perfectly tracks
real Aesop. This is expected: 248 × 100 = 24 800 characters is a tiny
corpus, and the model has only ~775 K parameters.
Per-layer goodness

Both variants drive positive goodness above the threshold and negative goodness below it within the first few epochs. Self-generated goodness oscillates more because the rollout (and hence the negative data) is regenerated every epoch, which presents a moving target.
Accuracy curves over training

Teacher-forcing accuracy climbs smoothly to 53% by epoch 30. Self-generated accuracy starts at the unigram floor (19.6%) and climbs more slowly, ending at 34% — comfortably above unigram, well above random-fixed-hidden, but visibly below teacher-forcing. With this small laptop-scale architecture and short training run, we did not reproduce the exact parity Hinton reports; both variants nonetheless train.
Deviations from the original procedure
- Architecture. Hinton uses 3 hidden layers of 2000 ReLUs each. We
use
500-500-500to keep training under a couple of minutes per variant on a NumPy CPU stack. Going to 2000 wide is a straightforward--layer-sizes 300,2000,2000,2000change; expected to roughly close the gap between the two variants and to lift absolute accuracy. - Self-generated rollout sampling. Hinton specifies an autoregressive
rollout. We expose
--rollout-temperature(default 1.0). With pure argmax (temperature = 0) the rollout collapses onto fixed-point attractors during the first few epochs (the model repeatedly emits' 'because it is the most-frequent character) and FF training destabilises. Sampling fromsoftmax(goodness / T)withT = 1.0keeps the negative distribution broad and avoids collapse. - Rollout refresh frequency. We refresh the entire 248-string
rollout every epoch (
--rollout-every 1). The “sleep phase” interpretation tolerates any refresh schedule; refreshing less often slightly increases the gap between the two variants in our experiments but does not change the qualitative result. - Optimiser. Hinton uses Adam with cosine LR decay. We use Adam at a
single fixed
lr = 0.003. We did not implement LR decay or warm-up. - Train / test split. The corpus is small (24 800 chars) and Hinton’s
experiment is about whether the FF mechanism can learn local
character statistics, not whether it generalises to held-out fables.
We therefore evaluate per-char accuracy on the same 248 strings used
for training. Holding out, e.g., the last 50 strings is a
one-line change in
evaluate_per_char_accuracy()if needed. - Window length 10. Hinton describes 10-character windows; we use the same. We did not sweep window length.
- Threshold θ = 2.0. Same as the supervised wave/7 FF run, same as Hinton’s preferred value across the paper.
Open questions / next experiments
- Architecture sweep. Does scaling to
3 × 2000close the residual gap between teacher-forcing and self-generated, as Hinton claims, or reveal a persistent variant-specific advantage? - Sleep-phase decoupling. What is the longest gap between rollout refreshes (in epochs) for which self-generated negatives still match teacher-forcing? If it is many epochs, the case for an offline “sleep phase” is strong.
- Hard-negative selection. Both variants currently use a single rollout per string. Selecting hardest negatives (windows whose current goodness is closest to the positive distribution) might tighten the goodness gap and lift accuracy without scaling the architecture.
- Energy / data-movement metric. This is the v1 baseline. The next pass is to instrument every layer with reuse-distance / ByteDMD tracking and ask: under FF, does the negative phase refetch any of the same activations as the positive phase, or is the data movement cleanly partitioned? The “sleep phase” intuition predicts very low cross-phase reuse.
- Held-out test set. Train on 200 strings, evaluate on 48 held-out ones. Does the per-char accuracy survive, or are we mostly memorising?
Reproducibility
| Python | 3.12.9 |
| NumPy | 2.x (whatever is on PATH) |
| OS | macOS arm64 |
| Random seed | exposed via --seed (default 0) |
| Final-run command | see Running |
| Aesop cache | ~/.cache/hinton-aesop/pg19994.txt (~170 KB; downloaded from gutenberg.org) |
The model_tf.npz and model_sg.npz artefacts are not committed —
regenerate them with the visualisation command (or pass
--save model_*.npz to ff_aesop_sequences.py directly).