| Title: | Stochastic Tree Ensembles (XBART and BART) for Supervised Learning and Causal Inference |
|---|---|
| Description: | Flexible stochastic tree ensemble software. Robust implementations of Bayesian Additive Regression Trees (BART) (Chipman, George, McCulloch (2010) <doi:10.1214/09-AOAS285>) for supervised learning and Bayesian Causal Forests (BCF) (Hahn, Murray, Carvalho (2020) <doi:10.1214/19-BA1195>) for causal inference. Enables model serialization and parallel sampling and provides a low-level interface for custom stochastic forest samplers. Includes the grow-from-root algorithm for accelerated forest sampling (He and Hahn (2021) <doi:10.1080/01621459.2021.1942012>), a log-linear leaf model for forest-based heteroskedasticity (Murray (2020) <doi:10.1080/01621459.2020.1813587>), and the cloglog BART model of Alam and Linero (2025) <doi:10.48550/arXiv.2502.00606> for ordinal outcomes. |
| Authors: | Drew Herren [aut, cre] (ORCID: <https://orcid.org/0000-0003-4109-6611>), Richard Hahn [aut], Jared Murray [aut], Carlos Carvalho [aut], Jingyu He [aut], Pedro Lima [ctb], Entejar Alam [ctb], stochtree contributors [cph], Eigen contributors [cph] (C++ source uses the Eigen library for matrix operations, see inst/COPYRIGHTS), xgboost contributors [cph] (C++ tree code and related operations include or are inspired by code from the xgboost library, see inst/COPYRIGHTS), treelite contributors [cph] (C++ tree code and related operations include or are inspired by code from the treelite library, see inst/COPYRIGHTS), Microsoft Corporation [cph] (C++ I/O and various project structure code include or are inspired by code from the LightGBM library, which is a copyright of Microsoft, see inst/COPYRIGHTS), Niels Lohmann [cph] (C++ source uses the JSON for Modern C++ library for JSON operations, see inst/COPYRIGHTS), Daniel Lemire [cph] (C++ source uses the fast_double_parser library internally, see inst/COPYRIGHTS), Victor Zverovich [cph] (C++ source uses the fmt library internally, see inst/COPYRIGHTS) |
| Maintainer: | Drew Herren <[email protected]> |
| License: | MIT + file LICENSE |
| Version: | 0.4.4 |
| Built: | 2026-06-05 21:44:35 UTC |
| Source: | https://github.com/stochastictree/stochtree |
Maintainer: Drew Herren [email protected] (ORCID)
Authors:
Richard Hahn
Jared Murray
Carlos Carvalho
Jingyu He
Other contributors:
Pedro Lima [contributor]
Entejar Alam [contributor]
stochtree contributors [copyright holder]
Eigen contributors (C++ source uses the Eigen library for matrix operations, see inst/COPYRIGHTS) [copyright holder]
xgboost contributors (C++ tree code and related operations include or are inspired by code from the xgboost library, see inst/COPYRIGHTS) [copyright holder]
treelite contributors (C++ tree code and related operations include or are inspired by code from the treelite library, see inst/COPYRIGHTS) [copyright holder]
Microsoft Corporation (C++ I/O and various project structure code include or are inspired by code from the LightGBM library, which is a copyright of Microsoft, see inst/COPYRIGHTS) [copyright holder]
Niels Lohmann (C++ source uses the JSON for Modern C++ library for JSON operations, see inst/COPYRIGHTS) [copyright holder]
Daniel Lemire (C++ source uses the fast_double_parser library internally, see inst/COPYRIGHTS) [copyright holder]
Victor Zverovich (C++ source uses the fmt library internally, see inst/COPYRIGHTS) [copyright holder]
Useful links:
Report bugs at https://github.com/StochasticTree/stochtree/issues
Run the BART algorithm for supervised learning.
    bart(
      X_train,
      y_train,
      leaf_basis_train = NULL,
      rfx_group_ids_train = NULL,
      rfx_basis_train = NULL,
      X_test = NULL,
      leaf_basis_test = NULL,
      rfx_group_ids_test = NULL,
      rfx_basis_test = NULL,
      observation_weights = NULL,
      num_gfr = 5,
      num_burnin = 0,
      num_mcmc = 100,
      previous_model_json = NULL,
      previous_model_warmstart_sample_num = NULL,
      general_params = list(),
      mean_forest_params = list(),
      variance_forest_params = list(),
      random_effects_params = list()
    )
X_train |
Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, while categorical columns stored as ordered factors will be passed as integers to the core algorithm, along with the metadata that the column is ordered categorical). |
y_train |
Outcome to be modeled by the ensemble. |
leaf_basis_train |
(Optional) Bases used to define a regression model |
rfx_group_ids_train |
(Optional) Group labels used for an additive random effects model. |
rfx_basis_train |
(Optional) Basis for "random-slope" regression in an additive random effects model.
If |
X_test |
(Optional) Test set of covariates used to define "out of sample" evaluation data.
May be provided either as a dataframe or a matrix, but the format of |
leaf_basis_test |
(Optional) Test set of bases used to define "out of sample" evaluation data.
While a test set is optional, the structure of any provided test set must match that
of the training set (i.e. if both |
rfx_group_ids_test |
(Optional) Test set group labels used for an additive random effects model. Test set evaluation for group labels that were not in the training set is not currently supported, though support is planned. |
rfx_basis_test |
(Optional) Test set basis for "random-slope" regression in additive random effects model. |
observation_weights |
(Optional) Numeric vector of observation weights of length |
num_gfr |
Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. |
num_burnin |
Number of "burn-in" iterations of the MCMC sampler. Default: 0. |
num_mcmc |
Number of "retained" iterations of the MCMC sampler. Default: 100. |
previous_model_json |
(Optional) JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: |
previous_model_warmstart_sample_num |
(Optional) Sample number from |
general_params |
(Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional.
|
mean_forest_params |
(Optional) A list of mean forest model parameters, each of which has a default value processed internally, so this argument list is optional.
|
variance_forest_params |
(Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
|
random_effects_params |
(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional.
|
List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
    # Simulate covariates and a step-function outcome driven by X[,1]
    n <- 100
    p <- 5
    X <- matrix(runif(n*p), ncol = p)
    f_XW <- (
        ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
        ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
        ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
        ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
    )
    noise_sd <- 1
    y <- f_XW + rnorm(n, 0, noise_sd)
    # Hold out 20% of the observations as a test set
    test_set_pct <- 0.2
    n_test <- round(test_set_pct*n)
    n_train <- n - n_test
    test_inds <- sort(sample(1:n, n_test, replace = FALSE))
    train_inds <- (1:n)[!((1:n) %in% test_inds)]
    X_test <- X[test_inds,]
    X_train <- X[train_inds,]
    y_test <- y[test_inds]
    y_train <- y[train_inds]
    # Run 10 grow-from-root iterations followed by 10 retained MCMC iterations
    bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
                       num_gfr = 10, num_burnin = 0, num_mcmc = 10)
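The data frame preprocessing described for `X_train` can be pictured with base R alone. What follows is a minimal sketch and not stochtree's internal code: the helper `sketch_preprocess` and the column names are hypothetical, and the package may differ in details such as column naming or how the ordered-categorical metadata is recorded. The sketch one-hot encodes an unordered factor, maps an ordered factor to its integer level codes, and passes numeric columns through.

```r
# Hypothetical illustration of the preprocessing described for X_train;
# this is NOT stochtree's internal implementation.
sketch_preprocess <- function(df) {
  cols <- list()
  for (nm in names(df)) {
    x <- df[[nm]]
    if (is.ordered(x)) {
      # Ordered categorical: keep a single column of integer level codes
      cols[[nm]] <- as.integer(x)
    } else if (is.factor(x)) {
      # Unordered categorical: one indicator column per level
      for (lv in levels(x)) {
        cols[[paste0(nm, "_", lv)]] <- as.numeric(x == lv)
      }
    } else {
      # Numeric columns pass through unchanged
      cols[[nm]] <- as.numeric(x)
    }
  }
  do.call(cbind, cols)
}

df <- data.frame(
  x_num = c(0.1, 0.5, 0.9),
  x_unord = factor(c("a", "b", "a")),
  x_ord = factor(c("low", "high", "mid"),
                 levels = c("low", "mid", "high"), ordered = TRUE)
)
X_processed <- sketch_preprocess(df)
```

The ordered check comes first because ordered factors are also factors in R; testing `is.factor` first would one-hot encode them and discard the ordering.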
BART models contain external pointers to C++ objects, which means they cannot
be correctly serialized to .Rds from an R session in their default state.
This group of serialization functions allows us to convert between C++ data structures and a persistent JSON
representation. The CppJson class wraps a performant C++ JSON API, and the functions
saveBARTModelToJson and createBARTModelFromJson save to and load from this format.
This representation, of course, also relies on external C++ pointers, so in order to
save and reload BART models across sessions, we provide two other interfaces.
saveBARTModelToJsonString converts a BART model to an in-memory string containing the model's
JSON representation and createBARTModelFromJsonString converts this representation back to a BART model object.
saveBARTModelToJsonFile and createBARTModelFromJsonFile save or reload a BART model
directly to / from a .json file.
Finally, for cases in which multiple BART models have been sampled (for instance, multiple processes
run via doParallel), we offer createBARTModelFromCombinedJson and createBARTModelFromCombinedJsonString for
loading a new combined BART model from a list of BART JSON objects / strings.
    saveBARTModelToJson(object)
    saveBARTModelToJsonFile(object, filename)
    saveBARTModelToJsonString(object)
    createBARTModelFromJson(json_object)
    createBARTModelFromJsonFile(json_filename)
    createBARTModelFromJsonString(json_string)
    createBARTModelFromCombinedJson(json_object_list)
    createBARTModelFromCombinedJsonString(json_string_list)
object |
Object of type |
filename |
String of filepath, must end in ".json" |
json_object |
Object of type |
json_filename |
String of filepath, must end in ".json" |
json_string |
JSON string dump |
json_object_list |
List of objects of type |
json_string_list |
List of JSON strings which can be parsed to objects of type |
saveBARTModelToJson returns an object of type CppJson.
saveBARTModelToJsonString returns a string dump of the BART model's JSON representation.
saveBARTModelToJsonFile returns nothing, but writes to the provided filename.
createBARTModelFromJson, createBARTModelFromJsonFile, createBARTModelFromJsonString,
createBARTModelFromCombinedJson, and createBARTModelFromCombinedJsonString all return
objects of type bartmodel.
    # Generate data
    n <- 100
    p <- 5
    X <- matrix(runif(n*p), ncol = p)
    y <- X[,1] + rnorm(n, 0, 1)
    # Sample BART model
    bart_model <- bart(X_train = X, y_train = y, num_gfr = 0,
                       num_burnin = 0, num_mcmc = 10)
    # Save to in-memory JSON
    bart_json <- saveBARTModelToJson(bart_model)
    # Save to JSON string
    bart_json_string <- saveBARTModelToJsonString(bart_model)
    # Save to JSON file
    tmpjson <- tempfile(fileext = ".json")
    saveBARTModelToJsonFile(bart_model, file.path(tmpjson))
    # Reload BART model from in-memory JSON object
    bart_model_roundtrip <- createBARTModelFromJson(bart_json)
    # Reload BART model from JSON string
    bart_model_roundtrip <- createBARTModelFromJsonString(bart_json_string)
    # Reload BART model from JSON file
    bart_model_roundtrip <- createBARTModelFromJsonFile(file.path(tmpjson))
    unlink(tmpjson)
    # Reload BART model from list of JSON objects
    bart_model_roundtrip <- createBARTModelFromCombinedJson(list(bart_json))
    # Reload BART model from list of JSON strings
    bart_model_roundtrip <- createBARTModelFromCombinedJsonString(list(bart_json_string))
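The multi-chain use case mentioned in the description can be sketched with the functions listed in this topic. This is an illustration under assumptions rather than a recommended workflow: it assumes stochtree is installed, uses `lapply` as a sequential stand-in for a parallel backend such as doParallel, and the name `num_chains` is introduced here purely for illustration.

```r
library(stochtree)

# Simulated data (same setup as the serialization example)
n <- 100
p <- 5
X <- matrix(runif(n * p), ncol = p)
y <- X[, 1] + rnorm(n, 0, 1)

# Sample several independent chains; lapply stands in for a parallel backend
num_chains <- 3  # illustrative choice, not a package parameter
json_strings <- lapply(seq_len(num_chains), function(i) {
  chain <- bart(X_train = X, y_train = y,
                num_gfr = 0, num_burnin = 0, num_mcmc = 10)
  # JSON strings are plain character data, so unlike the model object
  # they do not depend on external C++ pointers
  saveBARTModelToJsonString(chain)
})

# Load one combined model from the list of per-chain JSON strings
combined_model <- createBARTModelFromCombinedJsonString(json_strings)
```

Returning strings rather than model objects from each worker matters because the model objects hold external pointers, which do not survive being copied between R processes.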
Run the Bayesian Causal Forest (BCF) algorithm for regularized causal effect estimation.
    bcf(
      X_train,
      Z_train,
      y_train,
      propensity_train = NULL,
      rfx_group_ids_train = NULL,
      rfx_basis_train = NULL,
      X_test = NULL,
      Z_test = NULL,
      propensity_test = NULL,
      rfx_group_ids_test = NULL,
      rfx_basis_test = NULL,
      observation_weights = NULL,
      num_gfr = 5,
      num_burnin = 0,
      num_mcmc = 100,
      previous_model_json = NULL,
      previous_model_warmstart_sample_num = NULL,
      general_params = list(),
      prognostic_forest_params = list(),
      treatment_effect_forest_params = list(),
      variance_forest_params = list(),
      random_effects_params = list()
    )
X_train |
Covariates used to split trees in the ensemble. May be provided either as a dataframe or a matrix. Matrix covariates will be assumed to be all numeric. Covariates passed as a dataframe will be preprocessed based on the variable types (e.g. categorical columns stored as unordered factors will be one-hot encoded, while categorical columns stored as ordered factors will be passed as integers to the core algorithm, along with the metadata that the column is ordered categorical). |
Z_train |
Vector of (continuous or binary) treatment assignments. |
y_train |
Outcome to be modeled by the ensemble. |
propensity_train |
(Optional) Vector of propensity scores. If not provided, this will be estimated from the data.
If |
rfx_group_ids_train |
(Optional) Group labels used for an additive random effects model. |
rfx_basis_train |
(Optional) Basis for "random-slope" regression in an additive random effects model.
If |
X_test |
(Optional) Test set of covariates used to define "out of sample" evaluation data.
May be provided either as a dataframe or a matrix, but the format of |
Z_test |
(Optional) Test set of (continuous or binary) treatment assignments. |
propensity_test |
(Optional) Vector of propensity scores. If not provided, this will be estimated from the data. |
rfx_group_ids_test |
(Optional) Test set group labels used for an additive random effects model. Test set evaluation for group labels that were not in the training set is not currently supported, though support is planned. |
rfx_basis_test |
(Optional) Test set basis for "random-slope" regression in additive random effects model. |
observation_weights |
(Optional) Numeric vector of observation weights of length |
num_gfr |
Number of "warm-start" iterations run using the grow-from-root algorithm (He and Hahn, 2021). Default: 5. |
num_burnin |
Number of "burn-in" iterations of the MCMC sampler. Default: 0. |
num_mcmc |
Number of "retained" iterations of the MCMC sampler. Default: 100. |
previous_model_json |
(Optional) JSON string containing a previous BCF model. This can be used to "continue" a
sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest
samples. If the previous model used an internally estimated propensity score (i.e. |
previous_model_warmstart_sample_num |
(Optional) Sample number from |
general_params |
(Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional.
|
prognostic_forest_params |
(Optional) A list of prognostic forest model parameters, each of which has a default value processed internally, so this argument list is optional.
|
treatment_effect_forest_params |
(Optional) A list of treatment effect forest model parameters, each of which has a default value processed internally, so this argument list is optional.
|
variance_forest_params |
(Optional) A list of variance forest model parameters, each of which has a default value processed internally, so this argument list is optional.
|
random_effects_params |
(Optional) A list of random effects model parameters, each of which has a default value processed internally, so this argument list is optional.
|
List of sampling outputs and a wrapper around the sampled forests (which can be used for in-memory prediction on new data, or serialized to JSON on disk).
    # Simulate covariates, a prognostic function (mu_x), a propensity
    # score (pi_x) and a heterogeneous treatment effect (tau_x)
    n <- 500
    p <- 5
    X <- matrix(runif(n*p), ncol = p)
    mu_x <- (
        ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
        ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
        ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
        ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
    )
    pi_x <- (
        ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
        ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
        ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
        ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
    )
    tau_x <- (
        ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
        ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
        ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
        ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
    )
    Z <- rbinom(n, 1, pi_x)
    noise_sd <- 1
    y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
    # Hold out 20% of the observations as a test set
    test_set_pct <- 0.2
    n_test <- round(test_set_pct*n)
    n_train <- n - n_test
    test_inds <- sort(sample(1:n, n_test, replace = FALSE))
    train_inds <- (1:n)[!((1:n) %in% test_inds)]
    X_test <- X[test_inds,]
    X_train <- X[train_inds,]
    pi_test <- pi_x[test_inds]
    pi_train <- pi_x[train_inds]
    Z_test <- Z[test_inds]
    Z_train <- Z[train_inds]
    y_test <- y[test_inds]
    y_train <- y[train_inds]
    mu_test <- mu_x[test_inds]
    mu_train <- mu_x[train_inds]
    tau_test <- tau_x[test_inds]
    tau_train <- tau_x[train_inds]
    # Run 10 grow-from-root iterations followed by 10 retained MCMC iterations
    bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                     propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
                     propensity_test = pi_test, num_gfr = 10, num_burnin = 0,
                     num_mcmc = 10)
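The simulated design in the example is confounded: the first covariate drives both the prognostic function `mu_x` and the propensity `pi_x`. The base R sketch below, which does not use stochtree, reproduces that design with a larger sample and compares the true average treatment effect with a naive difference in means. The variable names `bin1`, `bin2`, `true_ate` and `naive_estimate` are introduced here for illustration only.

```r
set.seed(1)
n <- 5000
x1 <- runif(n)
x2 <- runif(n)

# Quartile bin (1 to 4) of each covariate, matching the example's step functions
bin1 <- findInterval(x1, c(0.25, 0.5, 0.75)) + 1
bin2 <- findInterval(x2, c(0.25, 0.5, 0.75)) + 1

mu_x  <- c(-7.5, -2.5, 2.5, 7.5)[bin1]  # prognostic function, driven by x1
pi_x  <- c(0.2, 0.4, 0.6, 0.8)[bin1]    # propensity score, also driven by x1
tau_x <- c(0.5, 1.0, 1.5, 2.0)[bin2]    # treatment effect, driven by x2

Z <- rbinom(n, 1, pi_x)
y <- mu_x + tau_x * Z + rnorm(n)

# True average treatment effect versus a naive comparison of group means
true_ate <- mean(tau_x)
naive_estimate <- mean(y[Z == 1]) - mean(y[Z == 0])
```

Treated units are concentrated where `mu_x` is high, so the naive comparison absorbs the prognostic difference between groups; in this design the population gap between the naive quantity and the true effect is 5. This is the kind of confounding that motivates supplying or estimating the propensity score.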
BCF models contain external pointers to C++ objects, which means they cannot
be correctly serialized to .Rds from an R session in their default state.
These functions allow us to convert between C++ data structures and a persistent JSON
representation. The CppJson class wraps a performant C++ JSON API, and the functions
saveBCFModelToJson and createBCFModelFromJson save to and load from this format.
This representation, of course, also relies on external C++ pointers, so in order to
save and reload BCF models across sessions, we provide two other interfaces.
saveBCFModelToJsonFile and createBCFModelFromJsonFile save or reload a BCF model's JSON
representation directly to / from a .json file.
saveBCFModelToJsonString and createBCFModelFromJsonString handle in-memory strings containing JSON data,
which can be written to disk or passed between processes.
Finally, for cases in which multiple BCF models have been sampled (for instance, sampled in multiple processes
via doParallel), we offer createBCFModelFromCombinedJson and createBCFModelFromCombinedJsonString for
loading a new combined BCF model from a list of BCF JSON objects or strings.
    saveBCFModelToJson(object)
    saveBCFModelToJsonFile(object, filename)
    saveBCFModelToJsonString(object)
    createBCFModelFromJson(json_object)
    createBCFModelFromJsonFile(json_filename)
    createBCFModelFromJsonString(json_string)
    createBCFModelFromCombinedJson(json_object_list)
    createBCFModelFromCombinedJsonString(json_string_list)
object |
Object of type |
filename |
String of filepath, must end in ".json" |
json_object |
Object of type |
json_filename |
String of filepath, must end in ".json" |
json_string |
JSON string dump |
json_object_list |
List of objects of type |
json_string_list |
List of JSON strings which can be parsed to objects of type |
saveBCFModelToJson returns an object of type CppJson.
saveBCFModelToJsonFile returns nothing, but writes to the provided filename.
saveBCFModelToJsonString returns a string dump of the BCF model's JSON representation.
createBCFModelFromJson, createBCFModelFromJsonFile, createBCFModelFromJsonString,
createBCFModelFromCombinedJson, and createBCFModelFromCombinedJsonString all return
objects of type bcfmodel.
    # Generate data
    n <- 100
    p <- 5
    X <- matrix(runif(n*p), ncol = p)
    pi_X <- runif(n, 0.3, 0.7)
    Z <- rbinom(n, size = 1, prob = pi_X)
    y <- X[,1] + Z + rnorm(n, 0, 1)
    # Sample BCF model
    bcf_model <- bcf(X_train = X, Z_train = Z, propensity_train = pi_X,
                     y_train = y, num_gfr = 0, num_burnin = 0, num_mcmc = 10)
    # Save to in-memory JSON
    bcf_json <- saveBCFModelToJson(bcf_model)
    # Save to JSON string
    bcf_json_string <- saveBCFModelToJsonString(bcf_model)
    # Save to JSON file
    tmpjson <- tempfile(fileext = ".json")
    saveBCFModelToJsonFile(bcf_model, file.path(tmpjson))
    # Reload BCF model from in-memory JSON object
    bcf_model_roundtrip <- createBCFModelFromJson(bcf_json)
    # Reload BCF model from JSON string
    bcf_model_roundtrip <- createBCFModelFromJsonString(bcf_json_string)
    # Reload BCF model from JSON file
    bcf_model_roundtrip <- createBCFModelFromJsonFile(file.path(tmpjson))
    unlink(tmpjson)
    # Reload BCF model from list of JSON objects
    bcf_model_roundtrip <- createBCFModelFromCombinedJson(list(bcf_json))
    # Reload BCF model from list of JSON strings
    bcf_model_roundtrip <- createBCFModelFromCombinedJsonString(list(bcf_json_string))
Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al. (2022).
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). doi:10.1002/9781118445112.stat08288
    calibrateInverseGammaErrorVariance(
      y,
      X,
      W = NULL,
      nu = 3,
      quant = 0.9,
      standardize = TRUE
    )
y |
Outcome to be modeled using BART, BCF or another nonparametric ensemble method. |
X |
Covariates to be used to partition trees in an ensemble or series of ensembles. |
W |
(Optional) Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: |
nu |
The shape parameter for the global error variance's IG prior. The scale parameter in the Sparapani et al (2021) parameterization is defined as |
quant |
(Optional) Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of |
standardize |
(Optional) Whether or not outcome should be standardized ( |
Value of lambda which determines the scale parameter of the global error variance prior (sigma^2 ~ IG(nu,nu*lambda))
    n <- 100
    p <- 5
    X <- matrix(runif(n*p), ncol = p)
    y <- 10*X[,1] - 20*X[,2] + rnorm(n)
    nu <- 3
    lambda <- calibrateInverseGammaErrorVariance(y, X, nu = nu)
    # Residual variance of a linear fit, used as an overestimate of sigma^2
    sigma2hat <- mean(resid(lm(y~X))^2)
    # Monte Carlo share of prior draws of sigma^2 that fall below sigma2hat
    mean(var(y)/rgamma(100000, nu, rate = nu*lambda) < sigma2hat)
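The Monte Carlo check in the example has a closed-form counterpart. The sketch below assumes the prior is read exactly as written in the return value, sigma^2 ~ IG(nu, nu*lambda) with shape `nu` and rate `nu*lambda`, and that the outcome is standardized by `var(y)` as in the example. The helper `sketch_lambda` is hypothetical and need not match the package's internal computation.

```r
# Hypothetical helper: solve P(sigma^2 < sigma2hat) = quant for lambda,
# assuming sigma^2 ~ IG(shape = nu, rate = nu * lambda).
# This may differ from what the package does internally.
sketch_lambda <- function(sigma2hat, nu = 3, quant = 0.9) {
  # 1 / sigma^2 ~ Gamma(nu, rate = nu * lambda), so
  # P(sigma^2 < s) = 1 - pgamma(1 / s, nu, rate = nu * lambda).
  # Setting this equal to quant and solving for lambda gives:
  sigma2hat * qgamma(1 - quant, shape = nu) / nu
}

set.seed(1)
n <- 100
p <- 5
X <- matrix(runif(n * p), ncol = p)
y <- 10 * X[, 1] - 20 * X[, 2] + rnorm(n)

# Residual variance of a linear fit, on the standardized outcome scale
sigma2hat_std <- mean(resid(lm(y ~ X))^2) / var(y)
lambda_sketch <- sketch_lambda(sigma2hat_std, nu = 3, quant = 0.9)

# Analytic prior probability that sigma^2 falls below the estimate
prior_prob <- 1 - pgamma(1 / sigma2hat_std, shape = 3, rate = 3 * lambda_sketch)
```

By construction `prior_prob` equals `quant`, which is the property the Monte Carlo line in the example approximates by simulation.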
Compute posterior credible intervals for specified terms from a fitted BART model. Supports intervals for mean functions, variance functions, random effects, and overall outcome predictions.
    computeBARTPosteriorInterval(
      model_object,
      terms,
      level = 0.95,
      scale = "linear",
      X = NULL,
      leaf_basis = NULL,
      rfx_group_ids = NULL,
      rfx_basis = NULL
    )
model_object |
A fitted BART or BCF model object of class |
terms |
A character string specifying the model term(s) for which to compute intervals. Options for BART models are |
level |
A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). |
scale |
(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability". |
X |
A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions). |
leaf_basis |
An optional matrix of basis function evaluations for mean forest models with regression defined in the leaves. Required for "leaf regression" models. |
rfx_group_ids |
An optional vector of group IDs for random effects. Required if the requested term includes random effects. |
rfx_basis |
An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. |
A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned.
    n <- 100
    p <- 5
    X <- matrix(rnorm(n * p), nrow = n, ncol = p)
    y <- 2 * X[,1] + rnorm(n)
    # Fit BART with the default sampler settings
    bart_model <- bart(y_train = y, X_train = X)
    # 90% credible intervals for the mean forest and the outcome prediction
    intervals <- computeBARTPosteriorInterval(
      model_object = bart_model,
      terms = c("mean_forest", "y_hat"),
      X = X,
      level = 0.90
    )
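To make the returned lower and upper bounds concrete, the sketch below computes an equal-tailed credible interval from a matrix of posterior draws in base R. Whether the package uses equal-tailed quantile intervals is an assumption here, and both the helper `sketch_interval` and the layout of `draws` (one row per observation, one column per retained sample) are hypothetical.

```r
# Hypothetical helper: equal-tailed credible interval from posterior draws.
# `draws` has one row per observation and one column per retained sample.
sketch_interval <- function(draws, level = 0.95) {
  alpha <- 1 - level
  list(
    lower = apply(draws, 1, quantile, probs = alpha / 2),
    upper = apply(draws, 1, quantile, probs = 1 - alpha / 2)
  )
}

set.seed(1)
# Stand-in for posterior samples of a mean function at 4 observations;
# row i is drawn around a mean of i
draws <- matrix(rnorm(4 * 1000, mean = rep(1:4, times = 1000)), nrow = 4)
ci <- sketch_interval(draws, level = 0.90)
```

With `level = 0.90`, each bound is the 5% or 95% quantile of that observation's draws, so the interval excludes 5% of the posterior mass in each tail.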
Compute posterior credible intervals for specified terms from a fitted BCF model. Supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions.
    computeBCFPosteriorInterval(
      model_object,
      terms,
      level = 0.95,
      scale = "linear",
      X = NULL,
      Z = NULL,
      propensity = NULL,
      rfx_group_ids = NULL,
      rfx_basis = NULL
    )
model_object |
A fitted BCF model object of class |
terms |
A character string specifying the model term(s) for which to compute intervals. Options are The treatment effect terms follow a three-level hierarchy:
Similarly for the prognostic term: |
level |
A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). |
scale |
(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing |
X |
(Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). |
Z |
(Optional) A vector or matrix of treatment assignments. Required if the requested term is |
propensity |
(Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities. |
rfx_group_ids |
An optional vector of group IDs for random effects. Required if the requested term includes random effects. |
rfx_basis |
An optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. |
A list containing the lower and upper bounds of the credible interval for the specified term. If multiple terms are requested, a named list with intervals for each term is returned.
n <- 100
p <- 5
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
pi_X <- pnorm(0.5 * X[,1])
Z <- rbinom(n, 1, pi_X)
mu_X <- X[,1]
tau_X <- 0.25 * X[,2]
y <- mu_X + tau_X * Z + rnorm(n)
bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X)
intervals <- computeBCFPosteriorInterval(
  model_object = bcf_model,
  terms = c("prognostic_function", "cate"),
  X = X,
  Z = Z,
  propensity = pi_X,
  level = 0.90
)
Compute a contrast using a BART model by making two sets of outcome predictions and taking their difference.
This function provides the flexibility to compute any contrast of interest by specifying covariates, leaf basis, and random effects
bases / IDs for both sides of a two term contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or
Y0 term and the minuend of the contrast as the Y1 term, though the requested contrast need not match the "control vs treatment"
terminology of a classic two-treatment causal inference problem. We mirror the function calls and terminology of the predict.bartmodel
function, labeling each prediction data term with a 1 to denote its contribution to the treatment prediction of a contrast and
0 to denote inclusion in the control prediction.
Only valid when there is either a mean forest or a random effects term in the BART model.
computeContrastBARTModel(
  object,
  X_0,
  X_1,
  leaf_basis_0 = NULL,
  leaf_basis_1 = NULL,
  rfx_group_ids_0 = NULL,
  rfx_group_ids_1 = NULL,
  rfx_basis_0 = NULL,
  rfx_basis_1 = NULL,
  type = "posterior",
  scale = "linear"
)
object |
Object of type |
X_0 |
Covariates used for prediction in the "control" case. Must be a matrix or dataframe. |
X_1 |
Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe. |
leaf_basis_0 |
(Optional) Bases used for prediction in the "control" case (by e.g. dot product with leaf values). Default: NULL. |
leaf_basis_1 |
(Optional) Bases used for prediction in the "treatment" case (by e.g. dot product with leaf values). Default: NULL. |
rfx_group_ids_0 |
(Optional) Test set group labels used for prediction from an additive random effects model in the "control" case. Test set evaluation for group labels that were not in the training set is not currently supported (but is planned for the near future). Must be a vector. |
rfx_group_ids_1 |
(Optional) Test set group labels used for prediction from an additive random effects model in the "treatment" case. Test set evaluation for group labels that were not in the training set is not currently supported (but is planned for the near future). Must be a vector. |
rfx_basis_0 |
(Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector. |
rfx_basis_1 |
(Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector. |
type |
(Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BART model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". |
scale |
(Optional) Scale of the contrast. Options are "linear", which returns contrast of predictions on the original scale of the mean forest / RFX terms, and "probability". |
Contrast matrix or vector, depending on whether type = "mean" or "posterior".
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
W <- matrix(runif(n*1), ncol = 1)
f_XW <- (
  ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
  ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, noise_sd)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
W_test <- W[test_inds,]
W_train <- W[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]
bart_model <- bart(X_train = X_train, leaf_basis_train = W_train, y_train = y_train,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 10)
contrast_test <- computeContrastBARTModel(
  bart_model,
  X_0 = X_test,
  X_1 = X_test,
  leaf_basis_0 = matrix(0, nrow = n_test, ncol = 1),
  leaf_basis_1 = matrix(1, nrow = n_test, ncol = 1),
  type = "posterior",
  scale = "linear"
)
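The two aggregation levels described for `type` can be sketched in base R as follows. The matrices `pred_0` and `pred_1` are simulated stand-ins for the two sets of posterior predictions, and the layout (rows index observations, columns index draws) is an assumption made for illustration rather than a statement about stochtree's internals.

```r
set.seed(2)
n_test <- 5
num_draws <- 10
# Simulated stand-ins for posterior predictions on each side of the contrast
pred_0 <- matrix(rnorm(n_test * num_draws, mean = 0), nrow = n_test)
pred_1 <- matrix(rnorm(n_test * num_draws, mean = 1), nrow = n_test)

# "posterior": one contrast per observation and draw (Y1 term minus Y0 term)
contrast_posterior <- pred_1 - pred_0

# "mean": average each observation's contrast over all draws
contrast_mean <- rowMeans(contrast_posterior)
```

Because averaging is linear, `contrast_mean` equals the difference of the two averaged predictions, so the order of differencing and averaging does not matter on the linear scale.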
Compute a contrast using a BCF model by making two sets of outcome predictions and taking their difference.
For simple BCF models with binary treatment, this will yield the same prediction as requesting terms = "cate"
in the predict.bcfmodel function. For more general models, such as models with continuous / multivariate treatments or
an additive random effects term with a coefficient on the treatment, this function provides the flexibility to compute
any contrast of interest by specifying covariates, treatment, and random effects bases and IDs for both sides of a two-term
contrast. For simplicity, we refer to the subtrahend of the contrast as the "control" or Y0 term and the minuend of the
contrast as the Y1 term, though the requested contrast need not match the "control vs treatment" terminology of a classic
two-arm experiment. We mirror the function calls and terminology of the predict.bcfmodel function, labeling each prediction
data term with a 1 to denote its contribution to the treatment prediction of a contrast and 0 to denote inclusion in the
control prediction.
computeContrastBCFModel(
  object,
  X_0,
  X_1,
  Z_0,
  Z_1,
  propensity_0 = NULL,
  propensity_1 = NULL,
  rfx_group_ids_0 = NULL,
  rfx_group_ids_1 = NULL,
  rfx_basis_0 = NULL,
  rfx_basis_1 = NULL,
  type = "posterior",
  scale = "linear"
)
object |
Object of type |
X_0 |
Covariates used for prediction in the "control" case. Must be a matrix or dataframe. |
X_1 |
Covariates used for prediction in the "treatment" case. Must be a matrix or dataframe. |
Z_0 |
Treatments used for prediction in the "control" case. Must be a matrix or vector. |
Z_1 |
Treatments used for prediction in the "treatment" case. Must be a matrix or vector. |
propensity_0 |
(Optional) Propensities used for prediction in the "control" case. Must be a matrix or vector. |
propensity_1 |
(Optional) Propensities used for prediction in the "treatment" case. Must be a matrix or vector. |
rfx_group_ids_0 |
(Optional) Test set group labels used for prediction from an additive random effects model in the "control" case. Test set evaluation for group labels that were not in the training set is not currently supported (but is planned for the near future). Must be a vector. |
rfx_group_ids_1 |
(Optional) Test set group labels used for prediction from an additive random effects model in the "treatment" case. Test set evaluation for group labels that were not in the training set is not currently supported (but is planned for the near future). Must be a vector. |
rfx_basis_0 |
(Optional) Test set basis used for prediction from an additive random effects model in the "control" case. Must be a matrix or vector. |
rfx_basis_1 |
(Optional) Test set basis used for prediction from an additive random effects model in the "treatment" case. Must be a matrix or vector. |
type |
(Optional) Aggregation level of the contrast. Options are "mean", which averages the contrast evaluations over every draw of a BCF model, and "posterior", which returns the entire matrix of posterior contrast estimates. Default: "posterior". |
scale |
(Optional) Scale of the contrast. Options are "linear", which returns a contrast on the original scale of the mean forest / RFX terms, and "probability", which transforms each contrast term into a probability of observing |
Contrast matrix or vector, depending on whether type = "mean" or "posterior".
n <- 500
p <- 5
X <- matrix(runif(n*p), ncol = p)
mu_x <- (
  ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
  ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
pi_x <- (
  ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
  ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
)
tau_x <- (
  ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
  ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
  ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
  ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
)
Z <- rbinom(n, 1, pi_x)
noise_sd <- 1
y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]
bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                 propensity_train = pi_train, num_gfr = 10,
                 num_burnin = 0, num_mcmc = 10)
tau_hat_test <- computeContrastBCFModel(
  bcf_model,
  X_0 = X_test,
  X_1 = X_test,
  Z_0 = rep(0, n_test),
  Z_1 = rep(1, n_test),
  propensity_0 = pi_test,
  propensity_1 = pi_test
)
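The statement that a binary-treatment contrast reproduces the CATE can be checked with a small base R sketch. It assumes the usual BCF decomposition of the outcome mean into a prognostic term plus a treatment effect term multiplied by the treatment; `mu_draws`, `tau_draws`, and `predict_outcome` are hypothetical stand-ins, not stochtree objects.

```r
set.seed(3)
n <- 6
num_draws <- 4
# Simulated stand-ins for draws of the prognostic and treatment effect functions
# (rows = observations, columns = draws)
mu_draws <- matrix(rnorm(n * num_draws), nrow = n)
tau_draws <- matrix(rnorm(n * num_draws, mean = 1), nrow = n)

# Hypothetical outcome prediction for a treatment vector z: mu(x) + tau(x) * z
# (z is recycled down the columns, so row i is multiplied by z[i])
predict_outcome <- function(z) mu_draws + tau_draws * z

z_0 <- rep(0, n)
z_1 <- rep(1, n)
contrast <- predict_outcome(z_1) - predict_outcome(z_0)
```

Here `contrast` is identical to `tau_draws`, since the prognostic term cancels. With a continuous treatment, choosing other values for `z_0` and `z_1` scales the treatment effect term by their difference.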
Wrapper around a C++ nlohmann::json object
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
json_ptr |
External pointer to a C++ nlohmann::json object |
num_forests |
Number of forests in the nlohmann::json object |
forest_labels |
Names of forest objects in the overall nlohmann::json object |
num_rfx |
Number of random effects terms in the nlohmann::json object |
rfx_container_labels |
Names of rfx container objects in the overall nlohmann::json object |
rfx_mapper_labels |
Names of rfx label mapper objects in the overall nlohmann::json object |
rfx_groupid_labels |
Names of rfx group id objects in the overall nlohmann::json object |
new()
Create a new CppJson object.
CppJson$new()
A new CppJson object.
add_forest()
Convert a forest container to json and add to the current CppJson object
CppJson$add_forest(forest_samples)
forest_samples |
ForestSamples R class |
None
add_random_effects()
Convert a random effects container to json and add to the current CppJson object
CppJson$add_random_effects(rfx_samples)
rfx_samples |
RandomEffectSamples R class |
None
add_scalar()
Add a scalar to the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$add_scalar(field_name, field_value, subfolder_name = NULL)
field_name |
The name of the field to be added to json |
field_value |
Numeric value of the field to be added to json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which to place the value |
None
add_integer()
Add an integer to the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$add_integer(field_name, field_value, subfolder_name = NULL)
field_name |
The name of the field to be added to json |
field_value |
Integer value of the field to be added to json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which to place the value |
None
add_boolean()
Add a boolean value to the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$add_boolean(field_name, field_value, subfolder_name = NULL)
field_name |
The name of the field to be added to json |
field_value |
Boolean value of the field to be added to json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which to place the value |
None
add_string()
Add a string value to the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$add_string(field_name, field_value, subfolder_name = NULL)
field_name |
The name of the field to be added to json |
field_value |
String value of the field to be added to json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which to place the value |
None
add_vector()
Add a vector to the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$add_vector(field_name, field_vector, subfolder_name = NULL)
field_name |
The name of the field to be added to json |
field_vector |
Vector to be stored in json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which to place the value |
None
add_integer_vector()
Add an integer vector to the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$add_integer_vector(field_name, field_vector, subfolder_name = NULL)
field_name |
The name of the field to be added to json |
field_vector |
Vector to be stored in json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which to place the value |
None
add_string_vector()
Add an array to the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$add_string_vector(field_name, field_vector, subfolder_name = NULL)
field_name |
The name of the field to be added to json |
field_vector |
Character vector to be stored in json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which to place the value |
None
add_list()
Add a list of vectors (as an object map of arrays) to the json object under the name "field_name"
CppJson$add_list(field_name, field_list)
field_name |
The name of the field to be added to json |
field_list |
List to be stored in json |
None
add_string_list()
Add a list of character vectors (as an object map of arrays) to the json object under the name "field_name"
CppJson$add_string_list(field_name, field_list)
field_name |
The name of the field to be added to json |
field_list |
List to be stored in json |
None
get_scalar()
Retrieve a scalar value from the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$get_scalar(field_name, subfolder_name = NULL)
field_name |
The name of the field to be accessed from json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which the field is stored |
None
get_integer()
Retrieve an integer value from the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$get_integer(field_name, subfolder_name = NULL)
field_name |
The name of the field to be accessed from json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which the field is stored |
None
get_boolean()
Retrieve a boolean value from the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$get_boolean(field_name, subfolder_name = NULL)
field_name |
The name of the field to be accessed from json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which the field is stored |
None
get_string()
Retrieve a string value from the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$get_string(field_name, subfolder_name = NULL)
field_name |
The name of the field to be accessed from json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which the field is stored |
None
get_vector()
Retrieve a vector from the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$get_vector(field_name, subfolder_name = NULL)
field_name |
The name of the field to be accessed from json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which the field is stored |
None
get_integer_vector()
Retrieve an integer vector from the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$get_integer_vector(field_name, subfolder_name = NULL)
field_name |
The name of the field to be accessed from json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which the field is stored |
None
get_string_vector()
Retrieve a character vector from the json object under the name "field_name" (with optional subfolder "subfolder_name")
CppJson$get_string_vector(field_name, subfolder_name = NULL)
field_name |
The name of the field to be accessed from json |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which the field is stored |
None
get_numeric_list()
Reconstruct a list of numeric vectors from the json object stored under "field_name"
CppJson$get_numeric_list(field_name, key_names)
field_name |
The name of the field to be accessed from json |
key_names |
Vector of names of list elements (each of which is a vector) |
None
get_string_list()
Reconstruct a list of string vectors from the json object stored under "field_name"
CppJson$get_string_list(field_name, key_names)
field_name |
The name of the field to be accessed from json |
key_names |
Vector of names of list elements (each of which is a vector) |
None
return_json_string()
Convert a JSON object to an in-memory string
CppJson$return_json_string()
JSON string
save_file()
Save a json object to file
CppJson$save_file(filename)
filename |
String of filepath, must end in ".json" |
None
load_from_file()
Load a json object from file
CppJson$load_from_file(filename)
filename |
String of filepath, must end in ".json" |
None
load_from_string()
Load a json object from string
CppJson$load_from_string(json_string)
json_string |
JSON string dump |
None
Wrapper around a C++ random number generator object (for reproducibility).
The class persists a C++ random number generator throughout an R session to
ensure a given seed generates the same outputs (on the same OS). If no seed is provided,
the C++ random number generator is initialized using std::random_device.
rng_ptr |
External pointer to a C++ std::mt19937 class |
new()
Create a new CppRNG object.
CppRNG$new(random_seed = -1)
random_seed |
(Optional) random seed for sampling |
A new CppRNG object.
Create an R class that wraps a C++ random number generator
createCppRNG(random_seed = -1)
random_seed |
(Optional) random seed for sampling |
CppRNG object
rng <- createCppRNG(1234)
rng <- createCppRNG()
Create a forest
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
createForest(
  num_trees,
  leaf_dimension = 1,
  is_leaf_constant = FALSE,
  is_exponentiated = FALSE
)
num_trees |
Number of trees in the forest |
leaf_dimension |
Dimensionality of the outcome model |
is_leaf_constant |
Whether leaf is constant |
is_exponentiated |
Whether forest predictions should be exponentiated before being returned |
Forest object
num_trees <- 100
leaf_dimension <- 2
is_leaf_constant <- FALSE
is_exponentiated <- FALSE
forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
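As a rough picture of what the is_exponentiated flag implies, the base R sketch below treats a forest prediction as the sum of per-tree contributions and optionally exponentiates that sum. The matrix `tree_outputs` and the helper `forest_prediction` are hypothetical and only illustrate the idea; they do not reflect stochtree's internal data structures.

```r
set.seed(6)
n <- 5
num_trees <- 100
# Simulated per-tree contributions (rows = observations, columns = trees)
tree_outputs <- matrix(rnorm(n * num_trees, sd = 0.05), nrow = n)

# Hypothetical helper: sum tree contributions, optionally exponentiating the sum
forest_prediction <- function(tree_outputs, is_exponentiated = FALSE) {
  raw <- rowSums(tree_outputs)
  if (is_exponentiated) exp(raw) else raw
}

pred_linear <- forest_prediction(tree_outputs)
pred_positive <- forest_prediction(tree_outputs, is_exponentiated = TRUE)
```

Exponentiating guarantees strictly positive predictions, which is what a forest-based variance (log-linear leaf) model needs.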
Create a forest dataset object
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
createForestDataset(covariates, basis = NULL, variance_weights = NULL)
covariates |
Matrix of covariates |
basis |
(Optional) Matrix of bases used to define a leaf regression |
variance_weights |
(Optional) Vector of observation-specific variance weights |
ForestDataset object
covariate_matrix <- matrix(runif(10*100), ncol = 10)
basis_matrix <- matrix(rnorm(3*100), ncol = 3)
weight_vector <- rnorm(100)
forest_dataset <- createForestDataset(covariate_matrix)
forest_dataset <- createForestDataset(covariate_matrix, basis_matrix)
forest_dataset <- createForestDataset(covariate_matrix, basis_matrix, weight_vector)
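To make the role of the basis concrete, the base R sketch below evaluates a single hypothetical tree with a leaf regression: each observation falls into a leaf according to its covariates, and its prediction is the dot product of its basis row with that leaf's coefficient vector. The split rule and `leaf_values` are invented for illustration and do not come from a fitted stochtree model.

```r
set.seed(5)
n <- 6
covariates <- matrix(runif(n * 2), ncol = 2)
basis <- matrix(rnorm(n * 3), ncol = 3)

# Hypothetical single split on the first covariate, with a 3-dimensional
# coefficient vector in each of the two leaves
leaf_values <- rbind(left = c(1, 0, -1), right = c(2, 2, 2))
leaf_index <- ifelse(covariates[, 1] <= 0.5, "left", "right")

# Prediction: dot product of each basis row with its leaf's coefficients
tree_pred <- rowSums(basis * leaf_values[leaf_index, , drop = FALSE])
```

With a constant-leaf model the basis is effectively a column of ones and the leaf holds a single value, so the same expression reduces to looking up the leaf value.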
Create a forest model object
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
createForestModel(forest_dataset, forest_model_config, global_model_config)
forest_dataset |
ForestDataset object, used to initialize forest sampling data structures |
forest_model_config |
ForestModelConfig object containing forest model parameters and settings |
global_model_config |
GlobalModelConfig object containing global model parameters and settings |
ForestModel object
num_trees <- 100
n <- 100
p <- 10
alpha <- 0.95
beta <- 2.0
min_samples_leaf <- 2
max_depth <- 10
feature_types <- as.integer(rep(0, p))
X <- matrix(runif(n*p), ncol = p)
forest_dataset <- createForestDataset(X)
forest_model_config <- createForestModelConfig(
  feature_types = feature_types,
  num_trees = num_trees,
  num_features = p,
  num_observations = n,
  alpha = alpha,
  beta = beta,
  min_samples_leaf = min_samples_leaf,
  max_depth = max_depth,
  leaf_model_type = 1
)
global_model_config <- createGlobalModelConfig(global_error_variance = 1.0)
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
Create a forest model config object
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
createForestModelConfig(
  feature_types = NULL,
  sweep_update_indices = NULL,
  num_trees = NULL,
  num_features = NULL,
  num_observations = NULL,
  variable_weights = NULL,
  leaf_dimension = 1,
  alpha = 0.95,
  beta = 2,
  min_samples_leaf = 5,
  max_depth = -1,
  leaf_model_type = 1,
  leaf_model_scale = NULL,
  variance_forest_shape = 1,
  variance_forest_scale = 1,
  cloglog_forest_shape = 2,
  cloglog_forest_rate = 2,
  cutpoint_grid_size = 100,
  num_features_subsample = NULL
)
feature_types |
Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical) |
sweep_update_indices |
Vector of (0-indexed) indices of trees to update in a sweep |
num_trees |
Number of trees in the forest being sampled |
num_features |
Number of features in training dataset |
num_observations |
Number of observations in training dataset |
variable_weights |
Vector specifying sampling probability for all p covariates in ForestDataset |
leaf_dimension |
Dimension of the leaf model (default: 1) |
alpha |
Root node split probability in tree prior (default: 0.95) |
beta |
Depth prior penalty in tree prior (default: 2) |
min_samples_leaf |
Minimum number of samples in a tree leaf (default: 5) |
max_depth |
Maximum depth of any tree in the ensemble in the model. Setting to |
leaf_model_type |
Integer specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 1. |
leaf_model_scale |
Scale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when |
variance_forest_shape |
Shape parameter for IG leaf models (applicable when |
variance_forest_scale |
Scale parameter for IG leaf models (applicable when |
cloglog_forest_shape |
Shape parameter for conditional gamma component of cloglog leaf models (applicable when |
cloglog_forest_rate |
Rate parameter for conditional gamma component of cloglog leaf models (applicable when |
cutpoint_grid_size |
Number of unique cutpoints to consider (default: 100) |
num_features_subsample |
Number of features to subsample for the GFR algorithm |
ForestModelConfig object
config <- createForestModelConfig(num_trees = 10, num_features = 5, num_observations = 100)
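In the standard BART tree prior of Chipman, George, McCulloch (2010), a node at depth d splits with probability alpha * (1 + d)^(-beta). Assuming the alpha and beta arguments above follow that parameterization (their descriptions suggest it, but this page does not state the formula), the base R sketch below shows how quickly the default values discourage deep trees. The helper `split_probability` is hypothetical and not part of stochtree.

```r
# Hypothetical helper: prior probability that a node at a given depth splits,
# under the Chipman, George, McCulloch (2010) tree prior
split_probability <- function(depth, alpha = 0.95, beta = 2) {
  alpha * (1 + depth)^(-beta)
}

depths <- 0:3
probs <- split_probability(depths)
# Root (depth 0) splits with probability alpha = 0.95;
# a depth-1 node splits with probability 0.95 / 4 = 0.2375
```

Larger values of beta shrink the split probability faster with depth, and smaller values of alpha make even the root split less likely.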
Create a container of forest samples
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
createForestSamples(
  num_trees,
  leaf_dimension = 1,
  is_leaf_constant = FALSE,
  is_exponentiated = FALSE
)
num_trees |
Number of trees |
leaf_dimension |
Dimensionality of the outcome model |
is_leaf_constant |
Whether leaf is constant |
is_exponentiated |
Whether forest predictions should be exponentiated before being returned |
ForestSamples object
num_trees <- 100
leaf_dimension <- 2
is_leaf_constant <- FALSE
is_exponentiated <- FALSE
forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
Create a global model config object
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
createGlobalModelConfig(global_error_variance = 1)
global_error_variance |
Global error variance parameter (default: |
GlobalModelConfig object
config <- createGlobalModelConfig(global_error_variance = 100)
Create an outcome object
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
createOutcome(outcome)
outcome |
Vector of outcome values |
Outcome object
X <- matrix(runif(10*100), ncol = 10)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
outcome <- createOutcome(y)
Create a RandomEffectSamples object
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
createRandomEffectSamples(num_components, num_groups, random_effects_tracker)
num_components |
Number of "components" or bases defining the random effects regression |
num_groups |
Number of random effects groups |
random_effects_tracker |
Object of type |
RandomEffectSamples object
n <- 100
rfx_group_ids <- sample(1:2, size = n, replace = TRUE)
rfx_basis <- matrix(rep(1.0, n), ncol=1)
num_groups <- length(unique(rfx_group_ids))
num_components <- ncol(rfx_basis)
rfx_tracker <- createRandomEffectsTracker(rfx_group_ids)
rfx_samples <- createRandomEffectSamples(num_components, num_groups, rfx_tracker)
Create a random effects dataset object
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
createRandomEffectsDataset(group_labels, basis, variance_weights = NULL)
group_labels |
Vector of group labels |
basis |
Matrix of bases used to define the random effects regression (for an intercept-only model, pass an array of ones) |
variance_weights |
(Optional) Vector of observation-specific variance weights |
RandomEffectsDataset object
rfx_group_ids <- sample(1:2, size = 100, replace = TRUE)
rfx_basis <- matrix(rnorm(3*100), ncol = 3)
weight_vector <- rnorm(100)
rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis)
rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis, weight_vector)
Create a RandomEffectsModel object
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
createRandomEffectsModel(num_components, num_groups)
num_components |
Number of "components" or bases defining the random effects regression |
num_groups |
Number of random effects groups |
RandomEffectsModel object
n <- 100
rfx_group_ids <- sample(1:2, size = n, replace = TRUE)
rfx_basis <- matrix(rep(1.0, n), ncol=1)
num_groups <- length(unique(rfx_group_ids))
num_components <- ncol(rfx_basis)
rfx_model <- createRandomEffectsModel(num_components, num_groups)
Create a RandomEffectsTracker object
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
createRandomEffectsTracker(rfx_group_indices)
rfx_group_indices |
Integer indices indicating groups used to define random effects |
RandomEffectsTracker object
n <- 100
rfx_group_ids <- sample(1:2, size = n, replace = TRUE)
rfx_basis <- matrix(rep(1.0, n), ncol=1)
num_groups <- length(unique(rfx_group_ids))
num_components <- ncol(rfx_basis)
rfx_tracker <- createRandomEffectsTracker(rfx_group_ids)
The functions in this group are designed to handle data preprocessing for stochastic forest models. For example, factor-valued columns in data frames are either one-hot encoded or converted to integer indices before the data frame is converted to a standard matrix format for sampling. This preprocessing routine defines a set of "steps" that must be repeated on out-of-sample datasets before predictions can be obtained from a sampled model.
preprocessTrainData preprocesses covariates for the forest sampler routines, depending on the input type.
Data frames are preprocessed based on their column types (numeric columns are not modified, ordered factors are integer coded, and unordered factors are one-hot encoded). Matrices are unmodified (assuming all columns are numeric).
This function also records and returns a "metadata" list with preprocessing details to ensure that other datasets can be preprocessed identically.
preprocessPredictionData preprocesses covariates for the forest sampler routines, based on the steps outlined in a metadata list produced by preprocessTrainData.
These procedures are handled internally in the bart() and bcf() functions, but they are provided in stochtree as convenience functions for users writing custom samplers.
Furthermore, while R lists can be serialized to RDS format, we offer a number of JSON serialization routines for the metadata list produced by preprocessTrainData for
consistency with the broader serialization approach of stochtree (see BARTSerialization and BCFSerialization).
Following the API for serializing bartmodel and bcfmodel objects, we can convert metadata to JSON or JSON strings via
savePreprocessorToJson and savePreprocessorToJsonString. Similarly, we can reload a metadata list from JSON or JSON strings
via createPreprocessorFromJson and createPreprocessorFromJsonString.
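The column-type rules described above for preprocessTrainData can be illustrated with base R alone. The sketch below is an analogue built on made-up data, not stochtree's internal implementation, and the exact column layout returned by preprocessTrainData may differ.

```r
# Illustration only: base R analogues of the three column-type rules
df <- data.frame(
  x1 = c(0.3, -1.2, 0.8, 2.1),
  x2 = factor(c("low", "high", "mid", "low"),
              levels = c("low", "mid", "high"), ordered = TRUE),
  x3 = factor(c("a", "c", "b", "a"))
)

# Numeric columns pass through unchanged
x1_numeric <- df$x1

# Ordered factors become integer codes that respect the level order
x2_integer <- as.integer(df$x2)

# Unordered factors become one indicator column per level
x3_onehot <- sapply(levels(df$x3), function(lv) as.numeric(df$x3 == lv))

# Assemble a plain numeric matrix
X <- cbind(x1 = x1_numeric, x2 = x2_integer, x3_onehot)
```

Whatever encoding is applied to the training data has to be reapplied with the same levels to new data, which is the role of the metadata list.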
preprocessTrainData(input_data)
preprocessPredictionData(input_data, metadata)
savePreprocessorToJson(object)
savePreprocessorToJsonString(object)
createPreprocessorFromJson(json_object)
createPreprocessorFromJsonString(json_string)
input_data |
Covariates, provided as either a dataframe or a matrix |
metadata |
List containing information on variables, including train set categories for categorical variables |
object |
List containing information on variables, including train set categories for categorical variables |
json_object |
in-memory wrapper around JSON C++ object containing covariate preprocessor metadata |
json_string |
in-memory JSON string containing covariate preprocessor metadata |
preprocessTrainData returns a list with transformed matrix data and a "metadata" list with details on the preprocessing procedures applied.
preprocessPredictionData returns a matrix reflecting the data transformations specified in the provided metadata list.
savePreprocessorToJson returns an object of type CppJson.
savePreprocessorToJsonString returns a string dump of the preprocessor's JSON representation.
createPreprocessorFromJson and createPreprocessorFromJsonString both return metadata lists.
# Check that running the same data through `preprocessTrainData`
# and `preprocessPredictionData` yields the same result
n <- 100
x1 <- rnorm(n)
x2 <- factor(sample(1:3, n, replace = TRUE), ordered = TRUE)
x3 <- factor(sample(1:3, n, replace = TRUE), ordered = FALSE)
df1 <- data.frame(x1 = x1, x2 = x2, x3 = x3)
df2 <- data.frame(x1 = x1, x2 = x2, x3 = x3)
preprocess_train_list <- preprocessTrainData(df1)
df1_process <- preprocess_train_list$data
df1_metadata <- preprocess_train_list$metadata
df2_process <- preprocessPredictionData(df2, df1_metadata)
all.equal(df1_process, df2_process)
# Save to in-memory JSON
metadata_json <- savePreprocessorToJson(df1_metadata)
# Save to JSON string
metadata_json_string <- savePreprocessorToJsonString(df1_metadata)
# Reload metadata list from in-memory JSON object
metadata_roundtrip <- createPreprocessorFromJson(metadata_json)
# Reload metadata list from JSON string
metadata_roundtrip <- createPreprocessorFromJsonString(metadata_json_string)
cov_df <- data.frame(x1 = 1:5, x2 = 5:1, x3 = 6:10)
metadata <- list(num_ordered_cat_vars = 0, num_unordered_cat_vars = 0,
                 num_numeric_vars = 3, numeric_vars = c("x1", "x2", "x3"))
X_preprocessed <- preprocessPredictionData(cov_df, metadata)
Generic function for extracting parameter samples from a model object (BCF, BART, etc...)
extractParameter(object, term)
object |
Fitted model object from which to extract parameter samples |
term |
Name of the parameter to extract (e.g., |
Parameter sample array
n <- 100
p <- 10
X <- matrix(runif(n*p), ncol = p)
rfx_group_ids <- sample(1:2, size = n, replace = TRUE)
rfx_basis <- rep(1.0, n)
y <- (-5 + 10*(X[,1] > 0.5)) + (-2*(rfx_group_ids==1)+2*(rfx_group_ids==2)) + rnorm(n)
bart_model <- bart(X_train=X, y_train=y, rfx_group_ids_train=rfx_group_ids,
                   rfx_basis_train = rfx_basis, num_gfr=0, num_mcmc=10)
sigma2_samples <- extractParameter(bart_model, "sigma2")
Extract a vector, matrix or array of parameter samples from a BART model by name.
Random effects are handled by a separate getRandomEffectSamples function due to the complexity of the random effects parameters.
If the requested model term is not found, an error is thrown.
The following conventions are used for parameter names:
Global error variance: "sigma2", "global_error_scale", "sigma2_global"
Leaf scale: "sigma2_leaf", "leaf_scale"
In-sample mean function predictions: "y_hat_train"
Test set mean function predictions: "y_hat_test"
In-sample variance forest predictions: "sigma2_x_train", "var_x_train"
Test set variance forest predictions: "sigma2_x_test", "var_x_test"
Ordinal model cutpoints (valid only for ordinal cloglog models): "cloglog_cutpoints", "cutpoints"
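The naming conventions above map several accepted spellings onto one underlying parameter. The sketch below shows one way such an alias lookup could behave, including the error for an unknown term. The `param_aliases` table, its canonical keys, and `resolve_term` are hypothetical names introduced for illustration and are not part of stochtree.

```r
# Hypothetical alias table built from the names listed above;
# the list keys are illustrative labels, not stochtree identifiers
param_aliases <- list(
  global_error_variance = c("sigma2", "global_error_scale", "sigma2_global"),
  leaf_scale            = c("sigma2_leaf", "leaf_scale"),
  y_hat_train           = "y_hat_train",
  y_hat_test            = "y_hat_test",
  variance_forest_train = c("sigma2_x_train", "var_x_train"),
  variance_forest_test  = c("sigma2_x_test", "var_x_test"),
  cloglog_cutpoints     = c("cloglog_cutpoints", "cutpoints")
)

# Hypothetical helper: map a user-supplied term to its canonical label,
# raising an error when the term is not recognized
resolve_term <- function(term) {
  for (canonical in names(param_aliases)) {
    if (term %in% param_aliases[[canonical]]) {
      return(canonical)
    }
  }
  stop("term '", term, "' is not a recognized parameter name")
}

resolved <- resolve_term("var_x_train")
unknown <- try(resolve_term("not_a_parameter"), silent = TRUE)
```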
## S3 method for class 'bartmodel'
extractParameter(object, term)
object |
Object of type |
term |
Name of the parameter to extract (e.g., |
Array of parameter samples. If the underlying parameter is a scalar, this will be a vector of length num_samples.
If the underlying parameter is vector-valued, this will be a (parameter_dimension x num_samples) matrix, and if the underlying
parameter is multidimensional, this will be an array of dimension (parameter_dimension_1 x parameter_dimension_2 x ... x num_samples).
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
f_XW <- (
  ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
  ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
snr <- 3
group_ids <- rep(c(1,2), n %/% 2)
rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
rfx_basis <- cbind(1, runif(n, -1, 1))
rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
E_y <- f_XW + rfx_term
y <- E_y + rnorm(n, 0, 1)*(sd(E_y)/snr)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]
rfx_group_ids_test <- group_ids[test_inds]
rfx_group_ids_train <- group_ids[train_inds]
rfx_basis_test <- rfx_basis[test_inds,]
rfx_basis_train <- rfx_basis[train_inds,]
rfx_term_test <- rfx_term[test_inds]
rfx_term_train <- rfx_term[train_inds]
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
                   rfx_group_ids_train = rfx_group_ids_train,
                   rfx_group_ids_test = rfx_group_ids_test,
                   rfx_basis_train = rfx_basis_train,
                   rfx_basis_test = rfx_basis_test,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 10)
sigma2_samples <- extractParameter(bart_model, "sigma2")
Extract a vector, matrix or array of parameter samples from a BCF model by name.
Random effects are handled by a separate getRandomEffectSamples function due to the complexity of the random effects parameters.
If the requested model term is not found, an error is thrown.
The following conventions are used for parameter names:
Global error variance: "sigma2", "global_error_scale", "sigma2_global"
Prognostic forest leaf scale: "sigma2_leaf_mu", "leaf_scale_mu", "mu_leaf_scale"
Treatment effect forest leaf scale: "sigma2_leaf_tau", "leaf_scale_tau", "tau_leaf_scale"
Adaptive coding parameters: "adaptive_coding" (returns both the control and treated parameters jointly, with control in the first row and treated in the second row)
In-sample mean function predictions: "y_hat_train"
Test set mean function predictions: "y_hat_test"
In-sample treatment effect forest predictions: "tau_hat_train"
Test set treatment effect forest predictions: "tau_hat_test"
Treatment effect intercept: "tau_0", "treatment_intercept", "tau_intercept"
In-sample variance forest predictions: "sigma2_x_train", "var_x_train"
Test set variance forest predictions: "sigma2_x_test", "var_x_test"
## S3 method for class 'bcfmodel'
extractParameter(object, term)
object |
Object of type |
term |
Name of the parameter to extract (e.g., |
Array of parameter samples. If the underlying parameter is a scalar, this will be a vector of length num_samples.
If the underlying parameter is vector-valued, this will be a (parameter_dimension x num_samples) matrix, and if the underlying
parameter is multidimensional, this will be an array of dimension (parameter_dimension_1 x parameter_dimension_2 x ... x num_samples).
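Because the sample index is always the last dimension, a single summary routine can cover all three return shapes. The sketch below uses simulated stand-ins rather than output from a fitted model, and `posterior_mean` is a hypothetical helper rather than a stochtree function.

```r
# Simulated stand-ins for the three shapes described above
num_samples <- 50
scalar_draws <- rexp(num_samples)                      # vector of length num_samples
vector_draws <- matrix(rnorm(2 * num_samples), nrow = 2)  # 2 x num_samples
array_draws  <- array(rnorm(3 * 2 * num_samples),
                      dim = c(3, 2, num_samples))      # 3 x 2 x num_samples

# Hypothetical helper: average over the last (sample) dimension
posterior_mean <- function(draws) {
  if (is.null(dim(draws))) {
    return(mean(draws))
  }
  d <- length(dim(draws))
  apply(draws, seq_len(d - 1), mean)
}

scalar_mean <- posterior_mean(scalar_draws)  # single number
vector_mean <- posterior_mean(vector_draws)  # one value per parameter dimension
array_mean  <- posterior_mean(array_draws)   # 3 x 2 matrix
```

The same pattern works with other summaries such as quantile, as long as the summary is applied over every dimension except the last.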
n <- 500
p <- 5
X <- matrix(runif(n*p), ncol = p)
mu_x <- (
  ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
  ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
pi_x <- (
  ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
  ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
  ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
  ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
)
tau_x <- (
  ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
  ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
  ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
  ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
)
Z <- rbinom(n, 1, pi_x)
E_XZ <- mu_x + Z*tau_x
snr <- 3
rfx_group_ids <- rep(c(1,2), n %/% 2)
rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
rfx_basis <- cbind(1, runif(n, -1, 1))
rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
bcf_model <- bcf(X_train = X, y_train = y, Z_train = Z,
                 rfx_group_ids_train = rfx_group_ids,
                 rfx_basis_train = rfx_basis,
                 num_gfr = 10, num_burnin = 0, num_mcmc = 10)
sigma2_samples <- extractParameter(bcf_model, "sigma2")
Wrapper around a C++ class that stores a single ensemble of decision trees (often treated as the "active forest" / current state of a forest term in a sampling loop in R)
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
forest_ptr |
External pointer to a C++ TreeEnsemble class |
internal_forest_is_empty |
Whether the forest has not yet been "initialized" such that its predict function can be called. |
new()
Create a new Forest object.
Forest$new(
  num_trees,
  leaf_dimension = 1,
  is_leaf_constant = FALSE,
  is_exponentiated = FALSE
)
num_trees |
Number of trees in the forest |
leaf_dimension |
Dimensionality of the outcome model |
is_leaf_constant |
Whether leaf is constant |
is_exponentiated |
Whether forest predictions should be exponentiated before being returned |
A new Forest object.
merge_forest()
Create a larger forest by merging the trees of this forest with those of another forest
Forest$merge_forest(forest)
forest |
Forest to be merged into this forest |
add_constant()
Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, constant_value will be added to every dimension of the leaves.
Forest$add_constant(constant_value)
constant_value |
Value that will be added to every leaf of every tree |
multiply_constant()
Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, constant_multiple will be multiplied through every dimension of the leaves.
Forest$multiply_constant(constant_multiple)
constant_multiple |
Value that will be multiplied by every leaf of every tree |
predict()
Predict forest on every sample in forest_dataset
Forest$predict(forest_dataset)
forest_dataset |
ForestDataset R class |
vector of predictions with as many rows as in forest_dataset
predict_raw()
Predict "raw" leaf values (without being multiplied by basis) for every sample in forest_dataset
Forest$predict_raw(forest_dataset)
forest_dataset |
ForestDataset R class |
Array of predictions for each observation in forest_dataset, with each prediction having the dimensionality of the forest's leaf model. In the case of a constant leaf model or univariate leaf regression, this array is a vector (length is the number of observations). In the case of a multivariate leaf regression, this array is a matrix (number of observations by leaf model dimension).
set_root_leaves()
Set a constant predicted value for every tree in the ensemble. Stops program if any tree is more than a root node.
Forest$set_root_leaves(leaf_value)
leaf_value |
Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension. |
prepare_for_sampler()
Set a constant predicted value for every tree in the ensemble and propagate this initialization through to a ForestModel object, which must be synchronized with the Forest during a sampler loop. Stops program if any tree is more than a root node.
Forest$prepare_for_sampler(
  dataset,
  outcome,
  forest_model,
  leaf_model_int,
  leaf_value
)
dataset |
ForestDataset Dataset class (covariates, basis, etc...) |
outcome |
Outcome Outcome class (residual / partial residual) |
forest_model |
ForestModel object storing tracking structures used in training / sampling |
leaf_model_int |
Integer value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance). |
leaf_value |
Constant leaf value(s) to be fixed for each tree in the ensemble. Can be either a single number or a vector, depending on the forest's leaf dimension. |
adjust_residual()
Adjusts residual based on the predictions of a forest
This is typically run just once at the beginning of a forest sampling algorithm. After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
Forest$adjust_residual(dataset, outcome, forest_model, requires_basis, add)
dataset |
ForestDataset object storing the covariates and bases for a given forest |
outcome |
Outcome object storing the residuals to be updated based on forest predictions |
forest_model |
ForestModel object storing tracking structures used in training / sampling |
requires_basis |
Whether or not a forest requires a basis for prediction |
add |
Whether forest predictions should be added to or subtracted from residuals |
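For a forest with constant leaves, the residual adjustment described for adjust_residual() reduces to simple arithmetic. The base R sketch below uses made-up values and plain vectors in place of the Outcome and Forest objects; it illustrates the bookkeeping only and is not the package's C++ implementation.

```r
# Made-up outcome values and a forest of root-only trees
y <- c(1.2, -0.4, 0.9, 2.3)
num_trees <- 5
leaf_value <- 0.1

# Every tree predicts its root value, so the forest prediction
# is the sum of the root values for each observation
forest_pred <- rep(num_trees * leaf_value, length(y))

# Subtracting the forest prediction corresponds to add = FALSE
residual <- y - forest_pred

# Adding it back corresponds to add = TRUE and restores the outcome
restored <- residual + forest_pred
```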
num_trees()
Return number of trees in each ensemble of a Forest object
Forest$num_trees()
Tree count
leaf_dimension()
Return output dimension of trees in a Forest object
Forest$leaf_dimension()
Leaf node parameter size
is_leaf_constant()
Return constant leaf status of trees in a Forest object
Forest$is_leaf_constant()
TRUE if leaves are constant, FALSE otherwise
is_exponentiated()
Return exponentiation status of trees in a Forest object
Forest$is_exponentiated()
TRUE if leaf predictions must be exponentiated, FALSE otherwise
add_numeric_split_tree()
Add a numeric (i.e. X[,i] <= c) split to a given tree in the ensemble
Forest$add_numeric_split_tree(
  tree_num,
  leaf_num,
  feature_num,
  split_threshold,
  left_leaf_value,
  right_leaf_value
)
tree_num |
Index of the tree to be split |
leaf_num |
Leaf to be split |
feature_num |
Feature that defines the new split |
split_threshold |
Value that defines the cutoff of the new split |
left_leaf_value |
Value (or vector of values) to assign to the newly created left node |
right_leaf_value |
Value (or vector of values) to assign to the newly created right node |
get_tree_leaves()
Retrieve a vector of indices of leaf nodes for a given tree in a given forest
Forest$get_tree_leaves(tree_num)
tree_num |
Index of the tree for which leaf indices will be retrieved |
get_tree_split_counts()
Retrieve a vector of split counts for every training set variable in a given tree in the forest
Forest$get_tree_split_counts(tree_num, num_features)
tree_num |
Index of the tree for which split counts will be retrieved |
num_features |
Total number of features in the training set |
get_forest_split_counts()
Retrieve a vector of split counts for every training set variable in the forest
Forest$get_forest_split_counts(num_features)
num_features |
Total number of features in the training set |
tree_max_depth()
Maximum depth of a specific tree in the forest
Forest$tree_max_depth(tree_num)
tree_num |
Tree index within forest |
Maximum leaf depth
average_max_depth()
Average the maximum depth of each tree in the forest
Forest$average_max_depth()
Average maximum depth
is_empty()
When a forest object is created, it is "empty" in the sense that none
of its component trees have leaves with values. There are two ways to
"initialize" a Forest object. First, the set_root_leaves() method
simply initializes every tree in the forest to a single node carrying
the same (user-specified) leaf value. Second, the prepare_for_sampler()
method initializes every tree in the forest to a single node with the
same value and also propagates this information through to a ForestModel
object, which must be synchronized with a Forest during a forest
sampler loop.
Forest$is_empty()
TRUE if a Forest has not yet been initialized with a constant root value, FALSE if the forest has already been initialized / grown.
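The initialization sequence described for is_empty() can be sketched end to end. This example assumes the constructor functions createForest and createForestDataset (the package's wrappers for these two classes) and assumes that tree, leaf, and feature indices are zero-based; neither point is stated in this section, so treat both as assumptions to check against the vignettes.

```r
library(stochtree)

# Six observations, two features
X <- matrix(c(0.1, 0.2, 0.3, 0.7, 0.8, 0.9,
              0.5, 0.5, 0.5, 0.5, 0.5, 0.5), ncol = 2)
forest_dataset <- createForestDataset(X)

# Two-tree forest with constant (scalar) leaves
forest <- createForest(num_trees = 2, leaf_dimension = 1,
                       is_leaf_constant = TRUE)
empty_before <- forest$is_empty()

# Initialize every tree to a single root node with leaf value 0
forest$set_root_leaves(0.0)
empty_after <- forest$is_empty()

# Split the root of the first tree on the first feature at 0.5
# (assumption: indices are zero-based)
forest$add_numeric_split_tree(
  tree_num = 0, leaf_num = 0, feature_num = 0,
  split_threshold = 0.5, left_leaf_value = -1.0, right_leaf_value = 1.0
)

# Predictions sum the leaf values across trees for each observation
preds <- forest$predict(forest_dataset)
split_counts <- forest$get_forest_split_counts(num_features = 2)
```

In a custom sampler, prepare_for_sampler() would replace set_root_leaves() so that the ForestModel tracking structures are initialized at the same time.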
Wrapper around a C++ dataset class used to sample a forest. A dataset consists of three matrices / vectors: covariates, bases, and variance weights. Both the basis vector and variance weights are optional.
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed; users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
data_ptr |
External pointer to a C++ ForestDataset class |
new()
Create a new ForestDataset object.
ForestDataset$new(covariates, basis = NULL, variance_weights = NULL)
covariates |
Matrix of covariates |
basis |
(Optional) Matrix of bases used to define a leaf regression |
variance_weights |
(Optional) Vector of observation-specific variance weights |
A new ForestDataset object.
update_basis()
Update basis matrix in a dataset
ForestDataset$update_basis(basis)
basis |
Updated matrix of bases used to define a leaf regression |
update_variance_weights()
Update variance_weights in a dataset
ForestDataset$update_variance_weights(variance_weights, exponentiate = F)
variance_weights |
Updated vector of variance weights used to define individual variance / case weights |
exponentiate |
Whether or not input vector should be exponentiated before being written to the Dataset's variance weights. Default: F. |
num_observations()
Return number of observations in a ForestDataset object
ForestDataset$num_observations()
Observation count
num_covariates()
Return number of covariates in a ForestDataset object
ForestDataset$num_covariates()
Covariate count
num_basis()
Return number of bases in a ForestDataset object
ForestDataset$num_basis()
Basis count
get_covariates()
Return covariates as an R matrix
ForestDataset$get_covariates()
Covariate data
get_basis()
Return bases as an R matrix
ForestDataset$get_basis()
Basis data
get_variance_weights()
Return variance weights as an R vector
ForestDataset$get_variance_weights()
Variance weight data
has_basis()
Whether or not a dataset has a basis matrix
ForestDataset$has_basis()
True if basis matrix is loaded, false otherwise
has_variance_weights()
Whether or not a dataset has variance weights
ForestDataset$has_variance_weights()
True if variance weights are loaded, false otherwise
has_auxiliary_dimension()
Whether or not a dataset has auxiliary data stored at the dimension indicated
ForestDataset$has_auxiliary_dimension(dim_idx)
dim_idx |
Dimension of auxiliary data |
True if auxiliary data has been allocated for dim_idx, false otherwise
add_auxiliary_dimension()
Initialize a new dimension / lane of auxiliary data and allocate data in its place
ForestDataset$add_auxiliary_dimension(dim_size)
dim_size |
Size of the new vector of data to allocate |
None
get_auxiliary_data_value()
Retrieve auxiliary data value
ForestDataset$get_auxiliary_data_value(dim_idx, element_idx)
dim_idx |
Dimension from which data value to be retrieved |
element_idx |
Element to retrieve from dimension dim_idx |
Floating point value stored in the requested auxiliary data space
set_auxiliary_data_value()
Set auxiliary data value
ForestDataset$set_auxiliary_data_value(dim_idx, element_idx, value)
dim_idx |
Dimension in which data value to be set |
element_idx |
Element to set within dimension dim_idx |
value |
Data value to set at auxiliary data dimension dim_idx and element element_idx |
None
get_auxiliary_data_vector()
Retrieve entire auxiliary data vector
ForestDataset$get_auxiliary_data_vector(dim_idx)
dim_idx |
Dimension to retrieve |
Vector of all of the auxiliary data stored at dimension dim_idx
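A short sketch of the ForestDataset accessors and update methods follows. It assumes the constructor function createForestDataset, which takes the same arguments as ForestDataset$new, and it supplies a basis and variance weights at construction on the assumption that the update methods expect those components to exist already.

```r
library(stochtree)

n <- 20
X <- matrix(runif(n * 3), ncol = 3)
basis <- matrix(1, nrow = n, ncol = 1)
weights <- rep(1, n)

# Dataset with all three components supplied up front
dataset <- createForestDataset(X, basis = basis, variance_weights = weights)

# Basic shape queries
dims <- c(dataset$num_observations(),
          dataset$num_covariates(),
          dataset$num_basis())
has_both <- dataset$has_basis() && dataset$has_variance_weights()

# Replace the basis with a new matrix of the same shape
new_basis <- matrix(rnorm(n), ncol = 1)
dataset$update_basis(new_basis)
basis_roundtrip <- dataset$get_basis()

# Supply weights on the log scale and exponentiate them on write
log_weights <- rnorm(n, sd = 0.1)
dataset$update_variance_weights(log_weights, exponentiate = TRUE)
weights_roundtrip <- dataset$get_variance_weights()
```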
Decision tree ensembles can be represented in part by a "kernel" function whose distance metric is based on the extent to which two observations are mapped to the same leaf nodes. This function group offers utilities for evaluating this kernel.
computeForestLeafIndices computes and returns a vector representation of a forest's
leaf predictions for every observation in a dataset.
The resulting vector has a "tree-major" format that can be easily re-represented
as a CSR sparse matrix: elements are organized so that the first n elements
correspond to leaf predictions for all n observations in a dataset for the
first tree in an ensemble, the next n elements correspond to predictions for
the second tree and so on. The "data" for each element corresponds to a uniquely
mapped column index that corresponds to a single leaf of a single tree (i.e.
if tree 1 has 3 leaves, its column indices range from 0 to 2, and then tree 2's
leaf indices begin at 3, etc...).
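The tree-major layout can be illustrated with a small base R sketch. The `leaf_indices` vector below is made-up toy data standing in for the output of computeForestLeafIndices (4 observations, 2 trees), and a dense indicator matrix is used instead of a true CSR structure to keep the example dependency-free. In practice the number of columns could presumably be taken as the value from computeForestMaxLeafIndex plus one.

```r
# Toy stand-in for computeForestLeafIndices output (hypothetical values):
# 4 observations, 2 trees; tree 1 owns columns 0..2, tree 2 owns columns 3..4
num_obs <- 4
num_trees <- 2
leaf_indices <- c(0, 2, 1, 0,   # tree 1, observations 1..4
                  3, 3, 4, 3)   # tree 2, observations 1..4
num_leaves <- max(leaf_indices) + 1

# In tree-major order the observation index cycles 1..n within each tree block
row_idx <- rep(seq_len(num_obs), times = num_trees)
col_idx <- leaf_indices + 1     # shift 0-based leaf columns to R's 1-based indexing

# Dense n x num_leaves leaf-membership indicator matrix
W <- matrix(0, nrow = num_obs, ncol = num_leaves)
W[cbind(row_idx, col_idx)] <- 1

# Kernel entry (i, j): number of trees in which observations i and j share a leaf
K <- W %*% t(W)
```

Here observations 1 and 4 fall in the same leaf of both trees, so their kernel entry equals the number of trees.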
computeForestLeafVariances returns each forest's leaf node scale parameters.
If leaf scale is not sampled for the forest in question, the function throws an error that the
leaf model does not have a stochastic scale parameter.
computeForestMaxLeafIndex computes and returns the largest possible leaf index computable by computeForestLeafIndices for the forests in a designated forest sample container.
These functions are intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
computeForestLeafIndices( model_object, covariates, forest_type = NULL, propensity = NULL, forest_inds = NULL )
computeForestLeafVariances(model_object, forest_type, forest_inds = NULL)
computeForestMaxLeafIndex(model_object, forest_type = NULL, forest_inds = NULL)
model_object |
Object of type |
covariates |
Covariates to use for prediction. Must have the same dimensions / column types as the data used to train a forest. |
forest_type |
Which forest to use from 1. BART
2. BCF
3. ForestSamples
|
propensity |
(Optional) Propensities used for prediction (BCF-only). |
forest_inds |
(Optional) Indices of the forest sample(s) for which to compute max leaf indices. If not provided,
this function will return max leaf indices for every sample of a forest.
This function uses 0-indexing, so the first forest sample corresponds to |
computeForestLeafIndices returns a vector of size num_obs * num_trees, where num_obs = nrow(covariates)
and num_trees is the number of trees in the relevant forest of model_object.
computeForestLeafVariances returns a vector of size length(forest_inds) with the leaf scale parameter for each requested forest.
computeForestMaxLeafIndex returns a vector containing the largest possible leaf index computable by computeForestLeafIndices for the forests in a designated forest sample container.
X <- matrix(runif(10*100), ncol = 10)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
bart_model <- bart(X, y, num_gfr=0, num_mcmc=10)
leaf_indices <- computeForestLeafIndices(bart_model, X, "mean")
leaf_indices <- computeForestLeafIndices(bart_model, X, "mean", forest_inds = 0)
leaf_indices <- computeForestLeafIndices(bart_model, X, "mean", forest_inds = c(1,3,9))
leaf_variances <- computeForestLeafVariances(bart_model, "mean")
leaf_variances <- computeForestLeafVariances(bart_model, "mean", 0)
leaf_variances <- computeForestLeafVariances(bart_model, "mean", c(1,3,5))
max_leaf_index <- computeForestMaxLeafIndex(bart_model, "mean")
max_leaf_index <- computeForestMaxLeafIndex(bart_model, "mean", 0)
max_leaf_index <- computeForestMaxLeafIndex(bart_model, "mean", c(1,3,9))
Wraps the C++ data structures needed to sample an ensemble of decision trees and exposes functionality to run a forest sampler (using either MCMC or the grow-from-root algorithm).
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
tracker_ptrExternal pointer to a C++ ForestTracker class
tree_prior_ptrExternal pointer to a C++ TreePrior class
new()
Create a new ForestModel object.
ForestModel$new( forest_dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf, max_depth = -1 )
forest_datasetForestDataset object, used to initialize forest sampling data structures
feature_typesFeature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
num_treesNumber of trees in the forest being sampled
nNumber of observations in forest_dataset
alphaRoot node split probability in tree prior
betaDepth prior penalty in tree prior
min_samples_leafMinimum number of samples in a tree leaf
max_depthMaximum depth that any tree can reach
A new ForestModel object.
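To give a feel for how alpha and beta interact, the sketch below assumes the standard BART tree prior of Chipman, George and McCulloch (2010), in which a node at depth d splits with probability alpha * (1 + d)^(-beta). The helper `split_prob` is hypothetical and not part of stochtree. Whether stochtree applies exactly this form internally is an assumption based on the parameter descriptions above.

```r
# Hypothetical helper: prior probability that a node at a given depth splits,
# assuming the Chipman-George-McCulloch form alpha * (1 + depth)^(-beta)
split_prob <- function(depth, alpha = 0.95, beta = 2) {
  alpha * (1 + depth)^(-beta)
}

depths <- 0:4                 # depth 0 is the root node
probs <- split_prob(depths)   # decreasing in depth; larger beta favors shallower trees
```

With alpha = 0.95 and beta = 2, the root splits with prior probability 0.95 and a depth-1 node with probability 0.2375, so deep trees are penalized heavily even when max_depth = -1 imposes no hard limit.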
sample_one_iteration()
Run a single iteration of the forest sampling algorithm (MCMC or GFR)
ForestModel$sample_one_iteration( forest_dataset, residual, forest_samples, active_forest, rng, forest_model_config, global_model_config, num_threads = -1, keep_forest = TRUE, gfr = TRUE )
forest_datasetDataset used to sample the forest
residualOutcome used to sample the forest
forest_samplesContainer of forest samples
active_forest"Active" forest updated by the sampler in each iteration
rngWrapper around C++ random number generator
forest_model_configForestModelConfig object containing forest model parameters and settings
global_model_configGlobalModelConfig object containing global model parameters and settings
num_threadsNumber of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's system, this will default to 1, otherwise to the maximum number of available threads.
keep_forest(Optional) Whether the updated forest sample should be saved to forest_samples. Default: TRUE.
gfr(Optional) Whether or not the forest should be sampled using the "grow-from-root" (GFR) algorithm. Default: TRUE.
get_cached_forest_predictions()
Extract an internally-cached prediction of a forest on the training dataset in a sampler.
ForestModel$get_cached_forest_predictions()
Vector with as many elements as observations in the training dataset
propagate_basis_update()
Propagates basis update through to the (full/partial) residual by iteratively (a) adding back in the previous prediction of each tree, (b) recomputing predictions for each tree (caching on the C++ side), (c) subtracting the new predictions from the residual.
This is useful in cases where a basis (for e.g. leaf regression) is updated outside of a tree sampler (as with e.g. adaptive coding for binary treatment BCF). Once a basis has been updated, the overall "function" represented by a tree model has changed and this should be reflected through to the residual before the next sampling loop is run.
ForestModel$propagate_basis_update(dataset, outcome, active_forest)
datasetForestDataset object storing the covariates and bases for a given forest
outcomeOutcome object storing the residuals to be updated based on forest predictions
active_forest"Active" forest updated by the sampler in each iteration
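The add-back / recompute / subtract loop described for propagate_basis_update can be mimicked in plain R. This is a toy illustration with ordinary matrices rather than stochtree objects: `leaf_values`, `basis_old`, and `basis_new` are invented for the example, and each tree's prediction is assumed to be its raw leaf value times a univariate basis.

```r
set.seed(1)
n <- 5
num_trees <- 3
y <- rnorm(n)
leaf_values <- matrix(rnorm(n * num_trees), n, num_trees)  # raw leaf value per obs and tree
basis_old <- rep(1, n)    # basis in effect when the residual was last computed
basis_new <- runif(n)     # basis after an update made outside the tree sampler

# Cached per-tree predictions and residual under the old basis
pred_old <- leaf_values * basis_old
residual <- y - rowSums(pred_old)

for (j in seq_len(num_trees)) {
  residual <- residual + pred_old[, j]         # (a) add back the previous prediction
  pred_new_j <- leaf_values[, j] * basis_new   # (b) recompute the tree's prediction
  residual <- residual - pred_new_j            # (c) subtract the new prediction
}
```

After the loop, the residual matches what would be obtained by recomputing it from scratch under the new basis.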
propagate_residual_update()
Update the current state of the outcome (i.e. partial residual) data by subtracting the current predictions of each tree.
This function is run after the Outcome class's update_data method, which overwrites the partial residual with an entirely new stream of outcome data.
ForestModel$propagate_residual_update(residual)
residualOutcome used to sample the forest
None
update_alpha()
Update alpha in the tree prior
ForestModel$update_alpha(alpha)
alphaNew value of alpha to be used
None
update_beta()
Update beta in the tree prior
ForestModel$update_beta(beta)
betaNew value of beta to be used
None
update_min_samples_leaf()
Update min_samples_leaf in the tree prior
ForestModel$update_min_samples_leaf(min_samples_leaf)
min_samples_leafNew value of min_samples_leaf to be used
None
update_max_depth()
Update max_depth in the tree prior
ForestModel$update_max_depth(max_depth)
max_depthNew value of max_depth to be used
None
get_alpha()
Query alpha in the tree prior
ForestModel$get_alpha()
Value of alpha in the tree prior
get_beta()
Query beta in the tree prior
ForestModel$get_beta()
Value of beta in the tree prior
get_min_samples_leaf()
Query min_samples_leaf in the tree prior
ForestModel$get_min_samples_leaf()
Value of min_samples_leaf in the tree prior
get_max_depth()
Query max_depth in the tree prior
ForestModel$get_max_depth()
Value of max_depth in the tree prior
Object used to get / set parameters and other model configuration options for a forest model in the "low-level" stochtree interface. The "low-level" stochtree interface enables a high degree of sampler customization, in which users employ R wrappers around C++ objects like ForestDataset, Outcome, CppRng, and ForestModel to run the Gibbs sampler of a BART model with custom modifications. ForestModelConfig allows users to specify / query the parameters of a forest model they wish to run.
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
Vector of (0-indexed) indices of trees to update in a sweep
Vector specifying sampling probability for all p covariates in ForestDataset
Number of trees in a forest
Number of features in a forest model training set
Number of observations in a forest model training set
Root node split probability in tree prior
Depth prior penalty in tree prior
Minimum number of samples in a tree leaf
Maximum depth of any tree in the ensemble in the model
Integer coded leaf model type
Scale parameter used in Gaussian leaf models
Shape parameter for IG leaf models
Scale parameter for IG leaf models
Shape parameter for conditional gamma component of cloglog leaf models
Rate parameter for conditional gamma component of cloglog leaf models
Number of unique cutpoints to consider
Number of features to subsample for the GFR algorithm
feature_typesVector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
sweep_update_indicesVector of trees to update in a sweep
num_treesNumber of trees in the forest being sampled
num_featuresNumber of features in training dataset
num_observationsNumber of observations in training dataset
leaf_dimensionDimension of the leaf model
alphaRoot node split probability in tree prior
betaDepth prior penalty in tree prior
min_samples_leafMinimum number of samples in a tree leaf
max_depthMaximum depth of any tree in the ensemble in the model. Setting to -1 does not enforce any depth limits on trees.
leaf_model_typeInteger specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression)
leaf_model_scaleScale parameter used in Gaussian leaf models
variable_weightsVector specifying sampling probability for all p covariates in ForestDataset
variance_forest_shapeShape parameter for IG leaf models (applicable when leaf_model_type = 3)
variance_forest_scaleScale parameter for IG leaf models (applicable when leaf_model_type = 3)
cloglog_forest_shapeShape parameter for conditional gamma component of cloglog leaf models (applicable when leaf_model_type = 4)
cloglog_forest_rateRate parameter for conditional gamma component of cloglog leaf models (applicable when leaf_model_type = 4)
cutpoint_grid_sizeNumber of unique cutpoints to consider
num_features_subsampleNumber of features to subsample for the GFR algorithm
new()
Create a new ForestModelConfig object.
ForestModelConfig$new( feature_types = NULL, sweep_update_indices = NULL, num_trees = NULL, num_features = NULL, num_observations = NULL, variable_weights = NULL, leaf_dimension = 1, alpha = 0.95, beta = 2, min_samples_leaf = 5, max_depth = -1, leaf_model_type = 1, leaf_model_scale = NULL, variance_forest_shape = 1, variance_forest_scale = 1, cloglog_forest_shape = 2, cloglog_forest_rate = 2, cutpoint_grid_size = 100, num_features_subsample = NULL )
feature_typesVector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
sweep_update_indicesVector of (0-indexed) indices of trees to update in a sweep
num_treesNumber of trees in the forest being sampled
num_featuresNumber of features in training dataset
num_observationsNumber of observations in training dataset
variable_weightsVector specifying sampling probability for all p covariates in ForestDataset
leaf_dimensionDimension of the leaf model (default: 1)
alphaRoot node split probability in tree prior (default: 0.95)
betaDepth prior penalty in tree prior (default: 2.0)
min_samples_leafMinimum number of samples in a tree leaf (default: 5)
max_depthMaximum depth of any tree in the ensemble in the model. Setting to -1 does not enforce any depth limits on trees. Default: -1.
leaf_model_typeInteger specifying the leaf model type (0 = constant leaf, 1 = univariate leaf regression, 2 = multivariate leaf regression). Default: 1.
leaf_model_scaleScale parameter used in Gaussian leaf models (can either be a scalar or a q x q matrix, where q is the dimensionality of the basis and is only >1 when leaf_model_type = 2). Calibrated internally as 1/num_trees, propagated along diagonal if needed for multivariate leaf models.
variance_forest_shapeShape parameter for IG leaf models (applicable when leaf_model_type = 3). Default: 1.
variance_forest_scaleScale parameter for IG leaf models (applicable when leaf_model_type = 3). Default: 1.
cloglog_forest_shapeShape parameter for conditional gamma component of cloglog leaf models (applicable when leaf_model_type = 4). Default: 2.
cloglog_forest_rateRate parameter for conditional gamma component of cloglog leaf models (applicable when leaf_model_type = 4). Default: 2.
cutpoint_grid_sizeNumber of unique cutpoints to consider (default: 100)
num_features_subsampleNumber of features to subsample for the GFR algorithm
A new ForestModelConfig object.
update_feature_types()
Update feature types
ForestModelConfig$update_feature_types(feature_types)
feature_typesVector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
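The integer coding of feature types can be derived from a data frame's column classes. The sketch below uses a made-up data frame and only builds the code vector (0 = numeric, 1 = ordered categorical, 2 = unordered categorical). How categorical columns must be encoded before being loaded into a ForestDataset is not covered here.

```r
# Hypothetical toy data with one column of each kind
df <- data.frame(
  x1 = c(0.1, 0.5, 0.9),
  x2 = factor(c("lo", "hi", "lo"), levels = c("lo", "hi"), ordered = TRUE),
  x3 = factor(c("a", "b", "c"))
)

# Map each column to its integer feature type code
feature_types <- vapply(df, function(col) {
  if (is.ordered(col)) {
    1L          # ordered categorical
  } else if (is.factor(col)) {
    2L          # unordered categorical
  } else {
    0L          # numeric
  }
}, integer(1))
```

The resulting vector has one entry per covariate, in column order, which is the shape the feature_types argument describes.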
update_sweep_indices()
Update sweep update indices
ForestModelConfig$update_sweep_indices(sweep_update_indices)
sweep_update_indicesVector of (0-indexed) indices of trees to update in a sweep
update_variable_weights()
Update variable weights
ForestModelConfig$update_variable_weights(variable_weights)
variable_weightsVector specifying sampling probability for all p covariates in ForestDataset
update_alpha()
Update root node split probability in tree prior
ForestModelConfig$update_alpha(alpha)
alphaRoot node split probability in tree prior
update_beta()
Update depth prior penalty in tree prior
ForestModelConfig$update_beta(beta)
betaDepth prior penalty in tree prior
update_min_samples_leaf()
Update minimum number of samples per leaf node in the tree prior
ForestModelConfig$update_min_samples_leaf(min_samples_leaf)
min_samples_leafMinimum number of samples in a tree leaf
update_max_depth()
Update max depth in the tree prior
ForestModelConfig$update_max_depth(max_depth)
max_depthMaximum depth of any tree in the ensemble in the model
update_leaf_model_scale()
Update scale parameter used in Gaussian leaf models
ForestModelConfig$update_leaf_model_scale(leaf_model_scale)
leaf_model_scaleScale parameter used in Gaussian leaf models
update_variance_forest_shape()
Update shape parameter for IG leaf models
ForestModelConfig$update_variance_forest_shape(variance_forest_shape)
variance_forest_shapeShape parameter for IG leaf models
update_variance_forest_scale()
Update scale parameter for IG leaf models
ForestModelConfig$update_variance_forest_scale(variance_forest_scale)
variance_forest_scaleScale parameter for IG leaf models
update_cloglog_forest_shape()
Update shape parameter for conditional gamma component of cloglog leaf models
ForestModelConfig$update_cloglog_forest_shape(cloglog_forest_shape)
cloglog_forest_shapeShape parameter for conditional gamma component of cloglog leaf models
update_cloglog_forest_rate()
Update rate parameter for conditional gamma component of cloglog leaf models
ForestModelConfig$update_cloglog_forest_rate(cloglog_forest_rate)
cloglog_forest_rateRate parameter for conditional gamma component of cloglog leaf models
update_cutpoint_grid_size()
Update number of unique cutpoints to consider
ForestModelConfig$update_cutpoint_grid_size(cutpoint_grid_size)
cutpoint_grid_sizeNumber of unique cutpoints to consider
update_num_features_subsample()
Update number of features to subsample for the GFR algorithm
ForestModelConfig$update_num_features_subsample(num_features_subsample)
num_features_subsampleNumber of features to subsample for the GFR algorithm
get_feature_types()
Query feature types for this ForestModelConfig object
ForestModelConfig$get_feature_types()
get_sweep_indices()
Query sweep update indices for this ForestModelConfig object
ForestModelConfig$get_sweep_indices()
get_variable_weights()
Query variable weights for this ForestModelConfig object
ForestModelConfig$get_variable_weights()
get_num_trees()
Query number of trees
ForestModelConfig$get_num_trees()
get_num_features()
Query number of features
ForestModelConfig$get_num_features()
get_num_observations()
Query number of observations
ForestModelConfig$get_num_observations()
get_alpha()
Query root node split probability in tree prior for this ForestModelConfig object
ForestModelConfig$get_alpha()
get_beta()
Query depth prior penalty in tree prior for this ForestModelConfig object
ForestModelConfig$get_beta()
get_min_samples_leaf()
Query minimum number of samples per leaf node in the tree prior for this ForestModelConfig object
ForestModelConfig$get_min_samples_leaf()
get_max_depth()
Query max depth in the tree prior for this ForestModelConfig object
ForestModelConfig$get_max_depth()
get_leaf_model_type()
Query (integer-coded) type of leaf model
ForestModelConfig$get_leaf_model_type()
get_leaf_model_scale()
Query scale parameter used in Gaussian leaf models for this ForestModelConfig object
ForestModelConfig$get_leaf_model_scale()
get_variance_forest_shape()
Query shape parameter for IG leaf models for this ForestModelConfig object
ForestModelConfig$get_variance_forest_shape()
get_variance_forest_scale()
Query scale parameter for IG leaf models for this ForestModelConfig object
ForestModelConfig$get_variance_forest_scale()
get_cloglog_forest_shape()
Query shape parameter for conditional gamma component of cloglog leaf models for this ForestModelConfig object
ForestModelConfig$get_cloglog_forest_shape()
get_cloglog_forest_rate()
Query rate parameter for conditional gamma component of cloglog leaf models for this ForestModelConfig object
ForestModelConfig$get_cloglog_forest_rate()
get_cutpoint_grid_size()
Query number of unique cutpoints to consider for this ForestModelConfig object
ForestModelConfig$get_cutpoint_grid_size()
get_num_features_subsample()
Query number of features to subsample for the GFR algorithm
ForestModelConfig$get_num_features_subsample()
Wrapper around a C++ class that stores draws from a random ensemble of decision trees.
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
forest_container_ptrExternal pointer to a C++ ForestContainer class
new()
Create a new ForestContainer object.
ForestSamples$new( num_trees, leaf_dimension = 1, is_leaf_constant = FALSE, is_exponentiated = FALSE )
num_treesNumber of trees
leaf_dimensionDimensionality of the outcome model
is_leaf_constantWhether leaf is constant
is_exponentiatedWhether forest predictions should be exponentiated before being returned
A new ForestContainer object.
collapse()
Collapse forests in this container by a pre-specified batch size.
For example, if we have a container of twenty 10-tree forests, and we
specify a batch_size of 5, then this method will yield four 50-tree
forests. "Excess" forests remaining after the size of a forest container
is divided by batch_size will be pruned from the beginning of the
container (i.e. earlier sampled forests will be deleted). This method
has no effect if batch_size is larger than the number of forests
in a container.
ForestSamples$collapse(batch_size)
batch_sizeNumber of forests to be collapsed into a single forest
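The bookkeeping behind collapse() can be worked through with a small helper. `collapse_plan` is a hypothetical function written for this illustration, not a stochtree API. It only computes the counts implied by the description above and does not touch any forest objects.

```r
# Hypothetical helper: counts implied by collapsing a container of
# num_forests forests (each with num_trees trees) in batches of batch_size
collapse_plan <- function(num_forests, num_trees, batch_size) {
  if (batch_size > num_forests) {
    # Described behavior: no effect when the batch exceeds the container size
    return(list(num_forests = num_forests, num_trees = num_trees, num_pruned = 0))
  }
  list(
    num_forests = num_forests %/% batch_size,   # collapsed forests that remain
    num_trees = num_trees * batch_size,         # trees per collapsed forest
    num_pruned = num_forests %% batch_size      # excess forests dropped from the start
  )
}

plan_a <- collapse_plan(20, 10, 5)   # the example from the text
plan_b <- collapse_plan(23, 10, 5)   # three earliest forests would be pruned
```

For the documented case of twenty 10-tree forests and batch_size = 5, the plan is four 50-tree forests with nothing pruned.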
combine_forests()
Merge specified forests into a single forest
ForestSamples$combine_forests(forest_inds)
forest_indsIndices of forests to be combined (0-indexed)
add_to_forest()
Add a constant value to every leaf of every tree of a given forest
ForestSamples$add_to_forest(forest_index, constant_value)
forest_indexIndex of forest whose leaves will be modified (0-indexed)
constant_valueValue to add to every leaf of every tree of the forest at forest_index
multiply_forest()
Multiply every leaf of every tree of a given forest by constant value
ForestSamples$multiply_forest(forest_index, constant_multiple)
forest_indexIndex of forest whose leaves will be modified (0-indexed)
constant_multipleValue to multiply through by every leaf of every tree of the forest at forest_index
load_from_json()
Create a new ForestContainer object from a json object
ForestSamples$load_from_json(json_object, json_forest_label)
json_objectObject of class CppJson
json_forest_labelLabel referring to a particular forest (i.e. "forest_0") in the overall json hierarchy
A new ForestContainer object.
append_from_json()
Append to a ForestContainer object from a json object
ForestSamples$append_from_json(json_object, json_forest_label)
json_objectObject of class CppJson
json_forest_labelLabel referring to a particular forest (i.e. "forest_0") in the overall json hierarchy
None
load_from_json_string()
Create a new ForestContainer object from a JSON string
ForestSamples$load_from_json_string(json_string, json_forest_label)
json_stringJSON string which parses into object of class CppJson
json_forest_labelLabel referring to a particular forest (i.e. "forest_0") in the overall json hierarchy
A new ForestContainer object.
append_from_json_string()
Append to a ForestContainer object from a JSON string
ForestSamples$append_from_json_string(json_string, json_forest_label)
json_stringJSON string which parses into object of class CppJson
json_forest_labelLabel referring to a particular forest (i.e. "forest_0") in the overall json hierarchy
None
predict()
Predict every tree ensemble on every sample in forest_dataset
ForestSamples$predict(forest_dataset)
forest_datasetForestDataset R class
matrix of predictions with as many rows as in forest_dataset
and as many columns as samples in the ForestContainer
predict_raw()
Predict "raw" leaf values (without being multiplied by basis) for every tree ensemble on every sample in forest_dataset
ForestSamples$predict_raw(forest_dataset)
forest_datasetForestDataset R class
Array of predictions for each observation in forest_dataset and
each sample in the ForestSamples class with each prediction having the
dimensionality of the forests' leaf model. In the case of a constant leaf model
or univariate leaf regression, this array is two-dimensional (number of observations,
number of forest samples). In the case of a multivariate leaf regression,
this array is three-dimensional (number of observations, leaf model dimension,
number of samples).
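The relationship between "raw" leaf values and basis-weighted predictions can be sketched in base R. The arrays below are invented toy values, and the sketch assumes the usual leaf regression form in which a prediction is the inner product of an observation's basis vector with its leaf parameter vector. The three-dimensional layout follows the description above.

```r
n <- 3   # observations
q <- 2   # leaf model dimension
s <- 2   # forest samples

# Toy "raw" array shaped (observations, leaf dimension, samples)
raw <- array(seq_len(n * q * s), dim = c(n, q, s))

# Toy basis with one row per observation and one column per leaf dimension
basis <- matrix(c(1, 1, 1, 0.5, 0.5, 0.5), nrow = n, ncol = q)

# Prediction for each observation and sample: sum over leaf dimensions of raw * basis
pred <- sapply(seq_len(s), function(k) rowSums(raw[, , k] * basis))
```

The result is an n x s matrix, the same shape described for predict().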
predict_raw_single_forest()
Predict "raw" leaf values (without being multiplied by basis) for a specific forest on every sample in forest_dataset
ForestSamples$predict_raw_single_forest(forest_dataset, forest_num)
forest_datasetForestDataset R class
forest_numIndex of the forest sample within the container
matrix of predictions with as many rows as in forest_dataset
and as many columns as dimensions in the leaves of trees in ForestContainer
predict_raw_single_tree()
Predict "raw" leaf values (without being multiplied by basis) for a specific tree in a specific forest on every observation in forest_dataset
ForestSamples$predict_raw_single_tree(forest_dataset, forest_num, tree_num)
forest_datasetForestDataset R class
forest_numIndex of the forest sample within the container
tree_numIndex of the tree to be queried
matrix of predictions with as many rows as in forest_dataset
and as many columns as dimensions in the leaves of trees in ForestContainer
set_root_leaves()
Set a constant predicted value for every tree in the ensemble. Stops program if any tree is more than a root node.
ForestSamples$set_root_leaves(forest_num, leaf_value)
forest_numIndex of the forest sample within the container.
leaf_valueConstant leaf value(s) to be fixed for each tree in the ensemble indexed by forest_num. Can be either a single number or a vector, depending on the forest's leaf dimension.
prepare_for_sampler()
Set a constant predicted value for every tree in the ensemble. Stops program if any tree is more than a root node.
ForestSamples$prepare_for_sampler( dataset, outcome, forest_model, leaf_model_int, leaf_value )
datasetForestDataset Dataset class (covariates, basis, etc...)
outcomeOutcome Outcome class (residual / partial residual)
forest_modelForestModel object storing tracking structures used in training / sampling
leaf_model_intInteger value encoding the leaf model type (0 = constant gaussian, 1 = univariate gaussian, 2 = multivariate gaussian, 3 = log linear variance).
leaf_valueConstant leaf value(s) to be fixed for each tree in the ensemble indexed by forest_num. Can be either a single number or a vector, depending on the forest's leaf dimension.
adjust_residual()
Adjusts residual based on the predictions of a forest
This is typically run just once at the beginning of a forest sampling algorithm. After trees are initialized with constant root node predictions, their root predictions are subtracted out of the residual.
ForestSamples$adjust_residual( dataset, outcome, forest_model, requires_basis, forest_num, add )
datasetForestDataset object storing the covariates and bases for a given forest
outcomeOutcome object storing the residuals to be updated based on forest predictions
forest_modelForestModel object storing tracking structures used in training / sampling
requires_basisWhether or not a forest requires a basis for prediction
forest_numIndex of forest used to update residuals
addWhether forest predictions should be added to or subtracted from residuals
save_json()
Store the trees and metadata of ForestSamples class in a json file
ForestSamples$save_json(json_filename)
json_filenameName of output json file (must end in ".json")
load_json()
Load trees and metadata for an ensemble from a json file. Note that
any trees and metadata already present in ForestSamples class will
be overwritten.
ForestSamples$load_json(json_filename)
json_filenameName of model input json file (must end in ".json")
num_samples()
Return number of samples in a ForestContainer object
ForestSamples$num_samples()
Sample count
num_trees()
Return number of trees in each ensemble of a ForestContainer object
ForestSamples$num_trees()
Tree count
leaf_dimension()
Return output dimension of trees in a ForestContainer object
ForestSamples$leaf_dimension()
Leaf node parameter size
is_leaf_constant()
Return constant leaf status of trees in a ForestContainer object
ForestSamples$is_leaf_constant()
TRUE if leaves are constant, FALSE otherwise
is_exponentiated()
Return exponentiation status of trees in a ForestContainer object
ForestSamples$is_exponentiated()
TRUE if leaf predictions must be exponentiated, FALSE otherwise
add_forest_with_constant_leaves()
Add a new all-root ensemble to the container, with all of the leaves set to the value / vector provided
ForestSamples$add_forest_with_constant_leaves(leaf_value)
leaf_valueValue (or vector of values) to initialize root nodes in tree
add_numeric_split_tree()
Add a numeric (i.e. X[,i] <= c) split to a given tree in the ensemble
ForestSamples$add_numeric_split_tree( forest_num, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value )
forest_numIndex of the forest which contains the tree to be split
tree_numIndex of the tree to be split
leaf_numLeaf to be split
feature_numFeature that defines the new split
split_thresholdValue that defines the cutoff of the new split
left_leaf_valueValue (or vector of values) to assign to the newly created left node
right_leaf_valueValue (or vector of values) to assign to the newly created right node
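The effect of a single numeric split on predictions can be shown without any stochtree objects. The sketch assumes that observations satisfying X[,i] <= c are routed to the left child, which is the common convention but is not stated explicitly above. It also uses R's 1-based column index, whereas the method's feature_num may be 0-indexed.

```r
# Toy covariate matrix with a single feature
X <- matrix(c(0.2, 0.7, 0.5, 0.9), ncol = 1)

feature_col <- 1          # R column index of the splitting feature
split_threshold <- 0.5    # cutoff c in the rule X[, i] <= c
left_leaf_value <- -1     # value for observations satisfying the rule (assumed left)
right_leaf_value <- 1     # value for all other observations

# Prediction of a root node split once into two leaves
pred <- ifelse(X[, feature_col] <= split_threshold, left_leaf_value, right_leaf_value)
```

An observation exactly at the threshold (0.5 here) satisfies the rule and receives the left leaf value.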
get_tree_leaves()
Retrieve a vector of indices of leaf nodes for a given tree in a given forest
ForestSamples$get_tree_leaves(forest_num, tree_num)
forest_numIndex of the forest which contains tree tree_num
tree_numIndex of the tree for which leaf indices will be retrieved
get_tree_split_counts()
Retrieve a vector of split counts for every training set variable in a given tree in a given forest
ForestSamples$get_tree_split_counts(forest_num, tree_num, num_features)
forest_numIndex of the forest which contains tree tree_num
tree_numIndex of the tree for which split counts will be retrieved
num_featuresTotal number of features in the training set
get_forest_split_counts()
Retrieve a vector of split counts for every training set variable in a given forest
ForestSamples$get_forest_split_counts(forest_num, num_features)
forest_numIndex of the forest for which split counts will be retrieved
num_featuresTotal number of features in the training set
get_aggregate_split_counts()
Retrieve a vector of split counts for every training set variable in a given forest, aggregated across ensembles and trees
ForestSamples$get_aggregate_split_counts(num_features)
num_featuresTotal number of features in the training set
get_granular_split_counts()
Retrieve a vector of split counts for every training set variable in a given forest, reported separately for each ensemble and tree
ForestSamples$get_granular_split_counts(num_features)
num_featuresTotal number of features in the training set
ensemble_tree_max_depth()
Maximum depth of a specific tree in a specific ensemble in a ForestSamples object
ForestSamples$ensemble_tree_max_depth(ensemble_num, tree_num)
ensemble_numEnsemble number
tree_numTree index within ensemble ensemble_num
Maximum leaf depth
average_ensemble_max_depth()
Average the maximum depth of each tree in a given ensemble in a ForestSamples object
ForestSamples$average_ensemble_max_depth(ensemble_num)
ensemble_num: Ensemble number
Average maximum depth
average_max_depth()
Average the maximum depth of each tree in each ensemble in a ForestSamples object
ForestSamples$average_max_depth()
Average maximum depth
num_forest_leaves()
Number of leaves in a given ensemble in a ForestSamples object
ForestSamples$num_forest_leaves(forest_num)
forest_num: Index of the ensemble to be queried
Count of leaves in the ensemble stored at forest_num
sum_leaves_squared()
Sum of squared (raw) leaf values in a given ensemble in a ForestSamples object
ForestSamples$sum_leaves_squared(forest_num)
forest_num: Index of the ensemble to be queried
Sum of squared leaf values in the ensemble stored at forest_num
is_leaf_node()
Whether or not a given node of a given tree in a given forest in the ForestSamples is a leaf
ForestSamples$is_leaf_node(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
TRUE if node is a leaf, FALSE otherwise
is_numeric_split_node()
Whether or not a given node of a given tree in a given forest in the ForestSamples is a numeric split node
ForestSamples$is_numeric_split_node(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
TRUE if node is a numeric split node, FALSE otherwise
is_categorical_split_node()
Whether or not a given node of a given tree in a given forest in the ForestSamples is a categorical split node
ForestSamples$is_categorical_split_node(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
TRUE if node is a categorical split node, FALSE otherwise
parent_node()
Parent node of given node of a given tree in a given forest in a ForestSamples object
ForestSamples$parent_node(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
Integer ID of the parent node
left_child_node()
Left child node of given node of a given tree in a given forest in a ForestSamples object
ForestSamples$left_child_node(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
Integer ID of the left child node
right_child_node()
Right child node of given node of a given tree in a given forest in a ForestSamples object
ForestSamples$right_child_node(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
Integer ID of the right child node
node_depth()
Depth of given node of a given tree in a given forest in a ForestSamples object, with 0 depth for the root node.
ForestSamples$node_depth(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
Integer valued depth of the node
node_split_index()
Split index of a given node of a given tree in a given forest in a ForestSamples object. Returns -1 if the node is a leaf.
ForestSamples$node_split_index(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
Integer valued index of the variable on which the node splits
node_split_threshold()
Threshold that defines a numeric split for a given node of a given tree in a given forest in a ForestSamples object.
Returns Inf if the node is a leaf or a categorical split node.
ForestSamples$node_split_threshold(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
Threshold defining a split for the node
node_split_categories()
Array of category indices that define a categorical split for a given node of a given tree in a given forest in a ForestSamples object.
Returns c(Inf) if the node is a leaf or a numeric split node.
ForestSamples$node_split_categories(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
Categories defining a split for the node
node_leaf_values()
Leaf node value(s) for a given node of a given tree in a given forest in a ForestSamples object.
Values are stale if the node is a split node.
ForestSamples$node_leaf_values(forest_num, tree_num, node_id)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
node_id: Index of the node to be queried
Vector (often univariate) of leaf values
num_nodes()
Number of nodes in a given tree in a given forest in a ForestSamples object.
ForestSamples$num_nodes(forest_num, tree_num)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
Count of total tree nodes
num_leaves()
Number of leaves in a given tree in a given forest in a ForestSamples object.
ForestSamples$num_leaves(forest_num, tree_num)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
Count of total tree leaves
num_leaf_parents()
Number of leaf parents (split nodes with two leaves as children) in a given tree in a given forest in a ForestSamples object.
ForestSamples$num_leaf_parents(forest_num, tree_num)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
Count of total tree leaf parents
num_split_nodes()
Number of split nodes in a given tree in a given forest in a ForestSamples object.
ForestSamples$num_split_nodes(forest_num, tree_num)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
Count of total tree split nodes
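To make the distinction between nodes, leaves, split nodes, and leaf parents concrete, here is a minimal base R sketch on a hand-built five-node tree. The `left` / `right` child vectors, the 0-based node IDs, and the -1 "no child" sentinel are illustrative assumptions and are not meant to describe stochtree's internal tree layout.

```r
# Hypothetical toy tree (0-based node IDs, -1 means "no child"):
#        0
#       / \
#      1   2
#     / \
#    3   4
left  <- c(1, 3, -1, -1, -1)
right <- c(2, 4, -1, -1, -1)
node_ids <- seq_along(left) - 1

# A node is a leaf if it has no children, otherwise it is a split node
is_leaf <- left == -1
is_split <- !is_leaf

# A leaf parent is a split node whose two children are both leaves
is_leaf_parent <- sapply(seq_along(left), function(i) {
  is_split[i] && is_leaf[left[i] + 1] && is_leaf[right[i] + 1]
})

num_nodes <- length(node_ids)            # analogous to num_nodes()
num_leaves <- sum(is_leaf)               # analogous to num_leaves()
num_split_nodes <- sum(is_split)         # analogous to num_split_nodes()
num_leaf_parents <- sum(is_leaf_parent)  # analogous to num_leaf_parents()
leaf_ids <- node_ids[is_leaf]            # analogous to leaves()
```

In this toy tree node 0 is a split node but not a leaf parent, because its left child (node 1) is itself a split node.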
nodes()
Array of node indices in a given tree in a given forest in a ForestSamples object.
ForestSamples$nodes(forest_num, tree_num)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
Indices of tree nodes
leaves()
Array of leaf indices in a given tree in a given forest in a ForestSamples object.
ForestSamples$leaves(forest_num, tree_num)
forest_num: Index of the forest to be queried
tree_num: Index of the tree to be queried
Indices of leaf nodes
delete_sample()
Modify the ForestSamples object by removing the forest sample indexed by forest_num
ForestSamples$delete_sample(forest_num)
forest_num: Index of the forest to be removed
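The node-level query methods above can be combined to inspect a sampled tree. The sketch below is one possible way to do this, not a prescribed workflow: it fits a small BART model, extracts its mean forest via JSON, and prints a one-line description of every node in one tree. It assumes that forest, tree, and node indices are all 0-based, and that the first retained forest and first tree are the ones of interest.

```r
library(stochtree)

# Fit a small BART model and extract its mean forest container
X <- matrix(runif(10 * 100), ncol = 10)
y <- -5 + 10 * (X[, 1] > 0.5) + rnorm(100)
bart_model <- bart(X, y, num_gfr = 0, num_mcmc = 10)
bart_json <- saveBARTModelToJson(bart_model)
forest_samples <- loadForestContainerJson(bart_json, "forest_0")

# Assumption: indices are 0-based, so (0, 0) is the first tree of the first forest
forest_num <- 0
tree_num <- 0

# Describe each node as a leaf, a numeric split, or a categorical split
for (node_id in forest_samples$nodes(forest_num, tree_num)) {
  depth <- forest_samples$node_depth(forest_num, tree_num, node_id)
  if (forest_samples$is_leaf_node(forest_num, tree_num, node_id)) {
    values <- forest_samples$node_leaf_values(forest_num, tree_num, node_id)
    desc <- paste("leaf with value(s)", paste(round(values, 3), collapse = ", "))
  } else if (forest_samples$is_numeric_split_node(forest_num, tree_num, node_id)) {
    split_index <- forest_samples$node_split_index(forest_num, tree_num, node_id)
    threshold <- forest_samples$node_split_threshold(forest_num, tree_num, node_id)
    desc <- paste("numeric split on feature", split_index, "at", round(threshold, 3))
  } else {
    split_index <- forest_samples$node_split_index(forest_num, tree_num, node_id)
    categories <- forest_samples$node_split_categories(forest_num, tree_num, node_id)
    desc <- paste("categorical split on feature", split_index,
                  "with categories", paste(categories, collapse = ", "))
  }
  cat("node", node_id, "at depth", depth, ":", desc, "\n")
}
```

With only a few MCMC iterations, many trees remain a single root node, in which case the loop prints one leaf.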
While the BARTSerialization and BCFSerialization topics focus on JSON serialization / deserialization for
entire bartmodel and bcfmodel objects, this function group provides an interface for a more focused use case:
loading a single ForestSamples container from a broader BART / BCF model (which may include multiple forests and other parametric terms).
loadForestContainerJson converts a CppJson object representing a BART or BCF model into a ForestSamples container
by extracting the JSON indexed by a forest label (e.g. "forest_0") and deserializing it into a ForestSamples object.
loadForestContainerCombinedJson and loadForestContainerCombinedJsonString operate similarly, but on a list of CppJson objects or
JSON strings, respectively, each representing a BART / BCF model with the same structure.
These functions are intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
loadForestContainerJson(json_object, json_forest_label)
loadForestContainerCombinedJson(json_object_list, json_forest_label)
loadForestContainerCombinedJsonString(json_string_list, json_forest_label)
json_object |
Object of class |
json_forest_label |
Label referring to a particular forest (i.e. "forest_0") in the overall json hierarchy (must exist in every json object in a list if a list is provided) |
json_object_list |
List of objects of class |
json_string_list |
List of strings that parse into objects of type |
Each of the functions in this group returns a ForestSamples object.
X <- matrix(runif(10*100), ncol = 10)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
bart_model <- bart(X, y, num_gfr=0, num_mcmc=10)
bart_json <- saveBARTModelToJson(bart_model)
bart_json_string <- saveBARTModelToJsonString(bart_model)
bart_json_list <- list(bart_json)
bart_json_string_list <- list(bart_json_string)
mean_forest <- loadForestContainerJson(bart_json, "forest_0")
mean_forest <- loadForestContainerCombinedJson(bart_json_list, "forest_0")
mean_forest <- loadForestContainerCombinedJsonString(bart_json_string_list, "forest_0")
A forest sampler features two types of state: ephemeral and persistent. Persistent state includes objects like ForestSamples and RandomEffectSamples which constitute part of the final sampled model. Ephemeral state supports the sampling computations, but is not retained after the sampler finishes.
The two primary forest-based bits of ephemeral state are the Forest and ForestModel classes, which represent the current state of a forest and its corresponding tracking data structures.
In a linear sampling loop, this ephemeral state is updated with each iteration of the sampler and any retained forests are copied to a ForestSamples object. However, in multi-chain settings, the state of a forest must typically be "reset" at the beginning of a new chain. These functions enable this process by synchronizing the state of a Forest and ForestModel with a corresponding element of a ForestSamples object, or by resetting both to their default (root) state.
resetActiveForest resets a Forest object, either from a specific forest in a ForestSamples
object or to an ensemble of single-node (i.e. root) trees.
resetForestModel re-initializes a forest model (tracking data structures) from a specific forest in a ForestSamples object.
These functions are intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
resetActiveForest(active_forest, forest_samples = NULL, forest_num = NULL)
resetForestModel(forest_model, forest, dataset, residual, is_mean_model)
active_forest |
Current active forest |
forest_samples |
(Optional) Container of forest samples from which to re-initialize active forest. If not provided, active forest will be reset to an ensemble of single-node (i.e. root) trees. |
forest_num |
(Optional) Index of forest samples from which to initialize active forest. If not provided, active forest will be reset to an ensemble of single-node (i.e. root) trees. |
forest_model |
Forest model with tracking data structures |
forest |
Forest from which to re-initialize forest model |
dataset |
Training dataset object |
residual |
Residual which will also be updated |
is_mean_model |
Whether the model being updated is a conditional mean model |
Neither function returns a value; both operate in place on the relevant Forest or ForestModel objects.
n <- 100
p <- 10
num_trees <- 100
leaf_dimension <- 1
is_leaf_constant <- TRUE
is_exponentiated <- FALSE
alpha <- 0.95
beta <- 2.0
min_samples_leaf <- 2
max_depth <- 10
feature_types <- as.integer(rep(0, p))
leaf_model <- 0
sigma2 <- 1.0
leaf_scale <- as.matrix(1.0)
variable_weights <- rep(1/p, p)
a_forest <- 1
b_forest <- 1
cutpoint_grid_size <- 100
X <- matrix(runif(n*p), ncol = p)
forest_dataset <- createForestDataset(X)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(n)
outcome <- createOutcome(y)
rng <- createCppRNG(1234)
global_model_config <- createGlobalModelConfig(global_error_variance=sigma2)
forest_model_config <- createForestModelConfig(
  feature_types=feature_types, num_trees=num_trees, num_observations=n,
  num_features=p, alpha=alpha, beta=beta, min_samples_leaf=min_samples_leaf,
  max_depth=max_depth, variable_weights=variable_weights,
  cutpoint_grid_size=cutpoint_grid_size, leaf_model_type=leaf_model,
  leaf_model_scale=leaf_scale
)
forest_model <- createForestModel(forest_dataset, forest_model_config, global_model_config)
active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
forest_samples <- createForestSamples(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
active_forest$prepare_for_sampler(forest_dataset, outcome, forest_model, 0, 0.)
forest_model$sample_one_iteration(
  forest_dataset, outcome, forest_samples, active_forest, rng,
  forest_model_config, global_model_config,
  keep_forest = TRUE, gfr = FALSE
)
resetActiveForest(active_forest, forest_samples, 0)
resetForestModel(forest_model, active_forest, forest_dataset, outcome, TRUE)
Generic function for extracting random effect samples from a model object (BCF, BART, etc...)
getRandomEffectSamples(object, ...)
object |
Fitted model object from which to extract random effects |
... |
Other parameters to be used in random effects extraction |
List of random effect samples
n <- 100
p <- 10
X <- matrix(runif(n*p), ncol = p)
rfx_group_ids <- sample(1:2, size = n, replace = TRUE)
rfx_basis <- rep(1.0, n)
y <- (-5 + 10*(X[,1] > 0.5)) + (-2*(rfx_group_ids==1)+2*(rfx_group_ids==2)) + rnorm(n)
bart_model <- bart(X_train=X, y_train=y, rfx_group_ids_train=rfx_group_ids,
                   rfx_basis_train = rfx_basis, num_gfr=0, num_mcmc=10)
rfx_samples <- getRandomEffectSamples(bart_model)
Extract raw sample values for each of the random effect parameter terms.
## S3 method for class 'bartmodel'
getRandomEffectSamples(object, ...)
object |
Object of type |
... |
Other parameters to be used in random effects extraction |
List of arrays. The alpha array has dimension (num_components, num_samples) and is simply a vector if num_components = 1.
The xi and beta arrays have dimension (num_components, num_groups, num_samples) and are simply matrices if num_components = 1.
The sigma array has dimension (num_components, num_samples) and is simply a vector if num_components = 1.
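As a rough illustration of how arrays with these dimensions can be summarized, the base R sketch below builds a simulated stand-in for a beta array and computes posterior means and intervals over the sample dimension. The array here is random noise created for illustration only (it is not extracted from a fitted model), and the variable names are hypothetical rather than the names of the returned list elements.

```r
set.seed(1)
num_components <- 2
num_groups <- 3
num_samples <- 50

# Simulated stand-in with dimension (num_components, num_groups, num_samples)
beta <- array(rnorm(num_components * num_groups * num_samples),
              dim = c(num_components, num_groups, num_samples))

# Posterior mean for each (component, group): average over the sample dimension
beta_mean <- apply(beta, c(1, 2), mean)

# 95% interval for each (component, group); the first dimension holds the two quantiles
beta_interval <- apply(beta, c(1, 2), quantile, probs = c(0.025, 0.975))

# With a single component the array reduces to a (num_groups, num_samples) matrix,
# so per-group means are just row means
beta_single <- beta[1, , ]
beta_single_mean <- rowMeans(beta_single)
```

The same `apply` pattern works for a (num_components, num_samples) array by averaging over the second dimension instead.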
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
f_XW <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
snr <- 3
group_ids <- rep(c(1,2), n %/% 2)
rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
rfx_basis <- cbind(1, runif(n, -1, 1))
rfx_term <- rowSums(rfx_coefs[group_ids,] * rfx_basis)
E_y <- f_XW + rfx_term
y <- E_y + rnorm(n, 0, 1)*(sd(E_y)/snr)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]
rfx_group_ids_test <- group_ids[test_inds]
rfx_group_ids_train <- group_ids[train_inds]
rfx_basis_test <- rfx_basis[test_inds,]
rfx_basis_train <- rfx_basis[train_inds,]
rfx_term_test <- rfx_term[test_inds]
rfx_term_train <- rfx_term[train_inds]
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
                   rfx_group_ids_train = rfx_group_ids_train,
                   rfx_group_ids_test = rfx_group_ids_test,
                   rfx_basis_train = rfx_basis_train,
                   rfx_basis_test = rfx_basis_test,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 10)
rfx_samples <- getRandomEffectSamples(bart_model)
Extract raw sample values for each of the random effect parameter terms.
## S3 method for class 'bcfmodel'
getRandomEffectSamples(object, ...)
object |
Object of type |
... |
Other parameters to be used in random effects extraction |
List of arrays. The alpha array has dimension (num_components, num_samples) and is simply a vector if num_components = 1.
The xi and beta arrays have dimension (num_components, num_groups, num_samples) and are simply matrices if num_components = 1.
The sigma array has dimension (num_components, num_samples) and is simply a vector if num_components = 1.
n <- 500
p <- 5
X <- matrix(runif(n*p), ncol = p)
mu_x <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
pi_x <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
)
tau_x <- (
    ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
    ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
    ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
    ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
)
Z <- rbinom(n, 1, pi_x)
E_XZ <- mu_x + Z*tau_x
snr <- 3
rfx_group_ids <- rep(c(1,2), n %/% 2)
rfx_coefs <- matrix(c(-1, -1, 1, 1), nrow=2, byrow=TRUE)
rfx_basis <- cbind(1, runif(n, -1, 1))
rfx_term <- rowSums(rfx_coefs[rfx_group_ids,] * rfx_basis)
y <- E_XZ + rfx_term + rnorm(n, 0, 1)*(sd(E_XZ)/snr)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]
rfx_group_ids_test <- rfx_group_ids[test_inds]
rfx_group_ids_train <- rfx_group_ids[train_inds]
rfx_basis_test <- rfx_basis[test_inds,]
rfx_basis_train <- rfx_basis[train_inds,]
rfx_term_test <- rfx_term[test_inds]
rfx_term_train <- rfx_term[train_inds]
mu_params <- list(sample_sigma2_leaf = TRUE)
tau_params <- list(sample_sigma2_leaf = FALSE)
bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                 propensity_train = pi_train,
                 rfx_group_ids_train = rfx_group_ids_train,
                 rfx_basis_train = rfx_basis_train, X_test = X_test,
                 Z_test = Z_test, propensity_test = pi_test,
                 rfx_group_ids_test = rfx_group_ids_test,
                 rfx_basis_test = rfx_basis_test,
                 num_gfr = 10, num_burnin = 0, num_mcmc = 10,
                 prognostic_forest_params = mu_params,
                 treatment_effect_forest_params = tau_params)
rfx_samples <- getRandomEffectSamples(bcf_model)
Object used to get / set global parameters and other global model configuration options in the "low-level" stochtree interface. The "low-level" stochtree interface enables a high degree of sampler customization, in which users employ R wrappers around C++ objects like ForestDataset, Outcome, CppRng, and ForestModel to run the Gibbs sampler of a BART model with custom modifications. GlobalModelConfig allows users to specify / query the global parameters of a model they wish to run.
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
global_error_variance: Global error variance parameter
new()
Create a new GlobalModelConfig object.
GlobalModelConfig$new(global_error_variance = 1)
global_error_variance: Global error variance parameter (default: 1.0)
A new GlobalModelConfig object.
update_global_error_variance()
Update global error variance parameter
GlobalModelConfig$update_global_error_variance(global_error_variance)
global_error_variance: Global error variance parameter
get_global_error_variance()
Query global error variance parameter for this GlobalModelConfig object
GlobalModelConfig$get_global_error_variance()
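A brief sketch of the get / set cycle follows, as it might appear inside a custom sampling loop where the global error variance is redrawn at each iteration. The value 0.5 is an arbitrary placeholder standing in for a freshly sampled variance, not the output of any stochtree sampler.

```r
library(stochtree)

# Create a config with an initial global error variance of 1
global_model_config <- createGlobalModelConfig(global_error_variance = 1.0)
initial_sigma2 <- global_model_config$get_global_error_variance()

# Placeholder for a newly sampled variance (arbitrary value for illustration)
new_sigma2 <- 0.5
global_model_config$update_global_error_variance(new_sigma2)

# The config now reports the updated value
current_sigma2 <- global_model_config$get_global_error_variance()
```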
The CppJson class wraps an external pointer to C++ JSON object for a seamless R serialization interface. This function group provides several utilities for creating CppJson objects: from scratch (i.e. initializing an empty JSON object), from a JSON file, or from a string that parses to valid JSON.
createCppJson creates a new (empty) CppJson object. createCppJsonFile creates and populates a CppJson object from data in a JSON file.
createCppJsonString creates and populates a CppJson object from a JSON string.
These functions are intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
createCppJson()
createCppJsonFile(json_filename)
createCppJsonString(json_string)
json_filename |
Name of JSON file to read. Must end in |
json_string |
JSON string |
Each of the functions in this group returns a CppJson object.
example_vec <- runif(10)
example_json <- createCppJson()
example_json$add_vector("myvec", example_vec)
tmpjson <- tempfile(fileext = ".json")
example_json$save_file(file.path(tmpjson))
example_json_roundtrip <- createCppJsonFile(file.path(tmpjson))
unlink(tmpjson)
example_json_string <- example_json$return_json_string()
example_json_roundtrip <- createCppJsonString(example_json_string)
Load a scalar from json
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
loadScalarJson(json_object, json_scalar_label, subfolder_name = NULL)
json_object |
Object of class |
json_scalar_label |
Label referring to a particular scalar / string value (i.e. "num_samples") in the overall json hierarchy |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which the scalar sits |
R vector
example_scalar <- 5.4
example_json <- createCppJson()
example_json$add_scalar("myscalar", example_scalar)
roundtrip_scalar <- loadScalarJson(example_json, "myscalar")
Load a vector from json
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
loadVectorJson(json_object, json_vector_label, subfolder_name = NULL)
json_object |
Object of class |
json_vector_label |
Label referring to a particular vector (i.e. "sigma2_global_samples") in the overall json hierarchy |
subfolder_name |
(Optional) Name of the subfolder / hierarchy under which vector sits |
R vector
example_vec <- runif(10)
example_json <- createCppJson()
example_json$add_vector("myvec", example_vec)
roundtrip_vec <- loadVectorJson(example_json, "myvec")
Outcome / partial residual used to sample an additive model. The outcome class is a wrapper around a vector of (mutable) outcomes for ML tasks (supervised learning, causal inference). When an additive tree ensemble is sampled, the outcome used to sample a specific model term is the "partial residual" consisting of the outcome minus the predictions of every other model term (trees, group random effects, etc...).
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
data_ptr: External pointer to a C++ Outcome class
new()
Create a new Outcome object.
Outcome$new(outcome)
outcome: Vector of outcome values
A new Outcome object.
get_data()
Extract raw data in R from the underlying C++ object
Outcome$get_data()
R vector containing (copy of) the values in Outcome object
add_vector()
Update the current state of the outcome (i.e. partial residual) data by adding the values of update_vector
Outcome$add_vector(update_vector)
update_vector: Vector to be added to outcome
None
subtract_vector()
Update the current state of the outcome (i.e. partial residual) data by subtracting the values of update_vector
Outcome$subtract_vector(update_vector)
update_vectorVector to be subtracted from outcome
None
update_data()
Update the current state of the outcome (i.e. partial residual) data by replacing each element with the elements of new_vector
Outcome$update_data(new_vector)
new_vectorVector from which to overwrite the current data
None
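The partial residual bookkeeping described for the Outcome class can be sketched as follows. This is a minimal illustration, assuming stochtree is installed; the data values and the names forest_pred, rfx_pred and new_forest_pred are hypothetical stand-ins for the predictions of two model terms, and only createOutcome() and the methods documented above are used.

```r
library(stochtree)

# Hypothetical outcome and current predictions of two model terms
y <- c(1.0, 2.0, 3.0, 4.0)
forest_pred <- c(0.5, 1.5, 2.5, 3.5)
rfx_pred <- c(0.1, 0.1, -0.1, -0.1)

# Start from the raw outcome and remove both terms to obtain the full residual
residual <- createOutcome(y)
residual$subtract_vector(forest_pred)
residual$subtract_vector(rfx_pred)

# Partial residual for the forest term: add the forest predictions back in,
# leaving the outcome minus every *other* term
residual$add_vector(forest_pred)
partial_resid_forest <- residual$get_data()

# Once the forest has been re-sampled, subtract its (hypothetical) new predictions
new_forest_pred <- c(0.6, 1.4, 2.6, 3.4)
residual$subtract_vector(new_forest_pred)
full_resid <- residual$get_data()
```

Because get_data() returns a copy, partial_resid_forest is not affected by the later subtract_vector() call.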
Create a New Outcome Model Object
OutcomeModel(outcome = "continuous", link = NULL)
outcome |
Character string specifying the outcome type. |
link |
Character string specifying the link function. |
An object of class outcome_model.
my_model <- OutcomeModel(outcome = "continuous", link = "identity")
Plot the BART model fit and any relevant sampled quantities. This will default to a traceplot of the global error scale and the in-sample mean forest predictions for the first train set observation. Since stochtree::bart() is flexible and it's possible to sample a model with a fixed global error scale and no mean forest, this procedure is adaptive and will attempt to plot a trace of whichever model terms are included if these two default terms are omitted.
## S3 method for class 'bartmodel'
plot(x, ...)
x |
The BART model object |
... |
Additional arguments |
BART model object unchanged after plotting
Plot the BCF model fit and any relevant sampled quantities. This will default to a traceplot of the global error scale and the in-sample mean forest predictions for the first train set observation. Since stochtree::bcf() is flexible and it's possible to sample a model with a fixed global error scale and no mean forest, this procedure will throw an error if these two default terms are omitted.
## S3 method for class 'bcfmodel'
plot(x, ...)
x |
The BCF model object |
... |
Additional arguments |
BCF model object unchanged after plotting
Predict from a sampled BART model on new data
## S3 method for class 'bartmodel'
predict(
  object,
  X,
  leaf_basis = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL,
  type = "posterior",
  terms = "all",
  scale = "linear",
  ...
)
object |
Object of type |
X |
Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. |
leaf_basis |
(Optional) Bases used for prediction (by e.g. dot product with leaf values). Default: |
rfx_group_ids |
(Optional) Test set group labels used for an additive random effects model. Test set evaluation for group labels that were not in the training set is not currently supported, but is planned for the near future. |
rfx_basis |
(Optional) Test set basis for "random-slope" regression in additive random effects model. |
type |
(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". |
terms |
(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "mean_forest", "rfx", "variance_forest", or "all". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present in a model, this function will return |
scale |
(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, "probability", which transforms predictions into class probabilities for models with discrete outcomes, and "class", which returns predicted outcome categories for discrete outcome models. "probability" is only valid for outcome models with |
... |
(Optional) Other prediction parameters. |
List of prediction matrices or single prediction matrix / vector, depending on the terms requested.
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
f_XW <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, noise_sd)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]
bart_model <- bart(X_train = X_train, y_train = y_train,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 10)
y_hat_test <- predict(bart_model, X=X_test)$y_hat
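The difference between the two documented values of type can be sketched as follows. This is a minimal illustration, assuming stochtree is installed; the simulated data and the name X_new are hypothetical, and the shapes noted in the comments follow the argument descriptions above rather than any guarantee about a particular package version.

```r
library(stochtree)

set.seed(1)
n <- 100
p <- 5
X <- matrix(runif(n * p), ncol = p)
y <- 5 * (X[, 1] > 0.5) + rnorm(n)
bart_model <- bart(X_train = X, y_train = y,
                   num_gfr = 10, num_burnin = 0, num_mcmc = 10)

# Hypothetical new covariates to predict on
X_new <- matrix(runif(20 * p), ncol = p)

# type = "posterior" (the default): one row per observation, one column per draw
y_hat_posterior <- predict(bart_model, X = X_new, type = "posterior")$y_hat

# type = "mean": predictions averaged over every retained draw
y_hat_mean <- predict(bart_model, X = X_new, type = "mean")$y_hat
```

The posterior matrix is the natural input for credible intervals (for example, row-wise quantiles), while the mean is convenient when only a point prediction is needed.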
Predict from a sampled BCF model on new data
## S3 method for class 'bcfmodel'
predict(
  object,
  X,
  Z,
  propensity = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL,
  type = "posterior",
  terms = "all",
  scale = "linear",
  ...
)
object |
Object of type |
X |
Covariates used to determine tree leaf predictions for each observation. Must be passed as a matrix or dataframe. |
Z |
Treatments used for prediction. |
propensity |
(Optional) Propensities used for prediction. |
rfx_group_ids |
(Optional) Test set group labels used for an additive random effects model. Test set evaluation for group labels that were not in the training set is not currently supported, but is planned for the near future. |
rfx_basis |
(Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects |
type |
(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". |
terms |
(Optional) Which model terms to include in the prediction. Options include The treatment effect terms follow a three-level hierarchy:
Similarly for the prognostic term: If a model doesn't have random effects or variance forest predictions but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present, this function will return |
scale |
(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing |
... |
(Optional) Other prediction parameters. |
List of prediction matrices or single prediction matrix / vector, depending on the terms requested.
n <- 500
p <- 5
X <- matrix(runif(n*p), ncol = p)
mu_x <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
pi_x <- (
    ((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
    ((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
)
tau_x <- (
    ((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
    ((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
    ((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
    ((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
)
Z <- rbinom(n, 1, pi_x)
noise_sd <- 1
y <- mu_x + tau_x*Z + rnorm(n, 0, noise_sd)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
pi_test <- pi_x[test_inds]
pi_train <- pi_x[train_inds]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
y_test <- y[test_inds]
y_train <- y[train_inds]
mu_test <- mu_x[test_inds]
mu_train <- mu_x[train_inds]
tau_test <- tau_x[test_inds]
tau_train <- tau_x[train_inds]
bcf_model <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train,
                 propensity_train = pi_train, num_gfr = 10,
                 num_burnin = 0, num_mcmc = 10)
preds <- predict(bcf_model, X_test, Z_test, pi_test)
Prints a summary of the BART model, including the model terms and their specifications.
## S3 method for class 'bartmodel'
print(x, ...)
x |
The BART model object |
... |
Additional arguments |
BART model object unchanged after printing summary
Prints a summary of the BCF model, including the model terms and their specifications.
## S3 method for class 'bcfmodel'
print(x, ...)
x |
The BCF model object |
... |
Additional arguments (currently unused) |
BCF model object unchanged after printing summary
Prints a summary of the ForestSamples object, including number of forests and the underlying model of each forest.
## S3 method for class 'ForestSamples'
print(x, ...)
x |
ForestSamples object |
... |
Additional arguments |
ForestSamples object unchanged after printing summary
Prints a summary of the RandomEffectSamples object, including number of forests and the underlying model of each forest.
## S3 method for class 'RandomEffectSamples'
print(x, ...)
x |
RandomEffectSamples object |
... |
Additional arguments |
RandomEffectSamples object unchanged after printing summary
Class that wraps the "persistent" aspects of a C++ random effects model, including draws of the parameters and a map from the original label indices to the 0-indexed label numbers used to place group samples in memory (i.e. the first label is stored in column 0 of the sample matrix, the second label is stored in column 1 of the sample matrix, etc.)
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
Coordinates various C++ random effects classes and persists those needed for prediction / serialization
rfx_container_ptr: External pointer to a C++ StochTree::RandomEffectsContainer class
label_mapper_ptr: External pointer to a C++ StochTree::LabelMapper class
training_group_ids: Unique vector of group IDs that were in the training dataset
new()
Create a new RandomEffectSamples object.
RandomEffectSamples$new()
A new RandomEffectSamples object.
load_in_session()
Construct RandomEffectSamples object from other "in-session" R objects
RandomEffectSamples$load_in_session( num_components, num_groups, random_effects_tracker )
num_components: Number of "components" or bases defining the random effects regression
num_groups: Number of random effects groups
random_effects_tracker: Object of type RandomEffectsTracker
None
load_from_json()
Construct RandomEffectSamples object from a json object
RandomEffectSamples$load_from_json( json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label )
json_object: Object of class CppJson
json_rfx_container_label: Label referring to a particular rfx sample container (e.g. "random_effect_container_0") in the overall json hierarchy
json_rfx_mapper_label: Label referring to a particular rfx label mapper (e.g. "random_effect_label_mapper_0") in the overall json hierarchy
json_rfx_groupids_label: Label referring to a particular set of rfx group IDs (e.g. "random_effect_groupids_0") in the overall json hierarchy
A new RandomEffectSamples object.
append_from_json()
Append random effect draws to RandomEffectSamples object from a json object
RandomEffectSamples$append_from_json( json_object, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label )
json_object: Object of class CppJson
json_rfx_container_label: Label referring to a particular rfx sample container (e.g. "random_effect_container_0") in the overall json hierarchy
json_rfx_mapper_label: Label referring to a particular rfx label mapper (e.g. "random_effect_label_mapper_0") in the overall json hierarchy
json_rfx_groupids_label: Label referring to a particular set of rfx group IDs (e.g. "random_effect_groupids_0") in the overall json hierarchy
None
load_from_json_string()
Construct RandomEffectSamples object from a JSON string
RandomEffectSamples$load_from_json_string( json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label )
json_string: JSON string which parses into object of class CppJson
json_rfx_container_label: Label referring to a particular rfx sample container (e.g. "random_effect_container_0") in the overall json hierarchy
json_rfx_mapper_label: Label referring to a particular rfx label mapper (e.g. "random_effect_label_mapper_0") in the overall json hierarchy
json_rfx_groupids_label: Label referring to a particular set of rfx group IDs (e.g. "random_effect_groupids_0") in the overall json hierarchy
A new RandomEffectSamples object.
append_from_json_string()
Append random effect draws to RandomEffectSamples object from a JSON string
RandomEffectSamples$append_from_json_string( json_string, json_rfx_container_label, json_rfx_mapper_label, json_rfx_groupids_label )
json_string: JSON string which parses into object of class CppJson
json_rfx_container_label: Label referring to a particular rfx sample container (e.g. "random_effect_container_0") in the overall json hierarchy
json_rfx_mapper_label: Label referring to a particular rfx label mapper (e.g. "random_effect_label_mapper_0") in the overall json hierarchy
json_rfx_groupids_label: Label referring to a particular set of rfx group IDs (e.g. "random_effect_groupids_0") in the overall json hierarchy
None
predict()
Predict random effects for each observation implied by rfx_group_ids and rfx_basis.
If a random effects model is "intercept-only" the rfx_basis will be a vector of ones of size length(rfx_group_ids).
RandomEffectSamples$predict(rfx_group_ids, rfx_basis = NULL)
rfx_group_ids: Indices of random effects groups in a prediction set
rfx_basis: (Optional) Basis used for random effects prediction
Matrix with as many rows as observations provided and as many columns as samples drawn of the model.
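The shape of this output, and the role of the all-ones basis in the intercept-only case, can be illustrated in base R without calling the package. In this sketch, rfx_predict_sketch() is a hypothetical helper (not part of stochtree), beta is a hand-written array of group-level coefficients with dimension (num_components, num_groups, num_samples), and groups are assumed to be indexed from 1 for convenience.

```r
# Hypothetical helper: pred[i, s] = sum_k basis[i, k] * beta[k, group_i, s]
rfx_predict_sketch <- function(beta, group_index, basis = NULL) {
  n <- length(group_index)
  num_components <- dim(beta)[1]
  num_samples <- dim(beta)[3]
  # Intercept-only model: the basis is a single column of ones
  if (is.null(basis)) basis <- matrix(1, nrow = n, ncol = 1)
  stopifnot(ncol(basis) == num_components, nrow(basis) == n)
  preds <- matrix(0, nrow = n, ncol = num_samples)
  for (s in seq_len(num_samples)) {
    beta_s <- matrix(beta[, , s], nrow = num_components)  # components x groups
    # Pick each observation's group coefficients, then take the dot product with its basis row
    preds[, s] <- rowSums(basis * t(beta_s[, group_index, drop = FALSE]))
  }
  preds
}

# Two groups, one component, two posterior draws (values are made up)
beta <- array(c(-2, 2, -1, 1), dim = c(1, 2, 2))
group_index <- c(1, 2, 2, 1)
preds <- rfx_predict_sketch(beta, group_index)
```

Each row of preds corresponds to an observation and each column to a draw, matching the return shape described above.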
extract_parameter_samples()
Extract the random effects parameters sampled. With the "redundant parameterization" of Gelman et al (2008), this includes four parameters: alpha (the "working parameter" shared across every group), xi (the "group parameter" sampled separately for each group), beta (the product of alpha and xi, which corresponds to the overall group-level random effects), and sigma (group-independent prior variance for each component of xi).
RandomEffectSamples$extract_parameter_samples()
List of arrays. The alpha array has dimension (num_components, num_samples) and is simply a vector if num_components = 1.
The xi and beta arrays have dimension (num_components, num_groups, num_samples) and are simply matrices if num_components = 1.
The sigma array has dimension (num_components, num_samples) and is simply a vector if num_components = 1.
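The relationship between the arrays can be sketched in base R with simulated values. This is an illustration of the stated dimensions only; it assumes that "the product of alpha and xi" is taken component by component within each draw, and it does not call extract_parameter_samples() itself.

```r
set.seed(42)
num_components <- 2
num_groups <- 3
num_samples <- 4

# Simulated stand-ins with the documented dimensions
alpha <- matrix(rnorm(num_components * num_samples),
                nrow = num_components, ncol = num_samples)
xi <- array(rnorm(num_components * num_groups * num_samples),
            dim = c(num_components, num_groups, num_samples))

# beta[k, j, s] = alpha[k, s] * xi[k, j, s] (assumed componentwise product)
beta <- array(NA_real_, dim = dim(xi))
for (s in seq_len(num_samples)) {
  # alpha[, s] is recycled down the columns, so row k is scaled by alpha[k, s]
  beta[, , s] <- alpha[, s] * xi[, , s]
}
```

With num_components = 1 the same arithmetic reduces to scaling each row of a (num_groups, num_samples) matrix of xi draws by the matching alpha draw.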
delete_sample()
Modify the RandomEffectSamples object by removing the parameter samples indexed by sample_num.
RandomEffectSamples$delete_sample(sample_num)
sample_num: Index of the RFX sample to be removed
extract_label_mapping()
Convert the mapping of group IDs to random effect components indices from C++ to R native format
RandomEffectSamples$extract_label_mapping()
List mapping group ID to random effect components.
num_samples()
Query the number of samples in the RandomEffectSamples object.
RandomEffectSamples$num_samples()
Integer number of samples
num_components()
Query the number of components in the RandomEffectSamples object.
RandomEffectSamples$num_components()
Integer number of components
num_groups()
Query the number of groups in the RandomEffectSamples object.
RandomEffectSamples$num_groups()
Integer number of groups
While the BARTSerialization and BCFSerialization topics focus on JSON serialization / deserialization for
entire bartmodel and bcfmodel objects, this function group provides an interface for a more focused use case:
loading a single RandomEffectSamples container from a broader BART / BCF model (which may include forests and other parametric terms).
loadRandomEffectSamplesJson converts a CppJson object representing a BART or BCF model into a RandomEffectSamples container
by extracting the JSON indexed by an integer label (e.g. 0) and deserializing it into a RandomEffectSamples object.
loadRandomEffectSamplesCombinedJson and loadRandomEffectSamplesCombinedJsonString operate similarly, but on a list of CppJson objects or
JSON strings, respectively, representing BART / BCF models with the same structure.
These functions are intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
loadRandomEffectSamplesJson(json_object, json_rfx_num)
loadRandomEffectSamplesCombinedJson(json_object_list, json_rfx_num)
loadRandomEffectSamplesCombinedJsonString(json_string_list, json_rfx_num)
json_object |
Object of class |
json_rfx_num |
Integer index indicating the position of the random effects term to be unpacked (must exist in every json object in a list if a list is provided) |
json_object_list |
List of objects of class |
json_string_list |
List of objects of class |
Each of the functions in this group returns a RandomEffectSamples object.
n <- 100
p <- 10
X <- matrix(runif(n*p), ncol = p)
rfx_group_ids <- sample(1:2, size = n, replace = TRUE)
rfx_basis <- rep(1.0, n)
y <- (-5 + 10*(X[,1] > 0.5)) + (-2*(rfx_group_ids==1)+2*(rfx_group_ids==2))
bart_model <- bart(X_train=X, y_train=y, rfx_group_ids_train=rfx_group_ids,
                   rfx_basis_train = rfx_basis, num_gfr=0, num_mcmc=10)
bart_json <- saveBARTModelToJson(bart_model)
bart_json_string <- saveBARTModelToJsonString(bart_model)
bart_json_list <- list(bart_json)
bart_json_string_list <- list(bart_json_string)
rfx_container <- loadRandomEffectSamplesJson(bart_json, 0)
rfx_container <- loadRandomEffectSamplesCombinedJson(bart_json_list, 0)
rfx_container <- loadRandomEffectSamplesCombinedJsonString(bart_json_string_list, 0)
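Once a container has been loaded, it can be inspected with the RandomEffectSamples methods documented earlier. The following is a minimal sketch, assuming stochtree is installed; it refits a small model so that it runs on its own, and the expected counts in the comments follow from the simulated setup (two groups, an intercept-only basis, num_mcmc = 10) rather than from any additional guarantee.

```r
library(stochtree)

set.seed(101)
n <- 100
p <- 10
X <- matrix(runif(n * p), ncol = p)
rfx_group_ids <- sample(1:2, size = n, replace = TRUE)
rfx_basis <- rep(1.0, n)
y <- (-5 + 10 * (X[, 1] > 0.5)) +
  (-2 * (rfx_group_ids == 1) + 2 * (rfx_group_ids == 2))
bart_model <- bart(X_train = X, y_train = y,
                   rfx_group_ids_train = rfx_group_ids,
                   rfx_basis_train = rfx_basis,
                   num_gfr = 0, num_mcmc = 10)

# Pull only the first (index 0) random effects term out of the serialized model
bart_json <- saveBARTModelToJson(bart_model)
rfx_container <- loadRandomEffectSamplesJson(bart_json, 0)

# Query the container through its documented methods
n_groups <- rfx_container$num_groups()          # two groups in this simulation
n_components <- rfx_container$num_components()  # intercept-only basis: one component
n_samples <- rfx_container$num_samples()        # one retained draw per MCMC iteration

# Predict the random effects term alone for the training observations
rfx_preds <- rfx_container$predict(rfx_group_ids, matrix(rfx_basis, ncol = 1))
```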
Dataset used to sample a random effects model. A random effects dataset consists of three matrices / vectors: group labels, bases, and variance weights. Variance weights are optional.
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
data_ptr: External pointer to a C++ RandomEffectsDataset class
new()
Create a new RandomEffectsDataset object.
RandomEffectsDataset$new(group_labels, basis, variance_weights = NULL)
group_labels: Vector of group labels
basis: Matrix of bases used to define the random effects regression (for an intercept-only model, pass an array of ones)
variance_weights: (Optional) Vector of observation-specific variance weights
A new RandomEffectsDataset object.
update_basis()
Update basis matrix in a dataset
RandomEffectsDataset$update_basis(basis)
basis: Updated matrix of bases used to define random slopes / intercepts
update_variance_weights()
Update variance_weights in a dataset
RandomEffectsDataset$update_variance_weights( variance_weights, exponentiate = F )
variance_weights: Updated vector of variance weights used to define individual variance / case weights
exponentiate: Whether or not input vector should be exponentiated before being written to the RandomEffectsDataset's variance weights. Default: F.
num_observations()
Return number of observations in a RandomEffectsDataset object
RandomEffectsDataset$num_observations()
Observation count
num_basis()
Return dimension of the basis matrix in a RandomEffectsDataset object
RandomEffectsDataset$num_basis()
Basis vector count
get_group_labels()
Return group labels as an R vector
RandomEffectsDataset$get_group_labels()
Group label data
get_basis()
Return bases as an R matrix
RandomEffectsDataset$get_basis()
Basis data
get_variance_weights()
Return variance weights as an R vector
RandomEffectsDataset$get_variance_weights()
Variance weight data
has_group_labels()
Whether or not a dataset has group label indices
RandomEffectsDataset$has_group_labels()
True if group label vector is loaded, false otherwise
has_basis()
Whether or not a dataset has a basis matrix
RandomEffectsDataset$has_basis()
True if basis matrix is loaded, false otherwise
has_variance_weights()
Whether or not a dataset has variance weights
RandomEffectsDataset$has_variance_weights()
True if variance weights are loaded, false otherwise
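The accessors above can be exercised on a small intercept-only dataset. This is a minimal sketch, assuming stochtree is installed; the group labels are made up, and createRandomEffectsDataset() is used as the constructor because it appears in this manual's own examples.

```r
library(stochtree)

# Hypothetical data: six observations in three groups, intercept-only basis
group_labels <- c(1L, 1L, 2L, 2L, 3L, 3L)
basis <- matrix(1, nrow = length(group_labels), ncol = 1)
rfx_dataset <- createRandomEffectsDataset(group_labels, basis)

n_obs <- rfx_dataset$num_observations()            # number of rows
n_basis <- rfx_dataset$num_basis()                 # number of basis columns
has_labels <- rfx_dataset$has_group_labels()
has_weights <- rfx_dataset$has_variance_weights()  # no weights were supplied here
basis_copy <- rfx_dataset$get_basis()
```

Checking has_variance_weights() before calling get_variance_weights() is a cheap guard, given that the class performs minimal input validation.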
The core "model" class for sampling random effects. Stores current model state, prior parameters, and procedures for sampling from the conditional posterior of each parameter.
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
rfx_model_ptr: External pointer to a C++ StochTree::RandomEffectsModel class
num_groups: Number of groups in the random effects model
num_components: Number of components (i.e. dimension of basis) in the random effects model
new()
Create a new RandomEffectsModel object.
RandomEffectsModel$new(num_components, num_groups)
num_components: Number of "components" or bases defining the random effects regression
num_groups: Number of random effects groups
A new RandomEffectsModel object.
sample_random_effect()
Sample from random effects model.
RandomEffectsModel$sample_random_effect( rfx_dataset, residual, rfx_tracker, rfx_samples, keep_sample, global_variance, rng )
rfx_dataset: Object of type RandomEffectsDataset
residual: Object of type Outcome
rfx_tracker: Object of type RandomEffectsTracker
rfx_samples: Object of type RandomEffectSamples
keep_sample: Whether sample should be retained in rfx_samples. If FALSE, the state of rfx_tracker will be updated, but the parameter values will not be added to the sample container. Samples are commonly discarded due to burn-in or thinning.
global_variance: Scalar global variance parameter
rng: Object of type CppRNG
None
predict()
Predict from (a single sample of a) random effects model.
RandomEffectsModel$predict(rfx_dataset, rfx_tracker)
rfx_dataset: Object of type RandomEffectsDataset
rfx_tracker: Object of type RandomEffectsTracker
Vector of predictions with size matching number of observations in rfx_dataset
set_working_parameter()
Set value for the "working parameter." This is typically used for initialization, but could also be used to interrupt or override the sampler.
RandomEffectsModel$set_working_parameter(value)
value: Parameter input
None
set_group_parameters()
Set value for the "group parameters." This is typically used for initialization, but could also be used to interrupt or override the sampler.
RandomEffectsModel$set_group_parameters(value)
value: Parameter input
None
set_working_parameter_cov()
Set value for the working parameter covariance. This is typically used for initialization, but could also be used to interrupt or override the sampler.
RandomEffectsModel$set_working_parameter_cov(value)
value: Parameter input
None
set_group_parameter_cov()
Set value for the group parameter covariance. This is typically used for initialization, but could also be used to interrupt or override the sampler.
RandomEffectsModel$set_group_parameter_cov(value)
value: Parameter input
None
set_variance_prior_shape()
Set shape parameter for the group parameter variance prior.
RandomEffectsModel$set_variance_prior_shape(value)
value: Parameter input
None
set_variance_prior_scale()
Set scale parameter for the group parameter variance prior.
RandomEffectsModel$set_variance_prior_scale(value)
value: Parameter input
None
A forest sampler features two types of state: ephemeral and persistent. Persistent state includes objects like ForestSamples and RandomEffectSamples which constitute part of the final sampled model. Ephemeral state supports the sampling computations, but is not retained after the sampler finishes.
The two primary random-effects-based bits of ephemeral state are the RandomEffectsModel and RandomEffectsTracker classes, which represent the current state of a random effects model and its corresponding tracking data structures.
In a linear sampling loop, this ephemeral state is updated with each iteration of the sampler and any retained random effects draws are copied to a RandomEffectSamples object. However, in multi-chain settings, the state of a random effects model must typically be "reset" at the beginning of a new chain. These functions enable this process by synchronizing the state of a RandomEffectsModel and RandomEffectsTracker with a corresponding element of a RandomEffectSamples object, or by resetting both to their default (root) state.
resetRandomEffectsModel resets a RandomEffectsModel object based on the parameters indexed by sample_num in a RandomEffectSamples object.
resetRandomEffectsTracker resets a RandomEffectsTracker object based on the parameters indexed by sample_num in a RandomEffectSamples object.
rootResetRandomEffectsModel resets a RandomEffectsModel object to its "default" state.
rootResetRandomEffectsTracker resets a RandomEffectsTracker object to its "default" state.
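The split between ephemeral and persistent state in a multi-chain run can be sketched with a toy sampler in base R. This is only an illustration of the control flow: the "model" here is a scalar random walk rather than a random effects model, and root_state, state and samples are hypothetical names standing in for the root configuration, the ephemeral model / tracker state, and the persistent sample container, respectively.

```r
set.seed(7)
num_chains <- 3
num_mcmc <- 5
root_state <- 0          # default ("root") configuration of the toy model

samples <- numeric(0)    # persistent: retained draws, kept across all chains
chain_id <- integer(0)

for (chain in seq_len(num_chains)) {
  # Ephemeral state is reset at the start of each chain (a "root reset");
  # a warm start would instead copy a previously retained draw into `state`
  state <- root_state
  for (i in seq_len(num_mcmc)) {
    state <- state + rnorm(1)        # toy update of the ephemeral state
    samples <- c(samples, state)     # copy the retained draw to persistent storage
    chain_id <- c(chain_id, chain)
  }
}
```

Only samples (and its chain labels) survive the loop; state is overwritten at each reset, which mirrors how the model and tracker objects are reused from chain to chain.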
These functions are intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checks are performed – users are responsible for providing the correct inputs. For tutorials on the "proper" usage of stochtree's advanced workflow, we provide several vignettes at https://stochtree.ai/
resetRandomEffectsModel(rfx_model, rfx_samples, sample_num, sigma_alpha_init)
resetRandomEffectsTracker(
  rfx_tracker,
  rfx_model,
  rfx_dataset,
  residual,
  rfx_samples
)
rootResetRandomEffectsModel(
  rfx_model,
  alpha_init,
  xi_init,
  sigma_alpha_init,
  sigma_xi_init,
  sigma_xi_shape,
  sigma_xi_scale
)
rootResetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, residual)
rfx_model |
Object of type |
rfx_samples |
Object of type |
sample_num |
Index of sample stored in |
sigma_alpha_init |
Initial value of the "working parameter" scale parameter. |
rfx_tracker |
Object of type |
rfx_dataset |
Object of type |
residual |
Object of type |
alpha_init |
Initial value of the "working parameter". |
xi_init |
Initial value of the "group parameters". |
sigma_xi_init |
Initial value of the "group parameters" scale parameter. |
sigma_xi_shape |
Shape parameter for the inverse gamma variance model on the group parameters. |
sigma_xi_scale |
Scale parameter for the inverse gamma variance model on the group parameters. |
All four functions return nothing and operate in-place on the relevant RandomEffectsModel or RandomEffectsTracker objects.
n <- 100
p <- 10
rfx_group_ids <- sample(1:2, size = n, replace = TRUE)
rfx_basis <- matrix(rep(1.0, n), ncol=1)
rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis)
y <- (-2*(rfx_group_ids==1)+2*(rfx_group_ids==2)) + rnorm(n)
y_std <- (y-mean(y))/sd(y)
outcome <- createOutcome(y_std)
rng <- createCppRNG(1234)
num_groups <- length(unique(rfx_group_ids))
num_components <- ncol(rfx_basis)
rfx_model <- createRandomEffectsModel(num_components, num_groups)
rfx_tracker <- createRandomEffectsTracker(rfx_group_ids)
rfx_samples <- createRandomEffectSamples(num_components, num_groups, rfx_tracker)
alpha_init <- rep(1,num_components)
xi_init <- matrix(rep(alpha_init, num_groups),num_components,num_groups)
sigma_alpha_init <- diag(1,num_components,num_components)
sigma_xi_init <- diag(1,num_components,num_components)
sigma_xi_shape <- 1
sigma_xi_scale <- 1
rfx_model$set_working_parameter(alpha_init)
rfx_model$set_group_parameters(xi_init)
rfx_model$set_working_parameter_cov(sigma_alpha_init)
rfx_model$set_group_parameter_cov(sigma_xi_init)
rfx_model$set_variance_prior_shape(sigma_xi_shape)
rfx_model$set_variance_prior_scale(sigma_xi_scale)
for (i in 1:3) {
    rfx_model$sample_random_effect(rfx_dataset=rfx_dataset, residual=outcome,
                                   rfx_tracker=rfx_tracker, rfx_samples=rfx_samples,
                                   keep_sample=TRUE, global_variance=1.0, rng=rng)
}
resetRandomEffectsModel(rfx_model, rfx_samples, 0, 1.0)
resetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, outcome, rfx_samples)
rootResetRandomEffectsModel(rfx_model, alpha_init, xi_init, sigma_alpha_init,
                            sigma_xi_init, sigma_xi_shape, sigma_xi_scale)
rootResetRandomEffectsTracker(rfx_tracker, rfx_model, rfx_dataset, outcome)
Class that defines a "tracker" for random effects models, most notably storing the data indices available in each group for quicker posterior computation and sampling of random effects terms. The class stores a mapping from every observation to its group index, a mapping from group indices to the training sample observations available in that group, and predictions for each observation.
This class is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checking is performed, so users are responsible for providing correct inputs. For tutorials on the proper usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
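The bookkeeping described above (a map from each observation to its group, and a map from each group to its observations) can be sketched in base R. This is only an illustration of the data structure, not the C++ implementation; the variable names and toy data are hypothetical.

```r
# Hypothetical toy data: six observations in two groups
group_ids <- c(2L, 1L, 2L, 2L, 1L, 1L)
residual <- c(1.5, -0.5, 2.5, 2.0, -1.0, -1.5)

# Observation -> group index (position among the sorted unique labels)
group_labels <- sort(unique(group_ids))
obs_to_group <- match(group_ids, group_labels)

# Group index -> indices of the observations in that group
group_to_obs <- split(seq_along(group_ids), obs_to_group)

# Per-observation predictions, started at zero (an assumed initial value)
rfx_predictions <- rep(0, length(group_ids))

# With the group -> observation map, a group-level statistic only
# touches the rows that belong to that group
group_means <- vapply(group_to_obs, function(idx) mean(residual[idx]), numeric(1))
group_means
```

Precomputing the group-to-observation map once means each sampling step can subset the residual directly instead of rescanning the group labels.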
| Field | Description |
|---|---|
| rfx_tracker_ptr | External pointer to a C++ StochTree::RandomEffectsTracker class |
new()
Create a new RandomEffectsTracker object.
RandomEffectsTracker$new(rfx_group_indices)
| Argument | Description |
|---|---|
| rfx_group_indices | Integer indices indicating groups used to define random effects |
A new RandomEffectsTracker object.
Draw sample_size samples from population_vector without replacement, weighted by sampling_probabilities
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checking is performed, so users are responsible for providing correct inputs. For tutorials on the proper usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
sample_without_replacement(population_vector, sampling_probabilities, sample_size)
| Argument | Description |
|---|---|
| population_vector | Vector from which to draw samples. |
| sampling_probabilities | Vector of probabilities of drawing each element of population_vector. |
| sample_size | Number of samples to draw from population_vector. |
Vector of size sample_size
a <- as.integer(c(4,3,2,5,1,9,7))
p <- c(0.7,0.2,0.05,0.02,0.01,0.01,0.01)
num_samples <- 5
sample_without_replacement(a, p, num_samples)
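For intuition, base R's sample() performs the same kind of draw when called with replace = FALSE and a prob vector. The sketch below uses only base R and is an analogue, not stochtree's implementation, so individual draws need not match those of sample_without_replacement even under the same seed.

```r
set.seed(1234)
a <- as.integer(c(4, 3, 2, 5, 1, 9, 7))
p <- c(0.7, 0.2, 0.05, 0.02, 0.01, 0.01, 0.01)

# Weighted draw of five elements without replacement (base R analogue)
draws <- sample(a, size = 5, replace = FALSE, prob = p)

# No element appears twice, and every draw comes from the population
stopifnot(anyDuplicated(draws) == 0, all(draws %in% a))

# The weights govern which element is drawn first: the value 4 has
# probability 0.7, so it should lead roughly 70% of repeated draws
first_draw <- replicate(2000, sample(a, size = 1, prob = p))
mean(first_draw == 4)
```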
Sample from the posterior predictive distribution for outcomes modeled by BART
sampleBARTPosteriorPredictive(
  model_object,
  X = NULL,
  leaf_basis = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL,
  num_draws_per_sample = NULL
)
| Argument | Description |
|---|---|
| model_object | A fitted BART model object of class bartmodel. |
| X | A matrix or data frame of covariates. Required if the BART model depends on covariates (e.g., contains a mean or variance forest). |
| leaf_basis | A matrix of bases for mean forest models with regression defined in the leaves. Required for "leaf regression" models. |
| rfx_group_ids | A vector of group IDs for the random effects model. Required if the BART model includes random effects. |
| rfx_basis | A matrix of bases for the random effects model. Required if the BART model includes random effects. |
| num_draws_per_sample | The number of posterior predictive samples to draw for each posterior sample. Defaults to a heuristic based on the number of samples in the BART model: if the model has more than 1000 draws, one draw from the likelihood is used per sample; otherwise draws are upsampled so that intervals are based on at least 1000 posterior predictive draws. |
Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
n <- 100
p <- 5
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
y <- 2 * X[,1] + rnorm(n)
bart_model <- bart(y_train = y, X_train = X)
ppd_samples <- sampleBARTPosteriorPredictive(model_object = bart_model, X = X)
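To make the returned array shape and the num_draws_per_sample default more concrete, here is a minimal base R sketch. It assumes a Gaussian outcome with a single global error variance and uses simulated stand-ins for the posterior draws. The helper default_draws_per_sample is hypothetical, and its rounding rule is an assumption rather than the package's exact heuristic.

```r
set.seed(1)
n <- 4               # observations
num_samples <- 250   # retained posterior samples

# Hypothetical helper mirroring the documented default: one draw per
# sample above 1000 posterior samples, otherwise enough draws per sample
# to reach at least 1000 in total (the rounding is an assumption)
default_draws_per_sample <- function(num_samples, target = 1000) {
  if (num_samples > target) 1L else as.integer(ceiling(target / num_samples))
}
num_draws <- default_draws_per_sample(num_samples)

# Stand-ins for posterior draws of the conditional mean (n x num_samples)
# and of the global error variance (length num_samples)
mean_draws <- matrix(rnorm(n * num_samples), nrow = n)
sigma2_draws <- 1 / rgamma(num_samples, shape = 5, rate = 5)

# For each posterior sample, draw num_draws outcomes from N(mean, sigma2)
ppd <- array(NA_real_, dim = c(n, num_samples, num_draws))
for (d in seq_len(num_draws)) {
  noise <- matrix(rnorm(n * num_samples), nrow = n)
  ppd[, , d] <- mean_draws + sweep(noise, 2, sqrt(sigma2_draws), `*`)
}
dim(ppd)  # 4 250 4
```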
Sample from the posterior predictive distribution for outcomes modeled by BCF
sampleBCFPosteriorPredictive(
  model_object,
  X = NULL,
  Z = NULL,
  propensity = NULL,
  rfx_group_ids = NULL,
  rfx_basis = NULL,
  num_draws_per_sample = NULL
)
| Argument | Description |
|---|---|
| model_object | A fitted BCF model object of class bcfmodel. |
| X | A matrix or data frame of covariates. |
| Z | A vector or matrix of treatment assignments. |
| propensity | (Optional) A vector or matrix of propensity scores. Required if the underlying model depends on user-provided propensities. |
| rfx_group_ids | (Optional) A vector of group IDs for the random effects model. Required if the BCF model includes random effects. |
| rfx_basis | (Optional) A matrix of bases for the random effects model. Required if the BCF model includes random effects. |
| num_draws_per_sample | (Optional) The number of samples to draw from the likelihood for each draw of the posterior. Defaults to a heuristic based on the number of samples in the BCF model: if the model has more than 1000 draws, one draw from the likelihood is used per sample; otherwise draws are upsampled to ensure at least 1000 posterior predictive draws. |
Array of posterior predictive samples with dimensions (num_observations, num_posterior_samples, num_draws_per_sample) if num_draws_per_sample > 1, otherwise (num_observations, num_posterior_samples).
n <- 100
p <- 5
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
pi_X <- pnorm(X[,1] / 2)
Z <- rbinom(n, 1, pi_X)
y <- 2 * X[,2] + 0.5 * X[,2] * Z + rnorm(n)
bcf_model <- bcf(X_train = X, Z_train = Z, y_train = y, propensity_train = pi_X)
ppd_samples <- sampleBCFPosteriorPredictive(model_object = bcf_model, X = X, Z = Z, propensity = pi_X)
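One common use of such an array is to reduce it to a per-observation predictive mean and interval. The sketch below does this in base R on a simulated stand-in array with the documented (num_observations, num_posterior_samples, num_draws_per_sample) layout; it does not call stochtree, and the 95% level is an arbitrary choice.

```r
set.seed(2)
n <- 3
num_samples <- 100
num_draws <- 10

# Stand-in for the array returned by a posterior predictive sampler
ppd <- array(rnorm(n * num_samples * num_draws, mean = 1),
             dim = c(n, num_samples, num_draws))

# Pool posterior samples and likelihood draws for each observation.
# R arrays are column-major, so reshaping keeps one row per observation.
pooled <- matrix(ppd, nrow = n)

ppd_mean <- rowMeans(pooled)
ppd_interval <- t(apply(pooled, 1, quantile, probs = c(0.025, 0.975)))
cbind(mean = ppd_mean, ppd_interval)
```

When num_draws_per_sample is 1 the result is already a matrix with one row per observation, so the reshaping step can be skipped.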
Sample one iteration of the (inverse gamma) global variance model
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checking is performed, so users are responsible for providing correct inputs. For tutorials on the proper usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
sampleGlobalErrorVarianceOneIteration(residual, dataset, rng, a, b)
| Argument | Description |
|---|---|
| residual | Outcome class |
| dataset | ForestDataset class |
| rng | C++ random number generator |
| a | Global variance shape parameter |
| b | Global variance scale parameter |
A single draw of the global error variance (the example assigns the result to sigma2).
X <- matrix(runif(10*100), ncol = 10)
y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
y_std <- (y-mean(y))/sd(y)
forest_dataset <- createForestDataset(X)
outcome <- createOutcome(y_std)
rng <- createCppRNG(1234)
a <- 1.0
b <- 1.0
sigma2 <- sampleGlobalErrorVarianceOneIteration(outcome, forest_dataset, rng, a, b)
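As a rough illustration of what one iteration of an inverse gamma variance model involves, the sketch below writes out the textbook conjugate update in base R. It assumes a prior sigma^2 ~ IG(a, b) with shape a and scale b, and independent Gaussian residuals; the package's internal parameterization of a and b may differ, and the residual vector is simulated.

```r
set.seed(1234)
residual <- rnorm(100, sd = 2)  # stand-in for the current model residuals
a <- 1.0
b <- 1.0
n <- length(residual)

# Conjugate update for sigma^2 ~ IG(a, b) with Gaussian residuals
post_shape <- a + n / 2
post_scale <- b + sum(residual^2) / 2

# If G ~ Gamma(shape, rate = scale), then 1 / G ~ IG(shape, scale)
sigma2_draw <- 1 / rgamma(1, shape = post_shape, rate = post_scale)
sigma2_draw
```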
Sample one iteration of the leaf parameter variance model (supported only for constant-leaf and univariate-basis models)
This function is intended for advanced use cases in which users require detailed control of sampling algorithms and data structures. Minimal input validation and error checking is performed, so users are responsible for providing correct inputs. For tutorials on the proper usage of stochtree's advanced workflow, see the vignettes at https://stochtree.ai/
sampleLeafVarianceOneIteration(forest, rng, a, b)
| Argument | Description |
|---|---|
| forest | C++ forest |
| rng | C++ random number generator |
| a | Leaf variance shape parameter |
| b | Leaf variance scale parameter |
A single draw of the leaf parameter variance (the example assigns the result to tau).
num_trees <- 100
leaf_dimension <- 1
is_leaf_constant <- TRUE
is_exponentiated <- FALSE
active_forest <- createForest(num_trees, leaf_dimension, is_leaf_constant, is_exponentiated)
rng <- createCppRNG(1234)
a <- 1.0
b <- 1.0
tau <- sampleLeafVarianceOneIteration(active_forest, rng, a, b)
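For a scalar leaf model, the analogous conjugate calculation treats the leaf parameters pooled across all trees as draws from N(0, tau) with tau ~ IG(a, b). The base R sketch below works under that assumption with simulated leaf values; it is not the package's code. It also checks the draws against the closed-form posterior mean scale / (shape - 1).

```r
set.seed(1234)
# Hypothetical leaf parameters pooled across every tree in the forest
leaf_values <- rnorm(40, sd = 0.3)
a <- 1.0
b <- 1.0
num_leaves <- length(leaf_values)

# Conjugate update for tau ~ IG(a, b) with leaf values ~ N(0, tau)
post_shape <- a + num_leaves / 2
post_scale <- b + sum(leaf_values^2) / 2

# Many draws, only to compare against the analytic posterior mean;
# a sampler would take a single draw per iteration
tau_draws <- 1 / rgamma(5000, shape = post_shape, rate = post_scale)
c(monte_carlo = mean(tau_draws), analytic = post_scale / (post_shape - 1))
```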
Summarize a BART fit with a description of the model that was fit and numeric summaries of any sampled quantities.
## S3 method for class 'bartmodel'
summary(object, ...)
| Argument | Description |
|---|---|
| object | The BART model object |
| ... | Additional arguments |
BART model object unchanged after summarizing
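The pattern of printing a summary and handing back the unchanged object is a common S3 idiom. The toy example below illustrates it with a hypothetical class toymodel; it is not the package's summary.bartmodel method, and whether that method returns the object visibly or invisibly is not stated here.

```r
# Hypothetical toy class standing in for a fitted model
toy_fit <- structure(
  list(sigma2_samples = c(0.9, 1.1, 1.0, 1.2)),
  class = "toymodel"
)

# S3 summary method: describe the fit, then return the object unchanged
summary.toymodel <- function(object, ...) {
  cat("Toy model with", length(object$sigma2_samples), "retained samples\n")
  print(summary(object$sigma2_samples))
  invisible(object)
}

returned <- summary(toy_fit)
identical(returned, toy_fit)
```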
Summarize a BCF fit with a description of the model that was fit and numeric summaries of any sampled quantities.
## S3 method for class 'bcfmodel'
summary(object, ...)
| Argument | Description |
|---|---|
| object | The BCF model object |
| ... | Additional arguments |
BCF model object unchanged after summarizing