Transformers learn factored representations
Transformers pretrained via next token prediction learn to factor their world into parts, representing these factors in orthogonal subspaces of the residual stream. We formalize two representational hypotheses: (1) a representation in the product space of all factors, whose dimension grows exponentially with the number of parts, or (2) a factored representation in orthogonal subspaces, whose dimension grows linearly. The factored representation is lossless when factors are conditionally independent, but sacrifices predictive fidelity otherwise, creating a tradeoff between dimensional efficiency and accuracy. We derive precise predictions about the geometric structure of activations for each, including the number of subspaces, their dimensionality, and the arrangement of context embeddings within them. We distinguish between these hypotheses using transformers trained on synthetic processes with known latent structure. Models learn factored representations when factors are conditionally independent, and continue to favor them early in training even when noise or hidden dependencies undermine conditional independence, reflecting an inductive bias toward factoring at the cost of fidelity. This provides a principled explanation for why transformers decompose the world into parts, and suggests that interpretable low-dimensional structure may persist even in models trained on complex data.
💡 Research Summary
This paper investigates whether transformer language models, when pretrained on next‑token prediction, internally factorize the world into independent components and represent each component in orthogonal subspaces of the residual stream. The authors develop a rigorous theoretical framework based on Generalized Hidden Markov Models (GHMMs) to formalize two competing representational hypotheses.
The first hypothesis, the “joint representation,” stores the predictive state in the full tensor‑product space of all latent factors. For N factors with dimensions d₁,…,d_N, this requires (∏ₙ dₙ − 1) dimensions, guaranteeing lossless prediction but incurring exponential growth in dimensionality. The second hypothesis, the “factored representation,” stores each factor’s predictive vector in its own low‑dimensional subspace, yielding a direct‑sum geometry that needs only Σₙ (dₙ − 1) dimensions. This representation is linearly efficient but becomes lossy when the latent factors are not conditionally independent.
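The two dimension counts can be checked with a few lines of arithmetic. This is a minimal sketch; the helper names `joint_dim` and `factored_dim` are our own, not the paper's, and the loop assumes every factor has dimension 3.

```python
from math import prod

def joint_dim(dims):
    """Dimension of the joint (tensor-product) representation: prod(d_n) - 1."""
    return prod(dims) - 1

def factored_dim(dims):
    """Dimension of the factored (direct-sum) representation: sum(d_n - 1)."""
    return sum(d - 1 for d in dims)

# Growth with the number of factors N, each factor assumed to have dimension 3.
for n_factors in range(1, 6):
    dims = [3] * n_factors
    print(n_factors, joint_dim(dims), factored_dim(dims))
```

For N = 1 to 5 this prints 2 vs 2, 8 vs 4, 26 vs 6, 80 vs 8 and 242 vs 10: the joint count multiplies by 3 with each added factor, while the factored count only adds 2.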
Conditional independence is defined as the ability to write each token‑conditioned transition operator T(x) as a tensor product of per‑factor operators: T(x)=⊗ₙ Tⁿ(x). Under this condition, the predictive vector after any context remains a product state η(x₁:ℓ)=⊗ₙ ηₙ(x₁:ℓ). The authors prove (Proposition 2.2 and Theorem 2.3) that product states form a low‑dimensional manifold that can be mapped losslessly to a set of orthogonal subspaces, establishing that a factored representation is sufficient and exact when the data‑generating process is conditionally independent.
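Why a product state stays a product state follows from the mixed-product property of the Kronecker product, (η¹ ⊗ η²)(T¹ ⊗ T²) = (η¹T¹) ⊗ (η²T²), together with the fact that the normalizer of a Kronecker product is the product of the normalizers. The NumPy sketch below assumes the usual row-vector GHMM update η ← ηT(x) / (ηT(x)𝟏) and uses random nonnegative 3×3 matrices as stand-ins for one token's per-factor operators. These are illustrative values, not the paper's Mess3 or Bloch Walk processes. In the paper's definition each Tⁿ may depend on the full token x; fixing a single token reduces each Tⁿ to one matrix.

```python
import numpy as np

def update(eta, T):
    """One GHMM belief update in the row-vector convention: eta' = eta T / (eta T 1)."""
    v = eta @ T
    return v / v.sum()

rng = np.random.default_rng(0)
# Hypothetical per-factor operators for a single observed token x.
T1, T2 = rng.random((3, 3)), rng.random((3, 3))
eta1, eta2 = rng.dirichlet(np.ones(3)), rng.dirichlet(np.ones(3))

# Joint update: the tensor-product operator T(x) = T1 ⊗ T2 acting on the product state.
joint_next = update(np.kron(eta1, eta2), np.kron(T1, T2))

# Factored update: each factor is updated in its own 3-dim space.
factored = np.concatenate([update(eta1, T1), update(eta2, T2)])  # 6 numbers, 4 free dims

# Lossless map back to the 9-dim joint state.
reconstructed = np.kron(factored[:3], factored[3:])
print(np.allclose(joint_next, reconstructed))
```

The factored copy stores 2 + 2 = 4 free parameters instead of the joint state's 8, yet `np.kron` rebuilds the joint state exactly. This is the sense in which the factored representation is lossless under conditional independence.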
From these theoretical results the authors derive three concrete, testable predictions (the “Factored World Hypothesis”, FWH): (1) the dimensionality of transformer activations should match Σₙ (dₙ − 1) rather than the exponential joint dimension; (2) these dimensions should organize into N mutually orthogonal subspaces, one per latent factor; (3) the model should exhibit this structure even when it has enough capacity to represent the full joint space, and even when the data only approximately satisfies conditional independence.
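Prediction (1) can be illustrated numerically. The sketch below samples random product states for five 3-dimensional factors and measures the affine dimension (matrix rank after mean-centering) of two encodings of the same states: the 15-coordinate concatenation of factor beliefs and the 243-coordinate joint tensor. The product states form a 10-dimensional manifold, but their linear span in the joint encoding fills the whole simplex. Drawing the beliefs from a Dirichlet distribution is an assumption that stands in for the states a real process would visit.

```python
import numpy as np

rng = np.random.default_rng(0)
dims, n_samples = [3, 3, 3, 3, 3], 2000

# Random per-factor belief vectors (points in each probability simplex).
factors = [rng.dirichlet(np.ones(d), size=n_samples) for d in dims]

# Factored encoding: direct sum of the factor beliefs (15 coordinates).
factored = np.concatenate(factors, axis=1)

# Joint encoding: the product state eta_1 ⊗ ... ⊗ eta_5 (243 coordinates), row by row.
joint = factors[0]
for f in factors[1:]:
    joint = np.einsum("ni,nj->nij", joint, f).reshape(n_samples, -1)

def affine_dim(points):
    """Dimension of the affine span of a point cloud: rank after mean-centering."""
    return int(np.linalg.matrix_rank(points - points.mean(axis=0)))

print(affine_dim(factored), affine_dim(joint))
```

With 2,000 samples this reports 10 for the factored encoding and 242 for the joint one. A linear readout of the joint encoding therefore needs all Πₙ dₙ − 1 directions, even though the states themselves have only Σₙ (dₙ − 1) degrees of freedom.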
To evaluate these predictions, the authors construct synthetic sequence datasets generated by GHMMs with five latent factors: three “Mess3” 3‑state HMMs and two “Bloch Walk” 3‑dimensional GHMMs. Sub‑tokens from each factor are combined via a Cartesian product into a single observed integer token, so the model never sees the factor structure explicitly. Three experimental regimes are explored: (a) fully independent factors (pure tensor‑product dynamics); (b) conditionally independent factors with a dependency chain, where a factor’s transition may depend on the sub‑tokens emitted by earlier factors, but because those sub‑tokens are part of the observed token, each T(x) still factorizes; and (c) factors that are not conditionally independent, where cross‑factor correlations persist after conditioning on the observed token.
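The Cartesian-product tokenization amounts to a mixed-radix encoding of the per-factor sub-tokens. In the sketch below, `encode` and `decode` are hypothetical helpers. The sub-vocabulary size of 3 per factor is an assumption for illustration, since the summary does not state the alphabet sizes.

```python
def encode(sub_tokens, vocab_sizes):
    """Combine per-factor sub-tokens into one integer token (mixed radix, first factor most significant)."""
    token = 0
    for s, v in zip(sub_tokens, vocab_sizes):
        if not 0 <= s < v:
            raise ValueError(f"sub-token {s} out of range for vocabulary size {v}")
        token = token * v + s
    return token

def decode(token, vocab_sizes):
    """Invert encode(): recover the per-factor sub-tokens from the combined token."""
    subs = []
    for v in reversed(vocab_sizes):
        token, s = divmod(token, v)
        subs.append(s)
    return subs[::-1]

vocab_sizes = [3, 3, 3, 3, 3]  # assumed: each factor emits one of 3 sub-tokens
tok = encode([2, 0, 1, 1, 0], vocab_sizes)
print(tok, decode(tok, vocab_sizes))
```

Here the five sub-tokens map to token 174 out of a combined vocabulary of 3⁵ = 243, and `decode` recovers them. NumPy's `np.ravel_multi_index` and `np.unravel_index` compute the same mapping in their default C order. The model only ever sees the integer, so any factor structure must be inferred from co-occurrence statistics.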
The authors probe trained transformers (of various sizes, up to 12 layers with a 768‑dimensional residual stream) by performing PCA on the residual stream, measuring explained variance to count the effective dimensionality, and testing whether the leading components split into mutually orthogonal subspaces, one per factor. They also fit linear maps from activations to each factor’s predictive distribution to verify that each factor is encoded in a distinct subspace.
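The three measurements can be sketched on idealized data: a PCA count of effective dimensions, a least-squares linear probe per factor, and a principal-angle check of orthogonality between the probe subspaces. The "activations" below are a synthetic stand-in and not a trained transformer. We assume a perfectly factored model that writes each factor's belief into its own block of orthonormal directions in a 64-dimensional space. The 0.99 variance threshold and the tolerances are also our assumptions, not the paper's settings.

```python
import numpy as np

rng = np.random.default_rng(0)
dims, n, d_model = [3, 3, 3, 3, 3], 4000, 64

# Stand-in activations (assumption): each factor's belief goes into its own orthonormal block.
beliefs = [rng.dirichlet(np.ones(d), size=n) for d in dims]
Q, _ = np.linalg.qr(rng.standard_normal((d_model, sum(dims))))  # 64 x 15, orthonormal columns
acts = np.concatenate(beliefs, axis=1) @ Q.T                      # n x 64

def n_components(X, threshold=0.99):
    """Smallest number of principal components whose cumulative explained variance >= threshold."""
    s = np.linalg.svd(X - X.mean(axis=0), compute_uv=False)
    ratio = np.cumsum(s**2) / np.sum(s**2)
    return int(np.searchsorted(ratio, threshold)) + 1

def probe_subspace(X, B, tol=1e-8):
    """Fit a least-squares probe from activations to one factor's beliefs.

    Returns an orthonormal basis for the activation directions the probe reads from,
    plus the relative fit residual."""
    Xc, Bc = X - X.mean(axis=0), B - B.mean(axis=0)
    W, *_ = np.linalg.lstsq(Xc, Bc, rcond=None)  # d_model x d_n (minimum-norm solution)
    U, s, _ = np.linalg.svd(W, full_matrices=False)
    basis = U[:, s > tol * s.max()]
    residual = np.linalg.norm(Xc @ W - Bc) / np.linalg.norm(Bc)
    return basis, residual

def max_overlap(U, V):
    """Cosine of the smallest principal angle between two subspaces (0 means orthogonal)."""
    return float(np.linalg.svd(U.T @ V, compute_uv=False).max())

subspaces, residuals = [], []
for B in beliefs:
    basis, residual = probe_subspace(acts, B)
    subspaces.append(basis)
    residuals.append(residual)

overlaps = [max_overlap(subspaces[i], subspaces[j])
            for i in range(len(dims)) for j in range(i + 1, len(dims))]
print("PCA dims:", n_components(acts))
print("probe subspace dims:", [U.shape[1] for U in subspaces])
print("max cross-factor overlap:", max(overlaps))
```

On this idealized data the script finds 10 principal components and a 2-dimensional probe subspace per factor. The cross-factor overlaps sit at floating-point noise level. Real residual-stream activations are noisier, so in practice the overlap would be compared against a tolerance rather than expected to vanish.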
Results show that in regime (a) the model’s activations collapse onto exactly 10 dimensions (Σₙ (dₙ − 1) = 5 × 2 for the five factors, versus 3⁵ − 1 = 242 for the joint space), and the subspaces carrying different factors are mutually orthogonal, confirming the factored geometry. In regime (b), despite the presence of inter‑factor dependencies, the model still learns an orthogonal factorized layout early in training; only later does a modest increase in dimensionality occur, reflecting the need to capture residual correlations. In regime (c), the model initially adopts the factored layout, even though it incurs a measurable loss in next‑token accuracy; as training proceeds, the model gradually expands its representation toward the full joint space, reducing loss but sacrificing the strict orthogonal factorization.
These observations support the authors’ claim that transformers possess an inductive bias toward low‑dimensional, orthogonal factorizations. The bias is strong enough to dominate early learning dynamics, even when it temporarily harms predictive performance. The bias can be interpreted as a preference for “dimension‑efficiency” over immediate fidelity, with the model later adjusting when the loss gradient pressures it to recover lost information.
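One way to put a number on the fidelity side of this tradeoff is the following illustration, which is ours and not the paper's loss metric. Suppose the true joint belief over factors is p but only per-factor vectors are stored. Then the best product reconstruction in the KL(p‖q) sense is the product of p's marginals. The irreducible excess is the total correlation KL(p‖∏ₙ pₙ), which is zero exactly when p factorizes, because KL(p‖∏ₙ qₙ) = KL(p‖∏ₙ pₙ) + Σₙ KL(pₙ‖qₙ).

```python
import numpy as np

def total_correlation(p):
    """KL(p || product of its marginals) in nats, for a joint distribution given as an N-d array."""
    p = np.asarray(p, dtype=float)
    q = np.ones_like(p)
    for axis in range(p.ndim):
        other = tuple(a for a in range(p.ndim) if a != axis)
        q = q * p.sum(axis=other, keepdims=True)  # broadcast each marginal back to full shape
    mask = p > 0                                   # 0 * log 0 terms contribute nothing
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

independent = np.outer([0.7, 0.3], [0.4, 0.6])       # factorizes: no fidelity cost
correlated = np.array([[0.45, 0.05], [0.05, 0.45]])  # hypothetical dependent pair of factors
print(total_correlation(independent), total_correlation(correlated))
```

The independent case gives 0 up to rounding. The correlated case gives 0.9 ln 1.8 + 0.1 ln 0.2 ≈ 0.368 nats, which is the cost any factored representation pays for this distribution no matter how its factor vectors are chosen.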
The paper concludes that the trade‑off between dimensional efficiency and predictive fidelity explains why transformers naturally decompose inputs into parts. This insight has practical implications: the existence of persistent orthogonal subspaces suggests avenues for model interpretability, parameter pruning, and modular fine‑tuning, especially in real‑world data where latent compositional structure is likely present but not explicit. The work thus bridges a theoretical understanding of latent factorization with empirical evidence, offering a principled explanation for the emergence of interpretable low‑dimensional structure in modern transformer models.