Putting a Face to Forgetting: Continual Learning meets Mechanistic Interpretability
Catastrophic forgetting in continual learning is often measured at the performance or last-layer representation level, overlooking the underlying mechanisms. We introduce a mechanistic framework that offers a geometric interpretation of catastrophic forgetting as the result of transformations to the encoding of individual features. These transformations can lead to forgetting by reducing the allocated capacity of features (worse representation) and disrupting their readout by downstream computations. Analysis of a tractable model formalizes this view, allowing us to identify best- and worst-case scenarios. Through experiments on this model, we empirically test our formal analysis and highlight the detrimental effect of depth. Finally, we demonstrate how our framework can be used in the analysis of practical models through the use of Crosscoders. We present a case study of a Vision Transformer trained on sequential CIFAR-10. Our work provides a new, feature-centric vocabulary for continual learning.
💡 Research Summary
The paper tackles catastrophic forgetting in continual learning from a mechanistic, feature‑level perspective rather than the usual performance- or final‑layer similarity metrics. It builds on the linear representation hypothesis, which assumes that each learned concept (a “feature”) is encoded as a linear direction (a feature vector) in the activation space of a layer. A layer’s activation for an input is a linear combination of these feature vectors weighted by scalar feature activations. Because the representation space is finite, multiple features compete for capacity; the amount of capacity allocated to a feature is quantified by a metric C_i that depends on the norm of its vector and its cosine similarity with other vectors. Overlap reduces C_i, making the feature harder to read out cleanly.
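The summary does not give the exact form of C_i, but a common formalization of per-feature capacity in the superposition literature is C_i = ‖ϕ_i‖⁴ / Σ_j (ϕ_i·ϕ_j)², which is 1 for a feature orthogonal to all others and shrinks as overlap grows. A minimal NumPy sketch under that assumed form (the paper's C_i may differ):

```python
import numpy as np

def allocated_capacity(Phi):
    """Per-feature capacity for feature vectors Phi (rows = features).

    Assumed form (the paper's exact C_i is not given here):
        C_i = ||phi_i||^4 / sum_j (phi_i . phi_j)^2
    This equals 1 for a feature orthogonal to all others and shrinks
    as cosine-similarity overlap with other features grows.
    """
    G = Phi @ Phi.T                       # Gram matrix of pairwise dot products
    return np.diag(G) ** 2 / (G ** 2).sum(axis=1)

# Two orthogonal features: each gets full capacity.
print(allocated_capacity(np.array([[1.0, 0.0], [0.0, 1.0]])))  # [1. 1.]

# Overlapping features: the shared direction reduces each C_i below 1.
print(allocated_capacity(np.array([[1.0, 0.0], [0.8, 0.6]])))
```

With the overlapping pair, both features end up with C_i ≈ 0.61: neither can be read out cleanly because part of each vector lies along the other.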
Forgetting is modeled as two primitive geometric transformations of feature vectors: rotations (changing direction) and scaling (changing magnitude). Rotations increase overlap with other features, while scaling can cause fading (norm → 0) or strengthening (norm ↑). Both transformations can lower allocated capacity and, independently, can leave the downstream readout vector r_i misaligned if it is not updated alongside the feature. Thus forgetting arises from (1) capacity degradation and (2) readout misalignment.
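The two transformations can be illustrated in a minimal sketch (not the paper's exact setup): a single feature vector ϕ with a frozen matched readout r = ϕ/‖ϕ‖², so that r·(aϕ) recovers the activation a; rotating or fading ϕ while leaving r fixed degrades the recovered signal.

```python
import numpy as np

phi = np.array([1.0, 0.0])           # feature vector
r = phi / np.dot(phi, phi)           # frozen readout matched to phi
a = 2.0                              # true feature activation

def rotate(v, theta):
    """Rotate a 2-D vector by angle theta (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([c * v[0] - s * v[1], s * v[0] + c * v[1]])

# Rotation: direction changes, so the frozen readout is misaligned
# and recovers only the component along the old direction.
phi_rot = rotate(phi, np.pi / 4)
print(r @ (a * phi_rot))             # ≈ 1.414, less than the true a = 2

# Scaling toward zero ("fading"): the recovered signal shrinks with the norm.
phi_faded = 0.1 * phi
print(r @ (a * phi_faded))           # 0.2
```

In both cases the readout still fires, but with the wrong magnitude; in a deeper network these errors feed into later layers, which is one intuition for the detrimental effect of depth mentioned in the abstract.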
To formalize these ideas, the authors introduce a tractable “feature‑reader” model: a fixed probe w reads a set of feature vectors Φ through a linear readout ŷ = wᵀ Φ f(x), where f(x) are non‑negative feature activations. Two tasks A and B are learned sequentially. Using gradient descent on task B with mean‑squared error loss, they derive the expected update for each feature vector:
Δϕ_i = –η ( Σ_j γ_j Σ_{ij} – β_i ) w,
where Σ_{ij} = E[f_i(x) f_j(x)] is the correlation between feature activations, γ_j = wᵀϕ_j is each feature's alignment with the probe, and β_i = E[y f_i(x)] is the correlation between feature i and the task‑B target.
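The update can be checked numerically in a minimal NumPy sketch of the feature‑reader model (dimensions, data distribution, and the linear teacher generating y are arbitrary choices, not taken from the paper): the empirical gradient‑descent update of Φ under the MSE loss matches the closed form –η(Σ_j γ_j Σ_{ij} – β_i)w term by term.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, n = 4, 3, 10_000            # activation dim, num features, samples

w = rng.normal(size=d)            # fixed probe
Phi = rng.normal(size=(d, k))     # feature vectors as columns phi_i
F = np.abs(rng.normal(size=(n, k)))   # non-negative feature activations f(x)
y = F @ rng.normal(size=k)        # targets from an arbitrary linear teacher

eta = 0.1
y_hat = F @ Phi.T @ w             # y_hat = w^T Phi f(x), one value per sample

# Empirical update of each column phi_i under loss (1/2)(y - y_hat)^2:
# dL/dphi_i = -E[(y - y_hat) f_i] w, so Delta phi_i = eta E[(y - y_hat) f_i] w.
grad = -((y - y_hat)[:, None] * F).mean(axis=0)[None, :] * w[:, None]
delta_empirical = -eta * grad                          # shape (d, k)

# Closed form: Delta phi_i = -eta (sum_j gamma_j Sigma_ij - beta_i) w,
# with gamma_j = w^T phi_j, Sigma_ij = E[f_i f_j], beta_i = E[y f_i].
gamma = Phi.T @ w
Sigma = F.T @ F / n
beta = F.T @ y / n
delta_formula = -eta * np.outer(w, Sigma @ gamma - beta)

print(np.max(np.abs(delta_empirical - delta_formula)))  # ~0: exact identity
```

The agreement is exact (up to floating point) because the closed form is an algebraic identity on the sample moments, not an approximation.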