---
title: "Deeper Dive on Sampled Forests in StochTree"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Tree-Inspection}
  %\VignetteEncoding{UTF-8}
  %\VignetteEngine{knitr::rmarkdown}
bibliography: vignettes.bib
editor_options:
  markdown:
    wrap: 72
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
```

While out of sample evaluation and MCMC diagnostics on parametric BART
components (i.e. $\sigma^2$, the global error variance) are helpful,
it's important to be able to inspect the trees in a BART / BCF model (or
a custom tree ensemble model). This vignette walks through some of the
features `stochtree` provides to query and understand the forests /
trees in a model.

To begin, we load the `stochtree` package.

```{r}
library(stochtree)
```

# Demo 1: Supervised Learning

Generate sample data where feature 10 is the only "important" feature.

```{r}
# Generate the data
n <- 500
p_x <- 10
X <- matrix(runif(n*p_x), ncol = p_x)
f_XW <- (
    ((0 <= X[,10]) & (0.25 > X[,10])) * (-7.5) +
    ((0.25 <= X[,10]) & (0.5 > X[,10])) * (-2.5) +
    ((0.5 <= X[,10]) & (0.75 > X[,10])) * (2.5) +
    ((0.75 <= X[,10]) & (1 > X[,10])) * (7.5)
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, 1)*noise_sd

# Split data into test and train sets
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- as.data.frame(X[test_inds,])
X_train <- as.data.frame(X[train_inds,])
y_test <- y[test_inds]
y_train <- y[train_inds]
```

## Sampling and Analysis

Run BART.

```{r}
num_gfr <- 10
num_burnin <- 0
num_mcmc <- 100
general_params <- list(keep_gfr = TRUE)
bart_model <- stochtree::bart(
    X_train = X_train, y_train = y_train, X_test = X_test,
    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
    general_params = general_params
)
```

Inspect the MCMC samples.

```{r}
plot(bart_model$sigma2_global_samples, ylab="sigma^2")
abline(h=noise_sd^2,col="red",lty=2,lwd=2.5)
plot(rowMeans(bart_model$y_hat_test), y_test,
     pch=16, cex=0.75, xlab = "pred", ylab = "actual")
abline(0,1,col="red",lty=2,lwd=2.5)
```

Check the variable split count in the last GFR sample.

```{r}
bart_model$mean_forests$get_forest_split_counts(10, p_x)
```

```{r}
bart_model$mean_forests$get_aggregate_split_counts(p_x)
```

The split counts appear relatively uniform across features, so let's dig
deeper and look at individual trees, starting with the first tree in the
last "grow-from-root" sample.

```{r}
splits <- bart_model$mean_forests$get_granular_split_counts(p_x)
splits[10,1,]
```

This tree has a single split on the only "important" feature. Now, let's
look at the second tree.

```{r}
splits[10,2,]
```

This tree also only splits on the important feature.

```{r}
splits[10,20,]
```

```{r}
splits[10,30,]
```

We see that "later" trees are splitting on other features, but we also
note that these trees are fitting an outcome that has already been
residualized by the many "relevant splits" made by trees 1 and 2.

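Reading split counts one tree at a time gets tedious, so it can help to
compare the share of splits per feature across groups of trees. The
chunk below is a minimal sketch on a small hand-built array rather than
on the fitted model: `toy_splits` and `split_proportions()` are
hypothetical names introduced here, and the (forest sample, tree,
feature) dimension order is assumed from the `splits[10,1,]` indexing
used above.

```{r}
# Toy stand-in for the array returned by get_granular_split_counts();
# dimensions are assumed to be (forest sample, tree, feature)
num_samples <- 2
num_trees <- 4
num_features <- 3
toy_splits <- array(0L, dim = c(num_samples, num_trees, num_features))

# Forest sample 2: early trees split on feature 3, later trees on others
toy_splits[2, 1, ] <- c(0L, 0L, 1L)
toy_splits[2, 2, ] <- c(0L, 0L, 2L)
toy_splits[2, 3, ] <- c(1L, 0L, 0L)
toy_splits[2, 4, ] <- c(0L, 1L, 1L)

# Hypothetical helper: share of splits per feature over a set of trees
split_proportions <- function(splits, sample_idx, tree_idx) {
  # Sum counts over the selected trees, one total per feature
  counts <- apply(splits[sample_idx, tree_idx, , drop = FALSE], 3, sum)
  total <- sum(counts)
  # Guard against a set of trees that contains no splits at all
  if (total == 0) {
    return(rep(0, length(counts)))
  }
  counts / total
}

early_trees <- split_proportions(toy_splits, 2, 1:2)
late_trees <- split_proportions(toy_splits, 2, 3:4)
early_trees
late_trees
```

With a fitted model, the same helper could presumably be applied to
`splits` directly, for example to compare trees `1:2` against trees
`20:30` in forest sample 10.
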

Now, let's inspect the first tree for this last GFR sample in more
depth, following [this scikit-learn
vignette](https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html).

```{r}
forest_num <- 9
tree_num <- 0
forests <- bart_model$mean_forests
nodes <- sort(forests$nodes(forest_num, tree_num))
for (nid in nodes) {
    # Indent each node by its depth in the tree
    node_depth <- forests$node_depth(forest_num, tree_num, nid)
    space_text <- rep("\t", node_depth)
    if (forests$is_leaf_node(forest_num, tree_num, nid)) {
        # Leaf nodes: report the leaf value(s)
        leaf_values <- forests$node_leaf_values(forest_num, tree_num, nid)
        cat(space_text, "node=", nid, " is a leaf node with value=",
            format(leaf_values, digits = 3), "\n", sep = "")
    } else {
        # Split nodes: report the split rule and both children
        left <- forests$left_child_node(forest_num, tree_num, nid)
        feature <- forests$node_split_index(forest_num, tree_num, nid)
        threshold <- forests$node_split_threshold(forest_num, tree_num, nid)
        right <- forests$right_child_node(forest_num, tree_num, nid)
        cat(space_text, "node=", nid, " is a split node, which tells us to go to node ",
            left, " if X[:, ", feature, "] <= ", format(threshold, digits = 3),
            " else to node ", right, "\n", sep = "")
    }
}
```
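
The rules printed by the loop above ("go left if the feature is at or
below the threshold, otherwise go right") can be followed by hand to see
which leaf an observation lands in. The chunk below is a minimal sketch
of that traversal on a hand-built tree, not on the fitted model:
`toy_tree`, its column names, and `predict_toy_tree()` are hypothetical,
node ids are 0-based to mirror the loop above, and feature indices are
treated as 1-based R column positions purely for this illustration.

```{r}
# Toy tree stored as a data frame, one row per node
toy_tree <- data.frame(
  node      = c(0, 1, 2, 3, 4),
  is_leaf   = c(FALSE, TRUE, FALSE, TRUE, TRUE),
  feature   = c(2, NA, 2, NA, NA),
  threshold = c(0.5, NA, 0.75, NA, NA),
  left      = c(1, NA, 3, NA, NA),
  right     = c(2, NA, 4, NA, NA),
  value     = c(NA, -5, NA, 2.5, 7.5)
)

# Hypothetical helper: follow split rules from the root to a leaf
predict_toy_tree <- function(tree, x) {
  nid <- 0
  repeat {
    row <- tree[match(nid, tree$node), ]
    if (row$is_leaf) {
      return(row$value)
    }
    # Go left when the feature value is at or below the threshold
    nid <- if (x[row$feature] <= row$threshold) row$left else row$right
  }
}

pred_low <- predict_toy_tree(toy_tree, c(0.1, 0.2))   # ends at leaf node 1
pred_mid <- predict_toy_tree(toy_tree, c(0.1, 0.6))   # ends at leaf node 3
pred_high <- predict_toy_tree(toy_tree, c(0.1, 0.9))  # ends at leaf node 4
c(pred_low, pred_mid, pred_high)
```

Storing the tree as a flat table keyed by node id mirrors the
node-by-node accessors used above, so the same loop structure could be
adapted to a sampled tree by replacing the table lookups with those
accessor calls.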