Cautious Optimizers: Improving Training with One Line of Code
AdamW has been the default optimizer for transformer pretraining. For many years, our community has searched for faster and more stable optimizers, with only limited success. In this work, we propose a one-line modification in PyTorch to any momentum-based optimizer, which we call the cautious optimizer, e.g. C-AdamW and C-Lion. Our theoretical result shows that this modification preserves Adam’s Hamiltonian function and does not break the convergence guarantee under Lyapunov analysis. In addition, our theoretical insight reveals a whole new family of optimizers. Among them, we pick the simplest one for empirical experiments, showing consistent speed-ups not only on LLM pretraining but also on image classification, with minimal extra hyperparameter tuning. Code is available at https://github.com/kyleliang919/C-Optim.
💡 Research Summary
The paper introduces “Cautious Optimizers,” a lightweight modification applicable to any momentum‑based optimizer such as AdamW, Lion, or Polyak momentum. The authors observe that in standard momentum methods the update direction uₜ does not always align with the current gradient gₜ. When the signs disagree, the update can increase the loss temporarily and cause oscillations, slowing convergence. To address this, they propose a simple element‑wise mask:
ϕ(uₜ ∘ gₜ) = α(uₜ ∘ gₜ) · I(uₜ ∘ gₜ > 0)
where I is the indicator function and α is a positive scaling factor (by default α = dim/(nnz + 1), where dim is the number of coordinates and nnz is the number of coordinates the mask keeps; the +1 avoids division by zero). The mask zeroes out every coordinate whose update direction and gradient have opposite signs, and scales the remaining coordinates up to compensate for the reduced update magnitude. In PyTorch this can be written in a single line (Algorithm 1), making it trivial to add to existing codebases.
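As a minimal sketch of the masking rule above, the NumPy function below applies ϕ to an arbitrary optimizer update u given the current gradient g. It assumes the PyTorch-style convention w ← w − lr·u (so a coordinate is "aligned" when uᵢgᵢ > 0) and the α = dim/(nnz + 1) scaling; the function name `cautious_update` is our own and is not taken from the authors' code.

```python
import numpy as np

def cautious_update(u: np.ndarray, g: np.ndarray) -> np.ndarray:
    """Mask and rescale an optimizer update u using the current gradient g.

    Assumes the convention w <- w - lr * update, so coordinate i is kept
    only when u_i * g_i > 0 (update and gradient agree in sign).
    """
    # Indicator I(u * g > 0) as a float mask of 0s and 1s.
    mask = (u * g > 0).astype(u.dtype)
    # Assumed default scaling alpha = dim / (nnz + 1); the +1 guards against nnz = 0.
    alpha = mask.size / (mask.sum() + 1.0)
    # Zero the misaligned coordinates and scale up the survivors.
    return u * mask * alpha
```

With u = [1, −2, 3, 0.5] and g = [2, 1, −1, 4], the second and third coordinates disagree in sign and are dropped, while the two survivors are scaled by α = 4/(2 + 1) = 4/3.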
Theoretical contributions
The authors place most momentum‑based optimizers within a continuous‑time Hamiltonian‑descent framework: a damped Hamiltonian system with a Lyapunov function H(w,s)=L(w)+K(s). While H is guaranteed to be non‑increasing, the actual loss L(w) may rise temporarily because kinetic energy can be traded for potential energy. By multiplying the kinetic gradient ∇K(sₜ) in the position update element-wise by the mask ϕ, evaluated at x = ∇L(wₜ) ∘ ∇K(sₜ), they obtain new dynamics (Equation 5). They prove (Theorem 2.1) that if the mask satisfies element‑wise constraints (ϕᵢ ≥ 1 when xᵢ > 0 and ϕᵢ ≤ 0 when xᵢ < 0), then both H and L decrease at least as fast as in the original system. The chosen mask ϕ(x) = α(x)·I(x ≥ 0) with α ≥ 1 meets these conditions. Consequently, the modified optimizer retains the convergence guarantees of the base method while ensuring that the loss is non-increasing along the continuous-time trajectory.
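To see why these element-wise conditions are enough for the loss part of Theorem 2.1, here is a short sketch. It assumes, for simplicity, that the position equation is dw/dt = −∇K(s) in the original system and dw/dt = −∇K(s) ∘ ϕ(x) in the cautious one, ignoring any additional damping terms; this is our simplification, not necessarily the exact form of Equation 5.

```latex
\begin{aligned}
x &:= \nabla L(w_t) \circ \nabla K(s_t) \\
\text{original: } \frac{d}{dt} L(w_t) &= -\nabla L(w_t)^\top \nabla K(s_t) = -\sum_i x_i \\
\text{cautious: } \frac{d}{dt} L(w_t) &= -\nabla L(w_t)^\top \big(\nabla K(s_t) \circ \phi(x)\big) = -\sum_i x_i \, \phi_i(x)
\end{aligned}
```

Coordinate by coordinate, xᵢ > 0 with ϕᵢ ≥ 1 gives xᵢϕᵢ ≥ xᵢ > 0, and xᵢ < 0 with ϕᵢ ≤ 0 gives xᵢϕᵢ ≥ 0 > xᵢ; coordinates with xᵢ = 0 contribute nothing to either sum. Hence the cautious dL/dt is never positive and never larger than the original one.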
In the discrete‑time setting, the paper shows (Theorem 2.3) that for μ‑smooth losses, a step size εₜ that is small enough relative to the mask‑induced reduction guarantees that each cautious update yields a loss no larger than the corresponding standard update. A stronger result (Theorem 2.4) provides a specific inner‑product‑based mask that guarantees loss decrease for any step size satisfying a simple inequality. Together, these results formalize the intuition that masking out misaligned directions does not hurt, at least for sufficiently small steps, and often helps.
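The discrete-time guarantee can be checked numerically on the toy loss used in the experiments, L(w) = 4w₁² + w₂². This single hand-picked case (the starting point, the momentum vector u and the step size are hypothetical) illustrates the claim rather than proving it, and it uses the unscaled mask α = 1, which satisfies the α ≥ 1 condition.

```python
import numpy as np

KAPPA = 4.0  # toy loss L(w) = KAPPA * w1^2 + w2^2

def loss(w: np.ndarray) -> float:
    return float(KAPPA * w[0] ** 2 + w[1] ** 2)

def grad(w: np.ndarray) -> np.ndarray:
    return np.array([2 * KAPPA * w[0], 2 * w[1]])

w = np.array([1.0, 1.0])
g = grad(w)                     # [8, 2]
u = np.array([1.0, -1.0])       # hypothetical momentum; coordinate 2 opposes the gradient
step = 0.1                      # hypothetical step size epsilon_t

standard = w - step * u               # plain momentum step
cautious = w - step * u * (u * g > 0) # masked step with alpha = 1

print(loss(w), loss(standard), loss(cautious))  # approximately: 5.0 4.45 4.24
```

The masked step keeps the helpful first coordinate and simply skips the uphill move on the second, so its loss is lower than the standard step's for every positive step size in this example: the gap is (1 + step)² − 1.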
Empirical evaluation
The authors conduct three sets of experiments:
- 2‑D quadratic toy problem (L(w)=κ w₁² + w₂², κ=4). They compare Gradient Descent (GD), Gradient Descent with Momentum (GDM), and Cautious GDM (C‑GDM). Visualizations show that C‑GDM eliminates overshooting and oscillations, and both the Hamiltonian and the loss decrease more rapidly.
- Large‑scale language model pre‑training. Using a 7‑billion‑parameter transformer, they train with AdamW, Lion, and their cautious counterparts (C‑AdamW, C‑Lion) on a 300‑billion‑token corpus. Hyperparameters (learning rate, betas, weight decay) are kept identical to the baseline. C‑AdamW and C‑Lion achieve 2–3 % better token‑level loss reduction, translating into faster convergence (fewer tokens to reach a given perplexity) without additional tuning.
- Vision benchmarks. They fine‑tune ResNet‑50 on CIFAR‑100 and ViT‑B/16 on ImageNet‑1K using AdamW and C‑AdamW. C‑AdamW improves top‑1 accuracy by ~0.4 % on CIFAR‑100 and ~0.6 % on ImageNet, while reducing total training time by roughly 5 % due to faster convergence.
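The 2-D toy comparison in the first experiment can be approximated with the short NumPy script below. The step size (0.04), momentum coefficient (0.9), starting point (1, 1), step count and the unscaled mask (α = 1) are our own assumptions, not the authors' settings. With these hypothetical settings the cautious run can still overshoot now and then before the mask stalls it, so the qualitative picture depends on the chosen hyperparameters.

```python
import numpy as np

KAPPA = 4.0  # L(w) = KAPPA * w1^2 + w2^2, as in the toy problem

def loss(w: np.ndarray) -> float:
    return float(KAPPA * w[0] ** 2 + w[1] ** 2)

def grad(w: np.ndarray) -> np.ndarray:
    return np.array([2 * KAPPA * w[0], 2 * w[1]])

def run(method: str, lr: float = 0.04, beta: float = 0.9,
        steps: int = 100, w0=(1.0, 1.0)) -> list:
    """Return the loss trajectory for method in {"gd", "gdm", "cgdm"}."""
    w = np.array(w0, dtype=float)
    m = np.zeros_like(w)
    losses = [loss(w)]
    for _ in range(steps):
        g = grad(w)
        if method == "gd":
            u = g
        else:
            m = beta * m + g          # heavy-ball momentum buffer
            u = m
            if method == "cgdm":
                # Zero coordinates where momentum opposes the gradient (alpha = 1).
                u = u * (u * g > 0)
        w = w - lr * u
        losses.append(loss(w))
    return losses

for name in ("gd", "gdm", "cgdm"):
    traj = run(name)
    print(name, traj[0], traj[-1])
```

Plotting the three trajectories (for example with matplotlib) gives loss curves analogous to the paper's visualization; tracking H along the way would additionally require defining a kinetic energy K for the momentum buffer.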
A notable practical point is that the method requires virtually no extra hyperparameter search; the same learning‑rate schedule and momentum coefficients used for the base optimizer work out‑of‑the‑box. The only new hyperparameter is the scaling factor α, which the authors set to a simple heuristic based on the proportion of aligned coordinates.
Critical assessment
The approach is elegant in its simplicity and the theoretical analysis is thorough, covering both continuous‑time Hamiltonian dynamics and discrete‑time step‑wise guarantees. However, the method’s effectiveness hinges on the behavior of α. In very high‑dimensional models the fraction of aligned coordinates can vary dramatically during training, potentially leading to overly aggressive scaling or under‑scaling. The paper provides a default heuristic but does not explore adaptive schemes or sensitivity analyses. Moreover, when uₜ and gₜ are almost perfectly anti‑aligned, the mask can zero out large portions of the update, temporarily stalling progress. The authors argue that momentum accumulation will eventually realign the directions, but empirical evidence for this recovery in highly non‑convex, rapidly changing loss landscapes is limited.
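The sensitivity concern about α can be made concrete with a small NumPy experiment. Assuming the α = dim/(nnz + 1) default (our reading of the heuristic) and a hypothetical helper `cautious_scale`, it shows the two extremes: when every coordinate is anti-aligned the step is exactly zero, and when only one coordinate is aligned, that coordinate alone receives a step scaled by dim/2.

```python
import numpy as np

def cautious_scale(u: np.ndarray, g: np.ndarray):
    """Return (alpha, masked update), assuming alpha = dim / (nnz + 1)."""
    mask = u * g > 0
    alpha = mask.size / (int(mask.sum()) + 1)
    return alpha, u * mask * alpha

dim = 1000
g = np.ones(dim)

# Fully anti-aligned: every coordinate is masked, so the step is exactly zero.
alpha_all_off, step_all_off = cautious_scale(-np.ones(dim), g)

# Only one aligned coordinate: it alone survives and is scaled by dim / 2.
u = -np.ones(dim)
u[0] = 1.0
alpha_one, step_one = cautious_scale(u, g)

print(alpha_all_off, np.abs(step_all_off).max())  # 1000.0 0.0
print(alpha_one, step_one[0])                      # 500.0 500.0
```

Between these extremes α varies smoothly with the aligned fraction, which is why a clipped or adaptive α is a natural candidate for the sensitivity study the paper leaves open.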
Finally, while the “one‑line” claim is technically correct for PyTorch, practical deployment still requires careful handling of the scaling factor and possibly of the effective learning rate (in Algorithm 2 the authors rescale the learning rate by the parameter dimension divided by the ℓ₀‑norm of the mask). This adds a small amount of bookkeeping that is easy to overlook in production pipelines.
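The bookkeeping mentioned above can be sketched as a single C-AdamW step in NumPy. This is a hypothetical re-implementation for illustration, not the authors' Algorithm 2: the function name `c_adamw_step`, the choice to apply decoupled weight decay outside the mask, and the dim/(‖mask‖₀ + 1) rescaling are our assumptions.

```python
import numpy as np

def c_adamw_step(p, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
                 eps=1e-8, weight_decay=0.0):
    """One hypothetical C-AdamW step. Returns updated (p, m, v); t starts at 1."""
    # Decoupled weight decay, applied to the parameters outside the mask (assumption).
    p = p * (1 - lr * weight_decay)
    # Standard Adam first/second moment estimates with bias correction.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    u = m_hat / (np.sqrt(v_hat) + eps)      # plain AdamW update direction
    # Cautious mask: keep coordinates where the update agrees with the gradient.
    mask = (u * g > 0).astype(p.dtype)
    # Assumed learning-rate rescaling: dim / (||mask||_0 + 1).
    scale = mask.size / (mask.sum() + 1.0)
    p = p - lr * scale * u * mask
    return p, m, v
```

Calling it with t = 1, 2, … and carrying m and v between calls reproduces the usual AdamW loop; a coordinate whose fresh gradient flips sign against the accumulated first moment is simply left unchanged (apart from weight decay) for that step.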
Conclusion and outlook
Cautious Optimizers provide a minimally invasive, theoretically sound way to improve the stability and speed of any momentum‑based optimizer. By masking out updates that would move against the current gradient, they guarantee that the loss does not increase in the continuous‑time limit (and, for sufficiently small step sizes, that each step is no worse than the base update), while preserving the original optimizer’s convergence properties. Empirical results on toy problems and on large‑scale language and vision tasks show consistent gains with virtually no extra hyperparameter tuning. Future work could explore adaptive scaling of α, extensions to other frameworks (e.g., JAX, TensorFlow), and a deeper investigation of the method’s behavior in highly non‑stationary or adversarial training regimes.