Trees & Forests
and the art of pruning
Introduction
In this tutorial, we will cover tree-based methods, give you insight into the advantages of some over others, and show how to evaluate the performance of these models. Overall, we will:
- discuss two evaluation measures, Gini index and cross entropy
- cover decision and regression trees,
- plot the resulting trees,
- cover bagging, random forest and boosting,
- give insights about how well they performed.
Getting Started
To follow along, you will need:
- tidyverse
- tree
- ISLR
- randomForest
- rpart and rattle (for the prettier tree plots)
- MLmetrics (for the evaluation functions)
- MASS (for the Boston data; it comes with R)
- bc.xlsx
- credit.xlsx
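The breast-cancer data (bc.xlsx) is loaded first. The loading code is not shown in the handout; a minimal sketch, assuming the readxl package and that the file sits in your working directory, would be:
library(readxl)              # installed with the tidyverse, but not attached by library(tidyverse)
bc <- read_excel("bc.xlsx")  # breast-cancer data used throughout this section
head(bc)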
## # A tibble: 6 x 10
## Age BMI Glucose Insulin HOMA Leptin Adiponectin Resistin MCP.1
## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 48 23.5 70 2.71 0.467 8.81 9.70 8.00 417.
## 2 83 20.7 92 3.12 0.707 8.84 5.43 4.06 469.
## 3 82 23.1 91 4.50 1.01 17.9 22.4 9.28 555.
## 4 68 21.4 77 3.23 0.613 9.88 7.17 12.8 928.
## 5 86 21.1 92 3.55 0.805 6.70 4.82 10.6 774.
## 6 49 22.9 92 3.23 0.732 6.83 13.7 10.3 530.
## # … with 1 more variable: Classification <chr>
Now, call tidyverse and the other packages:
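The exact call is not shown in the handout; a minimal sketch is:
library(tidyverse)   # dplyr, ggplot2, readr, ...
# The remaining packages (tree, rpart, rattle, MLmetrics, randomForest) are
# attached with library() calls as they are needed in the sections below.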
Evaluating Classifiers
When discussing classification in previous weeks, we used intuitive measures for the performance of classifiers: accuracy and confusion matrices. These are straightforward metrics and usually the first thing that comes to mind when we ask about performance: how many did you misclassify?
When training models, however, accuracy is not very handy; it is too insensitive to small improvements. Instead, we want a smoother function that we can differentiate and that is more closely connected to the heart (theory) of classification tasks.
The most common measure is the cross-entropy loss, which is simply the negative log-likelihood: how badly does my model’s probability distribution fit the true distribution?
\[CEloss = -\sum _{i}p_{i}\log q_{i}\ =\ -y\log {\hat {y}}-(1-y)\log(1-{\hat {y}})\]
When training decision tree models for classification tasks, we use a similar measure that specifically focuses on misclassification: what is the chance of misclassifying an observation if I use this model?
\[Gini = \sum_{i=1}^M p_i(1-p_i)\]
The above measure is between 0 and 1 and evaluates the impurity of the classification, where a lower Gini score is better.
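To build some intuition, here is a minimal hand-rolled illustration of both quantities; the function names are our own and the numbers are arbitrary (the MLmetrics package used below provides ready-made versions):
# Cross-entropy loss for a single binary observation with true label y and predicted probability y_hat
cross_entropy <- function(y, y_hat) -y * log(y_hat) - (1 - y) * log(1 - y_hat)
# Gini impurity of a node with class proportions p
gini_impurity <- function(p) sum(p * (1 - p))

cross_entropy(1, 0.9)        # confident and correct -> small loss (~0.105)
cross_entropy(1, 0.1)        # confident and wrong   -> large loss (~2.303)
gini_impurity(c(0.9, 0.1))   # fairly pure node      -> 0.18
gini_impurity(c(0.5, 0.5))   # maximally impure node (two classes) -> 0.5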
library("MLmetrics")
bc$Classification <- as.factor(bc$Classification)
logreg1 <- glm(Classification ~ Glucose , data= bc, family="binomial")
logreg2 <- glm(Classification ~ . , data= bc, family="binomial")
# summary(logreg2)
preds1 <- predict(logreg1, type = "response") > 0.5
preds2 <- predict(logreg2, type = "response") > 0.5
acc1 <- Accuracy(preds1, bc$Classification == 2)
acc2 <- Accuracy(preds2, bc$Classification == 2)
CEloss1 <- LogLoss(preds1, bc$Classification == 2)
CEloss2 <- LogLoss(preds2, bc$Classification == 2)
gini1 <- Gini(preds1, bc$Classification == 2)
gini2 <- Gini(preds2, bc$Classification == 2)
c(acc1,acc2)
## [1] 0.7241379 0.7931034
c(CEloss1,CEloss2)
## [1] 9.528055 7.146030
c(gini1,gini2)
## [1] 0.03064904 0.25661058
On all three measures the full model (logreg2) beats the glucose-only model: higher accuracy, lower cross-entropy loss and a higher Gini. (Note that MLmetrics’ Gini measures how well the predictions rank the true outcomes, so higher is better here; it is not the Gini impurity used for growing trees, where lower is better.)
Decision Trees
A decision tree aims to find a set of logical rules that partition the data points into pure subgroups sharing the same output value. Here, purity is measured with the Gini index, which evaluates impurity. At each iteration, the algorithm looks for a break point that splits the observations into two groups that are internally as pure as possible.
library("tree")
iris.tree <- tree(Species ~ Petal.Length + Sepal.Length, data=iris)
plot(iris.tree)
text(iris.tree)
library("rpart")
library("rattle")
iris.tree <- rpart(Species ~ Petal.Length + Petal.Width, data=iris)
fancyRpartPlot(iris.tree,caption = "")
We can read the above tree as follows:
- \(P(Y=Setosa \ | \ Petal.Length < 2.5) = 1\)
- \(P(Y=Versicolor \ | \ Petal.Length \geq 2.5 \ \& \ Petal.Width < 1.8) = 0.91\)
- \(P(Y=Virginica \ | \ Petal.Length \geq 2.5 \ \& \ Petal.Width \geq 1.8) = 0.98\)
What do the decision boundaries look like?
Decision trees are less flexible than linear regressions in one specific sense: the predictor space is carved into axis-aligned rectangles, so the decision boundaries have no slope.
Based on the estimated model above, the plot below has Petal.Width on the x axis and Petal.Length on the y axis. Keep in mind that it is not a scatter plot investigating the relationship between the two variables; the points are coloured by species and the dashed lines mark the decision boundaries estimated by the tree.
ggplot(iris, aes(x=Petal.Width, y=Petal.Length, colour=Species)) +
geom_point() +
geom_hline(aes(yintercept=2.5), linetype="dashed", colour="black") +
geom_text(aes(y=2.6,x=0, label="2.5"), colour="black") +
geom_linerange(aes(ymin=2.5, ymax = 7.5, x=1.8), linetype="dashed",colour="black") +
geom_vline(aes(xintercept=1.8), linetype="dotted",alpha=.5) +
geom_text(aes(y=1,x=1.9, label="1.8"), colour="black")
The region was split into three based on the decision boundaries estimated by the model.
Cross Validation
Decision trees are greedy algorithms: a good fit requires many data points, and they also have a strong tendency to overfit. To avoid such a disaster, similarly to regularization, we add a penalty term, \(\alpha |T|\), to the loss function, where \(\alpha\) is the strength of the penalty and \(|T|\) is the size of the tree (its number of terminal nodes). As the tree grows, the penalty grows linearly, so the procedure is forced to stop splitting early.
We will use two packages, tree and rpart. The latter applies the penalty by default. For the tree package, we will use the cv.tree function.
set.seed(156)
bc.tree <- tree(Classification ~ ., data=bc)
cv <- cv.tree(bc.tree, FUN = prune.misclass)
cv
## $size
## [1] 12 11 10 8 4 2 1
##
## $dev
## [1] 35 35 30 31 36 39 57
##
## $k
## [1] -Inf 0.0 1.0 2.0 2.5 4.5 20.0
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
The above code uses the prune.misclass function, which prunes the tree at each step to improve the loss. The output summarizes the k-fold (by default 10-fold) cross validation. size reports the number of terminal nodes of each candidate tree, and dev is the number of cross-validation misclassifications (the minimum is the best). k is not the number of folds; it is the penalty parameter that we denoted by \(\alpha\) above.
According to the above results, the best tree has 10 terminal nodes (the minimum dev is 30). We can use this wisdom to fit a better model:
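The pruning and prediction code is not shown in the handout; a sketch that would produce a confusion matrix like the one below, assuming in-sample predictions and the ConfusionMatrix helper from MLmetrics:
bc.tree.pruned <- prune.misclass(bc.tree, best = 10)    # keep the 10-node tree chosen by CV
preds.tree <- predict(bc.tree.pruned, type = "class")   # in-sample class predictions
ConfusionMatrix(preds.tree, bc$Classification)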
## y_pred
## y_true 1 2
## 1 48 4
## 2 5 59
Alternatively, you can use the rpart package, whose title-track rpart function follows the same procedure with some slight differences:
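Again, the handout omits the code; one way it might have looked (rpart prunes via its cp complexity parameter by default, and any further control settings are assumptions):
bc.rpart <- rpart(Classification ~ ., data = bc, method = "class")
preds.rpart <- predict(bc.rpart, type = "class")   # in-sample class predictions
ConfusionMatrix(preds.rpart, bc$Classification)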
## y_pred
## y_true 1 2
## 1 42 10
## 2 4 60
Regression Trees
Regression trees are based on the same principle. Here, we are not trying to reduce the Gini index but our good old friend, the MSE. Again, the goal is to find subgroups whose mean squared error is as small as possible; the prediction for every element of a subgroup is the same, namely the mean of the group.
We will use the Boston dataset from the MASS package, which contains information about 506 suburbs of Boston. The outcome is the median house price and the predictors are the crime rate, distance to employment centres, and so on. This time, we will properly split the data into train and test sets:
set.seed(156)
dat <- MASS::Boston
train_ind <- sample(1:nrow(dat), floor(0.5*nrow(dat)))
train <- dat[ train_ind,]
test <- dat[-train_ind,]
boston.tree <- tree(medv ~ .,data = train)
summary(boston.tree)
##
## Regression tree:
## tree(formula = medv ~ ., data = train)
## Variables actually used in tree construction:
## [1] "lstat" "rm" "dis"
## Number of terminal nodes: 7
## Residual mean deviance: 14.8 = 3642 / 246
## Distribution of residuals:
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## -11.9000 -2.3220 -0.3152 0.0000 2.6430 13.1000
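The RMSE below is presumably computed on the held-out test set; a sketch of the missing code:
preds.boston <- predict(boston.tree, newdata = test)
sqrt(mean((preds.boston - test$medv)^2))   # test-set RMSE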
## [1] 5.210991
The above is the RMSE of the model without cross validation. Now, we will prune the tree using cross validation:
cv <- cv.tree(boston.tree)
ggplot() + geom_line(aes(cv$size,cv$dev)) + geom_point(aes(cv$size,cv$dev))
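The code that prints the best size is not shown; it is most likely something like:
cv$size[which.min(cv$dev)]   # tree size with the smallest cross-validated deviance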
## [1] 7
The best model seems to have 7 nodes:
boston.tree.pruned <- prune.tree(boston.tree,best=7)
plot(boston.tree.pruned)
text(boston.tree.pruned)
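The RMSE of the pruned tree, again presumably on the test set. Since the unpruned tree already had 7 terminal nodes, pruning to 7 leaves it unchanged, hence the identical RMSE:
preds.pruned <- predict(boston.tree.pruned, newdata = test)
sqrt(mean((preds.pruned - test$medv)^2))   # test-set RMSE of the pruned tree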
## [1] 5.210991
Or we can use rpart as below:
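The rpart version is not shown either; a minimal sketch under the same train/test split (the default settings are an assumption):
boston.rpart <- rpart(medv ~ ., data = train)           # rpart applies its cp penalty by default
preds.rpart <- predict(boston.rpart, newdata = test)
sqrt(mean((preds.rpart - test$medv)^2))                 # test-set RMSE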
## [1] 5.202136
Bagging and Random Forest
As discussed earlier, decision trees have all the flaws we don’t want in ML: they overfit, they are greedy, and so on. So far we have tackled these problems with cross validation and pruning.
Trees have some interesting properties: (i) they are green, (ii) they love to be pruned, and (iii) they can grow into a forest. An interesting statistical property is that if we distribute the data into bags (with bootstrapping), fit a separate tree to each bag, and average their predictions, we get a model with roughly the same bias as a single tree but much lower variance; it overfits far less and makes better use of the data.
The above method is called bagging. Statistics, however, bears fruit in every season, and one season produced the random forest. The principle is the same: we bag the observations, but each time we fit a tree that uses not all predictors but only a random subset of them. By doing so, the model also learns what happens when a given variable is not available, the individual trees become less correlated, and the ensemble is better guarded against overfitting.
Bagging and random forests are essentially the same. For bagging, we fit a tree with all predictors to each bag; for a random forest, we randomly select a subset of predictors each time. We will control this with the mtry parameter:
library(randomForest)
set.seed(156)
dat <- MASS::Boston
train_ind <- sample(1:nrow(dat), floor(0.5*nrow(dat)))
train <- dat[ train_ind,]
test <- dat[-train_ind,]
###################### Bagging ##########################
# Fit Bagging model
boston.bag <- randomForest(medv ~ . , data = train,mtry = ncol(train)-1) # all columns but the output
plot(boston.bag)
The above plot shows the out-of-bag mean squared error as a function of the number of trees (ntree) used to fit the model. The error is high when only a few trees are averaged and stabilizes after roughly 50 trees. The good news is that the randomForest function takes care of this for you, growing 500 trees by default.
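The test error of the bagged model, computed with code along these lines:
preds.bag <- predict(boston.bag, newdata = test)
sqrt(mean((preds.bag - test$medv)^2))   # test-set RMSE of bagging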
## [1] 3.791728
That’s quite an improvement compared to regression trees.
Now, let’s grow a forest:
###################### Random Forest ####################
# Fit random forest model (mtry = 6 predictors tried at each split)
boston.forest <- randomForest(medv ~ . , data = train,mtry = 6, importance=T)
plot(boston.forest)
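And its test error, with code analogous to the bagging case:
preds.rf <- predict(boston.forest, newdata = test)
sqrt(mean((preds.rf - test$medv)^2))   # test-set RMSE of the random forest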
## [1] 3.49494
One byproduct of having a forest rather than a single tree is that we end up with many different trees, each with a different number of nodes and built from different predictors. We can use this to understand which predictors matter most, i.e. which ones were selected for the top splits most of the time.
IncNodePurity below is the total decrease in node impurity (measured by the residual sum of squares for regression) due to splits on that variable, averaged over all trees in the forest; the most useful predictor is therefore the one with the highest number. Similarly, %IncMSE is the percentage increase in mean squared error when the values of that variable are randomly permuted:
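The table below comes from the importance() function of randomForest (importance=T was set when fitting the forest so that %IncMSE is available):
importance(boston.forest)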
## %IncMSE IncNodePurity
## crim 6.3022590 1204.54752
## zn 0.5837553 169.51097
## indus 5.7631689 1408.12343
## chas 0.3740248 58.44584
## nox 4.0315232 640.43463
## rm 40.4924967 7029.55046
## age 2.6781163 518.40156
## dis 6.7859026 1278.83297
## rad 1.2015012 129.46638
## tax 2.8279482 491.26133
## ptratio 5.8094462 1340.50386
## black 0.8883838 311.44211
## lstat 86.0454893 9483.73956
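The accompanying plot is most likely produced with varImpPlot():
varImpPlot(boston.forest)   # dot chart of %IncMSE and IncNodePurity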
The above plot visualizes the table. lstat (percentage of the population of lower status) and rm (average number of rooms per dwelling) are by far the most important variables to include in the model.
Random Forest for Classification
credit <- read.csv("credit.csv")[,-1]
credit$Score <- as.factor(credit$Score) # the output variable must be factor
train_ind <- sample(1:nrow(credit), floor(0.8*nrow(credit)))
train <- credit[ train_ind, ]
test <- credit[-train_ind, ]
library(randomForest)
set.seed(156)
credit.forest <- randomForest(Score ~ . , data = train)
preds.forest <- predict(credit.forest,test, type = "class")
Accuracy(preds.forest,test$Score)
## [1] 0.775
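The Gini-based importance table below can be obtained with the same importance() function (without importance=T, only MeanDecreaseGini is reported for classification forests):
importance(credit.forest)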
## MeanDecreaseGini
## Status.of.existing.checking.account 36.185705
## Duration.in.months 33.524626
## Credit.history 17.484038
## Purpose 21.052925
## Credit.amount 46.325388
## Savings.account.bonds 15.869092
## Present.employment.since 16.489057
## Installment.rate.in.percentage.of.disposable.income 14.835646
## Personal.status.and.sex 11.097888
## Other.debtors...guarantors 6.473223
## Present.residence.since 14.014699
## Property 15.718111
## Age.in.years 34.438281
## Other.installment.plans 9.872230
## Housing 9.264186
## Number.of.existing.credits.at.this.bank 7.994438
## Job 10.926060
## Number.of.people.being.liable.to.provide.maintenance.for 5.140215
## Telephone 6.791196
## Foreign.worker 1.475774