Objectives

In this tutorial, we’ll learn how to use tree-based models (random forests) to predict the values of both categorical and continuous variables from the values of other variables.

  1. Understand how tree-based models, particularly the random forests algorithm, work, as well as when and why we might use these types of predictive models.
  2. Implement a random forests model in R.
  3. Interpret results from your model.

My breakdown of how this all works is based on the fabulous overview in the Introduction to Statistical Learning text I’ve referenced throughout the class (Chapter 8) - so if you want more detail and code, head there!

Let’s get started!


Set-up

You’ll need the following:

library(tidyverse)
library(randomForest)

Prediction

At this point, you should understand the important distinction between models we use for inference and models we use for prediction. Inferential models, such as the regression tools we used earlier in the class, are great for understanding relationships between variables. Predictive models, by contrast, are built to predict an event or outcome. This distinction matters: because we are less interested in understanding the fundamental relationships between variables, predictive models are often harder to explain using equations. These models often generate structures so complex that they are difficult for humans to interpret (see this week’s video!). What we can understand, however, are the basic algorithms that identify the structure that best predicts the outcome of interest. In this tutorial, we’ll focus on a commonly-used set of predictive models: tree-based models. We’ll work through an example using a tree-based model for regression (prediction of a continuous variable) and for classification (prediction of a categorical variable that takes on only a few discrete values).

So how do tree-based models work? This image is a really helpful overview:

I like this visual overview. It shows, at a high level, how tree-based models work. A tree is basically a giant flow chart of yes/no questions: at each split, the tree finds the value of a variable that best helps distill the category you’re trying to predict. Instead of the few trees in this example, full models can build hundreds of trees to best predict the outcome of interest. We try to find the tree structure that best predicts the outcome at the ends (the terminal nodes) of the trees (“Low Risk” versus “High Risk” in this example).

In this tutorial, we’re going to focus on one type of tree-based model, random forests. Random forests are a very popular tree-based model that uses bootstrap aggregation, or bagging: the algorithm builds tons of regression or classification trees, each on a bootstrapped sample of the data, and then forms a final prediction by averaging across all of these trees (or, for classification, taking a majority vote). Crucially, a random subset of predictors is considered at each split in each tree. This reduces correlation between trees in the forest, improving prediction accuracy!1
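To make the “forest of trees” idea concrete, here’s a toy sketch (with made-up tree votes, not output from a real model) of how bagging turns many individual tree classifications into one final prediction:

tree_votes <- c(1, 0, 1, 1, 1) # hypothetical class predictions from five trees for one observation
as.integer(mean(tree_votes) > 0.5) # majority vote across trees returns 1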

Random forests for classification

Random forests are most often used for classification problems: building models that best predict which category an observation falls into. In our case, we’re interested in predicting whether or not a particular place in the US is cultivated with our favorite crop ever, corn. This is real data I’m working with for a growing research project predicting how the places where we grow crops will change in a warmer world. This data.frame contains a column called AP, which stands for absence/presence. The column takes a value of 1 when corn is present in that pixel and 0 when corn is absent. The other columns stand for the following things:

  • B10_MEANTEMP_WARM - average temperature in the warmest quarter of the year
  • B4_TEMP_SEASONALITY - standard deviation of temperature * 100
  • IRR - whether or not the pixel is irrigated (0 == no, 1 == yes)
  • SLOPE - slope of terrain
  • ELEVATION - elevation of terrain
  • B2_MEAN_DIURNAL_RANGE - mean diurnal temperature range (average of monthly maximum minus minimum temperature)
  • B8_MEANTEMP_WET - average temperature in the wettest quarter
  • B18_PPT_WARMQ - average precipitation in the warmest quarter
  • B12_TOTAL_PPT - total annual precipitation
  • B15_PPT_SEASONALITY - precipitation seasonality (coefficient of variation of precipitation * 100)
  • T_CEC_SOIL - cation exchange capacity of the topsoil
  • T_OC - topsoil organic carbon
  • S_PH_H2O - subsoil pH

Soil data comes from the HWSD project, irrigation data from the MIRAD project, and topography from the National Elevation Dataset.

Note that all of the columns that start with B*... are based on the list of biovars built by the WorldClim team. If you’re interested in species distribution modeling or cool open-source climate futures data, check them out! There’s even an R package out there (dismo) that helps you play with this cool data.

ap <- readRDS("./data/corn_ap.RDS")
glimpse(ap)
## Rows: 4,961
## Columns: 14
## $ AP                    <fct> 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, ...
## $ B10_MEANTEMP_WARM     <dbl> 22.56833, 21.99825, 24.21244, 22.40566, 19.48...
## $ B4_TEMP_SEASONALITY   <dbl> 1078.1600, 916.1887, 836.2860, 1107.4468, 888...
## $ IRR                   <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, ...
## $ SLOPE                 <dbl> 0.09171792, 0.36307254, 0.50377464, 0.2550289...
## $ ELEVATION             <dbl> 196, 254, 306, 392, 723, 533, 353, 452, 948, ...
## $ B2_MEAN_DIURNAL_RANGE <dbl> 11.220123, 12.530332, 12.551253, 12.100369, 1...
## $ B8_MEANTEMP_WET       <dbl> 18.38585, 15.49067, 11.31021, 18.24522, 11.54...
## $ B18_PPT_WARMQ         <dbl> 333.7619, 375.3810, 369.4762, 348.3810, 373.8...
## $ B12_TOTAL_PPT         <dbl> 1019.2381, 1281.5238, 1487.5714, 960.9524, 13...
## $ B15_PPT_SEASONALITY   <dbl> 62.61211, 40.68610, 44.64520, 75.86792, 39.95...
## $ T_CEC_SOIL            <dbl> 19, 12, 6, 19, 12, 23, 11, 18, 7, 18, 26, 12,...
## $ T_OC                  <dbl> 1.68, 1.45, 0.98, 1.68, 1.45, 2.08, 0.82, 1.8...
## $ S_PH_H2O              <dbl> 6.8, 5.2, 5.0, 6.8, 5.2, 7.4, 6.2, 6.8, 4.8, ...

Ok, so how do we build a predictive model of where corn is grown using random forests? First, make sure you’ve installed and loaded the randomForest package. Then, remember that we need to split our data into training and testing data.2 We train our algorithm on the training data, and then see how well the model does at predicting the held-out testing data.

set.seed(1) # this ensures you generate the same random row numbers every time you run the code

# hold out 25% of the data
random_rn <- sample(nrow(ap), ceiling(nrow(ap)*.25)) # generate random row numbers
train <- ap[-random_rn,] # remove those random row numbers
test <- ap[random_rn,] # keep those random row numbers

# run rf, it's easy!
rf_ap <- randomForest(AP ~ ., data = train)

The last line of code actually runs the random forests algorithm. It’s easy to implement in R, but note that there are lots of parameters you can set behind the scenes that can affect model performance (things like the number of trees). Check out Introduction to Statistical Learning for more on how to optimize these parameters. For now, we’re going to keep things simple. By running this little line of code, you’ve just built a fairly sophisticated model to predict where corn is grown.
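Just as a sketch (these are made-up values, not tuned ones), here’s how you’d pass two of those parameters, the number of trees (ntree) and the number of variables tried at each split (mtry):

rf_tuned <- randomForest(AP ~ ., data = train, ntree = 1000, mtry = 4) # hypothetical values, not tuned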

rf_ap
## 
## Call:
##  randomForest(formula = AP ~ ., data = train) 
##                Type of random forest: classification
##                      Number of trees: 500
## No. of variables tried at each split: 3
## 
##         OOB estimate of  error rate: 10.54%
## Confusion matrix:
##      0    1 class.error
## 0 1612  217  0.11864407
## 1  175 1716  0.09254363

When you print out the model, it tells you a few useful things: how many trees were built in the model (Number of trees) and how many variables were tried at each split in each tree (yes, both are parameters you can change). It also reports the “out of bag” (OOB) estimate of the error rate, here 10.54%. Each tree is built on a bootstrap sample of the training data; the rows left out of that sample are the tree’s “out-of-bag” data, and the OOB error rate tells us how often the model mispredicts those held-out rows.
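If you want to grab that OOB error rate programmatically rather than reading it off the printout, it’s stored in the model object’s err.rate matrix, which has one row per tree; the last row gives the estimate for the full forest:

rf_ap$err.rate[nrow(rf_ap$err.rate), "OOB"] # ~0.105, matching the printout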

Finally, this returns our confusion matrix, which tells us how often the model got things wrong (guessed a 0 when the correct value was 1, or vice versa). We’ll build the same kind of confusion matrix on our test data to come up with an estimate of model performance. To do this, we first need to use our new model to predict our test data using the predict() function:

preds <- predict(rf_ap, test, type = "class") 
head(preds)
## 14389 11203 18797   421 14956 17496 
##     1     0     1     0     1     1 
## Levels: 0 1

preds now contains a factor of zeros and ones, one for each row in the test data.frame. Here, the predict() function predicts the value of AP (absence or presence) based on the values of the climate, soil, and topographic predictors in the test data. By setting type = 'class' we tell the predict() function to generate its best guess of the value of AP (0 or 1). We can also set type = 'prob', which is cool because it will generate the probability of that row being corn! Try it!
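For instance, here’s what the probability version looks like (same model and test data as above):

probs <- predict(rf_ap, test, type = "prob") # matrix with one column per class
head(probs) # each row gives the share of trees voting 0 versus 1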

So how do we assess predictive performance on the test data?

test$PREDICTIONS <- preds
cm <- table(test$PREDICTIONS, test$AP)
cm
##    
##       0   1
##   0 552  61
##   1  62 566

Here’s our confusion matrix on the test data. It shows, for example, that for 552 rows the model correctly guessed zero (no corn). We can use this confusion matrix to compute the model’s accuracy as follows:

accuracy <- (cm[1,1] + cm[2, 2])/nrow(test) # correct predictions divided by total observations
accuracy
## [1] 0.9008864

This is basically pulling out the number of times we get things right divided by the total number of observations! So our model got things right ~90% of the time. Not bad!
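As a quick sanity check, one minus the accuracy gives the test-set error rate, which should land close to the OOB estimate from earlier:

1 - accuracy # ~0.099, in line with the ~10.5% OOB error rate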

Another cool thing we can pull from this model is called a variable importance plot. This plot tells us how each variable contributes to the prediction of the outcome, so variables with a higher variable importance contribute more to predicting the outcome of interest. CAUTION. This does not mean that they cause the outcome, just that as the predictor values vary, they help predict variation in the outcome.3

imp <- as.data.frame(varImpPlot(rf_ap)) # this automatically creates a figure

imp$varnames <- rownames(imp) # row names to column
imp
##                       MeanDecreaseGini              varnames
## B10_MEANTEMP_WARM            292.30805     B10_MEANTEMP_WARM
## B4_TEMP_SEASONALITY          324.41407   B4_TEMP_SEASONALITY
## IRR                           73.57496                   IRR
## SLOPE                         99.42673                 SLOPE
## ELEVATION                    104.82330             ELEVATION
## B2_MEAN_DIURNAL_RANGE        136.54736 B2_MEAN_DIURNAL_RANGE
## B8_MEANTEMP_WET              143.27590       B8_MEANTEMP_WET
## B18_PPT_WARMQ                130.68100         B18_PPT_WARMQ
## B12_TOTAL_PPT                166.10387         B12_TOTAL_PPT
## B15_PPT_SEASONALITY           97.68377   B15_PPT_SEASONALITY
## T_CEC_SOIL                    97.82149            T_CEC_SOIL
## T_OC                          36.13708                  T_OC
## S_PH_H2O                     154.36127              S_PH_H2O

Here, the varImpPlot() computes MeanDecreaseGini, which is the mean decrease in Gini impurity. The scale of this value is irrelevant; only relative values matter. It’s basically telling us how much, on average, node purity improves each time that variable is chosen to split a node. I really like this overview of how the Gini measure works.
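As a quick worked example, the Gini impurity of a node is the sum across classes of p_k * (1 - p_k), where p_k is the fraction of the node’s observations in class k; good splits produce child nodes with lower impurity:

p <- c(0.9, 0.1) # hypothetical node: 90% of observations are corn, 10% are not
sum(p * (1 - p)) # impurity of 0.18; a perfectly pure node (p = c(1, 0)) scores 0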

We can use ggplot() to create a better visualization of importance:

ggplot(imp, aes(x=reorder(varnames, MeanDecreaseGini), y=MeanDecreaseGini)) + 
  geom_point() +
  geom_segment(aes(x=varnames,xend=varnames,y=0,yend=MeanDecreaseGini)) +
  ylab("Mean decrease Gini") +
  xlab("") +
  coord_flip() +
  theme_minimal()

So here, we see that temperature seasonality and temperatures during the warmest quarter strongly predict where corn is grown. This makes sense! But remember, this doesn’t necessarily mean these things have a causal relationship with corn production, just that they help predict where corn is grown.


Random forests for regression

The language here might be a bit confusing, but random forests for regression just means we’re using random forests to predict a continuous variable (like corn yields). Hey, we have data on corn yields! Let’s load it and RF it! Luckily, most of the code to implement this is the same; we just have to think differently about how we assess error. Instead of checking whether or not we got the classification correct, we’re now essentially looking at residuals: how far off from the actual yield values were our predictions? We can use our old friend mean squared error to look at this.

corn_yield <- readRDS("./data/corn.RDS")

# for RF to work, we need to remove the weird identifier variables in the data.frame that don't go in the RF regression
corn_yield <- corn_yield %>% select(-c(GEOID, STATE_NAME, COUNTY_NAME))

set.seed(1) 
random_rn <- sample(nrow(corn_yield), ceiling(nrow(corn_yield)*.25)) 
train <- corn_yield[-random_rn,] 
test <- corn_yield[random_rn,] # keep those random row numbers

rf_ap <- randomForest(CORN_YIELD ~ ., data = train)

We can use the same predict() function, but now to predict CORN_YIELD rather than the binary indicator we used above. Now, our confusion matrix approach to assessing error doesn’t really make sense any more. Instead, we need to look at how far off our predicted yields are from actual yields.

preds <- predict(rf_ap, test, type = "response")
test$PREDICTIONS <- preds
head(test %>% select(PREDICTIONS, CORN_YIELD))
##      PREDICTIONS CORN_YIELD
## 1107    208.4460      211.9
## 9093    131.0925      146.4
## 5505    130.2069      142.0
## 2628    151.5031      147.0
## 5768    121.6851      141.2
## 1758    105.3472       42.3

We can do this using mean squared error! Remember, this basically gives us a sense of how big our residuals are, or how far off our model is from the real values:

MSE <- mean((test$CORN_YIELD - test$PREDICTIONS)^2)
print(MSE)
## [1] 494.1617

Our goal here is to minimize the mean squared error, so as we run different models, we try to reduce how far off our modeled values are from actual values.
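Because the MSE is in squared units, it can also help to take its square root (the RMSE), which puts the typical prediction error back in the same units as CORN_YIELD:

sqrt(MSE) # root mean squared error, roughly 22.2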

Now, let’s use variable importance plots to look at the variables most predictive of yields (DRUM ROLL):

imp <- as.data.frame(varImpPlot(rf_ap)) # recompute importance for the regression model (this also draws the default plot)
imp$varnames <- rownames(imp) # row names to column

ggplot(imp, aes(x=reorder(varnames, IncNodePurity), y=IncNodePurity)) + 
  geom_point() +
  geom_segment(aes(x=varnames,xend=varnames,y=0,yend=IncNodePurity)) +
  ylab("Increase in node purity") +
  xlab("") +
  coord_flip() +
  theme_minimal()

Note that we’re assessing variable importance with a different metric here: the increase in node purity. This is analogous to the Gini-based metric above, but is calculated from the reduction in the sum of squared errors each time a variable is chosen for a split. What’s cool here is that we see that YEAR strongly predicts yield changes, which means there’s likely something changing through time that we haven’t included in our model that affects yield (technology? markets? prices?).


Additional Resources

  • Here’s a good overview of tree-based models versus linear models.
  • More on pruning trees here.
  • Good overview of prediction accuracy and interpretability here.
  • And some slides on random forest best practice.
  • I thought this was a helpful overview of how we measure variable importance in RF models.
  • And a great online free textbook reviewing lots of tree-based algorithms here.

  1. Love love love this overview with cool visualizations of how random forests work.

  2. Technically you want to split into three chunks, one for training, one for validation of training models, and a final independent data you only use for testing your final model.

  3. You have to be really careful interpreting these plots. Variable importance becomes really complex with highly collinear variables.