This vignette demonstrates how to use the `bart()` function for Bayesian supervised learning (Chipman, George, and McCulloch 2010), with an additional “variance forest” for modeling the conditional variance (see Murray 2021). To begin, we load the `stochtree` package.
Here, we generate data with a constant (zero) mean and a relatively simple covariate-modified variance function.
# Generate the data: zero mean, step-function standard deviation in X[, 1]
n <- 500
p_x <- 10
X <- matrix(runif(n * p_x), ncol = p_x)
# The conditional mean is identically zero. Store it as a length-n vector
# (not the scalar 0) so that indexing it by the train/test indices below
# yields zeros rather than NAs.
f_XW <- rep(0, n)
# Conditional standard deviation: piecewise constant in the first covariate
s_XW <- (
  ((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.5) +
  ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (1) +
  ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2) +
  ((0.75 <= X[, 1]) & (1 > X[, 1])) * (3)
)
y <- f_XW + rnorm(n, 0, 1) * s_XW
# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
# Every index not drawn for the test set goes to the training set
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- as.data.frame(X[test_inds, ])
X_train <- as.data.frame(X[train_inds, ])
y_test <- y[test_inds]
y_train <- y[train_inds]
f_x_test <- f_XW[test_inds]
f_x_train <- f_XW[train_inds]
s_x_test <- s_XW[test_inds]
s_x_train <- s_XW[train_inds]
We first sample the $\sigma^2(X)$ ensemble using “warm-start” initialization (He and Hahn 2023). This is the default in `stochtree`.
# Warm-start sampler settings: num_gfr grow-from-root iterations, no burn-in,
# then num_mcmc MCMC iterations
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
num_trees <- 20
# NOTE(review): a_0 is never passed to bart() below, so it does not affect
# the fit. Confirm whether it was meant to be a variance-forest prior
# parameter; otherwise it can be removed.
a_0 <- 1.5
# Total number of sampler iterations (kept for reference; not used below)
num_samples <- num_gfr + num_burnin + num_mcmc
# Keep the global error variance fixed and disable the mean forest
# (num_trees = 0), so that only the variance forest is sampled
general_params <- list(sample_sigma2_global = FALSE)
mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 0)
variance_forest_params <- list(num_trees = num_trees)
bart_model_warmstart <- stochtree::bart(
  X_train = X_train, y_train = y_train, X_test = X_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params,
  variance_forest_params = variance_forest_params
)
Inspect the MCMC samples
# Average each test observation's variance-forest predictions across the
# stored draws (assumes one column per draw) and plot the averages against
# the true standard deviation. Points near the dashed 45-degree line indicate
# good recovery.
plot(
  x = rowMeans(bart_model_warmstart$sigma_x_hat_test),
  y = s_x_test,
  pch = 16, cex = 0.75,
  xlab = "pred", ylab = "actual",
  main = "standard deviation function"
)
abline(a = 0, b = 1, col = "red", lty = 2, lwd = 2.5)
We now sample the $\sigma^2(X)$ ensemble using MCMC with root initialization (as in Chipman, George, and McCulloch 2010).
# MCMC-only sampler settings: no grow-from-root iterations, 1000 burn-in
# iterations started from root (single-leaf) trees, then num_mcmc draws
num_gfr <- 0
num_burnin <- 1000
num_mcmc <- 100
# Total number of sampler iterations (kept for reference; not used below)
num_samples <- num_gfr + num_burnin + num_mcmc
# Keep the global error variance fixed and disable the mean forest
# (num_trees = 0), so that only the variance forest is sampled
general_params <- list(sample_sigma2_global = FALSE)
mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 0)
variance_forest_params <- list(num_trees = num_trees)
bart_model_mcmc <- stochtree::bart(
  X_train = X_train, y_train = y_train, X_test = X_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params,
  variance_forest_params = variance_forest_params
)
Inspect the MCMC samples
# Average each test observation's variance-forest predictions across the
# stored draws (assumes one column per draw) and plot the averages against
# the true standard deviation. The dashed red line marks perfect agreement.
plot(
  x = rowMeans(bart_model_mcmc$sigma_x_hat_test),
  y = s_x_test,
  pch = 16, cex = 0.75,
  xlab = "pred", ylab = "actual",
  main = "standard deviation function"
)
abline(a = 0, b = 1, col = "red", lty = 2, lwd = 2.5)
Here, we generate data with a constant (zero) mean and a more complex covariate-modified variance function.
# Generate the data: zero mean, standard deviation that is a step function of
# X[, 1] scaled by X[, 3]
n <- 500
p_x <- 10
X <- matrix(runif(n * p_x), ncol = p_x)
# The conditional mean is identically zero. Store it as a length-n vector
# (not the scalar 0) so that indexing it by the train/test indices below
# yields zeros rather than NAs.
f_XW <- rep(0, n)
# Conditional standard deviation: piecewise slope in X[, 1] times X[, 3]
s_XW <- (
  ((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.5 * X[, 3]) +
  ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (1 * X[, 3]) +
  ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2 * X[, 3]) +
  ((0.75 <= X[, 1]) & (1 > X[, 1])) * (3 * X[, 3])
)
y <- f_XW + rnorm(n, 0, 1) * s_XW
# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
# Every index not drawn for the test set goes to the training set
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- as.data.frame(X[test_inds, ])
X_train <- as.data.frame(X[train_inds, ])
y_test <- y[test_inds]
y_train <- y[train_inds]
f_x_test <- f_XW[test_inds]
f_x_train <- f_XW[train_inds]
s_x_test <- s_XW[test_inds]
s_x_train <- s_XW[train_inds]
We first sample the $\sigma^2(X)$ ensemble using “warm-start” initialization (He and Hahn 2023). This is the default in `stochtree`.
# Warm-start sampler settings: num_gfr grow-from-root iterations, no burn-in,
# then num_mcmc MCMC iterations
num_trees <- 20
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
# Total number of sampler iterations (kept for reference; not used below)
num_samples <- num_gfr + num_burnin + num_mcmc
general_params <- list(sample_sigma2_global = FALSE)
# The mean forest is disabled (num_trees = 0), so its tree-prior settings
# presumably have no effect on the fit
mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 0,
                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
# beta = 1.25 (vs. 2 above) weakens the depth penalty alpha * (1 + d)^(-beta)
# of the BART tree prior, allowing deeper variance trees; min_samples_leaf = 1
# permits very small leaves
variance_forest_params <- list(num_trees = num_trees, alpha = 0.95,
                               beta = 1.25, min_samples_leaf = 1)
bart_model_warmstart <- stochtree::bart(
  X_train = X_train, y_train = y_train, X_test = X_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params,
  variance_forest_params = variance_forest_params
)
Inspect the MCMC samples
# Average each test observation's variance-forest predictions across the
# stored draws (assumes one column per draw) and plot the averages against
# the true standard deviation. The dashed red line marks perfect agreement.
plot(
  x = rowMeans(bart_model_warmstart$sigma_x_hat_test),
  y = s_x_test,
  pch = 16, cex = 0.75,
  xlab = "pred", ylab = "actual",
  main = "standard deviation function"
)
abline(a = 0, b = 1, col = "red", lty = 2, lwd = 2.5)
We now sample the $\sigma^2(X)$ ensemble using MCMC with root initialization (as in Chipman, George, and McCulloch 2010).
# MCMC-only sampler settings: no grow-from-root iterations, 1000 burn-in
# iterations started from root (single-leaf) trees, then num_mcmc draws
num_gfr <- 0
num_burnin <- 1000
num_mcmc <- 100
# Total number of sampler iterations (kept for reference; not used below)
num_samples <- num_gfr + num_burnin + num_mcmc
general_params <- list(sample_sigma2_global = FALSE)
# The mean forest is disabled (num_trees = 0), so its tree-prior settings
# presumably have no effect on the fit
mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 0,
                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
# Same variance-forest prior as the warm-start run, for a like-for-like
# comparison of the two samplers
variance_forest_params <- list(num_trees = num_trees, alpha = 0.95,
                               beta = 1.25, min_samples_leaf = 1)
bart_model_mcmc <- stochtree::bart(
  X_train = X_train, y_train = y_train, X_test = X_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params,
  variance_forest_params = variance_forest_params
)
Inspect the MCMC samples
# Average each test observation's variance-forest predictions across the
# stored draws (assumes one column per draw) and plot the averages against
# the true standard deviation. The dashed red line marks perfect agreement.
plot(
  x = rowMeans(bart_model_mcmc$sigma_x_hat_test),
  y = s_x_test,
  pch = 16, cex = 0.75,
  xlab = "pred", ylab = "actual",
  main = "standard deviation function"
)
abline(a = 0, b = 1, col = "red", lty = 2, lwd = 2.5)
Here, we generate data with (relatively simple) covariate-modified mean and variance functions.
# Generate the data: step-function mean in X[, 2] and step-function
# standard deviation in X[, 1]
n <- 500
p_x <- 10
X <- matrix(runif(n * p_x), ncol = p_x)
# Conditional mean: piecewise constant in the second covariate
f_XW <- (
  ((0 <= X[, 2]) & (0.25 > X[, 2])) * (-6) +
  ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (-2) +
  ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (2) +
  ((0.75 <= X[, 2]) & (1 > X[, 2])) * (6)
)
# Conditional standard deviation: piecewise constant in the first covariate
s_XW <- (
  ((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.5) +
  ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (1) +
  ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2) +
  ((0.75 <= X[, 1]) & (1 > X[, 1])) * (3)
)
y <- f_XW + rnorm(n, 0, 1) * s_XW
# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
# Every index not drawn for the test set goes to the training set
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- as.data.frame(X[test_inds, ])
X_train <- as.data.frame(X[train_inds, ])
y_test <- y[test_inds]
y_train <- y[train_inds]
f_x_test <- f_XW[test_inds]
f_x_train <- f_XW[train_inds]
s_x_test <- s_XW[test_inds]
s_x_train <- s_XW[train_inds]
We first sample the mean ($f(X)$) and variance ($\sigma^2(X)$) ensembles using “warm-start” initialization (He and Hahn 2023). This is the default in `stochtree`.
# Warm-start sampler settings: num_gfr grow-from-root iterations, no burn-in,
# then num_mcmc MCMC iterations. Both the mean and variance forests are
# active here (50 trees each).
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
# Total number of sampler iterations (kept for reference; not used below)
num_samples <- num_gfr + num_burnin + num_mcmc
general_params <- list(sample_sigma2_global = FALSE)
mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 50,
                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
# beta = 1.25 (vs. 2 for the mean forest) weakens the depth penalty of the
# BART tree prior, allowing deeper variance trees
variance_forest_params <- list(num_trees = 50, alpha = 0.95,
                               beta = 1.25, min_samples_leaf = 5)
bart_model_warmstart <- stochtree::bart(
  X_train = X_train, y_train = y_train, X_test = X_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params,
  variance_forest_params = variance_forest_params
)
Inspect the MCMC samples
# Mean forest: average each test observation's predictions across the stored
# draws (assumes one column per draw) and compare with the TRUE mean function
# f_x_test. Comparing with y_test would mix in the heteroskedastic noise and
# would not match the "mean function" title.
plot(rowMeans(bart_model_warmstart$y_hat_test), f_x_test,
     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual",
     main = "mean function")
abline(0, 1, col = "red", lty = 2, lwd = 2.5)
# Variance forest: averaged predictions vs. the true standard deviation
plot(rowMeans(bart_model_warmstart$sigma_x_hat_test), s_x_test,
     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual",
     main = "standard deviation function")
abline(0, 1, col = "red", lty = 2, lwd = 2.5)
We now sample the mean ($f(X)$) and variance ($\sigma^2(X)$) ensembles using MCMC with root initialization (as in Chipman, George, and McCulloch 2010).
# MCMC-only sampler settings: no grow-from-root iterations, 1000 burn-in
# iterations started from root (single-leaf) trees, then num_mcmc draws
num_gfr <- 0
num_burnin <- 1000
num_mcmc <- 100
# Total number of sampler iterations (kept for reference; not used below)
num_samples <- num_gfr + num_burnin + num_mcmc
general_params <- list(sample_sigma2_global = FALSE)
# Same forest priors as the warm-start run, for a like-for-like comparison
mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 50,
                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
variance_forest_params <- list(num_trees = 50, alpha = 0.95,
                               beta = 1.25, min_samples_leaf = 5)
bart_model_mcmc <- stochtree::bart(
  X_train = X_train, y_train = y_train, X_test = X_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params,
  variance_forest_params = variance_forest_params
)
Inspect the MCMC samples
# Mean forest: average each test observation's predictions across the stored
# draws (assumes one column per draw) and compare with the TRUE mean function
# f_x_test. Comparing with y_test would mix in the heteroskedastic noise and
# would not match the "mean function" title.
plot(rowMeans(bart_model_mcmc$y_hat_test), f_x_test,
     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual",
     main = "mean function")
abline(0, 1, col = "red", lty = 2, lwd = 2.5)
# Variance forest: averaged predictions vs. the true standard deviation
plot(rowMeans(bart_model_mcmc$sigma_x_hat_test), s_x_test,
     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual",
     main = "standard deviation function")
abline(0, 1, col = "red", lty = 2, lwd = 2.5)
Here, we generate data with more complex covariate-modified mean and variance functions.
# Generate the data: mean is a step function of X[, 2] scaled by X[, 4];
# standard deviation is a step function of X[, 1] scaled by X[, 3]
n <- 500
p_x <- 10
X <- matrix(runif(n * p_x), ncol = p_x)
# Conditional mean: piecewise slope in X[, 2] times X[, 4]
f_XW <- (
  ((0 <= X[, 2]) & (0.25 > X[, 2])) * (-6 * X[, 4]) +
  ((0.25 <= X[, 2]) & (0.5 > X[, 2])) * (-2 * X[, 4]) +
  ((0.5 <= X[, 2]) & (0.75 > X[, 2])) * (2 * X[, 4]) +
  ((0.75 <= X[, 2]) & (1 > X[, 2])) * (6 * X[, 4])
)
# Conditional standard deviation: piecewise slope in X[, 1] times X[, 3]
s_XW <- (
  ((0 <= X[, 1]) & (0.25 > X[, 1])) * (0.5 * X[, 3]) +
  ((0.25 <= X[, 1]) & (0.5 > X[, 1])) * (1 * X[, 3]) +
  ((0.5 <= X[, 1]) & (0.75 > X[, 1])) * (2 * X[, 3]) +
  ((0.75 <= X[, 1]) & (1 > X[, 1])) * (3 * X[, 3])
)
y <- f_XW + rnorm(n, 0, 1) * s_XW
# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
# Every index not drawn for the test set goes to the training set
train_inds <- setdiff(seq_len(n), test_inds)
X_test <- as.data.frame(X[test_inds, ])
X_train <- as.data.frame(X[train_inds, ])
y_test <- y[test_inds]
y_train <- y[train_inds]
f_x_test <- f_XW[test_inds]
f_x_train <- f_XW[train_inds]
s_x_test <- s_XW[test_inds]
s_x_train <- s_XW[train_inds]
We first sample the mean ($f(X)$) and variance ($\sigma^2(X)$) ensembles using “warm-start” initialization (He and Hahn 2023). This is the default in `stochtree`.
# Warm-start sampler settings: num_gfr grow-from-root iterations, no burn-in,
# then num_mcmc MCMC iterations. Both the mean and variance forests are
# active here (50 trees each).
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
# Total number of sampler iterations (kept for reference; not used below)
num_samples <- num_gfr + num_burnin + num_mcmc
general_params <- list(sample_sigma2_global = FALSE)
mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 50,
                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
# beta = 1.25 (vs. 2 for the mean forest) weakens the depth penalty of the
# BART tree prior, allowing deeper variance trees
variance_forest_params <- list(num_trees = 50, alpha = 0.95,
                               beta = 1.25, min_samples_leaf = 5)
bart_model_warmstart <- stochtree::bart(
  X_train = X_train, y_train = y_train, X_test = X_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params,
  variance_forest_params = variance_forest_params
)
Inspect the MCMC samples
# Mean forest: average each test observation's predictions across the stored
# draws (assumes one column per draw) and compare with the TRUE mean function
# f_x_test. Comparing with y_test would mix in the heteroskedastic noise and
# would not match the "mean function" title.
plot(rowMeans(bart_model_warmstart$y_hat_test), f_x_test,
     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual",
     main = "mean function")
abline(0, 1, col = "red", lty = 2, lwd = 2.5)
# Variance forest: averaged predictions vs. the true standard deviation
plot(rowMeans(bart_model_warmstart$sigma_x_hat_test), s_x_test,
     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual",
     main = "standard deviation function")
abline(0, 1, col = "red", lty = 2, lwd = 2.5)
We now sample the mean ($f(X)$) and variance ($\sigma^2(X)$) ensembles using MCMC with root initialization (as in Chipman, George, and McCulloch 2010).
# MCMC-only sampler settings: no grow-from-root iterations, 1000 burn-in
# iterations started from root (single-leaf) trees, then num_mcmc draws
num_gfr <- 0
num_burnin <- 1000
num_mcmc <- 100
# Total number of sampler iterations (kept for reference; not used below)
num_samples <- num_gfr + num_burnin + num_mcmc
general_params <- list(sample_sigma2_global = FALSE)
# Same forest priors as the warm-start run, for a like-for-like comparison
mean_forest_params <- list(sample_sigma2_leaf = FALSE, num_trees = 50,
                           alpha = 0.95, beta = 2, min_samples_leaf = 5)
variance_forest_params <- list(num_trees = 50, alpha = 0.95,
                               beta = 1.25, min_samples_leaf = 5)
bart_model_mcmc <- stochtree::bart(
  X_train = X_train, y_train = y_train, X_test = X_test,
  num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
  general_params = general_params,
  mean_forest_params = mean_forest_params,
  variance_forest_params = variance_forest_params
)
Inspect the MCMC samples
# Mean forest: average each test observation's predictions across the stored
# draws (assumes one column per draw) and compare with the TRUE mean function
# f_x_test. Comparing with y_test would mix in the heteroskedastic noise and
# would not match the "mean function" title.
plot(rowMeans(bart_model_mcmc$y_hat_test), f_x_test,
     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual",
     main = "mean function")
abline(0, 1, col = "red", lty = 2, lwd = 2.5)
# Variance forest: averaged predictions vs. the true standard deviation
plot(rowMeans(bart_model_mcmc$sigma_x_hat_test), s_x_test,
     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual",
     main = "standard deviation function")
abline(0, 1, col = "red", lty = 2, lwd = 2.5)