KLASS: KL-Guided Fast Inference in Masked Diffusion Models
Masked diffusion models have demonstrated competitive results on various tasks, including language generation. However, because of their iterative refinement process, inference is often bottlenecked by slow sampling that follows a static schedule. To overcome this problem, we introduce ‘KL-Adaptive Stability Sampling’ (KLASS), a fast yet effective sampling method that exploits token-level KL divergence to identify stable, high-confidence predictions. By unmasking multiple tokens in each iteration without any additional model training, our approach speeds up generation significantly while maintaining sample quality. On reasoning benchmarks, KLASS achieves up to $2.78\times$ wall-clock speedups while improving performance over standard greedy decoding, attaining state-of-the-art results among diffusion-based samplers. We further validate KLASS across diverse domains, including text, image, and molecular generation, showing its effectiveness as a broadly applicable sampler across different models.
💡 Research Summary
Masked diffusion models (MDMs) have become a powerful paradigm for generating sequences, images, and even molecular structures by iteratively denoising a fully masked input. Despite their strong performance, inference remains a bottleneck: samplers typically either unmask only a few tokens per step according to a fixed schedule (e.g., Top‑k or stochastic sampling) or rely on an auxiliary planner that adds computational overhead. Consequently, generation can be slow and prone to sub‑optimal token choices, especially on complex reasoning tasks.
The paper introduces KL‑Adaptive Stability Sampling (KLASS), a training‑free sampler that exploits two signals already produced by the diffusion model: (1) a confidence score, defined as the maximum probability in the token’s predicted categorical distribution, and (2) a token‑level Kullback‑Leibler (KL) divergence measuring how much the token’s distribution changes between consecutive diffusion steps. A token is considered “stable” if its confidence exceeds a threshold τ and its KL score stays below a threshold ε_KL over the last n steps. All stable tokens are unmasked in parallel at the current step; if no token is stable, the sampler falls back to unmasking the top‑u most confident tokens. This dynamic, data‑driven selection replaces static schedules and eliminates the need for any extra model or planner.
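The selection rule described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names (`token_kl`, `klass_select`, `klass_sample`), the hypothetical `model_fn` callable, the KL direction KL(current ‖ previous), greedy argmax decoding of newly unmasked positions, and the default values of τ, ε_KL, n and u are all assumptions chosen for illustration.

```python
import numpy as np


def token_kl(p_curr, p_prev, eps=1e-12):
    """Per-position KL(p_curr || p_prev) over the vocabulary axis.

    Both inputs have shape (seq_len, vocab) with rows summing to 1.
    The KL direction (current vs. previous step) is an assumption.
    """
    p_curr = np.clip(p_curr, eps, 1.0)
    p_prev = np.clip(p_prev, eps, 1.0)
    return np.sum(p_curr * (np.log(p_curr) - np.log(p_prev)), axis=-1)


def klass_select(probs, kl_history, masked, tau, eps_kl, top_u):
    """Indices of masked positions to unmask at the current step.

    probs:      (seq_len, vocab) current predictive distributions
    kl_history: (n, seq_len) KL scores from the last n steps
    masked:     (seq_len,) bool, True where the position is still masked
    """
    confidence = probs.max(axis=-1)                  # max token probability
    kl_stable = np.all(kl_history < eps_kl, axis=0)  # low KL on all n steps
    stable = masked & (confidence > tau) & kl_stable
    if stable.any():
        return np.flatnonzero(stable)                # unmask all stable tokens
    # Fallback: unmask the top_u most confident still-masked positions.
    masked_idx = np.flatnonzero(masked)
    order = np.argsort(-confidence[masked_idx])
    return masked_idx[order[:top_u]]


def klass_sample(model_fn, seq_len, mask_id, n_hist=2, tau=0.9,
                 eps_kl=1e-3, top_u=1):
    """Toy KLASS-style loop. `model_fn(tokens)` is a hypothetical callable
    returning a (seq_len, vocab) array of per-position distributions."""
    tokens = np.full(seq_len, mask_id)
    masked = np.ones(seq_len, dtype=bool)
    prev_probs, kl_hist, steps = None, [], 0
    while masked.any():
        probs = model_fn(tokens)
        # No KL exists at the first step; inf keeps every position "unstable"
        # until n_hist finite KL values have been observed.
        kl = (np.full(seq_len, np.inf) if prev_probs is None
              else token_kl(probs, prev_probs))
        kl_hist = (kl_hist + [kl])[-n_hist:]
        idx = klass_select(probs, np.stack(kl_hist), masked,
                           tau, eps_kl, top_u)
        tokens[idx] = probs[idx].argmax(axis=-1)     # greedy decode (assumed)
        masked[idx] = False
        prev_probs, steps = probs, steps + 1
    return tokens, steps
```

With a model whose predictions never change between steps, this loop spends its first two steps on fallback unmasking while the KL history fills up, then unmasks every remaining confident position in a single step. The `np.inf` placeholder is just one simple way to enforce the n-step history requirement; other warm-up policies are equally plausible.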
The authors provide a theoretical justification: if a well‑trained model approximates the true conditional distribution to within δ, any token that is wrong in the final context must exhibit a non‑negligible per‑step KL divergence along the reverse diffusion path. Contrapositively, a token whose KL stays small across steps is unlikely to be wrong, so low KL serves as a proxy for correctness, and delaying unmasking until KL stabilizes reduces the risk of premature, erroneous decisions.
Empirically, KLASS is evaluated on a suite of reasoning benchmarks (GSM8K, MATH, HumanEval, MBPP) using large‑scale language diffusion models (LLaDA, Dream). Compared with standard greedy, Top‑1/Top‑2, confidence‑only, and KL‑only baselines, KLASS consistently achieves higher accuracy (gains of 3–5 percentage points) while cutting the number of diffusion steps roughly in half. Wall‑clock speedups reach up to 2.78×. The method also generalizes to image synthesis and molecular generation, where it improves FID and validity rate, respectively. Ablation studies confirm that both confidence and KL thresholds are essential: too low ε_KL hampers speed, while too high τ degrades quality.
In summary, KLASS offers a simple yet effective way to accelerate masked diffusion inference by leveraging the model’s own internal dynamics. It requires no additional training, adds minimal memory overhead, and works across modalities. Limitations include the need to tune ε_KL and τ per domain and the O(vocab) cost of computing KL for large vocabularies, which could be mitigated with caching or approximate top‑k KL estimation. Nonetheless, KLASS represents a practical advancement toward fast, high‑quality generation with masked diffusion models.
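The approximate top‑k KL estimation mentioned above could take several forms. The sketch below shows one hypothetical variant, not taken from the paper: keep the union of the top‑k token ids of the two distributions and merge all remaining probability mass into a single "other" bin. The names `topk_kl` and `exact_kl` and the value `k=8` are illustrative assumptions. As written, `np.argpartition` still scans the full vocabulary, so the sketch only illustrates the approximation itself; real savings would require reusing top‑k values that are already available, for example cached from the previous step.

```python
import numpy as np


def topk_kl(p_curr, p_prev, k=8, eps=1e-12):
    """Approximate KL(p_curr || p_prev) for one position (1-D vocab arrays).

    Hypothetical approximation: keep the union of both top-k supports and
    lump all remaining probability mass into one extra "other" bin.
    """
    top_curr = np.argpartition(p_curr, -k)[-k:]
    top_prev = np.argpartition(p_prev, -k)[-k:]
    support = np.union1d(top_curr, top_prev)

    a = p_curr[support]
    b = p_prev[support]
    # Leftover mass outside the kept support becomes a single bin.
    a = np.append(a, max(1.0 - a.sum(), 0.0))
    b = np.append(b, max(1.0 - b.sum(), 0.0))
    a = np.clip(a, eps, 1.0)
    b = np.clip(b, eps, 1.0)
    return float(np.sum(a * (np.log(a) - np.log(b))))


def exact_kl(p, q, eps=1e-12):
    """Full-vocabulary KL(p || q), the O(V) reference value."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return float(np.sum(p * (np.log(p) - np.log(q))))
```

Because merging bins can only discard information, this coarse‑grained KL never exceeds the exact KL (data processing inequality, up to the small clipping constant). In the test KL < ε_KL it is therefore more permissive than the exact score, which may call for a smaller ε_KL when using such an approximation.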