Deep Learning · Optimization Theory · Training Dynamics
beyond gradient descent
Deep Learning Optimization
Loss landscapes and their geometry, saddle points, sharp versus flat minima,
the role of SGD noise, implicit bias, generalization theory, batch size
scaling, warmup schedules, neural scaling laws, and grokking.
▲ Extends Gradient Descent — that article established first-order optimization mechanics. This one asks a deeper question: why does gradient descent in deep networks find solutions that generalize, and what do the loss landscape's geometry and the optimizer's stochasticity have to do with it?
The loss landscape is the function \(\mathcal{L}: \mathbb{R}^P \to \mathbb{R}\) that maps every possible parameter vector to its training loss. For a model with \(P\) parameters, this is a hypersurface in \(\mathbb{R}^{P+1}\). Understanding its geometry — not just at the starting point, but globally — is the prerequisite for understanding everything else in this article.
The Hessian \(\mathbf{H}=\nabla^2\mathcal{L}\) is the second-order local map of the loss landscape. Its eigenvalues are the curvatures along its eigenvector directions: large eigenvalues correspond to steep walls (sharp directions), near-zero eigenvalues to flat valleys. For a model with \(P=10^{10}\) parameters, \(\mathbf{H}\) has \(10^{20}\) entries, far too many to store or compute explicitly, yet its spectral structure shapes much of the training dynamics discussed below.
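Although the full Hessian cannot be stored, its largest eigenvalue can be probed with Hessian-vector products alone. The following is a minimal sketch under stated assumptions: `grad_fn`, `hvp`, and `top_eigenvalue` are illustrative names, and the toy quadratic loss is chosen only because its curvatures (10, 1, 0.1) are known in advance.

```python
import numpy as np

def hvp(grad_fn, theta, v, eps=1e-4):
    """Hessian-vector product H @ v via a central difference of gradients."""
    return (grad_fn(theta + eps * v) - grad_fn(theta - eps * v)) / (2 * eps)

def top_eigenvalue(grad_fn, theta, iters=200, seed=0):
    """Estimate the largest-magnitude Hessian eigenvalue by power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=theta.shape)
    v /= np.linalg.norm(v)
    lam = 0.0
    for _ in range(iters):
        hv = hvp(grad_fn, theta, v)
        lam = float(v @ hv)              # Rayleigh quotient for the current unit vector
        v = hv / np.linalg.norm(hv)
    return lam

# Toy quadratic loss L = 0.5 * theta^T A theta (an assumption for illustration).
A = np.diag([10.0, 1.0, 0.1])
grad_fn = lambda th: A @ th
lam_max = top_eigenvalue(grad_fn, np.zeros(3))
print(lam_max)  # close to 10, the largest curvature of the toy loss
```

Only gradients are evaluated, so memory stays proportional to \(P\) rather than \(P^2\). In a real network the finite difference would usually be replaced by an automatic-differentiation Hessian-vector product.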
Why High-Dimensional Geometry Is Counter-Intuitive
Our 3D intuition
In 3D, "stuck" means surrounded by walls — you need to climb up in every direction to escape. A local minimum in 3D has positive curvature in every direction. Saddle points require finding the one downhill direction.
Reality in P dimensions
In \(P\sim10^{10}\) dimensions, a critical point with even one negative Hessian eigenvalue is NOT a local minimum: gradient descent (or noise) can escape along that direction. A true local minimum requires all \(P\) eigenvalues to be non-negative. Heuristically, the chance of this at a randomly chosen critical point shrinks rapidly with \(P\), so the overwhelming majority of critical points in deep networks are expected to be saddles rather than local minima.
Visualizing Loss Surfaces
Li et al. (2018) introduced filter normalization visualization: project the landscape onto a 2D plane defined by two random Gaussian directions \(\boldsymbol{\delta}_1, \boldsymbol{\delta}_2\) from a trained solution \(\boldsymbol{\theta}^*\), normalized to match filter-wise norms of the parameter matrices (removing the spurious curvature caused by scale differences between layers):
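Filter-Normalized Loss Visualization · Li et al. 2018
\[f(\alpha,\beta) = \mathcal{L}\left(\boldsymbol{\theta}^* + \alpha\,\boldsymbol{\delta}_1 + \beta\,\boldsymbol{\delta}_2\right), \qquad \boldsymbol{\delta}_{i,j} \leftarrow \frac{\boldsymbol{\delta}_{i,j}}{\|\boldsymbol{\delta}_{i,j}\|}\,\|\boldsymbol{\theta}^*_{i,j}\|\]
Here \(\boldsymbol{\delta}_{i,j}\) denotes the \(j\)-th filter of layer \(i\) within a direction, rescaled to the norm of the matching filter of \(\boldsymbol{\theta}^*\).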
This revealed a stunning visual result: deep networks with skip connections (ResNets) have dramatically smoother, more convex-looking loss landscapes than networks without skip connections — the skip connections literally smooth out the terrain, which is one reason they train more reliably. This is a visual corroboration that architectural choices change landscape geometry, not just parameter count.
The Key Question of This Article
The loss landscape has many minima. Gradient descent finds one of them. Classical machine learning theory asks "did you find the global minimum?" — and the answer for deep networks is almost certainly "no." But the right question turns out to be: which minimum did you find, and does it generalize? The rest of this article is about the geometry of that question.
02
Escaping
Saddle Points
// critical points · index of a saddle · why SGD escapes · the strict saddle property
A saddle point is a critical point (\(\nabla\mathcal{L}=\mathbf{0}\)) where the Hessian has both positive and negative eigenvalues — it looks like a local minimum from some directions and a local maximum from others. In high dimensions, they are vastly more common than local minima.
Saddle Point: Formal Definition · Characterization
\[\nabla\mathcal{L}(\boldsymbol{\theta}^*)=\mathbf{0}, \quad \mathbf{H}(\boldsymbol{\theta}^*) \text{ has at least one negative eigenvalue } \lambda_k < 0\]
\[\text{Index of the saddle} = \text{number of negative eigenvalues of }\mathbf{H}\]
In \(P\) dimensions, if the Hessian at a critical point behaves like a random symmetric matrix, the probability that ALL \(P\) eigenvalues are positive (a true local minimum) is exponentially small in \(P\). For a zero-centered random matrix, the Wigner semicircle law puts roughly half the eigenvalues below zero, which is the typical picture at high loss. Random-matrix analyses suggest that the index tends to fall as the loss decreases: high-loss critical points are overwhelmingly high-index saddles, while the critical points with few or no negative eigenvalues sit at loss values close to the global minimum.
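The random-matrix intuition can be checked numerically. This sketch treats a zero-centered random symmetric matrix as a stand-in for the Hessian at a generic critical point, which is a modeling assumption rather than a property of any real network; `negative_fraction` is an illustrative helper name.

```python
import numpy as np

def negative_fraction(P, trials=200, seed=0):
    """Return (mean fraction of negative eigenvalues, fraction of samples with all eigenvalues > 0)."""
    rng = np.random.default_rng(seed)
    fracs, all_positive = [], 0
    for _ in range(trials):
        M = rng.normal(size=(P, P))
        H = (M + M.T) / 2                      # random symmetric "Hessian"
        eig = np.linalg.eigvalsh(H)
        fracs.append(np.mean(eig < 0))
        all_positive += int(np.all(eig > 0))
    return float(np.mean(fracs)), all_positive / trials

for P in (2, 4, 8, 16):
    frac, p_min = negative_fraction(P)
    print(P, round(frac, 3), p_min)            # fraction stays near 0.5; "all positive" vanishes quickly
```

Even at \(P=8\) a sample with all-positive eigenvalues is rarely if ever drawn, which illustrates how quickly "minimum-like" critical points become rare in this toy model.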
How Gradient Descent Escapes Saddles
// Why exact gradient descent can get stuck, but SGD typically doesn't
1
Near a saddle point at \(\boldsymbol{\theta}^*\), the gradient is small. Along the negative-eigenvalue eigenvector \(\mathbf{v}_k\) (i.e., \(\mathbf{H}\mathbf{v}_k=\lambda_k\mathbf{v}_k\), \(\lambda_k<0\)):
\[\mathcal{L}(\boldsymbol{\theta}^*+t\mathbf{v}_k)\approx\mathcal{L}(\boldsymbol{\theta}^*)+\frac{t^2}{2}\lambda_k < \mathcal{L}(\boldsymbol{\theta}^*)\]
2
Exact gradient descent: if initialized exactly at the saddle (\(\boldsymbol{\theta}_0=\boldsymbol{\theta}^*\)), the gradient is zero → no update → stuck forever. In practice, any tiny perturbation along \(\mathbf{v}_k\) grows geometrically, by a factor of \(1+\eta|\lambda_k|\) per step, but escape can be extremely slow when \(|\lambda_k|\) is small.
3
SGD: gradient noise (from mini-batch sampling) has a component along \(\mathbf{v}_k\) with probability 1 (random noise projects onto every direction). The saddle is unstable under perturbation — SGD escapes saddles efficiently in practice.
Ge et al. (2015) proved that stochastic gradient descent with added noise converges to a local minimum (a second-order stationary point, with no significantly negative Hessian eigenvalue) in polynomial time, under the "strict saddle property": every saddle has at least one eigenvalue strictly below some negative threshold. Whether deep networks satisfy this property exactly is not established; empirical evidence suggests they often behave as if they do.
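The contrast between exact gradient descent and noisy gradient descent at a saddle can be seen on a two-dimensional toy problem. This is a sketch under assumptions: the loss \(\tfrac12x^2-\tfrac12y^2+\tfrac14y^4\) is chosen for illustration (saddle at the origin, minima at \(y=\pm1\)), and isotropic Gaussian noise stands in for mini-batch noise.

```python
import numpy as np

def loss(th):
    x, y = th
    return 0.5 * x**2 - 0.5 * y**2 + 0.25 * y**4

def grad(th):
    x, y = th
    return np.array([x, -y + y**3])

def run(noise_std, steps=500, lr=0.1, seed=0):
    """Gradient descent started exactly at the saddle, with optional gradient noise."""
    rng = np.random.default_rng(seed)
    th = np.zeros(2)
    for _ in range(steps):
        g = grad(th) + noise_std * rng.normal(size=2)
        th = th - lr * g
    return th

theta_gd = run(noise_std=0.0)     # exact GD: gradient is zero, iterate never moves
theta_sgd = run(noise_std=0.01)   # noisy GD: escapes along the negative-curvature direction
print(theta_gd, loss(theta_gd))
print(theta_sgd, loss(theta_sgd))
```

The noisy run settles near one of the two minima, where the loss is about \(-0.25\); which one it reaches depends on the random seed.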
Strict Saddle Property
Strict Saddle Property (Informal)
A function satisfies the strict saddle property if every critical point is either (a) an approximate local minimum (all Hessian eigenvalues \(\geq -\epsilon\)), or (b) a "strict saddle" with at least one eigenvalue \(\leq -\epsilon\) for some fixed \(\epsilon>0\). If a function is strict-saddle, then gradient descent with noise converges to a local minimum in polynomial time — saddle points are never stable fixed points, only transient obstacles. Empirically, deep networks with smooth activations appear to satisfy this, which is the leading theoretical explanation for why SGD reliably converges to low-loss solutions despite the non-convexity.
Not all minima are equal. Two parameter settings can achieve identical training loss yet generalize very differently. The key differentiator is the sharpness of the minimum: how steeply the loss rises when you perturb the parameters.
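Sharpness: Maximum Perturbation Sensitivity · Keskar et al. 2017
\[\phi_\epsilon(\boldsymbol{\theta}) = \max_{\|\boldsymbol{\delta}\|\leq\epsilon}\mathcal{L}(\boldsymbol{\theta}+\boldsymbol{\delta}) - \mathcal{L}(\boldsymbol{\theta})\]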
This measures how much the loss can increase within an ε-ball around θ — the maximum "sensitivity to perturbation." A flat minimum has small φ_ε; a sharp minimum has large φ_ε. At a flat minimum, many nearby parameter settings have approximately the same (low) training loss — at a sharp minimum, even tiny perturbations spike the loss dramatically.
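A rough numerical version of this sharpness measure is sketched below. It assumes random sampling on the ε-sphere, which only gives a lower bound on the true maximum and is reasonable in two dimensions; in high dimensions a gradient-ascent step is the usual substitute. The two quadratic losses and the name `sharpness` are illustrative.

```python
import numpy as np

def sharpness(loss_fn, theta, eps, n_samples=2000, seed=0):
    """Monte Carlo lower bound on max_{||d|| <= eps} L(theta + d) - L(theta)."""
    rng = np.random.default_rng(seed)
    d = rng.normal(size=(n_samples, theta.size))
    d *= eps / np.linalg.norm(d, axis=1, keepdims=True)   # random points on the eps-sphere
    base = loss_fn(theta)
    return max(loss_fn(theta + di) - base for di in d)

# Two minima at the origin with equal loss (zero) but different curvature.
flat = lambda th: 0.5 * th @ np.diag([0.1, 0.1]) @ th
sharp = lambda th: 0.5 * th @ np.diag([50.0, 0.1]) @ th

theta_star = np.zeros(2)
phi_flat = sharpness(flat, theta_star, eps=0.1)
phi_sharp = sharpness(sharp, theta_star, eps=0.1)
print(phi_flat, phi_sharp)
```

For a quadratic minimum the exact value is \(\tfrac12\epsilon^2\lambda_{\max}\), so the sharp example is bounded by 0.25 and the flat one equals 0.0005.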
Why Flat Minima Generalize Better — Intuitive Argument
Consider training vs test loss. At any finite sample size \(n\), the training loss and test loss differ — the gap is approximately bounded by complexity measures. Now consider two minima with equal training loss:
A flat minimum: even if the true test loss is shifted slightly from the training loss (due to distribution shift or finite-sample effects), the minimum is still approximately at a flat, low-loss region of the test loss — generalization is robust.
A sharp minimum: a small shift between training and test loss distributions can move the sharp minimum's basin significantly — the test loss at the same parameter setting may be much higher than the training loss, even when the two distributions are nearly identical.
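The PAC-Bayes framework makes this quantitative. In one common form of the bound, for losses bounded in \([0,1]\) and with probability at least \(1-\delta\) over the draw of \(n\) training samples,
\[\mathbb{E}_{\boldsymbol{\theta}\sim Q}\left[\mathcal{L}_{\text{test}}(\boldsymbol{\theta})\right] \leq \mathbb{E}_{\boldsymbol{\theta}\sim Q}\left[\mathcal{L}_{\text{train}}(\boldsymbol{\theta})\right] + \sqrt{\frac{\mathrm{KL}(Q\,\|\,P) + \ln(n/\delta)}{2(n-1)}}\]
Here \(Q\) is a distribution over parameters (a "fuzzy" version of our deterministic θ) and \(P\) is a prior fixed before seeing the data. If we set \(Q = \mathcal{N}(\boldsymbol{\theta}, \sigma^2\mathbf{I})\) (a Gaussian around the minimum we found, with standard deviation σ), the KL term measures how much information θ carries relative to the prior. A FLAT minimum is one where σ can be chosen LARGE while \(Q\) still covers a region of low loss, which gives a smaller KL and a tighter generalization bound. In this sense flatness and compressibility are two views of the same property.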
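The effect of σ on the KL term can be made concrete. This sketch assumes an isotropic Gaussian prior \(P=\mathcal{N}(\mathbf{0},\sigma_P^2\mathbf{I})\), which is one convenient choice and not the only one; `kl_isotropic` and the parameter values are illustrative.

```python
import numpy as np

def kl_isotropic(theta, sigma_q, sigma_p):
    """KL( N(theta, sigma_q^2 I) || N(0, sigma_p^2 I) ) in nats."""
    d = theta.size
    ratio = sigma_q**2 / sigma_p**2
    return 0.5 * (d * ratio + theta @ theta / sigma_p**2 - d - d * np.log(ratio))

theta = np.full(1000, 0.05)                                  # hypothetical trained parameters
kl_narrow = kl_isotropic(theta, sigma_q=0.01, sigma_p=1.0)   # sharp minimum: only a tiny sigma is tolerated
kl_wide = kl_isotropic(theta, sigma_q=0.5, sigma_p=1.0)      # flat minimum: a large sigma is tolerated
print(kl_narrow, kl_wide)
```

The wide posterior has a much smaller KL to the prior, which is the sense in which a flat minimum "costs fewer bits" to describe.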
SAM — Sharpness-Aware Minimization
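SAM Objective · Foret et al. 2021 · Optimization Algorithm
\[\min_{\boldsymbol{\theta}}\;\max_{\|\boldsymbol{\epsilon}\|_2\leq\rho}\mathcal{L}(\boldsymbol{\theta}+\boldsymbol{\epsilon}), \qquad \hat{\boldsymbol{\epsilon}}(\boldsymbol{\theta}) = \rho\,\frac{\nabla\mathcal{L}(\boldsymbol{\theta})}{\|\nabla\mathcal{L}(\boldsymbol{\theta})\|_2}\]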
SAM explicitly minimizes the WORST-CASE loss in a ρ-neighborhood, targeting sharpness directly rather than training loss alone. The update is a two-step process: (1) perturb θ to the approximate worst nearby point (maximization step, approximated by one gradient ascent step normalized to length ρ), (2) take the descent step using the gradient evaluated at the perturbed point. Foret et al. report consistent generalization improvements across vision models. The price is roughly two forward-backward passes per update, about twice the cost of a plain SGD step.
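The two-step update can be sketched in a few lines. This is a simplified illustration on a toy quadratic, assuming full-batch gradients and plain gradient descent as the base optimizer; `sam_step` is a hypothetical helper, not the reference implementation.

```python
import numpy as np

def sam_step(theta, grad_fn, lr=0.1, rho=0.05):
    """One SAM update: move to the approximate worst nearby point, then descend from theta."""
    g = grad_fn(theta)
    eps_hat = rho * g / (np.linalg.norm(g) + 1e-12)   # first-order worst-case perturbation of length rho
    g_sam = grad_fn(theta + eps_hat)                   # gradient evaluated at the perturbed point
    return theta - lr * g_sam, eps_hat

H = np.diag([5.0, 0.5])                                # toy curvature matrix (assumption)
loss_fn = lambda th: 0.5 * th @ H @ th
grad_fn = lambda th: H @ th

theta0 = np.array([1.0, 1.0])
theta = theta0.copy()
for _ in range(50):
    theta, eps_hat = sam_step(theta, grad_fn)
print(loss_fn(theta0), loss_fn(theta))
```

Each call to `sam_step` evaluates the gradient twice, which is where the roughly doubled per-step cost comes from.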
04
Stochasticity
SGD Noise and Its Role
// gradient noise covariance · effective temperature · annealing · noise as implicit regularization
The gradient descent article established the update rule. Here we examine the noise structure of stochastic gradient descent in depth — because it turns out the noise is not a nuisance to be minimized, but an active ingredient that shapes which minima are found.
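For a mini-batch \(\mathcal{B}\) of size \(B\), the stochastic gradient is the full gradient plus zero-mean noise whose covariance shrinks with the batch size:
\[\mathbf{g}_{\mathcal{B}}(\boldsymbol{\theta}) = \frac{1}{B}\sum_{i\in\mathcal{B}}\nabla\ell_i(\boldsymbol{\theta}), \qquad \text{Cov}\left[\mathbf{g}_{\mathcal{B}}\right] \approx \frac{\mathbf{C}(\boldsymbol{\theta})}{B}, \qquad \mathbf{C}(\boldsymbol{\theta}) = \frac{1}{n}\sum_{i=1}^{n}\left(\nabla\ell_i-\nabla\mathcal{L}\right)\left(\nabla\ell_i-\nabla\mathcal{L}\right)^T\]
The per-sample gradient covariance C(θ) encodes the geometry of the data's loss structure: its leading eigenvectors are the directions in which the gradient varies most from sample to sample. Importantly, the noise covariance C(θ)/B DECREASES with larger batch size B: larger batches give a MORE accurate gradient estimate and LESS gradient noise. We'll explore the consequences in §07.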
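The 1/B scaling of the gradient noise can be verified empirically. The sketch below assumes a small synthetic linear-regression problem and mini-batches sampled with replacement (so that the 1/B law holds exactly in expectation); all names and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 512, 5
X = rng.normal(size=(n, p))
y = X @ rng.normal(size=p) + 0.5 * rng.normal(size=n)
theta = np.zeros(p)

# Per-sample gradients of the squared error 0.5 * (x_i . theta - y_i)^2.
per_sample = (X @ theta - y)[:, None] * X               # shape (n, p)
full_grad = per_sample.mean(axis=0)
C = np.cov(per_sample, rowvar=False, bias=True)         # per-sample gradient covariance C(theta)

def minibatch_noise(B, trials=20000):
    """Empirical total variance of the mini-batch gradient around the full gradient."""
    idx = rng.integers(0, n, size=(trials, B))
    g = per_sample[idx].mean(axis=1)                     # (trials, p) mini-batch gradients
    return float(((g - full_grad) ** 2).sum(axis=1).mean())

noise_8, noise_32 = minibatch_noise(8), minibatch_noise(32)
print(noise_8, np.trace(C) / 8)      # empirical vs predicted tr(C) / B
print(noise_8 / noise_32)            # roughly 4: quadrupling B divides the noise variance by 4
```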
Effective Temperature Analogy
The dynamics of SGD in the vicinity of a minimum can be approximated by a Langevin equation — a stochastic differential equation studied in statistical physics for Brownian motion:
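SGD as Stochastic Differential Equation · Langevin View
\[d\boldsymbol{\theta}_t = -\nabla\mathcal{L}(\boldsymbol{\theta}_t)\,dt + \sqrt{\frac{\eta}{B}}\,\mathbf{C}(\boldsymbol{\theta}_t)^{1/2}\,d\mathbf{W}_t, \qquad T_{\text{eff}} \propto \frac{\eta}{B}\]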
The effective temperature T_eff controls the amplitude of parameter fluctuations. In statistical physics, temperature determines which minima are accessible — at high T, the system explores broadly; at low T, it settles into narrow low-energy minima. Here: high η/B (small batch or large step) → high effective temperature → SGD explores broadly and avoids sharp minima (narrow "energy basins"). Low η/B → low effective temperature → SGD freezes near the first acceptable minimum it finds, which may be sharp.
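The dependence of the fluctuation size on \(\eta/B\) can be checked on a one-dimensional quadratic, where the variance of the iterate obeys an exact recursion. This is a sketch under simplifying assumptions (constant curvature `h`, constant per-sample noise variance `sigma2`); the function name is illustrative.

```python
def stationary_variance(lr, batch, h=1.0, sigma2=1.0, steps=20000):
    """Iterate the exact variance recursion of noisy SGD on L = 0.5 * h * theta^2.

    Update: theta <- theta - lr * (h * theta + xi), with Var(xi) = sigma2 / batch.
    """
    v = 0.0
    for _ in range(steps):
        v = (1 - lr * h) ** 2 * v + lr**2 * sigma2 / batch
    return v

v_base = stationary_variance(lr=0.01, batch=32)
v_scaled = stationary_variance(lr=0.02, batch=64)   # same lr / batch ratio
v_hot = stationary_variance(lr=0.02, batch=32)      # doubled lr / batch ratio
print(v_scaled / v_base, v_hot / v_base)            # about 1 and about 2
```

The fixed point of the recursion is \(\eta\sigma^2/(Bh(2-\eta h))\), which for small \(\eta h\) depends on \(\eta\) and \(B\) only through their ratio.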
Noise as Implicit Regularization
The Subtle Power of Gradient Noise
The effective temperature argument makes a testable prediction: everything else equal, high \(\eta/B\) should lead to flatter minima and better generalization. The optimizer has no explicit knowledge of flatness; rather, high noise makes sharp, narrow minima unstable (the parameter oscillations exceed the basin width and kick the iterate out) while flat, wide minima remain stable (the entire broad basin is accessible). Smith & Le (2018) reported experiments consistent with this: test accuracy depends on the noise scale set by the ratio of learning rate to batch size, and large-batch, low-noise training tends to generalize worse at comparable training loss. The noise is doing the work.
05
Theory
Implicit Bias of Gradient Descent
// the implicit regularizer · linear networks → minimum norm · deep networks → max margin
Gradient descent does not just find a solution that fits the training data — it finds a specific solution determined by the optimizer's geometry. This "implicit bias" is a regularization effect that operates entirely without an explicit regularization term, and it is one of the deepest reasons deep networks generalize despite being highly overparameterized.
Linear Models: Gradient Descent → Minimum Norm
Theorem: Implicit Bias for Linear Regression
For a linear model \(\hat{\mathbf{y}}=\mathbf{X}\boldsymbol{\theta}\) initialized at \(\boldsymbol{\theta}_0=\mathbf{0}\) and trained with gradient descent on the MSE loss, the iterates converge to the minimum-norm interpolating solution \(\boldsymbol{\theta}^*=\mathbf{X}^+\mathbf{y}\) (the pseudoinverse solution, satisfying \(\mathbf{X}\boldsymbol{\theta}^*=\mathbf{y}\)) for any learning rate small enough for gradient descent to converge, even in the overparameterized setting where \(P \gg n\). This assumes an interpolating solution exists, which holds when \(\mathbf{X}\) has full row rank.
// Proof sketch — gradient descent stays in the row space of X
1
The gradient is \(\nabla_{\boldsymbol{\theta}}\mathcal{L} = 2\mathbf{X}^T(\mathbf{X}\boldsymbol{\theta}-\mathbf{y}) \in \text{row-space}(\mathbf{X})\) — always a linear combination of rows of \(\mathbf{X}\).
2
Starting from \(\boldsymbol{\theta}_0=\mathbf{0}\), every gradient descent update adds something in the row-space of \(\mathbf{X}\). By induction, \(\boldsymbol{\theta}_t\in\text{row-space}(\mathbf{X})\) for all \(t\).
3
The unique element of the row-space that satisfies \(\mathbf{X}\boldsymbol{\theta}=\mathbf{y}\) (the interpolation constraint) is precisely \(\boldsymbol{\theta}^*=\mathbf{X}^+\mathbf{y}\), which has minimum \(\ell_2\) norm among all interpolating solutions. ∎
No explicit L2 regularization was used anywhere. The minimum-norm solution emerged purely from gradient descent's initialization and geometry. This is "free" regularization.
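The theorem is easy to check numerically. The sketch below assumes a small random problem with 5 samples and 20 parameters and a step size chosen from the largest eigenvalue of \(\mathbf{X}\mathbf{X}^T\); the sizes, seed, and iteration count are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 5, 20                                  # overparameterized: p >> n
X = rng.normal(size=(n, p))
y = rng.normal(size=n)

lam_max = np.linalg.eigvalsh(X @ X.T).max()
lr = 0.5 / lam_max                            # below the 1 / lam_max stability limit for this loss
theta = np.zeros(p)                           # zero init keeps every iterate in row-space(X)
for _ in range(20000):
    theta -= lr * 2 * X.T @ (X @ theta - y)   # gradient of ||X theta - y||^2

theta_min_norm = np.linalg.pinv(X) @ y
print(np.linalg.norm(theta - theta_min_norm))  # essentially zero
```

Starting from a nonzero \(\boldsymbol{\theta}_0\) would instead converge to the interpolating solution closest to \(\boldsymbol{\theta}_0\), since the component of \(\boldsymbol{\theta}_0\) outside the row space is never updated.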
Deep Networks: Max Margin
Implicit Bias: Linear Networks on Classification · Soudry et al. 2018
\[\text{For separable data, gradient descent on logistic loss converges to the }\textbf{maximum L2-margin}\text{ classifier:}\]
\[\boldsymbol{\theta}_t/\|\boldsymbol{\theta}_t\| \to \underset{\boldsymbol{\theta}:\|\boldsymbol{\theta}\|=1}{\arg\max}\;\min_i y_i\mathbf{x}_i^T\boldsymbol{\theta} \quad\text{as } t\to\infty\]
Gradient descent on logistic/cross-entropy loss (the standard deep learning objective) implicitly converges in direction to the maximum-margin classifier, the same solution found by a hard-margin SVM with explicit margin maximization! The optimizer finds this solution for FREE, with no explicit max-margin objective in the code, although the convergence is slow: the gap to the max-margin direction shrinks only like \(1/\log t\). This result extends (with more complex implicit biases) to deep linear networks and, empirically, to deep nonlinear networks trained on real data.
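The drift toward the max-margin direction can be observed on a tiny separable problem. This is a sketch under assumptions: two hand-picked points (already multiplied by their labels), a linear classifier through the origin, and a hard-margin solution computed in closed form because both constraints happen to be active for this data.

```python
import numpy as np

# Two points already multiplied by their labels: z_i = y_i * x_i (separable through the origin).
Z = np.array([[1.0, 0.2],
              [0.2, 2.0]])

def normalized_margin(theta):
    return float((Z @ theta).min() / np.linalg.norm(theta))

theta = np.zeros(2)
lr, margins = 0.5, {}
for t in range(1, 20001):
    s = 1.0 / (1.0 + np.exp(Z @ theta))          # sigmoid(-z_i . theta)
    theta += lr * (s[:, None] * Z).sum(axis=0)   # gradient descent on sum_i log(1 + exp(-z_i . theta))
    if t in (1, 100, 20000):
        margins[t] = normalized_margin(theta)

# Hard-margin solution for comparison: min ||w|| s.t. Z w >= 1 (both constraints active for this data).
w_svm = Z.T @ np.linalg.solve(Z @ Z.T, np.ones(2))
max_margin = 1.0 / np.linalg.norm(w_svm)
print(margins, max_margin)
```

The norm of `theta` keeps growing while its direction slowly rotates toward `w_svm`, so the normalized margin approaches `max_margin` from below without reaching it in finite time.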
Why Implicit Bias Is Crucial for Overparameterization
Infinitely Many Perfect Training Solutions, One Is Chosen
A network with more parameters than data points can fit the training data perfectly with infinitely many different parameter settings. Classical theory provides no guidance about which one gradient descent finds — they could all have wildly different test performance. The implicit bias resolves this: gradient descent is not choosing randomly from this infinite set. It is choosing the structurally simplest fitting solution in a sense determined by the optimizer's geometry (minimum norm, maximum margin, or more complex inductive biases for deep networks). This "structured interpolation" is one of the leading explanations for why massive overparameterized networks generalize at all.
06
Foundation
Generalization — Why Overparameterized Networks Work
// classical vs modern theory · double descent · benign overfitting · the interpolation threshold
Classical machine learning theory (VC dimension, Rademacher complexity, bias-variance tradeoff) predicts that interpolating the training data perfectly (zero training loss) leads to catastrophic generalization failure. Deep networks do this routinely — and generalize superbly. Something fundamental in the theory must be missing.
The Classical Bias-Variance Tradeoff
Bias-Variance Decomposition · Classical Theory
\[\mathbb{E}\left[(\hat{f}(\mathbf{x})-y)^2\right] = \underbrace{\text{Bias}^2}_{\text{underfitting}} + \underbrace{\text{Variance}}_{\text{overfitting}} + \sigma^2\]
\[\text{Classical prediction: variance} \uparrow\text{ monotonically with model complexity}\]
Classical theory says: after crossing the interpolation threshold (fitting training data perfectly), variance explodes → generalization collapses. The optimal model is at the bias-variance "sweet spot" just before interpolation. Neural networks empirically violate this prediction at scale.
Double Descent
Double Descent Phenomenon · Belkin et al. 2019
\[\text{Risk}(P) = \begin{cases}\text{classical U-shape} & P < n \text{ (under-parameterized)} \\ \text{peaks} & P \approx n \text{ (interpolation threshold)} \\ \text{decreases again!} & P \gg n \text{ (over-parameterized)}\end{cases}\]
Belkin et al. demonstrated that as model size P passes through n (the number of training samples), test error first follows the classical U-shape, then PEAKS at the interpolation threshold (the model is just barely able to fit the training data, in a "rigid," high-variance way), then DECREASES again as P grows further — because overparameterized models have the freedom to find the MINIMUM NORM interpolating solution (§05), which turns out to have low test error. Double descent was observed in random feature models, kernel methods, and neural networks.
Benign Overfitting
Theorem: Benign Overfitting (Bartlett et al. 2020)
For linear models in high-dimensional settings where the data covariance has rapidly decaying spectrum, the minimum-norm interpolating solution can have zero training error AND asymptotically optimal test error simultaneously — "benign" because the overfitting causes no generalization harm. The precise condition: the "effective dimensionality" of the covariance's tail (dimensions beyond the top-\(n\)) must be large enough to absorb the noise from memorizing training labels without affecting predictions on unseen test points (which lie primarily in the top-n signal directions).
This result provides a formal version of the double-descent intuition: when the model has far more degrees of freedom than training points, it can simultaneously memorize training data (using the many tail dimensions) AND generalize (because the signal-relevant directions are learned correctly and minimally corrupted by the memorization).
07
Scaling
Batch Size Theory
// the linear scaling rule · critical batch size · gradient noise scale · Goyal et al. 2017
Batch size is one of the most consequential hyperparameters in deep learning training — it controls the trade-off between throughput (samples per second), wall-clock time, and generalization. The theory of how to scale batch size optimally is surprisingly precise.
The Linear Scaling Rule
Linear Scaling Rule (Goyal et al. 2017) · Facebook AI Research
\[\text{If batch size multiplied by } k: \quad \eta \leftarrow k\cdot\eta \quad\text{(multiply learning rate by same } k\text{)}\]
Intuition via effective temperature: recall from §04 that T_eff ∝ η/B. To keep T_eff constant when B → kB (training dynamics unchanged), we need η → kη. This works in practice up to some critical batch size B_crit beyond which the linear rule breaks down and additional techniques (warmup, §08) are needed.
Critical Batch Size — The Efficiency Frontier
Critical Batch Size (McCandlish et al. 2018) · OpenAI Scaling
\[B_{\text{crit}} \triangleq \frac{\text{tr}(\mathbf{C}(\boldsymbol{\theta}))}{\|\nabla\mathcal{L}(\boldsymbol{\theta})\|^2} = \frac{\text{gradient noise scale}}{\text{gradient signal scale}}\]
\[\text{For } B \ll B_{\text{crit}}: \text{ parallelism is efficient — each sample adds new information}\]
\[\text{For } B \gg B_{\text{crit}}: \text{ redundant work — extra samples add diminishing gradient accuracy}\]
B_crit is the batch size at which gradient noise and gradient signal are equal in magnitude. Below B_crit, increasing batch size improves gradient accuracy proportionally — linear speedup in training efficiency. Above B_crit, increasing batch size gives diminishing returns in gradient accuracy — you're just averaging away already-negligible noise. This defines the "efficiency frontier" of batch size scaling.
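The ratio defining \(B_{\text{crit}}\) can be estimated directly from per-sample gradients. The sketch below assumes a synthetic linear-regression problem and evaluates the ratio at a point far from the optimum and at a point close to it; `simple_noise_scale` is an illustrative helper name.

```python
import numpy as np

def simple_noise_scale(per_sample_grads):
    """Estimate tr(C) / ||mean gradient||^2 from an (n, p) array of per-sample gradients."""
    g = per_sample_grads.mean(axis=0)
    trace_c = per_sample_grads.var(axis=0).sum()
    return float(trace_c / (g @ g))

rng = np.random.default_rng(0)
n, p = 4096, 10
X = rng.normal(size=(n, p))
theta_true = rng.normal(size=p)
y = X @ theta_true + 0.5 * rng.normal(size=n)

def per_sample_grads(theta):
    return (X @ theta - y)[:, None] * X      # gradients of 0.5 * (x_i . theta - y_i)^2

b_far = simple_noise_scale(per_sample_grads(theta_true + 3.0))    # far from the optimum
b_near = simple_noise_scale(per_sample_grads(theta_true + 0.03))  # close to the optimum
print(b_far, b_near)
```

Near the optimum the mean gradient is small while the label noise keeps the per-sample gradients spread out, so the estimated critical batch size is much larger there.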
Practical Consequences
Regime | Batch Size | Steps to Convergence | Wall-Clock Time | Generalization
Small batch SGD | \(B \ll B_{\text{crit}}\) | Many steps | Long (serial) | Best (high T_eff, flat minima)
Critical batch | \(B \approx B_{\text{crit}}\) | Optimal efficiency | Moderate | Good
Large batch SGD | \(B \gg B_{\text{crit}}\) | Fewer steps, but redundant | Shorter (parallel) | Worse (low T_eff, sharp minima)
McCandlish et al. showed that \(B_{\text{crit}}\) grows as training progresses. Early in training, the mean gradient is large compared with the per-sample noise (small \(B_{\text{crit}}\)); late in training, the mean gradient shrinks while the per-sample noise does not, so noise dominates (large \(B_{\text{crit}}\)) and larger batches can be used without efficiency loss. This motivates increasing batch size during training as an alternative to learning rate decay: both lower the \(\eta/B\) ratio that sets T_eff.
08
Scheduling
Learning Rate Warmup
// why large LR crashes early training · gradient variance theory · cosine schedule · warmup variants
Modern large-scale training invariably starts with a learning rate warmup — gradually increasing the learning rate from near-zero to its target value over the first few hundred to a few thousand steps. This seemingly small detail is critical for training stability and has a precise mathematical justification.
Why Large LR Crashes at the Start
// Mathematical analysis via gradient variance at initialization
1
At initialization, weights are random (e.g., Kaiming uniform/normal). Layer outputs have HIGH variance — small input perturbations can produce wildly different activations. The per-sample gradient variance \(\text{Var}(\nabla\ell_i)\) is correspondingly large: gradients fluctuate dramatically from sample to sample.
2
The learning rate for stable gradient descent is bounded by the descent lemma (Convex Optimization masterclass): \(\eta < 2/L\) where \(L = \lambda_{\max}(\mathbf{H})\). Early in training, \(\lambda_{\max}\) can be large compared with \(2/\eta_{\max}\), so the learning rate that is stable at initialization is smaller than the target value.
3
Starting with a large learning rate violates the descent-lemma bound → loss spikes. The network may never recover: a catastrophic update in the first steps can scramble the weights, so subsequent steps start from a worse-than-initial position.
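Adam adds a further issue: its second-moment estimates start at 0 and are built from only a handful of gradients in the first steps. With bias correction, the first update is \(\eta\,g/|g|\) per coordinate, a full-size step of magnitude about \(\eta\) regardless of how small or unreliable the gradient is, and the normalizer \(\sqrt{\hat{v}}\) stays noisy until enough steps have been averaged. This is one reason Adam in particular benefits from warmup: it gives the second-moment estimates time to stabilize.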
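The size of Adam's first update can be computed by hand. The sketch below follows the standard Adam update with bias correction, starting from zero moment estimates; the default hyperparameters shown are the commonly used ones and `adam_first_step` is an illustrative name.

```python
import numpy as np

def adam_first_step(grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Size of the very first bias-corrected Adam update, starting from m = v = 0."""
    m = (1 - beta1) * grad
    v = (1 - beta2) * grad**2
    m_hat = m / (1 - beta1)            # bias correction at t = 1
    v_hat = v / (1 - beta2)
    return lr * m_hat / (np.sqrt(v_hat) + eps)

step_small = adam_first_step(np.array([1e-4, -1e-4]))   # tiny gradients
step_large = adam_first_step(np.array([10.0, -10.0]))   # huge gradients
print(step_small, step_large)                           # both about +/- lr per coordinate
```

Because the first step has magnitude close to `lr` in every coordinate no matter how small the gradient is, ramping `lr` up from near zero directly limits the size of these early, poorly calibrated updates.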
Common Warmup + Decay Schedules
Warmup + Cosine Decay (Standard Modern Schedule) · Recipe
\[\eta(t) = \begin{cases}\eta_{\max}\cdot\frac{t}{T_{\text{warm}}} & t \leq T_{\text{warm}} \quad\text{(linear warmup)} \\ \eta_{\max}\cdot\frac{1}{2}\!\left(1+\cos\!\left(\pi\,\frac{t-T_{\text{warm}}}{T-T_{\text{warm}}}\right)\right) & t > T_{\text{warm}} \quad\text{(cosine decay)}\end{cases}\]
T_warm is often in the range of a few hundred to a few thousand steps; T is the total number of training steps. The cosine schedule decays smoothly from η_max to near-zero, avoiding the abrupt "learning rate drop" artifacts of step schedules. Warmup followed by cosine decay is a common default in modern large-model training (for example GPT-3 and LLaMA); BERT-style recipes pair the same warmup with linear decay instead.
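The schedule above translates directly into code. This is a minimal sketch of the piecewise formula; the function name `lr_schedule` and the example values (peak rate 3e-4, 1,000 warmup steps, 10,000 total steps) are illustrative assumptions.

```python
import math

def lr_schedule(t, eta_max, t_warm, t_total):
    """Linear warmup to eta_max over t_warm steps, then cosine decay to zero at t_total."""
    if t <= t_warm:
        return eta_max * t / t_warm
    progress = (t - t_warm) / (t_total - t_warm)
    return eta_max * 0.5 * (1.0 + math.cos(math.pi * progress))

schedule = [lr_schedule(t, eta_max=3e-4, t_warm=1000, t_total=10000) for t in range(10001)]
print(schedule[0], schedule[1000], schedule[5500], schedule[10000])
```

In practice the cosine branch is often decayed to a small floor (for example a fraction of the peak) instead of exactly zero; that variant only changes the last line of the function.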
Warmup Duration Scaling
Warmup Scales With Model Size and Batch Size
The required warmup duration scales roughly proportionally with the target learning rate and inversely with the initial gradient-step stability. Empirically: larger models (higher \(\lambda_{\max}\) at init due to more layers and parameters) need longer warmup. Larger batch sizes (where the effective step size per epoch is larger) need longer warmup. The Transformer training recipe (Vaswani et al. 2017) used \(T_{\text{warm}} = 4000\) steps for a medium-sized model; modern LLM training (Chinchilla-scale) typically uses 1-5% of total training steps as warmup. A simple rule: warmup should be long enough for the second-moment estimates of Adam to stabilize, which requires \(\approx 1/(1-\beta_2)\) steps (\(\approx 1000\) steps for \(\beta_2=0.999\)).
09
Empirical
Neural Scaling Laws
// Kaplan et al. 2020 · Chinchilla · compute-optimal training · power laws in loss and compute
Scaling laws are perhaps the most striking empirical regularity in modern deep learning: across many orders of magnitude in model size, dataset size, and compute budget, loss follows precise power law relationships that allow quantitative predictions about training outcomes before any training is done.
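\[\mathcal{L}(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N}, \qquad \mathcal{L}(D) \approx \left(\frac{D_c}{D}\right)^{\alpha_D}, \qquad \mathcal{L}(C) \approx \left(\frac{C_c}{C}\right)^{\alpha_C}\]
Here N = number of non-embedding parameters, D = dataset size in tokens, C = training compute in FLOPs, and each law applies when the other two quantities are not the bottleneck. Kaplan et al. (2020) report small exponents, roughly \(\alpha_N\approx0.076\), \(\alpha_D\approx0.095\), and \(\alpha_C\approx0.05\). The trends span many orders of magnitude (more than seven in compute) for language models, and the fitted exponents change only modestly with architectural details. N_c, D_c, C_c are constants determined by fitting empirical data. The consistent power-law behavior (rather than, say, exponential or linear) is itself a profound and not-fully-explained property of deep learning.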
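Fitting such a law amounts to a straight-line fit in log-log space. The sketch below uses synthetic data only: the model sizes, the exponent 0.076, the constant 8.8e13, and the noise level are illustrative inputs, not measurements from real training runs.

```python
import numpy as np

rng = np.random.default_rng(0)
N = np.logspace(6, 10, 9)                        # synthetic model sizes
true_alpha, N_c = 0.076, 8.8e13                  # illustrative constants, not a fit to real runs
loss = (N_c / N) ** true_alpha * np.exp(0.01 * rng.normal(size=N.size))  # multiplicative noise

# A power law is a straight line in log-log space: log L = alpha * log N_c - alpha * log N.
slope, intercept = np.polyfit(np.log(N), np.log(loss), deg=1)
alpha_hat = -slope
N_c_hat = np.exp(intercept / alpha_hat)
print(alpha_hat, N_c_hat)
```

Because the exponent is small, the constant `N_c_hat` is very sensitive to noise in the fit, which is one reason extrapolations from scaling-law fits are usually quoted with wide uncertainty.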
Chinchilla — Compute-Optimal Scaling
// Hoffmann et al. 2022 — finding the optimal N,D for a given compute budget C
1
For a fixed compute budget \(C = 6ND\) FLOPs (approximately, for a transformer with \(N\) parameters trained on \(D\) tokens, where 6 accounts for forward + backward passes):
\[\mathcal{L}(N,D) \approx E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}\]
2
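Minimize \(\mathcal{L}\) subject to \(C = 6ND\) (budget constraint). Using Lagrange multipliers (from the Lagrange Multipliers masterclass), the optimum scales as \(N_{\text{opt}} \propto C^{\beta/(\alpha+\beta)}\) and \(D_{\text{opt}} \propto C^{\alpha/(\alpha+\beta)}\). Hoffmann et al. estimate both exponents to be close to 0.5; one of their estimates gives:
\[N_{\text{opt}} \propto C^{0.49}, \quad D_{\text{opt}} \propto C^{0.51}\]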
3
Chinchilla's practical rule: for every doubling of model size \(N\), roughly double the dataset size \(D\) as well. Or: train approximately 20 tokens per parameter for compute-optimal performance.
Kaplan et al. underestimated the data scaling exponent and so recommended far too few training tokens relative to model parameters. Chinchilla corrected this: Gopher (280B params, 300B tokens) was undertrained; Chinchilla (70B params, 1.4T tokens) achieved better performance on the same compute budget. The field recalibrated around the ~20 tokens/parameter rule, and subsequent models such as LLaMA were trained on even more tokens per parameter than this, accepting extra training compute in exchange for a smaller model that is cheaper at inference.
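The steps above can be turned into a small calculator. Substituting \(D=C/(6N)\) into the parametric loss and setting the derivative with respect to \(N\) to zero gives \(N^{\alpha+\beta}\propto C^{\beta}\), which is where the exponents below come from. The function names are illustrative, and the 20 tokens-per-parameter default is the rule of thumb quoted in the text, not an exact constant.

```python
import math

def compute_optimal(C, tokens_per_param=20.0):
    """Split a FLOP budget C = 6 * N * D using a fixed tokens-per-parameter ratio."""
    N = math.sqrt(C / (6.0 * tokens_per_param))
    return N, tokens_per_param * N

def optimal_exponents(alpha, beta):
    """Exponents (a, b) with N_opt ~ C^a and D_opt ~ C^b for L = E + A/N^alpha + B/D^beta."""
    return beta / (alpha + beta), alpha / (alpha + beta)

C_chinchilla = 6 * 70e9 * 1.4e12            # budget implied by 70B parameters and 1.4T tokens
N_opt, D_opt = compute_optimal(C_chinchilla)
print(f"{N_opt:.3g} parameters, {D_opt:.3g} tokens")
print(optimal_exponents(0.3, 0.3))          # equal exponents give an even split: (0.5, 0.5)
```

With the 20:1 ratio the calculator recovers Chinchilla's own configuration from its compute budget, which is just a consistency check on the arithmetic.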
Why Power Laws and Not Something Else?
A Partial Explanation
Bahri et al. (2021) and others have argued that power laws arise from a "random matrix theory + statistical mechanics" picture: the data distribution can be decomposed into independent "tasks" of varying difficulty, and a model with \(N\) parameters solves tasks in order of increasing difficulty (easiest first). If task difficulties follow a power-law distribution (itself plausible from Zipf-like statistics of natural language), loss as a function of capacity follows a power law. But this remains partially speculative: the precise mathematical mechanism behind empirical scaling laws is an active area of theoretical research.
10
Phenomenon
Grokking
// Power et al. 2022 · sudden generalization long after memorization · phase transition · algorithmic representations
Grokking (Power et al. 2022) is one of the most surprising optimization discoveries of the 2020s: networks can achieve perfect training accuracy on a task while validation accuracy stays near chance for tens of thousands of additional training steps, and then generalize perfectly after an abrupt jump. The memorized solution "transmutes" into a generalizing one long after the training loss has converged.
The Canonical Experiment
Grokking Setup: Modular Arithmetic · Power et al. 2022
\[\text{Task: } (a + b) \bmod p \text{ for } a,b\in\{0,\ldots,p-1\}, \quad p = 97\]
\[\text{Data: } \text{train on } 30\%\text{ to }70\% \text{ of all } p^2 \text{ pairs; evaluate on the rest}\]
\[\text{Observations: training loss } \to 0 \text{ in } \sim 100 \text{ steps};\quad\text{validation accuracy } \to 100\% \text{ in } \sim 10{,}000\text{-}100{,}000 \text{ steps}\]
The network memorizes the training set immediately. But memorization and generalization use fundamentally different internal circuits. After memorization, the network continues to be trained with L2 weight decay (which slowly drives weights toward zero). The network gradually replaces its memorizing solution with a generalizing one — but this process can take 100-1000× more steps than the initial memorization.
Why Grokking Happens — The Representation Hypothesis
// Nanda et al. 2023 mechanistic analysis of grokking in mod-p addition
1
Memorizing solution: lookup-table behavior in which each (a,b) pair is individually memorized by specific network components. High weight norm (so heavily penalized by L2 weight decay), no generalizable structure.
2
Generalizing solution (revealed by interpretability): the network discovers that modular addition has a Fourier structure — specifically, the operation \((a+b)\bmod p\) can be computed via a specific periodic representation using frequencies \(\omega_k = 2\pi k/p\):
\[\cos(\omega_k(a+b)) = \cos(\omega_k a)\cos(\omega_k b) - \sin(\omega_k a)\sin(\omega_k b)\]
The network's activations, when analyzed via Fourier transform, show that the generalizing solution implements EXACTLY this identity — it represents a and b as sets of (cos,sin) pairs at specific frequencies, multiplies them, and reads off the result via another linear layer. This is a genuine "discovery" of mathematical structure by gradient descent.
3
The phase transition: L2 regularization continuously penalizes the high-norm memorizing solution and slowly builds up the lower-norm generalizing solution in parallel. When the generalizing solution becomes more efficient (lower total loss after regularization), the network undergoes a sudden phase transition, rapidly discarding memorization in favor of the algorithmic generalizing representation.
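The Fourier mechanism described in step 2 can be written out explicitly. The sketch below is a hand-coded version of the trigonometric algorithm, not the trained network: it assumes five arbitrary frequencies and scores each candidate output \(c\) by \(\sum_k\cos(\omega_k(a+b-c))\), which is maximal exactly when \(c=(a+b)\bmod p\).

```python
import numpy as np

p = 97
freqs = 2 * np.pi * np.arange(1, 6) / p          # a few frequencies w_k = 2*pi*k/p (choice is arbitrary)

def mod_add_fourier(a, b):
    """Compute (a + b) mod p using only cos/sin features of a, b and candidate outputs c."""
    ca, sa = np.cos(freqs * a), np.sin(freqs * a)
    cb, sb = np.cos(freqs * b), np.sin(freqs * b)
    cos_ab = ca * cb - sa * sb                   # cos(w(a+b)) via the angle-addition identity
    sin_ab = sa * cb + ca * sb                   # sin(w(a+b))
    c = np.arange(p)[:, None]
    # logit(c) = sum_k cos(w_k (a + b - c)), maximal exactly when c = (a + b) mod p
    logits = (cos_ab * np.cos(freqs * c) + sin_ab * np.sin(freqs * c)).sum(axis=1)
    return int(np.argmax(logits))

all_correct = all(mod_add_fourier(a, b) == (a + b) % p for a in range(p) for b in range(p))
print(all_correct)  # True
```

No pair is stored anywhere, so the same few trigonometric features handle every input, including pairs never seen in training. That is what makes this solution generalize where the lookup table cannot.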
Grokking's Broader Significance
Memorization is not the end state: even after achieving 100% training accuracy, gradient descent continues making meaningful updates to the network's internal representations — motivated by regularization rather than loss reduction.
Generalization can be sudden: the standard view that generalization improves gradually alongside training loss is wrong for tasks with clean algorithmic structure. A phase transition separates memorizing and generalizing phases.
Regularization is central to grokking: in these experiments weight decay is what destabilizes the memorizing solution, and without it the transition is greatly delayed or does not appear within the training budget. Regularization tips the balance toward the more parameter-efficient generalizing solution.
Training duration as a hyperparameter: standard early stopping (stop when validation loss stops improving) would terminate training during the memorization phase, long before the generalizing phase begins. Grokking suggests that patience in training — and careful regularization — can unlock qualitatively better solutions.
11
Synthesis
The Complete Mental Model
// everything unified · one diagram · the complete optimization story
Fig 1. Deep Learning Optimization complete map — from loss landscape geometry through saddle points, flatness, SGD noise, implicit bias, generalization, batch/schedule engineering, scaling laws, and grokking.
The complete story, unified:
The loss landscape is a high-dimensional hypersurface whose local geometry is entirely captured by the Hessian. In \(P\sim10^{10}\) dimensions, our 3D intuition fails — true local minima are extraordinarily rare; almost all critical points are saddle points.
Saddle points have at least one negative Hessian eigenvalue. They are unstable under the gradient noise of SGD — unlike exact gradient descent, which can get stuck. Networks satisfying the strict saddle property allow SGD to efficiently escape every saddle in polynomial time.
Not all minima are equal. Flat minima (small Hessian eigenvalues, broad loss basins) generalize better than sharp minima. The PAC-Bayes framework makes this precise: flatness equals compressibility, and compressibility bounds generalization error.
SGD noise is not a problem to minimize — it's a tool. The effective temperature \(T_{\text{eff}} \propto \eta/B\) determines whether SGD explores broadly (finding flat minima) or settles early (finding sharp ones). High \(\eta/B\) is its own form of implicit regularization.
Implicit bias means gradient descent — even without any explicit regularization term — finds structurally simple solutions: minimum-norm for regression, maximum-margin for classification. This is why overparameterized networks generalize at all: they are not choosing randomly from the infinite set of interpolating solutions.
Classical generalization theory is wrong for modern networks. The double descent phenomenon and benign overfitting show that interpolating the training data perfectly can actually be the best choice when sufficiently overparameterized — provided the implicit bias selects the right interpolant.
Batch size and learning rate are coupled through \(T_{\text{eff}}\). The linear scaling rule maintains \(T_{\text{eff}}\) constant when scaling batch size. The critical batch size \(B_{\text{crit}}\) is the efficiency frontier beyond which larger batches provide diminishing gradient-accuracy returns.
Learning rate warmup is necessary because the Hessian's largest eigenvalue is large at initialization, making large steps unstable. Warmup gradually increases the step size as the landscape becomes more stable and the gradient variance decreases.
Neural scaling laws reveal that loss follows power laws in model size, data size, and compute — allowing quantitative prediction of training outcomes. Chinchilla showed that both model and data must be scaled together for compute-optimal training.
Grokking reveals that training beyond apparent convergence (100% training accuracy) can trigger a phase transition from memorization to genuine generalization — with L2 regularization as the thermodynamic force driving this transition. It shows that training duration is an underappreciated hyperparameter.
The Unifying Answer
Why does gradient descent in deep networks find solutions that generalize, despite the profound non-convexity and astronomical overparameterization? The answer is a confluence of all ten phenomena in this article: the high-dimensional landscape's saddle-dominated structure funnels SGD toward minima rather than saddles; SGD's gradient noise provides an effective temperature that biases the search toward flat, generalizing minima over sharp, memorizing ones; gradient descent's implicit bias selects the structurally simplest interpolating solution within whatever basin it reaches; and regularization — explicit weight decay or implicit via finite training budget and SGD noise — continuously sculpts the solution toward one that trades training-set fit for representational compactness, which is exactly the property that the PAC-Bayes theorem identifies as the source of generalization. Generalization in deep learning is not a coincidence or a mystery — it is the accumulated statistical physics of these optimization dynamics, all working in the same direction.