JAX-Privacy: A library for differentially private machine learning
JAX-Privacy is a library designed to simplify the deployment of robust and performant mechanisms for differentially private machine learning. Guided by design principles of usability, flexibility, and efficiency, JAX-Privacy serves both researchers requiring deep customization and practitioners who want a more out-of-the-box experience. The library provides verified, modular primitives for every component of mechanism design, including batch selection, gradient clipping, noise addition, accounting, and auditing, and brings together a large body of recent research on differentially private ML.
💡 Research Summary
JAX‑Privacy is a comprehensive library designed to streamline the development and deployment of differentially private (DP) machine‑learning pipelines within the JAX ecosystem. The authors begin by outlining the practical challenges that arise when implementing DP‑ML, such as efficient per‑example gradient clipping, the gap between the theoretically assumed Poisson subsampling and real‑world fixed‑size batch loading, and the risk of “silent failures,” where a flawed privacy mechanism appears to work correctly while providing weaker guarantees than claimed. To address these issues, JAX‑Privacy provides a set of modular, verified primitives covering the entire DP workflow: batch selection, gradient clipping, noise addition, privacy accounting, and empirical auditing.
The batch selection API is deliberately framework‑agnostic, returning only integer indices via a pure NumPy interface. This design enables seamless integration with in‑memory arrays as well as disk‑based data loaders (e.g., pygrain). Because random‑access datasets often produce variable‑size batches, the library offers utilities to pad batches to a fixed size, thereby reducing the number of JIT recompilations required during training.
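The index-based design described above can be sketched in a few lines of plain NumPy. The function names below are illustrative, not the library's actual API; the sketch shows the two ideas the paragraph describes: Poisson subsampling that returns only integer indices, and padding variable-size index batches to a fixed shape so jit-compiled training code sees a single static shape.

```python
import numpy as np

def poisson_batch_indices(rng, num_examples, sampling_prob):
    """Poisson subsampling: each example joins the batch independently
    with probability `sampling_prob`, so the batch size varies."""
    mask = rng.uniform(size=num_examples) < sampling_prob
    return np.flatnonzero(mask)

def pad_indices(indices, padded_size, pad_value=-1):
    """Pad a variable-size index batch to a fixed size; downstream code
    masks out the `pad_value` entries instead of recompiling per shape."""
    out = np.full(padded_size, pad_value, dtype=np.int64)
    out[: len(indices)] = indices
    return out

rng = np.random.default_rng(0)
idx = poisson_batch_indices(rng, num_examples=10_000, sampling_prob=0.01)
batch = pad_indices(idx, padded_size=200)  # expected batch size is ~100
```

Because only indices cross the API boundary, the same sampler drives an in-memory `array[batch]` lookup or a random-access disk loader equally well.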
Per‑example clipping, historically a trade‑off between sequential loops (low memory, poor utilization) and full vectorization (high memory), is implemented using a higher‑order composition of jax.vmap, jax.grad, and jax.lax.scan. A user‑configurable micro‑batch size lets practitioners interpolate between the two extremes, achieving memory‑efficient execution without sacrificing speed. The returned gradient function automatically attaches a sensitivity attribute, preventing silent mismatches between clipping and noising and ensuring that the noise scale matches the true ℓ₂‑sensitivity of the clipped gradient sum.
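The vmap/grad/scan composition can be sketched for a toy squared-error model. This is a simplified illustration, not JAX-Privacy's actual implementation: `jax.vmap(jax.grad(...))` vectorizes per-example gradients within a micro-batch, while `jax.lax.scan` loops sequentially over micro-batches, so the micro-batch size trades peak memory against parallelism exactly as the paragraph describes.

```python
import jax
import jax.numpy as jnp

def per_example_loss(params, x, y):
    # Toy model: squared error of a linear predictor on ONE example.
    return (jnp.dot(x, params) - y) ** 2

def clipped_grad_sum(params, xs, ys, clip_norm=1.0, microbatch_size=32):
    """Sum of per-example gradients, each clipped to l2-norm <= clip_norm,
    so the l2-sensitivity of the result is exactly clip_norm."""
    # Per-example gradients within a micro-batch, vectorized by vmap.
    grad_fn = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))

    def body(carry, mb):
        mb_x, mb_y = mb
        g = grad_fn(params, mb_x, mb_y)                  # [mb, dim]
        norms = jnp.linalg.norm(g, axis=1, keepdims=True)
        g = g * jnp.minimum(1.0, clip_norm / norms)      # clip each row
        return carry + g.sum(axis=0), None

    # Sequential loop over micro-batches keeps peak memory at O(mb * dim).
    n = xs.shape[0] // microbatch_size
    xs = xs.reshape(n, microbatch_size, -1)
    ys = ys.reshape(n, microbatch_size)
    total, _ = jax.lax.scan(body, jnp.zeros_like(params), (xs, ys))
    return total
```

Setting `microbatch_size=1` recovers the low-memory sequential loop; setting it to the full batch size recovers full vectorization.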
Noise addition is exposed through the standard optax.GradientTransformation API, allowing a drop‑in replacement of the optimizer’s update rule. This abstraction supports simple i.i.d. Gaussian noise (DP‑SGD) as well as more sophisticated correlated‑noise schemes such as DP‑FTRL and matrix‑factorization mechanisms (DP‑MF). The implementation exploits the embarrassingly parallel nature of noise generation, scaling across multiple machines with minimal configuration.
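The GradientTransformation pattern is just a pair of pure `init`/`update` functions. As a minimal NumPy stand-in (not JAX-Privacy's or optax's actual code), the i.i.d. Gaussian case reduces to calibrating the noise standard deviation to the product of the noise multiplier and the ℓ₂-sensitivity of the clipped gradient sum:

```python
import numpy as np
from typing import NamedTuple

class NoiseState(NamedTuple):
    rng: np.random.Generator  # PRNG state carried across steps

def gaussian_noise_transform(noise_multiplier, l2_sensitivity, seed=0):
    """Minimal stand-in for an optax-style GradientTransformation that
    adds Gaussian noise calibrated to the clipped sum's l2-sensitivity,
    i.e. the noising step of DP-SGD."""
    def init(params):
        return NoiseState(np.random.default_rng(seed))

    def update(grad_sum, state, params=None):
        sigma = noise_multiplier * l2_sensitivity
        noisy = grad_sum + state.rng.normal(scale=sigma, size=grad_sum.shape)
        return noisy, state

    return init, update
```

Correlated-noise schemes such as DP-FTRL keep extra state in `init`/`update` so that the noise added at step t depends on earlier steps, which is exactly what the stateful transformation interface accommodates.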
For privacy accounting, JAX‑Privacy builds directly on the dp‑accounting library, inheriting state‑of‑the‑art Rényi DP, Gaussian DP, and group‑privacy calculators. Consequently, users can obtain tight ε‑δ guarantees for a wide range of mechanisms, including truncated‑batch DP‑SGD, group‑privacy extensions, and DP‑BandMF.
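dp-accounting's accountants implement considerably tighter, subsampling-aware analyses than can be shown in a few lines, but the underlying idea can be illustrated with a self-contained Rényi-DP bound for T compositions of the unsampled Gaussian mechanism (sensitivity 1), using the standard RDP expression alpha / (2 sigma^2) and the basic RDP-to-(ε, δ) conversion:

```python
import numpy as np

def gaussian_rdp_epsilon(noise_multiplier, steps, delta, orders=None):
    """Simplified RDP accounting for T compositions of the unsampled
    Gaussian mechanism with sensitivity 1:
      RDP(alpha) = alpha / (2 * sigma^2), composed additively over steps,
      eps = min_alpha [ T * RDP(alpha) + log(1/delta) / (alpha - 1) ].
    Illustrative only: dp_accounting provides tighter conversions and
    handles subsampled and correlated-noise mechanisms."""
    if orders is None:
        orders = np.arange(1.25, 256.0, 0.25)  # grid of Renyi orders > 1
    rdp = steps * orders / (2.0 * noise_multiplier ** 2)
    eps = rdp + np.log(1.0 / delta) / (orders - 1.0)
    return float(eps.min())
```

As expected, the bound tightens as the noise multiplier grows and loosens as more steps are composed.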
The auditing component implements a canary‑insertion membership‑inference framework. By inserting synthetic “canary” examples into the training data with a known probability and then measuring the model’s ability to distinguish their presence, the library computes an empirical privacy leakage ε_emp. Comparing ε_emp to the theoretical bound ε_theory provides a practical sanity check for implementation bugs or overly loose accounting. The audit suite includes recent attacks and metrics from the literature (e.g., Nasr et al., Steinke et al., Mahloujifar et al.).
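The core arithmetic behind such an audit is simple: any (ε, 0)-DP mechanism forces a membership-inference attack's true-positive rate to satisfy TPR ≤ e^ε · FPR, so log(TPR/FPR) lower-bounds the realized ε. The sketch below is a point estimate only; a rigorous audit, as in the cited attacks, replaces it with confidence intervals (e.g., Clopper-Pearson):

```python
import numpy as np

def empirical_epsilon(scores_in, scores_out, threshold):
    """Point estimate of privacy leakage from attack scores.
    scores_in:  attack scores on runs where the canary WAS inserted.
    scores_out: attack scores on runs where it was NOT.
    DP implies TPR <= e^eps * FPR, so log(TPR/FPR) lower-bounds eps."""
    tpr = np.mean(np.asarray(scores_in) >= threshold)   # canary detected
    fpr = np.mean(np.asarray(scores_out) >= threshold)  # false alarm
    if tpr == 0 or fpr == 0:
        return float("inf") if tpr > fpr else 0.0
    return max(0.0, float(np.log(tpr / fpr)))
```

If the resulting ε_emp exceeds the theoretical ε_theory, something is wrong: either the implementation is buggy or the accounting does not match the mechanism actually run.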
To cater to both researchers and production engineers, JAX‑Privacy also offers a high‑level Keras‑style API. This wrapper abstracts away the low‑level training loop, allowing users to declare a Keras model, dataset, and privacy budget in a concise configuration object. While less flexible than the core primitives, the Keras API guarantees a correct and performant DP implementation out‑of‑the‑box.
Performance experiments compare JAX‑Privacy against plain JAX, Opacus (the PyTorch DP library), and non‑DP PyTorch across three model families (CNN, State‑Space, Transformer) and three scales (≈1 M, 10 M, 100 M parameters). Throughput is measured as examples processed per second over 50 dummy‑training iterations. Results show that JAX‑Privacy achieves 40‑90 % of the throughput of non‑private JAX and is competitive with Opacus (typically within 0.5‑1.0× of its throughput). Notably, for State‑Space models the gap narrows dramatically, indicating that the library’s micro‑batch clipping and noise pipelines are highly efficient for sequence‑heavy workloads.
Finally, the authors describe large‑scale internal deployments at Google, where JAX‑Privacy was used to fine‑tune multi‑billion‑parameter Gemma models across thousands of machines with DP‑FTRL and to pre‑train a 1 B‑parameter model with DP‑MF. In these settings, users are responsible for sharding data and model parameters, while JAX‑Privacy guarantees that DP‑specific intermediates are correctly sharded and that privacy accounting remains accurate.
In summary, JAX‑Privacy delivers a unified, well‑tested, and scalable stack for differentially private machine learning on JAX. By exposing low‑level primitives for custom research and a high‑level Keras interface for production, it bridges the gap between theoretical DP mechanisms and practical, large‑scale deployment, while providing tools for rigorous verification through auditing and tight accounting.