One of my research studies uses a random forest algorithm to predict continuous biomechanical variables. If you don’t know anything about random forests, I’ll try to describe them in a sentence: they are a type of algorithm that aggregates a large number of predictions made by individual decision trees into a single prediction for the whole forest. Specifically, I am using a Quantile Regression Forest (QRF) in R, fit with the caret and quantregForest packages loaded in the code below.
When I was assessing the accuracy of my QRF model, I wanted to visualize the variability in the individual tree predictions alongside the aggregated prediction (seeing the forest for the trees, so to speak). After looking around online, I couldn’t find any examples of someone doing this, so I thought I’d do a quick write-up on how I accomplished it (plot at the bottom). Maybe it will be useful to someone wanting to better visualize what is going on inside the “black box” of this machine learning algorithm. For this example, I’ll use the mtcars dataset to try to predict a car’s MPG. Jump to the bottom for the link to the whole R script.
library(tidyr)
library(caret)
library(quantregForest)
library(ggplot2)
library(ggridges)

# head(mtcars)
qrf <- train(mpg ~ cyl + disp + hp + drat + wt, data = mtcars, method = 'qrf') # train QRF model
We can then make predictions using predict(). However, for a QRF this function does not return the prediction of each individual tree (500 of them in this case) by default. To get them, we need to change the class of finalModel to only ‘randomForest’, so that predict() dispatches to the randomForest method, which accepts predict.all = TRUE. This returns predictions for the forest (aggregate) and for each tree (individual):
n_cars <- 8
class(qrf$finalModel)
## [1] "quantregForest" "randomForest"
class(qrf$finalModel) <- 'randomForest'
pred <- predict(qrf$finalModel, mtcars[1:n_cars,], predict.all = TRUE)
str(pred)
## List of 2
##  $ aggregate : Named num [1:8] 20.9 20.9 23.9 20.1 17.5 ...
##   ..- attr(*, "names")= chr [1:8] "Mazda RX4" "Mazda RX4 Wag" "Datsun 710" "Hornet 4 Drive" ...
##  $ individual: num [1:8, 1:500] 21.2 21.2 21.8 21.2 16.5 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : chr [1:8] "Mazda RX4" "Mazda RX4 Wag" "Datsun 710" "Hornet 4 Drive" ...
##   .. ..$ : NULL
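As a sanity check on what predict.all returns, here is a minimal sketch that skips caret and fits quantregForest directly (an assumption made only to keep the example short; qrf_fit, rf, x, and y are illustrative names, not part of the script above). It changes the class on a copy, so the original QRF object stays usable for quantile predictions, and checks that the forest prediction is the mean of the tree predictions.

```r
library(quantregForest)  # attaches randomForest as a dependency

set.seed(1)
x <- mtcars[, c("cyl", "disp", "hp", "drat", "wt")]
y <- mtcars$mpg
qrf_fit <- quantregForest(x = x, y = y, ntree = 500)

# Work on a copy so qrf_fit keeps its quantregForest class
rf <- qrf_fit
class(rf) <- "randomForest"

# predict.all = TRUE returns both the forest and the per-tree predictions
p <- predict(rf, newdata = x[1:8, ], predict.all = TRUE)

dim(p$individual)  # one row per car, one column per tree

# For regression, the aggregate should be the average over the trees
agreement <- all.equal(unname(p$aggregate), unname(rowMeans(p$individual)))
agreement
```

Copying first is a design choice, not a requirement: overwriting the class in place, as in the walkthrough, works too, but it changes how later predict() calls on that object behave.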
I chose to visualize the tree predictions with ridgeplots in a predicted vs. observed plot. The format of pred isn’t exactly ggplot-friendly, so we need to do some reshaping with gather() and merge() before plotting:
# reshape individual tree predictions for ggplot
obs <- mtcars$mpg[1:n_cars]
car <- row.names(mtcars)[1:n_cars]
tree_pred <- as.data.frame(pred$individual)
colnames(tree_pred) <- 1:500
tree_pred$obs <- obs
tree_pred$car <- car
tree_pred_long <- gather(tree_pred, tree, pred, 1:500)

# reshape aggregate predictions for ggplot
agg <- aggregate(pred ~ obs + car, data = tree_pred_long, FUN = mean) # same as pred$aggregate values
colnames(agg) <- c('obs', 'car', 'mean_pred')
pred_obs <- data.frame(pred = pred$aggregate, obs = obs, car = car)
# merge on car as well as obs: two cars share the same observed MPG (21.0),
# so merging on obs alone would duplicate their rows
pred_obs <- merge(pred_obs, agg, by = c('obs', 'car'))

# plot distribution of tree predictions and aggregate prediction of final QRF model
ggplot(tree_pred_long) +
  geom_abline(slope = 1, intercept = 0, lty = 2) + # line of identity
  geom_density_ridges(aes(x = pred, y = obs, group = car, fill = car),
                      alpha = 0.3, color = 'black', rel_min_height = 0.02, size = 0.5) +
  geom_point(data = pred_obs, aes(x = pred, y = obs, fill = car), pch = 21, size = 2) +
  theme_classic() +
  ggtitle('Seeing the (Random) Forest for the Trees') +
  coord_fixed(xlim = c(12,28), ylim = c(12,28)) +
  scale_y_continuous('Observed MPG', breaks = seq(15,30,5)) +
  scale_x_continuous('Predicted MPG', breaks = seq(15,30,5))
## Picking joint bandwidth of 0.276
I’ve annotated the figure above to highlight the tree vs. forest predictions. The height of each ridgeplot is not in the units of the y axis; each ridge sits at its car’s observed MPG and shows the relative distribution of the individual tree predictions that make up the quantile regression forest. It’s kinda neat to see how the tree predictions compare to the final aggregated prediction!
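A side note on the reshaping step: gather() still works, but tidyr now marks it as superseded in favor of pivot_longer(). Below is a rough sketch of the same wide-to-long reshape on a tiny made-up table (3 cars, 4 trees; the numbers are illustrative and not model output).

```r
library(tidyr)

# Toy stand-in for pred$individual: rows are cars, columns are trees
m <- matrix(c(21.2, 21.8, 16.5,
              20.4, 23.1, 17.9,
              21.0, 24.0, 18.2,
              20.7, 22.5, 17.1), nrow = 3)
tree_pred <- as.data.frame(m)
colnames(tree_pred) <- 1:4
tree_pred$obs <- c(21.0, 22.8, 18.7)
tree_pred$car <- c("Mazda RX4", "Datsun 710", "Hornet Sportabout")

# One row per car/tree pair; the tree number comes from the old column names
tree_pred_long <- pivot_longer(
  tree_pred,
  cols = -c(obs, car),
  names_to = "tree",
  values_to = "pred"
)
```

The resulting columns (obs, car, tree, pred) match what the plotting code expects, so this should be a drop-in swap for the gather() line.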
A gist of the R script can be found here
Tips and Tricks
This type of plot can get messy when many predictions are made within a small range of values. Adjusting the opacity (alpha) can help, or you can plot a random subset of your predictions to get a general idea.
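Plotting a random subset could look something like the sketch below. It uses a toy long-format table with hypothetical car names so that it runs on its own; with the real data you would sample from the car column of tree_pred_long instead.

```r
set.seed(42)

# Toy long-format table: 6 hypothetical cars x 5 trees
tree_pred_long <- expand.grid(car = paste("car", 1:6), tree = 1:5,
                              stringsAsFactors = FALSE)
tree_pred_long$pred <- rnorm(nrow(tree_pred_long), mean = 20, sd = 2)

# Sample observations (not rows), so each kept car retains all of its trees
keep <- sample(unique(tree_pred_long$car), size = 3)
tree_pred_sub <- tree_pred_long[tree_pred_long$car %in% keep, ]
```

Sampling by car rather than by row keeps each ridge's full distribution intact.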
You might notice that the ridgeplots are not entirely continuous (e.g. Merc 240D). This is purely aesthetic: you can adjust the threshold of what to show with the rel_min_height argument of geom_density_ridges() (0.02 in the code above).
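The plotting code above passes rel_min_height = 0.02 to geom_density_ridges(), which trims parts of each ridge whose density falls below that fraction of the peak. Here is a small sketch of its effect on made-up data (the groups and values are hypothetical):

```r
library(ggplot2)
library(ggridges)

set.seed(1)
# Toy data: group "b" has two clusters with a low-density gap between them
d <- data.frame(
  group = rep(c("a", "b"), each = 200),
  value = c(rnorm(200, 20, 1), rnorm(100, 15, 0.5), rnorm(100, 25, 0.5))
)

# rel_min_height = 0 draws the full density, tails included
p_full <- ggplot(d, aes(x = value, y = group)) +
  geom_density_ridges(rel_min_height = 0)

# A larger threshold trims low-density regions, which can split a ridge
p_trim <- ggplot(d, aes(x = value, y = group)) +
  geom_density_ridges(rel_min_height = 0.02)
```

Printing p_full and p_trim side by side should show the gap appear in group "b" once the threshold is raised.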
If you have categorical variables in your QRF, you will need to convert them to numeric before using predict(). I guess when you use caret’s train() with the formula-style input, it creates dummy variables for the categorical predictors. This can be fixed using the
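One possible way to build such dummy columns yourself is base R's model.matrix(). This is an assumption about what would work here, not necessarily the fix the post originally pointed to; cars and x_new are illustrative names.

```r
# Toy data with one categorical predictor
cars <- mtcars[1:8, c("mpg", "wt", "cyl")]
cars$cyl <- factor(cars$cyl)

# Expand the factor into 0/1 dummy columns; [, -1] drops the intercept column
x_new <- model.matrix(mpg ~ wt + cyl, data = cars)[, -1]
colnames(x_new)
```

The factor levels in the new data need to match those seen in training; otherwise the dummy columns will not line up with what the model expects.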
If you have an interaction in your model, I think you will need a column that represents that interaction before using predict(). That’s my best guess, as it throws an error about regression equation variables not being present in the new dataset you’re testing on.
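If that guess is right, one way to create the interaction column is to let model.matrix() expand the formula, as in this sketch. The hp * wt interaction is a hypothetical example, not part of the model above, and I am assuming the fitted model expects the column name that model.matrix() produces.

```r
# Hypothetical model with an hp-by-wt interaction
new_cars <- mtcars[1:8, c("mpg", "hp", "wt")]

# hp * wt expands to hp, wt, and their product; [, -1] drops the intercept
x_new <- model.matrix(mpg ~ hp * wt, data = new_cars)[, -1]
colnames(x_new)

# The interaction column is just the elementwise product
same <- all.equal(unname(x_new[, "hp:wt"]), new_cars$hp * new_cars$wt)
same
```

Passing x_new (or a data frame built from it) as the new data would then give predict() a column for every term in the formula.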