CauScale: Neural Causal Discovery at Scale
Causal discovery is essential for advancing data-driven fields such as scientific AI and data analysis, yet existing approaches face significant time- and space-efficiency bottlenecks when scaling to large graphs. To address this challenge, we present CauScale, a neural architecture designed for efficient causal discovery that scales inference to graphs with up to 1000 nodes. CauScale improves time efficiency via a reduction unit that compresses data embeddings, and improves space efficiency by adopting tied attention weights to avoid maintaining axis-specific attention maps. To maintain high causal discovery accuracy, CauScale adopts a two-stream design: a data stream extracts relational evidence from high-dimensional observations, while a graph stream integrates statistical graph priors and preserves key structural signals. CauScale successfully scales to 500-node graphs during training, where prior work fails due to space limitations. Across test data with varying graph scales and causal mechanisms, CauScale achieves 99.6% mAP on in-distribution data and 84.4% on out-of-distribution data, while delivering 4× to 13,000× inference speedups over prior methods. Our project page is at https://github.com/OpenCausaLab/CauScale.
💡 Research Summary
CauScale is a novel neural architecture designed to overcome the severe time and memory bottlenecks that have limited existing causal discovery methods when scaling to large graphs. The authors identify three primary inefficiencies in current amortized (zero‑shot) approaches such as AVICI: (1) the attention mechanism grows quadratically with the number of variables, (2) the number of observational samples (often orders of magnitude larger than the number of variables) is processed naively, and (3) the models lack a principled way to fuse high‑dimensional data evidence with structural priors.
To address these issues, CauScale introduces a two‑stream design consisting of a data stream and a graph stream. The data stream processes the raw observational matrix D (including an intervention indicator) through linear embedding and axial attention, while the graph stream starts from a statistical prior ρ (the inverse covariance of D) and maintains an n × n embedding that captures potential causal links. The streams interact via a Data‑2‑Graph block. Inside this block, the data stream’s attention output is pooled across the sample dimension, passed through two separate Feed‑Forward Networks to obtain node‑level vectors u and v, and combined as ω = u vᵀ, a dense relational matrix that encodes directed pairwise evidence. This matrix is concatenated with the current graph embedding and projected back to the original dimensionality, thereby injecting data‑driven relational cues into the graph stream without discarding structural information.
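The Data‑2‑Graph block described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the paper's implementation: the dimensions are toy, the FFNs are single hidden‑layer networks with randomly initialized weights, and ω = u vᵀ is taken literally as an n × n matrix that is appended to the graph embedding as an extra channel before the output projection.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, d = 8, 4, 16  # samples, variables, embedding dim (toy sizes)

# Toy parameters: two node-level FFNs and the back-projection.
W1u, W2u = rng.normal(size=(d, d)) / d**0.5, rng.normal(size=(d, d)) / d**0.5
W1v, W2v = rng.normal(size=(d, d)) / d**0.5, rng.normal(size=(d, d)) / d**0.5
W_proj = rng.normal(size=(d + 1, d)) / d**0.5  # projects (d+1) channels back to d

def ffn(x, W1, W2):
    """Two-layer feed-forward network with ReLU."""
    return np.maximum(x @ W1, 0.0) @ W2

def data2graph(data_emb, graph_emb):
    """Inject pooled data-stream evidence into the graph stream.

    data_emb:  (m, n, d) data-stream embedding (samples x variables x dim)
    graph_emb: (n, n, d) graph-stream embedding
    """
    pooled = data_emb.mean(axis=0)               # (n, d): pool over samples
    u = ffn(pooled, W1u, W2u)                    # (n, d) "source" node vectors
    v = ffn(pooled, W1v, W2v)                    # (n, d) "target" node vectors
    omega = u @ v.T                              # (n, n) directed pairwise evidence
    fused = np.concatenate([graph_emb, omega[..., None]], axis=-1)  # (n, n, d+1)
    return fused @ W_proj                        # (n, n, d): back to original dim

out = data2graph(rng.normal(size=(m, n, d)), rng.normal(size=(n, n, d)))
```

Because ω is asymmetric (ω[i, j] = uᵢ·vⱼ ≠ ω[j, i] in general), the fused embedding can carry directional evidence for i → j versus j → i, which the prediction head later exploits.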
A key innovation is the Reduction Unit, which periodically compresses the data‑stream embedding along the observation axis by average‑pooling groups of r samples. Because causal signals are primarily expressed through dependencies among variables within each sample, aggregating across samples after several Data‑2‑Graph transformations preserves essential information while reducing the effective sample length from m to m/r^⌊b/k⌋ after b blocks (with k blocks between reductions). This yields a theoretical reduction of the sample‑axis attention cost from O(n m²) to roughly O(n (m/r)²) and similarly reduces node‑axis costs.
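The pooling step itself is simple; a minimal NumPy sketch (assuming, for brevity, that the sample count m is divisible by the reduction factor r):

```python
import numpy as np

def reduction_unit(data_emb, r):
    """Average-pool groups of r consecutive samples along the sample axis.

    data_emb: (m, n, d) data-stream embedding; m must be divisible by r
              (a simplifying assumption for this sketch).
    Returns:  (m // r, n, d)
    """
    m, n, d = data_emb.shape
    return data_emb.reshape(m // r, r, n, d).mean(axis=1)

x = np.arange(24, dtype=float).reshape(4, 3, 2)  # m=4, n=3, d=2
y = reduction_unit(x, r=2)                        # -> shape (2, 3, 2)
```

Each application shrinks the sample axis by a factor of r, so attention over that axis in subsequent blocks becomes roughly r² cheaper.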
Memory consumption is further curtailed through tied attention weights. Standard axial attention stores separate attention maps for each target row (or column), requiring O(R H C²) memory. By sharing the attention weight matrix across the target axis (as proposed by Rao et al., 2021), CauScale stores only a single H × C × C tensor, cutting memory usage by a factor of R. This is especially beneficial when the number of samples R = m far exceeds the number of variables C = n.
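A single‑head sketch of tied attention along one axis, in the spirit of Rao et al. (2021): instead of computing R separate C × C attention maps (one per row), the query–key logits are summed over the row axis, so only one shared map is materialized. Projections are left as identities here for brevity; a real layer would learn Q/K/V matrices.

```python
import numpy as np

def tied_axial_attention(x):
    """Single-head attention along the column axis with row-tied weights.

    x: (R, C, d) -- R rows (e.g. samples), C columns (e.g. variables).
    Untied axial attention stores an (R, C, C) logit tensor; tying sums
    logits over rows, storing a single shared (C, C) map instead.
    """
    R, C, d = x.shape
    q = k = v = x  # identity Q/K/V projections (sketch simplification)
    # Shared logits: contract over rows and feature dim, scaled by sqrt(d*R).
    logits = np.einsum("rcd,red->ce", q, k) / np.sqrt(d * R)
    logits -= logits.max(axis=-1, keepdims=True)      # numerical stability
    w = np.exp(logits)
    w /= w.sum(axis=-1, keepdims=True)                # one (C, C) attention map
    return np.einsum("ce,red->rcd", w, v)             # applied to every row

out = tied_axial_attention(np.random.default_rng(0).normal(size=(6, 4, 8)))
```

With R = m samples and C = n variables, the stored attention state drops from O(m n²) to O(n²) per head, which is exactly where the factor‑of‑R memory saving comes from.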
The final prediction head operates on the graph‑stream output after the last block. For each unordered node pair (i, j), a small feed‑forward network processes the concatenated embeddings of (i, j) and (j, i) to produce logits for three states: no edge, i → j, and j → i. A softmax yields edge probabilities, and the model does not enforce acyclicity during inference, allowing it to naturally handle real‑world data that may contain cycles.
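The pairwise head can be sketched as follows. The hidden size and random weights are placeholders; the structure (concatenate the (i, j) and (j, i) embeddings, score three mutually exclusive states, softmax) follows the description above.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 4, 8
graph_out = rng.normal(size=(n, n, d))             # final graph-stream embedding
W1 = rng.normal(size=(2 * d, d)) / d**0.5          # toy FFN weights
W2 = rng.normal(size=(d, 3)) / d**0.5              # 3 logits per pair

def edge_probs(g):
    """Score {no edge, i->j, j->i} for each unordered node pair (i, j)."""
    n = g.shape[0]
    probs = {}
    for i in range(n):
        for j in range(i + 1, n):
            pair = np.concatenate([g[i, j], g[j, i]])   # (2d,): both directions
            logits = np.maximum(pair @ W1, 0.0) @ W2    # small FFN -> 3 logits
            e = np.exp(logits - logits.max())
            probs[(i, j)] = e / e.sum()                 # softmax over 3 states
    return probs

p = edge_probs(graph_out)
```

Note that the three states are scored independently per pair, so nothing in this head constrains the sampled graph to be acyclic, consistent with the paper's choice not to enforce acyclicity at inference time.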
Experimental evaluation spans synthetic datasets with varying graph sizes (50–1000 nodes) and causal mechanisms, as well as single‑cell RNA‑seq data. CauScale achieves 99.6% mean average precision (mAP) on in‑distribution test sets and 84.4% mAP on out‑of‑distribution (OOD) sets, outperforming prior state‑of‑the‑art methods. Inference speedups range from 4× to 13,000× relative to baselines, with 1000‑node graphs processed in sub‑second time. Training scalability is demonstrated up to 500‑node graphs, a regime where AVICI runs out of memory.
The paper’s contributions are threefold: (i) introducing the first large‑scale pre‑training framework for amortized causal discovery, (ii) proposing architectural components (reduction unit, tied attention, two‑stream data‑graph interaction) that jointly improve computational efficiency and discovery accuracy, and (iii) providing extensive empirical evidence that these components enable both speed and performance gains across diverse causal settings. Limitations include the lack of explicit DAG enforcement (which may be required in certain domains) and reliance on the inverse‑covariance prior, which could be sensitive to non‑Gaussian or highly nonlinear data. Future work may explore adaptive priors, post‑hoc acyclicity projection, and extensions to heterogeneous data modalities.