SketchOGD: Memory-Efficient Continual Learning
When machine learning models are trained continually on a sequence of tasks, they are often liable to forget what they learned on previous tasks, a phenomenon known as catastrophic forgetting. Proposed solutions to catastrophic forgetting tend to involve storing information about past tasks, meaning that memory usage is a chief consideration in determining their practicality. This paper develops a memory-efficient solution to catastrophic forgetting using the idea of matrix sketching, in the context of a simple continual learning algorithm known as orthogonal gradient descent (OGD). OGD finds weight updates that aim to preserve performance on prior datapoints, using gradients of the model on those datapoints. However, since the memory cost of storing prior model gradients grows with the runtime of the algorithm, OGD is ill-suited to continual learning over long time horizons. To address this problem, we propose SketchOGD. SketchOGD employs an online sketching algorithm to compress model gradients as they are encountered into a matrix of a fixed, user-determined size. In contrast to existing memory-efficient variants of OGD, SketchOGD runs online without the need for advance knowledge of the total number of tasks, is simple to implement, and is more amenable to analysis. We provide theoretical guarantees on the approximation error of the relevant sketches under a novel metric suited to the downstream task of OGD. Experimentally, we find that SketchOGD tends to outperform current state-of-the-art variants of OGD given a fixed memory budget.
💡 Research Summary
Continual learning aims to enable neural networks to acquire new tasks without erasing knowledge of previously learned ones, a challenge commonly referred to as catastrophic forgetting. Orthogonal Gradient Descent (OGD) is a simple regularization‑based continual‑learning method that stores the gradients of past examples in a matrix G and projects every new weight update onto the subspace orthogonal to range(G). While OGD is attractive because it does not require storing raw data and admits clean theoretical analysis, its memory footprint grows linearly with the number of training steps (O(p T), where p is the number of parameters and T the total number of updates). For modern over‑parameterized networks, this scaling quickly becomes prohibitive.
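The projection step described above can be made concrete with a minimal NumPy sketch. All sizes and variable names here are illustrative, not the paper's code: `G` stands in for the stored past-gradient matrix, and the candidate update is projected onto the orthogonal complement of `range(G)`.

```python
import numpy as np

rng = np.random.default_rng(0)
p, t = 100, 5                       # p parameters, t stored past gradients
G = rng.standard_normal((p, t))     # columns: gradients on past examples
g_new = rng.standard_normal(p)      # proposed weight update for the new task

Q, _ = np.linalg.qr(G)              # orthonormal basis for range(G)
g_proj = g_new - Q @ (Q.T @ g_new)  # remove the component inside range(G)

# The projected update is orthogonal to every stored past gradient,
# so (to first order) the loss on past examples is unchanged.
assert np.allclose(G.T @ g_proj, 0, atol=1e-8)
```

Storing `G` explicitly is exactly what costs O(p T) memory; SketchOGD replaces `G` with a fixed-size compressed summary.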
The paper introduces SketchOGD, a memory‑efficient variant of OGD that leverages matrix sketching to compress the gradient history into a fixed‑size summary. Sketching approximates a large matrix A by multiplying it with random Gaussian test matrices Ω and Ψ, producing low‑dimensional sketches Y = AΩ and W = ΨA. These sketches can be updated online as new gradients arrive, require only O(p k) memory (for a sketch size k ≪ T), and are linear in A, so the sketch built incrementally is identical to the one obtained by sketching the full matrix after the fact.
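The linearity property just described can be checked directly: accumulating Y = AΩ one column of A at a time (one rank-one update per incoming gradient) gives exactly the sketch of the full matrix. Sizes below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
p, T, k = 50, 20, 8
A = rng.standard_normal((p, T))      # e.g. a matrix of model gradients
Omega = rng.standard_normal((T, k))  # fixed random Gaussian test matrix

Y = np.zeros((p, k))
for i in range(T):                   # online: one column (gradient) at a time
    Y += np.outer(A[:, i], Omega[i])  # Y <- Y + a_i * omega_i^T

assert np.allclose(Y, A @ Omega)     # identical to sketching A in one shot
```

This is why the sketch can be maintained in a streaming fashion without ever materializing the full gradient history.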
Three concrete SketchOGD algorithms are proposed:
- SketchOGD‑1 directly sketches the gradient matrix G. It maintains Y = GΩ and updates it with each new gradient g as Y ← Y + gωᵀ (where ω is a freshly sampled Gaussian vector). The orthonormal basis used for projection is orth(Y). The memory cost is p k, the smallest among the variants.
- SketchOGD‑2 sketches the symmetric product GGᵀ. Using the same Ω, it updates Y ← Y + g(gᵀΩ). The basis is again orth(Y). This variant stores both Y and Ω, costing 2 p k memory, but benefits from the richer information contained in the Gram matrix.
- SketchOGD‑3 employs a fully symmetric sketch of GGᵀ, drawing both Ω and Ψ. It maintains two sketches Y and W, updates both with each new gradient g, and extracts the projection basis from the concatenation of the two sketches.
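The first two update rules above can be sketched in a few lines of NumPy. This is a rough illustration under assumed shapes, not the authors' implementation; `k`, `p`, and the number of streamed gradients are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(2)
p, k = 200, 10
Omega = rng.standard_normal((p, k))   # shared Gaussian test matrix (variant 2)

Y1 = np.zeros((p, k))                 # SketchOGD-1: sketch of G
Y2 = np.zeros((p, k))                 # SketchOGD-2: sketch of G G^T

for _ in range(30):                   # stream of incoming model gradients
    g = rng.standard_normal(p)
    omega = rng.standard_normal(k)    # fresh Gaussian vector for variant 1
    Y1 += np.outer(g, omega)          # Y1 <- Y1 + g omega^T    (sketches G)
    Y2 += np.outer(g, g @ Omega)      # Y2 <- Y2 + g (g^T Omega) (sketches G G^T)

# Projection basis for OGD: an orthonormal basis for the sketch's range.
Q, _ = np.linalg.qr(Y1)
assert Q.shape == (p, k)
```

In both cases the stored state is a fixed p × k matrix (plus Ω for variant 2), so memory no longer grows with the number of training steps.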