POP: Online Structural Pruning Enables Efficient Inference of Large Foundation Models
Large foundation models (LFMs) achieve strong performance through scaling, yet current structural pruning methods derive fixed pruning decisions during inference, overlooking sparsity patterns that emerge in the autoregressive token generation. In this paper, we propose POP (Partition-guided Online Pruning), an efficient online structural pruning framework that enables context-conditioned dynamic pruning with minimal computational overhead. POP partitions model channels into retained, candidate, and pruned regions, where prefilling defines a coarse pruning partition, and the decoding stage generates a fine-grained mask within the candidate region, avoiding full-channel re-evaluation. The coarse pruning partition preserves consistently important weights, while the fine-grained masking provides context-conditioned variation during decoding. Moreover, POP is a lightweight, plug-and-play method that requires no preprocessing, including offline calibration, retraining, or learning predictors. Extensive evaluations across diverse LFMs, including large language models (LLMs), mixture-of-experts models (MoEs), and vision-language models (VLMs), demonstrate that POP consistently delivers higher accuracy than existing pruning approaches while incurring smaller computational overhead and minimizing inference latency.
💡 Research Summary
The paper introduces POP (Partition‑guided Online Pruning), a novel framework for structural pruning that operates dynamically during the autoregressive inference of large foundation models (LFMs) such as LLMs, MoEs, and vision‑language models. Traditional structural pruning methods compute a static mask offline—often after a calibration or fine‑tuning step—and then apply the same mask throughout inference. This static approach ignores the fact that, during generation, the model’s activation patterns change dramatically from the prompt (prefill) stage to each subsequent token decoding step. The authors term this phenomenon “contextual sparsity”: a subset of neurons (or channels) becomes important depending on the current token context.
POP addresses contextual sparsity with a two‑stage process. In the prefill stage, which needs no separate offline calibration, element‑wise importance scores are computed using a simple weight‑activation product: I_{i,k} = |W_{i,k}|·‖X_k‖_2, where X_k aggregates activations of input channel k over the whole prompt. These element‑wise scores are aggregated across input channels (e.g., by sum or mean) to obtain an output‑channel importance vector I_out. Based on I_out, channels are partitioned into three regions:
- Retained – consistently high‑importance channels that are always kept.
- Pruned – consistently low‑importance channels that are permanently removed.
- Candidate – intermediate‑importance channels that may be kept or dropped depending on the decoding context.
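The prefill scoring and three-way partition described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the authors' implementation: the function names, the `retain_frac` and `prune_frac` parameters, and the choice of sum as the aggregation are hypothetical, since the summary does not say how the region boundaries are chosen.

```python
import math

def channel_importance(W, X):
    """Score each output channel as sum_k |W[i][k]| * ||X[:, k]||_2.

    W: weight rows, shape (out_channels, in_channels).
    X: prompt activations, shape (tokens, in_channels).
    Sum aggregation is an assumption; the text also mentions mean.
    """
    in_dim = len(W[0])
    # L2 norm of each input channel's activations over all prompt tokens.
    col_norm = [math.sqrt(sum(row[k] ** 2 for row in X)) for k in range(in_dim)]
    return [sum(abs(w) * n for w, n in zip(row, col_norm)) for row in W]

def partition(scores, retain_frac, prune_frac):
    """Split channel indices into (retained, candidate, pruned) by score.

    retain_frac and prune_frac are hypothetical knobs for this sketch.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    n_keep = int(len(order) * retain_frac)
    n_prune = int(len(order) * prune_frac)
    cut = len(order) - n_prune
    return order[:n_keep], order[n_keep:cut], order[cut:]

# Toy layer: 4 output channels, 2 input channels, 2 prompt tokens.
W = [[1.0, 0.0], [0.0, 0.1], [0.5, 0.5], [2.0, 1.0]]
X = [[3.0, 0.0], [4.0, 1.0]]
scores = channel_importance(W, X)
retained, candidate, pruned = partition(scores, retain_frac=0.25, prune_frac=0.25)
```

In this toy case channel 3 lands in the retained region, channel 1 in the pruned region, and channels 0 and 2 stay candidates to be decided at decoding time.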
During each decoding step t, POP recomputes a lightweight importance estimate for only the candidate channels using the current token’s activation X_t: I_i(t) = |W_i|·‖X_t‖_2. A predefined proportion (e.g., top 10‑20% of candidates) is then selected, and a binary mask is applied to the candidate region only. Because the retained and pruned regions are fixed, the dynamic step touches only a small subset of the model, incurring negligible overhead (reported as <3 % FLOPs increase). No additional forward passes, probing, or learned predictors are required.
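A minimal sketch of the per-token step may help make the cost argument concrete. It assumes that, for a single token, the per-channel norm reduces to the absolute activation value, and it takes the retained and candidate index lists as given; the function name `decode_active_channels` and the `keep_frac` parameter are hypothetical.

```python
def decode_active_channels(W, x_t, retained, candidate, keep_frac):
    """Return the set of output channels active for one decoding step.

    Only candidate channels are re-scored; retained channels are always
    active and pruned channels are never touched.
    """
    # Assumed score: sum_k |W[i][k]| * |x_t[k]| for the current token.
    scores = {
        i: sum(abs(w) * abs(x) for w, x in zip(W[i], x_t))
        for i in candidate
    }
    n_keep = int(len(candidate) * keep_frac)
    chosen = sorted(candidate, key=lambda i: scores[i], reverse=True)[:n_keep]
    return set(retained) | set(chosen)

# Same toy layer shape as a 4x2 weight matrix; indices are illustrative.
W = [[1.0, 0.0], [0.0, 0.1], [0.5, 0.5], [2.0, 1.0]]
retained, candidate = [3], [0, 2]

active_a = decode_active_channels(W, [0.0, 2.0], retained, candidate, keep_frac=0.5)
active_b = decode_active_channels(W, [2.0, 0.0], retained, candidate, keep_frac=0.5)
```

The two tokens select different candidate channels (2 for the first, 0 for the second) while channel 3 stays active in both, which is the context-conditioned variation the method relies on. The scoring loop runs over the candidate rows only, so its cost scales with the candidate region rather than the full layer.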
Key contributions include:
- Demonstrating that static pruning decisions made at prefill time cause systematic bias for later tokens, especially on long‑form generation tasks.
- Proposing a tri‑state channel partition that balances global stability (retained) with local adaptability (candidate).
- Providing a training‑free, plug‑and‑play method that works across model families without hardware‑specific sparse kernels.
Empirical evaluation spans Llama2‑7B, Llama3.1‑8B, several mixture‑of‑experts models, and CLIP‑based vision‑language models. Across a variety of benchmarks (short‑form QA such as ARC‑C, long‑form generation like MBPP, and multimodal tasks), POP consistently outperforms static baselines (e.g., Tyr, LLM‑Pruner, FLAP) at comparable pruning ratios (20‑30 %). Notably, on MBPP a 20 % static prune drops performance to ~35 % of the dense model, whereas POP recovers it to >70 %. FLOPs overhead remains under 2.85 %, and overall inference speed improves by 1.29× on Llama2‑7B. The method also shows robustness to different candidate‑region sizes, allowing practitioners to trade accuracy for latency.
Limitations are acknowledged: the proportion of candidate channels is manually set, and the current implementation focuses on feed‑forward networks (FFNs), though the authors note that extending to attention heads is straightforward. Future work could integrate adaptive selection of the candidate budget, possibly via a lightweight predictor or meta‑learning, and explore more fine‑grained token‑level importance estimation.
In summary, POP offers a practical, low‑overhead solution for online, context‑aware structural pruning, bridging the gap between the high accuracy of dense LFMs and the efficiency demands of real‑world deployment. Its simplicity, model‑agnostic nature, and strong empirical gains make it a compelling addition to the toolbox for LFM inference optimization.