Trees & Forests

and the art of pruning


Introduction

In this tutorial, we will cover tree-based methods and give you some insight into the advantages of some over others. We will also show how to evaluate the performance of these models. Overall, we will:

  • discuss two evaluation measures, the Gini index and cross entropy,
  • cover classification (decision) and regression trees,
  • plot the resulting trees,
  • cover bagging, random forest and boosting,
  • give some insight into how well they perform.

Getting Started

To follow along, you will need:

  • tidyverse
  • tree
  • ISLR
  • randomForest
  • readxl, MLmetrics, rpart, rattle and MASS (also used below)
  • bc.xlsx
  • credit.csv
bc <- readxl::read_excel("bc.xlsx")
head(bc)
## # A tibble: 6 x 10
##     Age   BMI Glucose Insulin  HOMA Leptin Adiponectin Resistin MCP.1
##   <dbl> <dbl>   <dbl>   <dbl> <dbl>  <dbl>       <dbl>    <dbl> <dbl>
## 1    48  23.5      70    2.71 0.467   8.81        9.70     8.00  417.
## 2    83  20.7      92    3.12 0.707   8.84        5.43     4.06  469.
## 3    82  23.1      91    4.50 1.01   17.9        22.4      9.28  555.
## 4    68  21.4      77    3.23 0.613   9.88        7.17    12.8   928.
## 5    86  21.1      92    3.55 0.805   6.70        4.82    10.6   774.
## 6    49  22.9      92    3.23 0.732   6.83       13.7     10.3   530.
## # … with 1 more variable: Classification <chr>

Now, load tidyverse and the other packages:

library('tidyverse')
library('tree')
library('ISLR')
theme_set(theme_minimal())

Evaluating Classifiers

When discussing classification in previous weeks, we used intuitive measures of classifier performance: accuracy and the confusion matrix. These are straightforward metrics and usually the first thing that comes to mind when we ask about performance: how many observations did the model misclassify?

When training models, however, accuracy is not very convenient. It only changes when a prediction crosses the decision threshold, so it is insensitive to small improvements. Instead, we want a smoother function that can be differentiated and that is more closely connected to the heart (theory) of classification.

The most common measure is the cross-entropy loss, which is the negative log-likelihood of the model: "How badly does my model's probability distribution \(q\) fit the true distribution \(p\)?" For a single observation with binary outcome \(y \in \{0, 1\}\) and predicted probability \(\hat{y} = P(Y = 1)\), it reads:

\[CEloss = -\sum _{i}p_{i}\log q_{i}\ =\ -y\log {\hat {y}}-(1-y)\log(1-{\hat {y}})\]
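As a minimal sketch in base R, the binary cross entropy averaged over observations can be computed straight from the formula. The helper name cross_entropy and the toy probabilities are made up for illustration; the small eps clipping only guards against log(0).

```r
# Mean binary cross entropy: -mean(y * log(y_hat) + (1 - y) * log(1 - y_hat))
cross_entropy <- function(y, y_hat, eps = 1e-15) {
  y_hat <- pmin(pmax(y_hat, eps), 1 - eps)   # keep probabilities away from 0 and 1
  -mean(y * log(y_hat) + (1 - y) * log(1 - y_hat))
}

y     <- c(1, 0, 1)          # true labels (toy data)
y_hat <- c(0.9, 0.2, 0.6)    # predicted P(Y = 1)
ce <- cross_entropy(y, y_hat)
ce                           # about 0.280

cross_entropy(1, 0.01)       # one confident mistake costs about 4.6
```

Note that y_hat must be a probability: with 0/1 class predictions, every mistake costs -log(eps) and the loss degenerates into a rescaled error rate.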

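When training decision trees for classification, we use a related measure that focuses on misclassification, the Gini index (Gini impurity). For a node in which class \(i\) has proportion \(p_i\), it is the probability of misclassifying a randomly drawn observation if we label it at random according to the class proportions of that node: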

\[Gini = \sum_{i=1}^M p_i(1-p_i)\]

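The above measure is 0 for a pure node (all observations in one class) and at most \(1 - 1/M\) when the \(M\) classes are equally frequent (0.5 for two classes). It evaluates the impurity of a node, so a lower Gini index is better.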
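A small base-R sketch of the Gini index of a node computed from its class labels (gini_index is our own helper, not a package function). It gives 0 for a pure node, 0.5 for two balanced classes and 2/3 for three balanced classes.

```r
# Gini index of a node: sum over classes of p_i * (1 - p_i)
gini_index <- function(labels) {
  p <- table(labels) / length(labels)   # class proportions in the node
  sum(p * (1 - p))
}

gini_index(c("a", "a", "a", "a"))   # pure node: 0
gini_index(c("a", "a", "b", "b"))   # two balanced classes: 0.5
gini_index(iris$Species)            # three balanced classes: 2/3
```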

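library("MLmetrics")
bc$Classification <- as.factor(bc$Classification)
logreg1 <- glm(Classification ~ Glucose, data = bc, family = "binomial")  # Glucose only
logreg2 <- glm(Classification ~ ., data = bc, family = "binomial")        # all predictors

# summary(logreg2)
# Hard class predictions: TRUE means "predicted class 2"
preds1 <- predict(logreg1, type = "response") > 0.5
preds2 <- predict(logreg2, type = "response") > 0.5

acc1    <- Accuracy(preds1, bc$Classification == 2)
acc2    <- Accuracy(preds2, bc$Classification == 2)
# Caution: LogLoss() expects predicted probabilities. With TRUE/FALSE predictions it
# clips them away from 0 and 1, so each error costs about 34.5 and the result is
# roughly 34.5 x the misclassification rate.
CEloss1 <- LogLoss(preds1, bc$Classification == 2)
CEloss2 <- LogLoss(preds2, bc$Classification == 2)
# Caution: MLmetrics::Gini() is a Gini coefficient measuring how well y_pred ranks
# the observations (higher is better), not the Gini impurity defined above.
gini1   <- Gini(preds1, bc$Classification == 2)
gini2   <- Gini(preds2, bc$Classification == 2)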

c(acc1,acc2)
## [1] 0.7241379 0.7931034
c(CEloss1,CEloss2)
## [1] 9.528055 7.146030
c(gini1,gini2)
## [1] 0.03064904 0.25661058

Decision Trees

A decision tree is a non-linear model that looks for a set of logical (if/else) rules that partition the data points into pure subgroups, i.e. groups whose observations share the same output value. Purity is measured with an impurity measure such as the Gini index (the default in rpart) or the deviance, i.e. cross entropy (the default in tree). At each iteration, the algorithm searches for the split point that divides the observations into two groups that are as pure as possible.
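To make the split search concrete, here is a minimal base-R sketch (the helper names are ours) that tries every midpoint between consecutive distinct values of one predictor and keeps the threshold with the lowest size-weighted Gini index of the two children. Real implementations such as tree and rpart search all predictors and are far more efficient; this is only an illustration of the principle.

```r
# Gini index of a vector of class labels
gini <- function(y) {
  p <- table(y) / length(y)
  sum(p * (1 - p))
}

# Exhaustive search for the best threshold split on one numeric predictor
best_split <- function(x, y) {
  xs   <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2   # midpoints between distinct values
  n    <- length(y)
  impurity <- sapply(cuts, function(cut) {
    left  <- y[x < cut]
    right <- y[x >= cut]
    # children's Gini weighted by their share of the observations
    length(left) / n * gini(left) + length(right) / n * gini(right)
  })
  i <- which.min(impurity)
  list(threshold = cuts[i], impurity = impurity[i])
}

best <- best_split(iris$Petal.Length, iris$Species)
best
```

On Petal.Length this picks the threshold 2.45, which isolates setosa: the left child is pure and the right child holds versicolor and virginica 50/50, for a weighted Gini of 1/3.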

library("tree")

iris.tree <- tree(Species ~ Petal.Length + Sepal.Length, data=iris)
plot(iris.tree)
text(iris.tree)

library("rpart")
library("rattle")
iris.tree <- rpart(Species ~ Petal.Length + Petal.Width, data = iris)
fancyRpartPlot(iris.tree,caption = "")

We can read the above tree as follows (thresholds rounded as displayed in the plot):

  • \(P(Y=Setosa | Petal.Length < 2.5) = 1\)
  • \(P(Y=Versicolor | Petal.Length \geq 2.5 \ \& \ Petal.Width < 1.8) = 0.91\)
  • \(P(Y=Virginica | Petal.Length \geq 2.5 \ \& \ Petal.Width \geq 1.8) = 0.98\)
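These probabilities are simply the class proportions inside each leaf. As a quick base-R check, we can assign every flower to a leaf using the rounded thresholds shown in the plot and tabulate the species per leaf (the leaf labels are our own).

```r
# Assign each flower to a leaf using the two splits read off the plot
leaf <- with(iris, ifelse(Petal.Length < 2.5, "setosa leaf",
                   ifelse(Petal.Width < 1.8, "versicolor leaf", "virginica leaf")))

# Row-wise proportions = the tree's estimate of P(Y | leaf)
props <- prop.table(table(leaf, iris$Species), margin = 1)
round(props, 2)
```

The versicolor leaf holds 49 versicolor out of 54 flowers (0.91) and the virginica leaf 45 virginica out of 46 (0.98).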

What do the decision boundaries look like?

Unlike logistic regression, whose decision boundary can be a sloped line, a decision tree splits the feature space into rectangles: each split is a threshold on a single predictor, so every boundary is parallel to one of the axes (no slope).

Based on the tree estimated above, the plot below has Petal.Width on the x axis and Petal.Length on the y axis. The points are coloured by their true species, and the dashed lines mark the tree's splits, which divide the plane into the regions where the tree predicts each class.

ggplot(iris, aes(x = Petal.Width, y = Petal.Length, colour = Species)) + 
  geom_point() + 
  # first split: Petal.Length < 2.5 isolates setosa
  geom_hline(yintercept = 2.5, linetype = "dashed", colour = "black") + 
  annotate("text", x = 0, y = 2.6, label = "2.5") +
  # second split: Petal.Width < 1.8, applied only above the first split
  annotate("segment", x = 1.8, xend = 1.8, y = 2.5, yend = 7.5, linetype = "dashed") + 
  geom_vline(xintercept = 1.8, linetype = "dotted", alpha = .5) + 
  annotate("text", x = 1.9, y = 1, label = "1.8")

The plane is split into three regions by the decision boundaries estimated by the model.

Cross Validation

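Decision trees are greedy algorithms: each split is the best one available at that step, with no look-ahead, and they need a lot of data to fit well. They also tend to overfit, since a tree grown deep enough can fit the training data almost perfectly. So, to avoid such a disaster, and similar to regularization, we add a penalty term \(\alpha |T|\) to the tree's loss \(R(T)\) (e.g. misclassification error, Gini index or deviance), giving \(R_\alpha(T) = R(T) + \alpha |T|\). Here \(\alpha \geq 0\) is the strength of the penalty and \(|T|\) is the number of terminal nodes (one more than the number of splits). Because the penalty grows linearly with the size of the tree, a split is only worth keeping if it reduces the loss by at least \(\alpha\), so larger values of \(\alpha\) give smaller trees. In practice, a large tree is grown first and then pruned back (cost-complexity pruning), with \(\alpha\) chosen by cross validation.

We will use two packages, tree and rpart. rpart applies a complexity penalty by default: its complexity parameter cp defaults to 0.01, and any split that does not improve the overall fit by at least that fraction is not attempted. With the tree package, we grow the tree and then choose the pruning level by cross validation with the cv.tree function.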
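The trade-off can be illustrated with a few made-up numbers. In this base-R sketch, subtrees is a hypothetical sequence of pruned trees (number of leaves and training error), and for each \(\alpha\) we pick the subtree that minimizes \(R(T) + \alpha |T|\).

```r
# Hypothetical subtrees of one large tree: number of leaves |T| and training error R(T)
subtrees <- data.frame(size = c(12, 8, 4, 2, 1),
                       err  = c(10, 14, 22, 35, 57))

# Cost-complexity criterion R_alpha(T) = R(T) + alpha * |T|
cost_complexity <- function(err, size, alpha) err + alpha * size

# For a given alpha, return the size of the subtree with the smallest penalized cost
best_size <- function(alpha) {
  cost <- cost_complexity(subtrees$err, subtrees$size, alpha)
  subtrees$size[which.min(cost)]
}

alphas <- c(0, 1.5, 5, 20, 30)
chosen <- sapply(alphas, best_size)
chosen   # 12 8 4 2 1
```

With alpha = 0 the largest tree wins; as alpha grows, the chosen tree shrinks all the way down to the root. Cross validation is what picks a sensible alpha in between.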

set.seed(156)
bc.tree <- tree(Classification ~ ., data=bc)      # grow the full tree
cv <- cv.tree(bc.tree, FUN = prune.misclass)      # 10-fold CV, pruning by misclassifications
cv
## $size
## [1] 12 11 10  8  4  2  1
## 
## $dev
## [1] 35 35 30 31 36 39 57
## 
## $k
## [1] -Inf  0.0  1.0  2.0  2.5  4.5 20.0
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

The above code uses the prune.misclass function, which prunes the tree using the number of misclassifications as the loss. The output summarizes the k-fold cross validation (10-fold by default): size is the number of terminal nodes of each pruned tree, and dev is the cross-validated number of misclassifications (the minimum is the best). k is not the number of folds; it is the penalty parameter we denoted by \(\alpha\).

ggplot() + 
  geom_point(aes(x=cv$size, y=cv$dev)) + 
  geom_line(aes(x=cv$size, y=cv$dev))

According to the above results, the best tree has 10 terminal nodes (the lowest dev, 30 misclassifications). We can use this wisdom to prune the tree:

bc.tree.pruned <- prune.misclass(bc.tree,best=10)
plot(bc.tree.pruned)
text(bc.tree.pruned)

# Note: predictions on the training data, so this confusion matrix is in-sample
preds <- predict(bc.tree.pruned, type = "class")
ConfusionMatrix(preds, bc$Classification)
##       y_pred
## y_true  1  2
##      1 48  4
##      2  5 59
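Summary numbers follow directly from the confusion matrix. A short base-R sketch with the counts copied from the output above (rows are the true classes); keep in mind they are in-sample, since the tree was evaluated on the data it was fitted to.

```r
# Confusion matrix of the pruned tree: rows = truth, columns = prediction
cm <- matrix(c(48,  4,
                5, 59), nrow = 2, byrow = TRUE,
             dimnames = list(y_true = c("1", "2"), y_pred = c("1", "2")))

accuracy <- sum(diag(cm)) / sum(cm)    # correctly classified / total = 107 / 116
recall   <- diag(cm) / rowSums(cm)     # per-class true positive rate
accuracy
recall
```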

Alternatively, the rpart function from the rpart package follows the same procedure with some slight differences:

bc.tree <- rpart(Classification ~ ., data=bc)
fancyRpartPlot(bc.tree, main = "")

preds   <- predict(bc.tree, type = "class")
ConfusionMatrix(preds, bc$Classification)
##       y_pred
## y_true  1  2
##      1 42 10
##      2  4 60
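rpart also runs 10-fold cross validation internally (xval = 10 by default) and stores the results in its cptable, but it does not prune automatically. A sketch of CV-based pruning with rpart, shown on iris so that it runs without the bc file: grow a large tree with cp = 0, pick the cp with the lowest cross-validated error (xerror) and prune back to it.

```r
library(rpart)

set.seed(156)
# Grow a deliberately large tree (cp = 0) with 10-fold cross validation
big <- rpart(Species ~ ., data = iris,
             control = rpart.control(cp = 0, xval = 10))

# One row per candidate subtree, with its cross-validated relative error
big$cptable

# Keep the subtree with the lowest cross-validated error
best_cp <- big$cptable[which.min(big$cptable[, "xerror"]), "CP"]
pruned  <- prune(big, cp = best_cp)
pruned
```

printcp() and plotcp() display the same table. A common alternative is the one-standard-error rule: take the smallest tree whose xerror is within one xstd of the minimum.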

Regression Trees

Regression trees are based on the same principle. Here, however, we are not trying to reduce the Gini index but our good old friend, the squared error (RSS, or equivalently MSE). At each step, the algorithm looks for the split whose two subgroups have the smallest total squared error around their own means. Every observation in a terminal node gets the same prediction: the mean of that group.

We will use the Boston dataset from the MASS package, which contains information about 506 suburbs of Boston. The output is the median house value (medv, in $1000s) and the predictors include the crime rate, the distance to employment centres and so on. This time, we will properly split the data into training and test sets:
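A minimal base-R sketch of a single regression-tree split on made-up toy data (helper names are ours): each candidate threshold is scored by the total RSS of the two children around their own means, and the winning split defines two leaves that predict their group means.

```r
# RSS of a group around its own mean
rss <- function(y) sum((y - mean(y))^2)

best_regression_split <- function(x, y) {
  xs    <- sort(unique(x))
  cuts  <- (head(xs, -1) + tail(xs, -1)) / 2   # midpoints between distinct values
  total <- sapply(cuts, function(cut) rss(y[x < cut]) + rss(y[x >= cut]))
  i     <- which.min(total)
  cut   <- cuts[i]
  list(threshold = cut, rss = total[i],
       left_mean = mean(y[x < cut]), right_mean = mean(y[x >= cut]))
}

# Toy data: the response jumps from about 5 to about 15 after x = 10
x <- 1:20
y <- c(rep(5, 10), rep(15, 10)) + c(-1, 1)   # alternating noise of -1 / +1

split1 <- best_regression_split(x, y)
split1   # threshold 10.5, leaf predictions 5 and 15
```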

set.seed(156)
dat <- MASS::Boston
train_ind <- sample(1:nrow(dat), floor(0.5*nrow(dat)))

train <- dat[ train_ind,]
test  <- dat[-train_ind,]

boston.tree <- tree(medv ~ .,data = train)
summary(boston.tree)
## 
## Regression tree:
## tree(formula = medv ~ ., data = train)
## Variables actually used in tree construction:
## [1] "lstat" "rm"    "dis"  
## Number of terminal nodes:  7 
## Residual mean deviance:  14.8 = 3642 / 246 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -11.9000  -2.3220  -0.3152   0.0000   2.6430  13.1000
plot(boston.tree)
text(boston.tree)

# Test Performance
preds <- predict(boston.tree,test)
RMSE(preds,test$medv)
## [1] 5.210991

The above is the RMSE of the model without cross validation. Now, we will prune the tree using cross validation:

cv <- cv.tree(boston.tree)
ggplot() + geom_line(aes(cv$size,cv$dev)) + geom_point(aes(cv$size,cv$dev))

cv$size[which.min(cv$dev)]
## [1] 7

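The best model has 7 terminal nodes. Since the unpruned tree already has 7 terminal nodes (see the summary above), pruning to best = 7 returns the same tree, which is why the test RMSE does not change: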

boston.tree.pruned <- prune.tree(boston.tree,best=7)
plot(boston.tree.pruned)
text(boston.tree.pruned)

preds <- predict(boston.tree.pruned,test)
RMSE(preds,test$medv)
## [1] 5.210991

Or we can use rpart as below:

boston.tree <- rpart(medv ~ .,data = train)
fancyRpartPlot(boston.tree, caption ="")

# Test Performance
preds <- predict(boston.tree,test)
RMSE(preds,test$medv)
## [1] 5.202136

Bagging and Random Forest

As discussed earlier, decision trees have many flaws we don't want in ML: they overfit, they are greedy, and so on. We partially overcame these problems with pruning and cross validation.

Trees have some charming properties: (i) they are green, (ii) they love to be pruned, and (iii) they can grow into a forest. A useful statistical property is the following: if we draw many bootstrap samples ("bags") from the data, fit a deep tree to each one and average their predictions, the average keeps roughly the low bias of a single deep tree but has a much lower variance. In other words, the ensemble overfits far less than any of its trees.

Averaging trees fitted on bootstrap samples is called bagging (bootstrap aggregating). Later, statistics bore another fruit: the random forest. Its principle is the same, we bag the observations, but each tree is built differently: at every split, only a random subset of the predictors is considered as split candidates. This stops a few strong predictors from dominating every tree, so the trees are less correlated and averaging them reduces the variance further, which makes the forest better guarded against overfitting.

Bagging and random forest are therefore essentially the same: bagging is the special case in which all \(p\) predictors are considered at every split. In the randomForest package, we control this with the mtry parameter (mtry = p gives bagging; the default is max(floor(p/3), 1) for regression and floor(sqrt(p)) for classification):
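To see the mechanics, here is a base-R sketch of bagging on simulated data. To keep it short it bags regression stumps (one-split trees) rather than the deep trees randomForest grows, and all function names are ours: each stump is fitted to a bootstrap sample and the predictions are averaged.

```r
set.seed(1)
# Simulated data: a noisy step function with the jump at x = 5
n <- 200
x <- runif(n, 0, 10)
y <- ifelse(x < 5, 2, 8) + rnorm(n)

# Regression stump: one split chosen to minimize the RSS of the two children
fit_stump <- function(x, y) {
  xs   <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2
  rss  <- sapply(cuts, function(cut) {
    l <- y[x < cut]
    r <- y[x >= cut]
    sum((l - mean(l))^2) + sum((r - mean(r))^2)
  })
  cut <- cuts[which.min(rss)]
  list(cut = cut, left = mean(y[x < cut]), right = mean(y[x >= cut]))
}

predict_stump <- function(stump, newx) {
  ifelse(newx < stump$cut, stump$left, stump$right)
}

# Bagging: one stump per bootstrap sample ("bag")
bag_stumps <- function(x, y, B = 50) {
  lapply(seq_len(B), function(b) {
    idx <- sample(length(x), replace = TRUE)   # draw n rows with replacement
    fit_stump(x[idx], y[idx])
  })
}

# Average the predictions of all stumps
predict_bag <- function(bag, newx) {
  preds <- vapply(bag, predict_stump, numeric(length(newx)), newx = newx)
  rowMeans(matrix(preds, nrow = length(newx)))
}

bag <- bag_stumps(x, y, B = 50)
predict_bag(bag, c(1, 4.9, 5.1, 9))
```

Because each bag yields a slightly different split point, the averaged prediction changes gradually around x = 5 instead of jumping at a single threshold: this smoothing is the variance reduction at work.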
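A base-R sketch of the predictor-sampling step, using the 13 predictor names of MASS::Boston hard-coded below. It only mimics how candidate predictors are drawn at each split; it does not grow a tree. The default values follow randomForest's documented defaults for mtry.

```r
# The 13 predictors of MASS::Boston (all columns except medv)
predictors <- c("crim", "zn", "indus", "chas", "nox", "rm", "age",
                "dis", "rad", "tax", "ptratio", "black", "lstat")
p <- length(predictors)

# randomForest's default mtry values
mtry_regression     <- max(floor(p / 3), 1)   # 4
mtry_classification <- floor(sqrt(p))         # 3

# Bagging is the special case mtry = p: every split sees all predictors
mtry_bagging <- p                             # 13

# In a random forest, every split draws its own random subset of candidates
set.seed(156)
split_candidates <- replicate(3, sample(predictors, mtry_regression), simplify = FALSE)
split_candidates
```

The forest fitted below uses mtry = 6, somewhat more than the regression default of 4.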

library(randomForest)
set.seed(156)
dat <- MASS::Boston
train_ind <- sample(1:nrow(dat), floor(0.5*nrow(dat)))
train <- dat[ train_ind,]
test  <- dat[-train_ind,]

###################### Bagging ##########################
# Fit Bagging model
boston.bag <- randomForest(medv ~ . , data = train,mtry = ncol(train)-1) # all columns but the output
plot(boston.bag)

The above plot shows the out-of-bag (OOB) mean squared error of the bagged model as a function of the number of trees in the ensemble (ntree, 500 by default). Typically the error is high and unstable with only a few trees, drops quickly as more trees are averaged, and then levels off. Adding more trees does not make the model overfit; it only costs computation time.
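The OOB error is computed for each tree on the training rows left out of its bootstrap sample, which is roughly a third of them. A quick base-R check of that fraction for a training set of 253 rows (the size of the 50% Boston split used here):

```r
n <- 253   # training rows in the 50/50 Boston split

# Probability that a given row is never drawn in n draws with replacement
p_oob <- (1 - 1 / n)^n
p_oob      # close to exp(-1), about 0.37

# Simulation: fraction of rows left out of a bootstrap sample, averaged over bags
set.seed(156)
oob_frac <- replicate(1000, {
  idx <- sample(n, replace = TRUE)
  mean(!(seq_len(n) %in% idx))
})
mean(oob_frac)
```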

# Test performance
preds.bag <- predict(boston.bag,test)
RMSE(preds.bag,test$medv) 
## [1] 3.791728

That’s quite an improvement compared to regression trees.

Now, let’s grow a forest:

###################### Random Forest ####################
# Fit random forest model: 6 of the 13 predictors are tried at each split
boston.forest <- randomForest(medv ~ . , data = train, mtry = 6, importance = TRUE) 
plot(boston.forest)

# Test performance (RMSE)
preds.forest <- predict(boston.forest,test)
RMSE(preds.forest,test$medv) 
## [1] 3.49494

One byproduct of having a forest rather than a single tree is that we have many different trees, built on different bootstrap samples and with different predictors at their splits. Aggregating over all of them tells us how much each predictor contributes to the fit, which gives a measure of variable importance.

In the table below, IncNodePurity is the total decrease in node impurity (for regression, the residual sum of squares) from splits on that variable, averaged over all trees in the forest. %IncMSE is the increase in the out-of-bag mean squared error when the values of that variable are randomly permuted. For both, the best predictor is the one with the highest number (note that $importance holds the unscaled %IncMSE, whereas importance() and varImpPlot() divide it by its standard error by default):

boston.forest$importance
##            %IncMSE IncNodePurity
## crim     6.3022590    1204.54752
## zn       0.5837553     169.51097
## indus    5.7631689    1408.12343
## chas     0.3740248      58.44584
## nox      4.0315232     640.43463
## rm      40.4924967    7029.55046
## age      2.6781163     518.40156
## dis      6.7859026    1278.83297
## rad      1.2015012     129.46638
## tax      2.8279482     491.26133
## ptratio  5.8094462    1340.50386
## black    0.8883838     311.44211
## lstat   86.0454893    9483.73956
varImpPlot(boston.forest)

The above plot visualizes the two importance measures. lstat (percentage of lower-status population) and rm (average number of rooms per dwelling) are by far the most important predictors in this model.
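The idea behind %IncMSE (permutation importance) can be sketched in a few lines of base R. This is a simplified, model-agnostic version on simulated data, with a linear model standing in for the forest, and it measures the increase in training MSE; randomForest instead permutes each variable in the out-of-bag data of every tree.

```r
set.seed(42)
# Simulated data: x1 matters a lot, x2 a little, x3 not at all
n <- 500
d <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
d$y <- 3 * d$x1 + 0.5 * d$x2 + rnorm(n)

fit <- lm(y ~ x1 + x2 + x3, data = d)   # stand-in for any fitted model
mse <- function(model, data) mean((data$y - predict(model, data))^2)
base_mse <- mse(fit, d)

# Permute one predictor at a time and record how much the MSE increases
perm_importance <- sapply(c("x1", "x2", "x3"), function(v) {
  shuffled <- d
  shuffled[[v]] <- sample(shuffled[[v]])   # break the link between v and y
  mse(fit, shuffled) - base_mse
})
round(perm_importance, 3)
```

x1 dominates, x2 matters a little and the pure-noise x3 scores close to zero.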

Random Forest for Classification

credit <- read.csv("credit.csv", stringsAsFactors = TRUE)[,-1] # text columns as factors; drop the first column
credit$Score <- as.factor(credit$Score) # the output variable must be a factor for classification

train_ind <- sample(1:nrow(credit), floor(0.8*nrow(credit)))
train <- credit[ train_ind, ]
test  <- credit[-train_ind, ]

library(randomForest)
set.seed(156)

credit.forest <- randomForest(Score ~ . , data = train) 

preds.forest <- predict(credit.forest,test, type = "class")
Accuracy(preds.forest,test$Score) 
## [1] 0.775
credit.forest$importance
##                                                          MeanDecreaseGini
## Status.of.existing.checking.account                             36.185705
## Duration.in.months                                              33.524626
## Credit.history                                                  17.484038
## Purpose                                                         21.052925
## Credit.amount                                                   46.325388
## Savings.account.bonds                                           15.869092
## Present.employment.since                                        16.489057
## Installment.rate.in.percentage.of.disposable.income             14.835646
## Personal.status.and.sex                                         11.097888
## Other.debtors...guarantors                                       6.473223
## Present.residence.since                                         14.014699
## Property                                                        15.718111
## Age.in.years                                                    34.438281
## Other.installment.plans                                          9.872230
## Housing                                                          9.264186
## Number.of.existing.credits.at.this.bank                          7.994438
## Job                                                             10.926060
## Number.of.people.being.liable.to.provide.maintenance.for         5.140215
## Telephone                                                        6.791196
## Foreign.worker                                                   1.475774
varImpPlot(credit.forest)
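MeanDecreaseGini is built from the impurity decrease achieved by each split. As a base-R illustration on iris (helper names are ours), the decrease for the setosa split Petal.Length < 2.5 is the parent's Gini index minus the size-weighted Gini index of the children. Roughly speaking, randomForest sums such decreases over all splits on a variable and averages them over the trees.

```r
gini_index <- function(labels) {
  p <- table(labels) / length(labels)
  sum(p * (1 - p))
}

# Impurity decrease of a split: parent Gini minus size-weighted Gini of the children
gini_decrease <- function(labels, goes_left) {
  n     <- length(labels)
  left  <- labels[goes_left]
  right <- labels[!goes_left]
  gini_index(labels) -
    (length(left) / n * gini_index(left) + length(right) / n * gini_index(right))
}

dec <- gini_decrease(iris$Species, iris$Petal.Length < 2.5)
dec   # 2/3 - 1/3 = 1/3
```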