R code

library("tidyverse")
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.2 ──
## ✔ ggplot2 3.4.0     ✔ purrr   1.0.1
## ✔ tibble  3.2.1     ✔ dplyr   1.1.2
## ✔ tidyr   1.2.1     ✔ stringr 1.4.1
## ✔ readr   2.1.2     ✔ forcats 0.5.2
## Warning: package 'tibble' was built under R version 4.2.3
## Warning: package 'purrr' was built under R version 4.2.3
## Warning: package 'dplyr' was built under R version 4.2.3
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
theme_set(theme_bw())
library("rpart")
library("rpart.plot")
## Warning: package 'rpart.plot' was built under R version 4.2.3
library("randomForest")
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## 
## The following object is masked from 'package:dplyr':
## 
##     combine
## 
## The following object is masked from 'package:ggplot2':
## 
##     margin
library("DT")
## Warning: package 'DT' was built under R version 4.2.3

In these slides, we’ll introduce classification and regression trees (CART) and random forests.

For simplicity in the analyses, I will use a subset of the diamonds data set: the categorical variables are removed, price is analyzed on the log scale, and 100 observations are randomly sampled for each of the training and testing sets.

set.seed(20230425) # This matches what was used in a previous set of slides
n <- 100
d <- diamonds %>%
  dplyr::select(-cut, -color, -clarity) %>% # drop the categorical variables
  rename(lprice = price) %>%
  mutate(lprice = log(lprice))              # use log(price) as the response

train <- d %>% sample_n(n)
test  <- d %>% sample_n(n) # drawn independently of train, so the two sets may overlap

Load the error file saved in a previous set of slides.

error <- read_csv("../28-penalty/error.csv")
## Rows: 10 Columns: 4
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (2): group, method
## dbl (2): in_sample, out_of_sample
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Classification and Regression Trees (CART)

Tree-based methods take a regression approach in which dummy variables are created by splitting a continuous explanatory variable at a cutoff (above versus below). Within each resulting branch, the observations can be split again, on another variable or on the same variable at a different cutoff. This methodology can be used for many regression problems, including both binary and continuous responses. Here we will investigate a continuous response.
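
As a minimal sketch of this idea (my addition, not in the original slides), a tree with a single split is equivalent to a regression on one dummy variable; the 5.705 cutoff used here is the first split chosen by the tree fit in the next section, and the two leaf means equal the lm intercept and intercept plus slope.

m1 <- rpart(lprice ~ ., data = train,
            control = rpart.control(maxdepth = 1)) # allow only a single split
m1                                        # two leaves: the mean of each group
lm(lprice ~ I(x >= 5.705), data = train)  # same means via one dummy variable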

Fit

m <- rpart(lprice ~ ., data = train) # regression tree with default control settings

Tree

Trees (in the real world) are constructed of a trunk, branches, and leaves. CART trees use the same structure: the trunk is composed of all observations, the branches split those observations, and the splits eventually end in leaves.

These leaves are composed of a collection of observations. The mean of those observations is the estimated and predicted mean for any observations that fall into that leaf.

The plot below shows the splits (dummy variables) used to divide the observations. The top number in each box is the mean of the observations in that node, and the percentage in each box is the percentage of all observations that end up in that branch or leaf.

rpart.plot(m, uniform = TRUE)

This regression tree uses only the explanatory variables carat, x, and z. These variables are all related to the size of the diamond: carat is its weight, x is its length (mm), and z is its depth (mm). Thus the tree is ordered from the smallest diamonds on the left to the largest on the right, and the log price increases from left to right.

Regression model

Recall that a regression model is specified by a set of explanatory variables \(X_1,\ldots,X_p\). Here each explanatory variable is the product of the indicator (dummy) variables along the path that leads to a leaf. Based on the tree above, the explanatory variables are

  • \(X_1 = \mathrm{I}(x < 5.7)\mathrm{I}(z < 3.1)\mathrm{I}(x < 4.4)\)
  • \(X_2 = \mathrm{I}(x < 5.7)\mathrm{I}(z < 3.1)\mathrm{I}(x \ge 4.4)\)
  • \(X_3 = \mathrm{I}(x < 5.7)\mathrm{I}(z \ge 3.1)\)
  • \(X_4 = \mathrm{I}(x \ge 5.7)\mathrm{I}(carat < 1.2)\mathrm{I}(x < 6.1)\)
  • \(X_5 = \mathrm{I}(x \ge 5.7)\mathrm{I}(carat < 1.2)\mathrm{I}(x \ge 6.1)\)
  • \(X_6 = \mathrm{I}(x \ge 5.7)\mathrm{I}(carat \ge 1.2)\mathrm{I}(x < 7.3)\)
  • \(X_7 = \mathrm{I}(x \ge 5.7)\mathrm{I}(carat \ge 1.2)\mathrm{I}(x \ge 7.3)\)

Recall that for these data lprice is our response.

Some of these dummy variables can be simplified because one condition implies another, e.g. since \(x < 4.4\) implies \(x < 5.7\),

  • \(X_1 = \mathrm{I}(z < 3.1)\mathrm{I}(x < 4.4)\)

The resulting regression model (with no intercept, since the dummy variables partition the observations) has \[ E[Y_i] = \beta_1X_1 + \beta_2X_2 + \beta_3X_3 + \beta_4X_4 + \beta_5X_5 + \beta_6X_6 + \beta_7X_7 \] where \(\beta_j\) is the mean for the observations with \(X_j = 1\).

We can verify that this is what is going on by computing the group means, but to do so we need more precise cut points than those shown in the plot. Printing the fitted model provides them.

m
## n= 100 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 100 100.5598000 7.769913  
##    2) x< 5.705 51  11.2735400 6.904977  
##      4) z< 3.07 36   2.6600300 6.655123  
##        8) x< 4.37 9   0.4573568 6.317919 *
##        9) x>=4.37 27   0.8381956 6.767524 *
##      5) z>=3.07 15   0.9724444 7.504626 *
##    3) x>=5.705 49  11.4211500 8.670154  
##      6) carat< 1.18 27   1.9915980 8.328375  
##       12) x< 6.105 9   0.1660813 8.030495 *
##       13) x>=6.105 18   0.6276279 8.477315 *
##      7) carat>=1.18 22   2.4048540 9.089610  
##       14) x< 7.345 12   0.2206264 8.867408 *
##       15) x>=7.345 10   0.8807686 9.356251 *
train %>%
  mutate(
    X1 = (x < 5.705)*(z < 3.07)*(x < 4.37),
    X2 = (x < 5.705)*(z < 3.07)*(x >= 4.37),
    X3 = (x < 5.705)*(z >= 3.07),
    X4 = (x >= 5.705)*(carat < 1.18)*(x < 6.105),
    X5 = (x >= 5.705)*(carat < 1.18)*(x >= 6.105),
    X6 = (x >= 5.705)*(carat >= 1.18)*(x < 7.345),
    X7 = (x >= 5.705)*(carat >= 1.18)*(x >= 7.345)
  ) %>%
  summarize(
    mean1 = sum(lprice*X1)/sum(X1),
    mean2 = sum(lprice*X2)/sum(X2),
    mean3 = sum(lprice*X3)/sum(X3),
    mean4 = sum(lprice*X4)/sum(X4),
    mean5 = sum(lprice*X5)/sum(X5),
    mean6 = sum(lprice*X6)/sum(X6),
    mean7 = sum(lprice*X7)/sum(X7)
  ) %>%
  round(1) %>%
  pivot_longer(everything())
## # A tibble: 7 × 2
##   name  value
##   <chr> <dbl>
## 1 mean1   6.3
## 2 mean2   6.8
## 3 mean3   7.5
## 4 mean4   8  
## 5 mean5   8.5
## 6 mean6   8.9
## 7 mean7   9.4
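
Equivalently (a quick check added here, not in the original slides), a linear model with no intercept on these dummy variables returns the same leaf means as its coefficients.

train %>%
  mutate(
    X1 = (x < 5.705)*(z < 3.07)*(x < 4.37),
    X2 = (x < 5.705)*(z < 3.07)*(x >= 4.37),
    X3 = (x < 5.705)*(z >= 3.07),
    X4 = (x >= 5.705)*(carat < 1.18)*(x < 6.105),
    X5 = (x >= 5.705)*(carat < 1.18)*(x >= 6.105),
    X6 = (x >= 5.705)*(carat >= 1.18)*(x < 7.345),
    X7 = (x >= 5.705)*(carat >= 1.18)*(x >= 7.345)
  ) %>%
  lm(lprice ~ X1 + X2 + X3 + X4 + X5 + X6 + X7 - 1, data = .) %>%
  coef() %>%
  round(1)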

Iterative construction

We can get a sense of the iterative (greedy) construction of the model by growing the tree one level at a time.
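
The sketch below is my illustration (the original slides may have used a different tool): refit the tree with an increasing maxdepth and plot each fit.

for (depth in 1:3) {
  m_step <- rpart(lprice ~ ., data = train,
                  control = rpart.control(maxdepth = depth)) # limit the depth of the tree
  rpart.plot(m_step, main = paste("maxdepth =", depth))
}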

Predictions

p <- bind_rows(
  test  %>% mutate(p = predict(m, newdata = test),  type = "test"),
  train %>% mutate(p = predict(m, newdata = train), type = "train")
)

ggplot(p, aes(x = p, y = lprice, shape = type, color = type)) + 
  geom_abline(intercept = 0, slope = 1, color = "gray") + 
  geom_point(position = position_dodge(width = 0.1)) 

p_train <- predict(m, newdata = train)
p_test  <- predict(m, newdata = test)

error <- bind_rows(
  error,
  data.frame(
    group         = "Tree",
    method        = "default",
    in_sample     = mean((p_train - train$lprice)^2),
    out_of_sample = mean((p_test  -  test$lprice)^2)
  )
)

Tuning parameters

args(rpart.control)
## function (minsplit = 20L, minbucket = round(minsplit/3), cp = 0.01, 
##     maxcompete = 4L, maxsurrogate = 5L, usesurrogate = 2L, xval = 10L, 
##     surrogatestyle = 0L, maxdepth = 30L, ...) 
## NULL

Each argument can tune the CART model to underfit or overfit. For example,

  • minsplit
    • low values lead to overfitting
    • high values lead to underfitting
  • minbucket
    • low values lead to overfitting
    • high values lead to underfitting
  • cp (complexity parameter; see the cross-validation aside after this list)
    • low values lead to overfitting
    • high values lead to underfitting
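
As an aside (added here, not in the original slides), rpart already performs internal cross-validation (the xval argument above), which is a common way to choose cp for the default fit m.

printcp(m) # cp table: rel error is in-sample, xerror is cross-validated
plotcp(m)  # cross-validated error versus cp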

Underfit

m <- rpart(lprice ~ ., data = train,
           control = rpart.control(
             minsplit = 40,  # require many observations before attempting a split
             minbucket = 20, # require large leaves
             cp = 0.1        # heavily penalize additional splits
           ))
rpart.plot(m)

p_train <- predict(m, newdata = train)
p_test  <- predict(m, newdata = test)

error <- bind_rows(
  error,
  data.frame(
    group         = "Tree",
    method        = "underfit",
    in_sample     = mean((p_train - train$lprice)^2),
    out_of_sample = mean((p_test  -  test$lprice)^2)
  )
)

Overfit

m <- rpart(lprice ~ ., data = train,
           control = rpart.control(
             minsplit = 10,  # allow splits with few observations
             minbucket = 5,  # allow small leaves
             cp = 0.001      # barely penalize additional splits
           ))
rpart.plot(m)

p_train <- predict(m, newdata = train)
p_test  <- predict(m, newdata = test)

error <- bind_rows(
  error,
  data.frame(
    group         = "Tree",
    method        = "overfit",
    in_sample     = mean((p_train - train$lprice)^2),
    out_of_sample = mean((p_test  -  test$lprice)^2)
  )
)

Random forests

The idea behind random forests is encapsulated in the name.

Forest means that we will create a collection of tree models, i.e. a forest. Since we have a collection of trees, we will use a model-averaged prediction from those trees, and since the trees are all interchangeable, this average will be unweighted.

Random indicates that randomness will be employed in constructing the variety of trees. This randomness enters in two steps:

  1. each time we construct a tree, we will sample from the data with replacement and
  2. at every split, we will randomly choose a subset of the explanatory variables to split on.

In the first step, each tree is fit to its own resampled data. Since some observations will not be included in a particular resample, those observations can be used to evaluate out-of-sample error for that tree; in random forests, this is called the out-of-bag (OOB) error.
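
To make these two steps concrete, below is a conceptual sketch (my illustration, with an arbitrary 50 trees and 2 variables per tree; note that randomForest actually chooses a new random subset of variables at every split, whereas this sketch picks one subset per tree).

vars <- c("carat", "depth", "table", "x", "y", "z")     # explanatory variables
my_forest <- lapply(1:50, function(b) {
  boot <- train[sample(nrow(train), replace = TRUE), ]  # 1. resample rows with replacement
  cols <- sample(vars, 2)                               # 2. random subset of variables
  rpart(reformulate(cols, response = "lprice"), data = boot)
})
# unweighted model-averaged prediction across the trees
p_forest <- rowMeans(sapply(my_forest, predict, newdata = test))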

Fit

m <- randomForest(lprice ~ ., data = train)
m
## 
## Call:
##  randomForest(formula = lprice ~ ., data = train) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 2
## 
##           Mean of squared residuals: 0.04560889
##                     % Var explained: 95.46

Out-of-bag error versus number of trees.

plot(m)

Prediction

p_train <- predict(m, newdata = train)
p_test  <- predict(m, newdata = test)

error <- bind_rows(
  error,
  data.frame(
    group         = "Random forest",
    method        = "default",
    in_sample     = mean((p_train - train$lprice)^2),
    out_of_sample = mean((p_test  -  test$lprice)^2)
  )
)

Arguments

The random forest algorithm has a number of arguments that can be tuned and that will affect predictive performance.

?randomForest
  • sampsize: number of resampled observations used to train each tree
  • mtry: number of explanatory variables randomly tried at each split
  • nodesize: minimum number of observations in each leaf
  • ntree: number of trees to grow
  • replace: whether to sample observations with or without replacement

Overfit

We can overfit by training each tree on all of the data (sampling without replacement), considering nearly all of the explanatory variables at every split (mtry = 5 of the 6 available), and making the minimum number of observations in each leaf as small as possible.

With these settings every tree sees the same training data, so the only remaining randomness comes from which 5 of the 6 explanatory variables are considered at each split. Note also that because no observations are left out of any tree, there are no out-of-bag samples, which is why the out-of-bag error reported below is NaN.

m <- randomForest(lprice ~ ., data = train,
                  sampsize = nrow(train), # use every observation for every tree
                  replace = FALSE,        # sample without replacement
                  mtry = 5,               # consider 5 of the 6 explanatory variables at each split
                  nodesize = 1            # allow leaves with a single observation
                  )
m
## 
## Call:
##  randomForest(formula = lprice ~ ., data = train, sampsize = nrow(train),      replace = FALSE, mtry = 5, nodesize = 1) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 5
## 
##           Mean of squared residuals: NaN
##                     % Var explained: NaN
p_train <- predict(m, newdata = train)
p_test  <- predict(m, newdata = test)

error <- bind_rows(
  error,
  data.frame(
    group         = "Random forest",
    method        = "overfit",
    in_sample     = mean((p_train - train$lprice)^2),
    out_of_sample = mean((p_test  -  test$lprice)^2)
  )
)

Underfit

To underfit, we can reverse all the settings in the previous section. Here we will make the sample for each tree small, try only 1 explanatory variable at each split, require a relatively large number of observations in each leaf, and use a small number of trees.

m <- randomForest(lprice ~ ., data = train,
                  sampsize = 0.3*nrow(train), # each tree trains on only 30% of the observations
                  mtry = 1,                   # consider only 1 variable at each split
                  nodesize = 10,              # require at least 10 observations in each leaf
                  ntree = 20                  # grow only a small number of trees
                  )
m
## 
## Call:
##  randomForest(formula = lprice ~ ., data = train, sampsize = 0.3 *      nrow(train), mtry = 1, nodesize = 10, ntree = 20) 
##                Type of random forest: regression
##                      Number of trees: 20
## No. of variables tried at each split: 1
## 
##           Mean of squared residuals: 0.1147601
##                     % Var explained: 88.59
p_train <- predict(m, newdata = train)
p_test  <- predict(m, newdata = test)

error <- bind_rows(
  error,
  data.frame(
    group         = "Random forest",
    method        = "underfit",
    in_sample     = mean((p_train - train$lprice)^2),
    out_of_sample = mean((p_test  -  test$lprice)^2)
  )
)

Summary

Implementations

There are a number of different packages that implement tree-based approaches to regression modeling. Many of these approaches, along with additional machine and statistical learning approaches, are listed on the CRAN Task View: Machine Learning and Statistical Learning.

Probabilistic predictions

Tree models are just a particular type of regression model, and thus prediction intervals should be straightforward, but they do not appear to be implemented in predict.rpart. Construction of prediction intervals from random forest models is an active area of research; the piRF package aims to combine multiple approaches for regression models in a single package.
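
As a rough sketch only (my addition; this is not a calibrated prediction interval, since the spread of tree predictions reflects variability across trees rather than the residual noise a true prediction interval must also include), the individual tree predictions from a random forest can be inspected with predict.all = TRUE.

m_rf <- randomForest(lprice ~ ., data = train)            # default forest
pa   <- predict(m_rf, newdata = test, predict.all = TRUE) # keep every tree's prediction
spread <- t(apply(pa$individual, 1, quantile, probs = c(0.025, 0.975)))
head(cbind(estimate = pa$aggregate, spread))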

Comparison

Below is a table comparing performance for a number of methods for predicting the log price of diamonds from continuous explanatory variables. While the results here indicate similar performance among all these methods, we should be careful about drawing too many conclusions from this analysis. As a reminder, for computational reasons, we are only using 100 training and 100 testing data points. In addition, the relationship between the explanatory variables and log price seems reasonably linear. For more complicated relationships and more data, the more flexible methods will likely perform better.

error %>%
  datatable(
    rownames = FALSE,
    caption = "In and out-of-sample error for various prediction methods",
    filter = "top"
  ) %>%
  formatRound(columns = c("in_sample","out_of_sample"), digits = 3)