While out-of-sample evaluation and MCMC diagnostics on parametric BART components (e.g. σ², the global error variance) are helpful, it's also important to be able to inspect the trees in a BART / BCF model (or a custom tree ensemble model). This vignette walks through some of the features stochtree provides to query and understand the forests / trees in a model.
To begin, we load the stochtree package.
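library(stochtree)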
Generate sample data where feature 10 is the only “important” feature.
# Generate the data
n <- 500
p_x <- 10
X <- matrix(runif(n*p_x), ncol = p_x)
f_XW <- (
  ((0 <= X[,10]) & (0.25 > X[,10])) * (-7.5) +
  ((0.25 <= X[,10]) & (0.5 > X[,10])) * (-2.5) +
  ((0.5 <= X[,10]) & (0.75 > X[,10])) * (2.5) +
  ((0.75 <= X[,10]) & (1 > X[,10])) * (7.5)
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, 1)*noise_sd
# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- as.data.frame(X[test_inds,])
X_train <- as.data.frame(X[train_inds,])
y_test <- y[test_inds]
y_train <- y[train_inds]
Run BART.
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
general_params <- list(keep_gfr = TRUE)
bart_model <- stochtree::bart(
X_train = X_train, y_train = y_train, X_test = X_test,
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
general_params = general_params
)
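With keep_gfr = TRUE and no burn-in, the fitted model retains all num_gfr + num_mcmc = 110 forest samples; note that the forest accessors used below index these samples starting from 0.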
Inspect the MCMC samples by plotting predicted versus actual outcomes on the test set.
plot(rowMeans(bart_model$y_hat_test), y_test,
     pch = 16, cex = 0.75, xlab = "pred", ylab = "actual")
abline(0, 1, col = "red", lty = 2, lwd = 2.5)
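A quick numeric companion to the plot is the test-set RMSE of the posterior mean prediction:
# Test-set RMSE of the posterior mean prediction
sqrt(mean((rowMeans(bart_model$y_hat_test) - y_test)^2))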
Check the variable split counts, aggregated over every tree and sample in the ensemble.
bart_model$mean_forests$get_aggregate_split_counts(p_x)
#> [1] 2349 2962 2692 2256 2436 3450 2428 2147 2846 3587
The split counts appear relatively uniform across features, so let’s dig deeper and look at individual trees, starting with the first tree in the last “grow-from-root” sample.
splits <- bart_model$mean_forests$get_granular_split_counts(p_x)
splits[10,1,]
#> [1] 0 0 0 0 0 0 0 0 0 1
This tree has a single split on the only “important” feature. Now, let’s look at the second tree.
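# Split counts for the second tree of the last GFR sample
splits[10,2,]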
This tree also only splits on the important feature.
We see that “later” trees split on other features, but note that these trees are fitting an outcome that has already been residualized against the “relevant splits” made by trees 1 and 2.
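To see this pattern across the whole sample at once, we can list the features each tree splits on. A minimal sketch, assuming (as the indexing above suggests) that get_granular_split_counts returns an array laid out as [sample, tree, feature]:
# Features (1-based) that each tree in the last GFR sample splits on
lapply(seq_len(dim(splits)[2]), function(tree) which(splits[10, tree, ] > 0))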
Now, let’s inspect the first tree of this last GFR sample in more depth, following the approach of a similar scikit-learn vignette.
# Forest and tree indices are 0-based, so forest 9 is the 10th
# (i.e. final) GFR sample and tree 0 is its first tree
forest_num <- 9
tree_num <- 0
nodes <- sort(bart_model$mean_forests$nodes(forest_num, tree_num))
for (nid in nodes) {
if (bart_model$mean_forests$is_leaf_node(forest_num, tree_num, nid)) {
node_depth <- bart_model$mean_forests$node_depth(forest_num, tree_num, nid)
space_text <- rep("\t", node_depth)
leaf_values <- bart_model$mean_forests$node_leaf_values(forest_num, tree_num, nid)
cat(space_text, "node=", nid, " is a leaf node with value=",
format(leaf_values, digits = 3), "\n", sep = "")
} else {
node_depth <- bart_model$mean_forests$node_depth(forest_num, tree_num, nid)
space_text <- rep("\t", node_depth)
left <- bart_model$mean_forests$left_child_node(forest_num, tree_num, nid)
feature <- bart_model$mean_forests$node_split_index(forest_num, tree_num, nid)
threshold <- bart_model$mean_forests$node_split_threshold(forest_num, tree_num, nid)
right <- bart_model$mean_forests$right_child_node(forest_num, tree_num, nid)
cat(space_text, "node=", nid, " is a split node, which tells us to go to node ",
left, " if X[:, ", feature, "] <= ", format(threshold, digits = 3),
" else to node ", right, "\n", sep = "")
}
}
#> node=0 is a split node, which tells us to go to node 1 if X[:, 9] <= 0.497 else to node 2
#> 	node=1 is a leaf node with value=-0.0962
#> 	node=2 is a leaf node with value=0.261
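Note that node_split_index is also 0-based, so X[:, 9] is the tenth feature (X[,10] in R), and the threshold of roughly 0.5 is the point that best separates the lower half of the simulated step function (mean -5) from the upper half (mean 5). The same loop works for any retained sample; for instance, assuming the GFR draws are stored ahead of the MCMC draws (we kept both), the final MCMC sample would be indexed as:
# Assumed layout: GFR samples first, then MCMC samples (0-based indices)
forest_num <- num_gfr + num_mcmc - 1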