The Effect of Mini-Batch Noise on the Implicit Bias of Adam
With limited high-quality data and growing compute, multi-epoch training is regaining importance across sub-areas of deep learning. Adam(W), versions of which are the go-to optimizers for many tasks such as next-token prediction, has two momentum hyperparameters $(β_1, β_2)$ controlling memory and one very important hyperparameter, batch size, controlling (in particular) the amount of mini-batch noise. We introduce a theoretical framework to understand how mini-batch noise influences the implicit bias of memory in Adam (depending on $β_1$, $β_2$) towards sharper or flatter regions of the loss landscape, a property commonly observed to correlate with the generalization gap in multi-epoch training. We find that for large batch sizes, higher $β_2$ increases the magnitude of anti-regularization by memory (hurting generalization), but as the batch size becomes smaller, the dependence of (anti-)regularization on $β_2$ is reversed. A similar monotonicity shift (in the opposite direction) happens in $β_1$. In particular, the common default pair $(β_1, β_2) = (0.9, 0.999)$ is a good choice if batches are small; for larger batches, in many settings moving $β_1$ closer to $β_2$ is much better in terms of validation accuracy in multi-epoch training. Moreover, our theoretical derivations connect the scale of the batch size at which the shift happens to the scale of the critical batch size. We illustrate this effect in experiments with small-scale data in the about-to-overfit regime.
💡 Research Summary
The paper investigates how mini‑batch noise interacts with the momentum hyper‑parameters (β₁, β₂) of Adam (and its variant AdamW) to produce an implicit bias toward sharper or flatter regions of the loss landscape, which in turn affects generalization in multi‑epoch training. The authors first formalize “memory algorithms” – optimizers whose updates depend on the entire history of past gradients – and invoke a recent theorem (Cattaneo & Shigida, 2022) that allows any such algorithm to be approximated by a memory‑less iteration consisting of a main term and a correction term, with an error of order O(η²) for step size η. Applying this framework to SGD with momentum reveals two correction components: one proportional to the squared gradient norm (‖∇L‖²), which corresponds to an ℓ₂‑sharpness penalty, and another proportional to the trace of the gradient‑noise covariance matrix (tr Σ), which captures the effect of stochasticity and is also predictive of flatness.
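The two correction quantities identified above can both be estimated from per-example gradients within a mini-batch. The following is a minimal sketch of that estimation, not code from the paper; the function name and array layout are illustrative assumptions.

```python
import numpy as np

def correction_terms(G):
    """Estimate the two correction quantities from per-example gradients.

    G: (n, d) array whose rows are per-example gradients g_i.
    Returns (‖∇L‖², tr Σ) where ∇L is the mean gradient and Σ is the
    empirical covariance of the gradient noise g_i − ∇L.
    """
    g_bar = G.mean(axis=0)                      # full-batch gradient ∇L
    grad_norm_sq = float(g_bar @ g_bar)         # ‖∇L‖²: the ℓ₂-sharpness penalty term
    centered = G - g_bar                        # per-example gradient noise
    # tr Σ is the sum of per-coordinate variances of the noise
    tr_sigma = float((centered ** 2).sum() / (len(G) - 1))
    return grad_norm_sq, tr_sigma
```

With this decomposition, the ‖∇L‖² term is deterministic while tr Σ scales with the amount of mini-batch noise, which is how batch size enters the analysis.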
The authors then extend the analysis to Adam. By rewriting Adam's update in terms of weighted sums μₜ,ₖ and νₜ,ₖ of past gradients and squared gradients, and by expanding the mini-batch loss deviation dₖ = Lₖ − L together with its derivatives, they derive an explicit expression for the expected parameter change after one iteration, separating a main descent term from correction terms analogous to those found for SGD with momentum.
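The rewriting of Adam's moments as explicit weighted sums over past gradients can be sketched as follows. This is a hypothetical illustration (function names and the list-based gradient history are assumptions, not the paper's notation); the bias correction and update rule follow the standard Adam recursion of Kingma & Ba.

```python
import numpy as np

def adam_moments(grads, beta1=0.9, beta2=0.999):
    """Express Adam's moments as weighted sums over the gradient history.

    grads: list of gradient arrays g_1..g_t. Each past gradient g_k
    receives weight (1−β)·β^(t−k), which equals the exponential
    moving average computed by the usual recursion.
    """
    t = len(grads)
    m = sum((1 - beta1) * beta1 ** (t - k) * g for k, g in enumerate(grads, start=1))
    v = sum((1 - beta2) * beta2 ** (t - k) * g ** 2 for k, g in enumerate(grads, start=1))
    return m, v

def adam_step(params, grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update written against the full gradient history."""
    t = len(grads)
    m, v = adam_moments(grads, beta1, beta2)
    m_hat = m / (1 - beta1 ** t)                # standard bias correction
    v_hat = v / (1 - beta2 ** t)
    return params - lr * m_hat / (np.sqrt(v_hat) + eps)
```

Writing the moments this way makes the role of memory explicit: β₁ and β₂ set how quickly the weights (1−β)·β^(t−k) decay over past mini-batches, which is the dependence the expansion above tracks.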