StochTree: BART-based modeling in R and Python
stochtree is a C++ library for Bayesian tree ensemble models such as BART and Bayesian Causal Forests (BCF), as well as user-specified variations. Unlike previous BART packages, stochtree provides bindings to both R and Python for full interoperability. stochtree boasts a more comprehensive range of models relative to previous packages, including heteroskedastic forests, random effects, and treed linear models. Additionally, stochtree offers flexible handling of model fits: the ability to save model fits, reinitialize models from existing fits (facilitating improved model initialization heuristics), and pass fits between R and Python. On both platforms, stochtree exposes lower-level functionality, allowing users to specify models incorporating Bayesian tree ensembles without needing to modify C++ code. We illustrate the use of stochtree in three settings: i) straightforward applications of existing models such as BART and BCF, ii) models that include more sophisticated components like heteroskedasticity and leaf-wise regression models, and iii) as a component of custom MCMC routines to fit nonstandard tree ensemble models.
💡 Research Summary
The paper introduces stochtree, a modern C++ library for Bayesian tree‑ensemble models such as Bayesian Additive Regression Trees (BART) and Bayesian Causal Forests (BCF). Unlike earlier BART implementations that are confined to a single language, stochtree provides seamless bindings for both R and Python, enabling users to fit models, serialize results, and transfer fitted objects across the two environments without leaving the high‑level language. The library consolidates a wide array of recent methodological extensions (heteroskedastic forests, group‑level random effects, leaf‑wise linear regression, multivariate leaf models, monotone regression, survival analysis, and more) into a single package, thereby reducing the need for custom C++ development.
A central technical contribution is the incorporation of a warm‑start strategy based on the X‑BART “grow‑from‑root” (GFR) algorithm. By running a small number of fast GFR iterations before the full MCMC, stochtree can initialize the Markov chain in a high‑likelihood region, dramatically shortening burn‑in, especially for large datasets. Users control the number of GFR steps (num_gfr), burn‑in iterations (num_burnin), and retained MCMC draws (num_mcmc) directly in the high‑level bart() or bcf() functions. Hyperparameters such as the CGM98 tree‑growth prior (α, β), the number of trees, minimum leaf size, and maximum depth are passed via the mean_forest_params list, while global settings (including priors for the leaf‑mean variance) reside in general_params.
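As a sketch of how these controls fit together in R (num_gfr, num_burnin, num_mcmc, mean_forest_params, and general_params are named in the paper; the specific field names inside the parameter lists and the values shown are illustrative assumptions, not verbatim package documentation):

```r
library(stochtree)

# Warm-started BART: a handful of fast grow-from-root (GFR) iterations
# initialize the chain, then standard MCMC draws are retained.
bart_fit <- bart(
  X_train = X, y_train = y,
  num_gfr = 20,      # warm-start iterations (X-BART grow-from-root)
  num_burnin = 0,    # warm-starting largely replaces long burn-in
  num_mcmc = 1000,   # retained MCMC draws
  mean_forest_params = list(   # field names assumed for illustration
    num_trees = 200,           # ensemble size
    alpha = 0.95, beta = 2.0,  # CGM98 tree-depth prior parameters
    min_samples_leaf = 5       # minimum leaf size
  )
  # global settings, e.g. the leaf-mean variance prior, would go in
  # a general_params list passed alongside mean_forest_params
)
```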
stochtree's API is deliberately dual‑layered. The high‑level interface (stochtree::bart and stochtree::bcf in R; BARTModel and BCFModel in Python) handles data preprocessing, model fitting, and prediction with a syntax mirroring familiar R packages. The low‑level interface exposes the underlying C++ structures, allowing advanced users to modify tree priors, implement custom leaf functions, or embed the ensemble as a component of a larger MCMC scheme. Model objects can be saved to JSON strings (saveBARTModelToJsonString) and reloaded (createBARTModelFromJsonString), facilitating checkpointing, extended sampling, and cross‑language workflows.
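The checkpointing workflow can be sketched in R using the two serialization functions named above (the round trip shown is an assumed typical usage; the predict call is an assumption about the fitted-model interface):

```r
library(stochtree)

# Serialize a fitted BART model to a portable JSON string
json_string <- saveBARTModelToJsonString(bart_fit)

# Persist it to disk, or hand it to a Python session, then restore it
restored_fit <- createBARTModelFromJsonString(json_string)

# The restored object supports the same downstream operations,
# e.g. prediction on new covariates (method name assumed):
# y_hat <- predict(restored_fit, X_new)
```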
Four illustrative examples demonstrate the library's breadth. 1) Supervised learning on the Friedman function showcases a standard BART fit with 200 trees, warm‑started by 20 GFR iterations and followed by 10,000 MCMC draws; trace plots of the global error variance illustrate rapid convergence. 2) Causal inference with BCF on a synthetic treatment dataset highlights the built‑in causal forest model and its own warm‑start variant, enabling efficient estimation of heterogeneous treatment effects. 3) Hierarchical modeling on a semi‑synthetic CIC dataset incorporates group‑level random effects and heteroskedastic error variance, illustrating how stochtree handles clustered data and non‑constant noise. 4) Regression discontinuity on an academic probation dataset employs leaf‑wise linear regression, demonstrating that tree‑based linear models can seamlessly encode design‑specific structures. In each case, the same workflow of data preparation, prior specification, sampling, prediction, and serialization is reproduced in both R and Python, underscoring the library's language‑agnostic design.
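The first example can be sketched end to end in R. The Friedman (1991) test function below is the standard benchmark; the bart() argument names follow the paper, while the sample size, noise level, and draw counts are illustrative rather than the paper's exact settings:

```r
library(stochtree)

# Simulate from the Friedman benchmark function
set.seed(2024)
n <- 1000; p <- 10
X <- matrix(runif(n * p), ncol = p)
f_X <- 10 * sin(pi * X[, 1] * X[, 2]) + 20 * (X[, 3] - 0.5)^2 +
       10 * X[, 4] + 5 * X[, 5]
y <- f_X + rnorm(n)

# Warm-started BART fit: 20 GFR iterations, then retained MCMC draws
bart_fit <- bart(
  X_train = X, y_train = y,
  num_gfr = 20, num_burnin = 0, num_mcmc = 1000,
  mean_forest_params = list(num_trees = 200)  # field name assumed
)
```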
Overall, stochtree unifies recent advances in Bayesian tree modeling into a performant, extensible, and interoperable framework. By exposing low‑level C++ objects alongside high‑level language bindings, it lowers the barrier for researchers to prototype novel tree‑based priors, integrate ensembles into custom Bayesian pipelines, and leverage GPU or parallel-computing extensions in the future. The authors anticipate that the library will accelerate the diffusion of cutting‑edge BART methodology into applied statistics, causal inference, and machine learning practice.