Multiway Multislice PHATE: Visualizing Hidden Dynamics of RNNs through Training
Recurrent neural networks (RNNs) are a widely used tool for sequential data analysis; however, they are still often seen as black boxes. Visualizing the internal dynamics of RNNs is a critical step toward understanding their functional principles and developing better architectures and optimization strategies. Prior studies typically emphasize network representations only after training, overlooking how those representations evolve during learning. Here, we present Multiway Multislice PHATE (MM-PHATE), a graph-based embedding method for visualizing the evolution of RNN hidden states across the multiple dimensions spanned by RNNs: time, training epoch, and units. Across controlled synthetic benchmarks and real RNN applications, MM-PHATE preserves hidden-representation community structure among units and reveals training-phase changes in representation geometry. In controlled synthetic systems spanning multiple bifurcation families and smooth state-space warps, MM-PHATE recovers qualitative dynamical progression while distinguishing family-level differences. In task-trained RNNs, the embedding identifies information-processing and compression-related phases during training, and time-resolved geometric and entropy-based summaries align with linear probes, time-step ablations, and label–state mutual information. These results show that MM-PHATE provides an intuitive and comprehensive way to inspect RNN hidden dynamics across training and to better understand how model architecture and learning dynamics relate to performance.
💡 Research Summary
This paper introduces Multiway Multislice PHATE (MM‑PHATE), a graph‑based dimensionality‑reduction technique designed to visualize the evolution of recurrent neural network (RNN) hidden states across three axes: time‑steps within a sequence, training epochs, and hidden units. Existing visualization methods either focus on post‑training representations or are limited to feed‑forward networks, failing to capture the intertwined temporal and learning dynamics intrinsic to RNNs. MM‑PHATE addresses this gap by extending the Multislice PHATE (M‑PHATE) framework—originally built for feed‑forward networks—to handle the additional temporal dimension inherent in recurrent architectures.
The method begins by constructing a four‑dimensional tensor T(τ, ω, i, k) that stores the activation of hidden unit i at time‑step ω of training epoch τ for input sample k. Each activation is z‑scored across input samples so that per‑unit differences in mean and variance do not dominate the similarity structure. From this tensor, two similarity kernels are derived:
- Intra‑step kernel – measures similarity between different hidden units i and j at the same epoch τ and time‑step ω. It uses an adaptive Gaussian kernel whose bandwidth σ(τ, ω, i) is set to the distance to the k‑th nearest neighbor (k = 5 in all experiments), allowing the kernel to adapt to local density variations.
- Inter‑step kernel – measures similarity of the same hidden unit i across different epochs and time‑steps (τ, ω) ↔ (η, ν). This kernel employs a fixed bandwidth ε equal to the average nearest‑neighbor distance across the entire dataset, thereby linking the temporal and training dimensions.
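The two kernels above can be sketched in NumPy as follows. The tensor shapes, random stand-in activations, and function names are illustrative assumptions, not the authors' implementation; only the kernel forms (adaptive per-point bandwidth vs. a single fixed bandwidth) follow the description in the text.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions: n epochs, s time-steps, m hidden units, p input samples.
n, s, m, p = 3, 4, 8, 16
T = rng.standard_normal((n, s, m, p))  # stand-in for recorded activations

# z-score each unit's activations across input samples
T = (T - T.mean(axis=-1, keepdims=True)) / (T.std(axis=-1, keepdims=True) + 1e-12)

def intra_step_kernel(T, tau, omega, k=5):
    """Adaptive Gaussian kernel between hidden units at one fixed (epoch, time-step)."""
    X = T[tau, omega]                                     # (m, p): one row per unit
    D = np.linalg.norm(X[:, None] - X[None, :], axis=-1)  # pairwise unit distances
    sigma = np.sort(D, axis=1)[:, k]                      # distance to k-th neighbor
    return np.exp(-(D / sigma[:, None]) ** 2)             # row-adaptive bandwidth

def inter_step_kernel(T, unit):
    """Fixed-bandwidth kernel linking the same unit across (epoch, time-step) pairs."""
    X = T[:, :, unit].reshape(-1, T.shape[-1])            # (n*s, p)
    D = np.linalg.norm(X[:, None] - X[None, :], axis=-1)
    nn = np.sort(D, axis=1)[:, 1]                         # skip self-distance in col 0
    eps = nn.mean()                                       # average NN distance
    return np.exp(-(D / eps) ** 2)

K_intra = intra_step_kernel(T, tau=0, omega=0)            # (m, m) affinities
K_inter = inter_step_kernel(T, unit=0)                    # (n*s, n*s) affinities
```

Note that the adaptive bandwidth makes the intra-step kernel asymmetric before the symmetrization step described below.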
These kernels are assembled into a large block‑structured matrix K of size (n × s × m) × (n × s × m), where n is the number of epochs, s the number of time‑steps, and m the number of hidden units. The block diagonal contains intra‑step affinities, while off‑diagonal blocks contain inter‑step affinities for the same unit; all other entries are zero. After symmetrization and row‑normalization, K serves as the diffusion operator for PHATE.
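The block structure can be sketched with a toy assembly loop. The sizes, the placeholder Gaussian with a fixed bandwidth, and the indexing helper are assumptions for illustration; only the sparsity pattern (intra-step blocks on the diagonal, inter-step entries only for matching units, zeros elsewhere) and the symmetrize-then-row-normalize step follow the text.

```python
import numpy as np

rng = np.random.default_rng(1)
n, s, m, p = 2, 3, 5, 10               # hypothetical: epochs, steps, units, samples
T = rng.standard_normal((n, s, m, p))

N = n * s * m                          # one graph node per (epoch, step, unit)
K = np.zeros((N, N))

def idx(tau, omega, i):
    """Flatten (epoch, time-step, unit) into a single node index."""
    return (tau * s + omega) * m + i

def gauss(x, y, bw=1.0):               # placeholder fixed-bandwidth kernel
    return np.exp(-np.sum((x - y) ** 2) / bw ** 2)

# Block diagonal: intra-step affinities between all unit pairs at each (tau, omega).
for tau in range(n):
    for omega in range(s):
        for i in range(m):
            for j in range(m):
                K[idx(tau, omega, i), idx(tau, omega, j)] = gauss(
                    T[tau, omega, i], T[tau, omega, j])

# Off-diagonal blocks: inter-step affinities for the same unit only.
for i in range(m):
    for a in range(n * s):
        for b in range(n * s):
            if a != b:
                ta, oa, tb, ob = a // s, a % s, b // s, b % s
                K[idx(ta, oa, i), idx(tb, ob, i)] = gauss(
                    T[ta, oa, i], T[tb, ob, i])

K = (K + K.T) / 2                      # symmetrize
P = K / K.sum(axis=1, keepdims=True)   # row-normalize -> diffusion operator
```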
PHATE then computes multi‑step diffusion probabilities, transforms them into an information‑theoretic distance, and embeds the points into 2‑D or 3‑D space using classical multidimensional scaling. The resulting plot simultaneously displays the trajectory of each hidden unit across training and time, preserving community structure (clusters of units with similar functional roles) while revealing global geometric changes.
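The three PHATE steps named above (diffusion powering, an information-theoretic "potential" distance, classical MDS) can be sketched on a stand-in affinity matrix. The diffusion time t, the log-based potential transform, and the toy input are assumptions; the actual method tunes these quantities.

```python
import numpy as np

rng = np.random.default_rng(2)
N = 30
A = rng.random((N, N))
A = (A + A.T) / 2                            # stand-in symmetric affinity matrix
P = A / A.sum(axis=1, keepdims=True)         # row-stochastic diffusion operator

t = 8                                        # diffusion time (a free parameter)
Pt = np.linalg.matrix_power(P, t)            # multi-step transition probabilities

# Potential distance: Euclidean distance between log-transformed transition rows.
U = -np.log(Pt + 1e-12)
D = np.linalg.norm(U[:, None] - U[None, :], axis=-1)

# Classical MDS on the potential distances -> 2-D embedding coordinates.
J = np.eye(N) - np.ones((N, N)) / N          # centering matrix
B = -0.5 * J @ (D ** 2) @ J
vals, vecs = np.linalg.eigh(B)
order = np.argsort(vals)[::-1]
emb = vecs[:, order[:2]] * np.sqrt(np.maximum(vals[order[:2]], 0))
```

Each row of `emb` is then one plotted point, i.e. one (epoch, time-step, unit) triple, so a single scatter plot traces every unit through both training and sequence time.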
The authors evaluate MM‑PHATE on two fronts:
- Synthetic benchmarks – They generate dynamical systems based on Hopf and Pitchfork bifurcations, optionally applying smooth warps to the state space. MM‑PHATE successfully recovers the qualitative progression of the underlying dynamics, distinguishes different bifurcation families, and remains robust to warping, demonstrating that the method captures true dynamical relationships rather than artefacts of the embedding.
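A toy generator in the spirit of these benchmarks can be written from the standard Hopf normal form. The integration scheme, step sizes, and initial conditions here are assumptions; the paper's exact benchmark construction is not reproduced.

```python
import numpy as np

def hopf_trajectory(mu, omega=1.0, x0=(0.5, 0.0), dt=0.01, steps=2000):
    """Euler-integrate the Hopf normal form:
         dx/dt = mu*x - omega*y - x*(x^2 + y^2)
         dy/dt = omega*x + mu*y - y*(x^2 + y^2)
    Below the bifurcation (mu < 0) trajectories spiral into the origin;
    above it (mu > 0) they settle onto a limit cycle of radius sqrt(mu)."""
    x, y = x0
    traj = np.empty((steps, 2))
    for k in range(steps):
        r2 = x * x + y * y
        dx = mu * x - omega * y - x * r2
        dy = omega * x + mu * y - y * r2
        x, y = x + dt * dx, y + dt * dy
        traj[k] = (x, y)
    return traj

below = hopf_trajectory(mu=-0.2)   # decays toward the fixed point at the origin
above = hopf_trajectory(mu=0.5)    # approaches the limit cycle of radius sqrt(0.5)
```

Varying mu across trajectories yields a family of systems whose qualitative dynamics change at the bifurcation, which is the kind of structure the embedding is tested against.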
- Real RNN tasks – Experiments on sequence prediction and language modeling tasks show a characteristic two‑phase pattern during training. Early epochs exhibit an “expansion” phase where hidden units spread out in the embedding, reflecting high variability and information capture. Later epochs show a “compression” phase where units cluster, indicating dimensionality reduction and task‑specific specialization. These phases align with three independent diagnostics:
- Linear probes – Time‑resolved linear regression from hidden states to targets shows a sharp increase in predictability during expansion and a plateau during compression.
- Time‑step ablations – Masking specific time‑steps leads to performance drops precisely at the epochs identified as critical by the embedding.
- Label–state mutual information – Mutual information between hidden activations and task labels peaks during the transition from expansion to compression, matching the geometric changes observed.
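Two of these diagnostics (a linear probe and a histogram-based mutual-information estimate) can be sketched on synthetic hidden states. The data, the least-squares probe, and the binning scheme are illustrative assumptions; the paper's estimators may differ.

```python
import numpy as np

rng = np.random.default_rng(3)
n_samples, n_units = 200, 16
H = rng.standard_normal((n_samples, n_units))       # hidden states at one time-step
w = rng.standard_normal(n_units)
y = H @ w + 0.1 * rng.standard_normal(n_samples)    # synthetic regression target

# Linear probe: least-squares fit; R^2 serves as a time-resolved predictability score.
coef, *_ = np.linalg.lstsq(H, y, rcond=None)
r2 = 1 - np.sum((y - H @ coef) ** 2) / np.sum((y - y.mean()) ** 2)

# Label-state mutual information via a 2-D histogram estimate on a 1-D projection.
def hist_mi(a, b, bins=10):
    joint, _, _ = np.histogram2d(a, b, bins=bins)
    pxy = joint / joint.sum()
    px, py = pxy.sum(axis=1), pxy.sum(axis=0)
    nz = pxy > 0
    return np.sum(pxy[nz] * np.log(pxy[nz] / (px[:, None] * py[None, :])[nz]))

mi = hist_mi(H @ coef, y)
```

Running the probe and the MI estimate at every time-step and epoch produces the curves that are compared against the expansion and compression phases seen in the embedding.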
Comparisons with PCA, t‑SNE, UMAP, Isomap, LLE, and the original M‑PHATE reveal that only MM‑PHATE simultaneously preserves temporal continuity, maintains inter‑unit community structure, and provides meaningful entropy‑based summaries. Parameter sensitivity analyses (varying k and ε) show that the visualizations are stable across a reasonable range.
In summary, MM‑PHATE offers a unified, interpretable view of RNN hidden‑state dynamics across both the temporal axis of sequences and the training axis of epochs. By coupling diffusion‑based geometry with a multi‑dimensional graph that respects both intra‑step and inter‑step similarities, the method uncovers information‑theoretic phases (expansion/compression) and links them to functional performance metrics. The authors suggest future extensions to non‑recurrent architectures such as Transformers, real‑time monitoring for early detection of over‑fitting, and integration with model‑based control of training curricula.