Memory Caching: RNNs with Growing Memory
Transformers have been established as the de-facto backbones for most recent advances in sequence modeling, mainly due to their growing memory capacity that scales with the context length. While beneficial for retrieval tasks, this growing memory incurs quadratic complexity, which has motivated recent studies to explore viable subquadratic recurrent alternatives. Despite showing promising preliminary results in diverse domains, such recurrent architectures underperform Transformers in recall-intensive tasks, a shortfall often attributed to their fixed-size memory. In this paper, we introduce Memory Caching (MC), a simple yet effective technique that enhances recurrent models by caching checkpoints of their memory states (a.k.a. hidden states). Memory Caching allows the effective memory capacity of RNNs to grow with sequence length, offering a flexible trade-off that interpolates between the fixed memory (i.e., $O(L)$ complexity) of RNNs and the growing memory (i.e., $O(L^2)$ complexity) of Transformers. We propose four variants of MC, including gated aggregation and sparse selective mechanisms, and discuss their implications for both linear and deep memory modules. Our experimental results on language modeling and long-context understanding tasks show that MC enhances the performance of recurrent models, supporting its effectiveness. The results on in-context recall tasks indicate that while Transformers achieve the best accuracy, our MC variants show competitive performance, close the gap with Transformers, and perform better than state-of-the-art recurrent models.
💡 Research Summary
The paper addresses the fundamental limitation of recurrent neural networks (RNNs) – their fixed‑size hidden state – which hampers performance on tasks that require long‑range recall. While Transformers provide a growing memory that scales with sequence length, they incur quadratic time and memory costs (O(L²)). To bridge this gap, the authors propose Memory Caching (MC), a simple yet versatile framework that augments recurrent models with cached checkpoints of their hidden states.
Core idea:
A sequence of length L is split into N segments S¹,…,Sᴺ of possibly varying lengths. After processing each segment i, its final hidden state M⁽ⁱ⁾ is stored as a checkpoint. When a new token arrives, its query qₜ is evaluated not only against the current (online) memory but also against all stored checkpoints via an aggregation function Agg. This yields a computational complexity of O(N L), which can be tuned from O(L) (standard RNN, N = 1) up to O(L²) (Transformer‑like, N = L).
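The segment-and-checkpoint scheme above can be sketched in a few lines. The following is a hypothetical minimal illustration, not the paper's exact method: it uses a linear matrix memory updated by outer products (as in linear attention), caches a checkpoint every `segment_len` tokens, and uses a plain sum over checkpoint read-outs as the aggregation function Agg; the function name and fixed segment length are assumptions for illustration.

```python
import numpy as np

def memory_caching(keys, values, queries, segment_len):
    """Hypothetical sketch of Memory Caching with a linear matrix memory.

    The online memory M_t accumulates outer products k v^T. After every
    `segment_len` tokens, the current memory state is cached as a
    checkpoint. Each query reads from the online memory and, via a simple
    sum aggregation (standing in for Agg), from all cached checkpoints,
    giving O(N L) total read cost for N checkpoints.
    """
    d = keys.shape[1]
    memory = np.zeros((d, d))   # online memory M_t
    checkpoints = []            # cached states M^(1), ..., M^(i-1)
    outputs = []
    for t, (k, v, q) in enumerate(zip(keys, values, queries)):
        memory = memory + np.outer(k, v)       # recurrent (linear) update
        read = q @ memory                      # read-out from online memory
        for ckpt in checkpoints:               # Agg: sum over checkpoints
            read = read + q @ ckpt
        outputs.append(read)
        if (t + 1) % segment_len == 0:         # end of segment: cache state
            checkpoints.append(memory.copy())
    return np.stack(outputs)
```

With `segment_len >= L` no checkpoint is ever read and the sketch reduces to a standard linear recurrence (the N = 1 end of the trade-off); with `segment_len = 1` every past state is cached, approaching the Transformer-like N = L regime. Note that because this linear memory is cumulative, its checkpoints are prefixes of the online state, matching the paper's observation that linear memories can in principle be pre-summed.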
Four MC variants:
- Residual Memory – the simplest form, summing all cached memories with the online memory (a residual connection). Even though linear memories can be pre‑summed, the residual still improves recall in practice.
- Gated Residual Memory (GRM) – introduces token‑dependent gates γ⁽ⁱ⁾ₜ ∈