Training a Machine Learning model using Tree-based methods#

teaching: 20
exercises: 0
questions:

  • “How do we train a Machine Learning model using tree-based methods?”

objectives:

  • “Learn to use different tree-based algorithms for Machine Learning training”

keypoints:

  • “Decision Tree, Random Forest”

Training the model using Decision Trees#

  • Decision trees can take inputs where some variables are numerical and some are categorical

  • Very handy when dealing with demographic data (gender, age, education level)

  • Works for continuous or categorical outputs

  • Very flexible, can capture complex relationships between inputs and outputs

  • Easy to interpret

  • Training can be slow

  • Prone to overfitting

Splitting algorithms#

  • Gini impurity (categorical outputs; see the short sketch after this list)

  • Chi-square (categorical outputs)

  • Cross-entropy & information gain (categorical outputs)

  • Variance reduction (continuous outputs)
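
As a quick illustration of the Gini and entropy criteria, the short base-R sketch below (written for this lesson, not part of rpart or caret) computes both measures for a vector of class labels:

# Illustrative helper functions; rpart computes these internally when splitting
gini_impurity <- function(labels) {
  p <- table(labels) / length(labels)   # class proportions
  1 - sum(p^2)                          # Gini = 1 - sum(p_k^2)
}

entropy <- function(labels) {
  p <- table(labels) / length(labels)
  p <- p[p > 0]                         # drop empty classes to avoid log(0)
  -sum(p * log2(p))                     # entropy = -sum(p_k * log2(p_k))
}

data(iris)
gini_impurity(iris$Species)  # three balanced classes: 1 - 3*(1/3)^2 = 0.667
entropy(iris$Species)        # three balanced classes: log2(3) = 1.585

A lower impurity (or a higher information gain) after a candidate split means a better split.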

Implementation#

Here we will use the iris data:

library(caret)
data(iris)

# split the data: 60% for training, 40% for testing,
# stratified by Species so class proportions are preserved
set.seed(123)
indT <- createDataPartition(y=iris$Species,p=0.6,list=FALSE)
training <- iris[indT,]
testing  <- iris[-indT,]

Next we will train a decision tree using method="rpart" with the Gini splitting criterion:

ModFit_rpart <- train(Species~.,data=training,method="rpart",
                      parms = list(split = "gini"))
# rpart supports "gini" or "information" as the splitting criterion

# fancier plot of the fitted tree (install rattle once from the console)
install.packages("rattle")
library(rattle)
fancyRpartPlot(ModFit_rpart$finalModel)

Apply the decision tree model to predict the output of the testing data:

predict_rpart <- predict(ModFit_rpart,testing)
confusionMatrix(predict_rpart, testing$Species)

# flag which test rows were predicted correctly and plot them
testing$PredRight <- predict_rpart==testing$Species
ggplot(testing,aes(x=Petal.Width,y=Petal.Length))+
  geom_point(aes(col=PredRight))


Training the model using Random Forest#


  • A single decision tree is prone to overfitting. A solution is to grow a number of decision trees, i.e. a forest.

  • Each tree is grown on a bootstrap sample of the training data (bagging)

  • All trees are used to predict the output, and the final prediction is the majority vote (in classification) or the average (in regression), as sketched after this list

  • Very versatile, good generalization (especially for classification), handles missing data well

  • Quite slow and hard to interpret (prone to be used as a “black box”)
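
To make the majority-vote step concrete, here is a toy sketch in base R; the three “tree” prediction vectors are invented for illustration only (randomForest performs this aggregation internally):

# Invented predictions from three hypothetical trees for three observations
tree1 <- c("setosa", "versicolor", "virginica")
tree2 <- c("setosa", "virginica",  "virginica")
tree3 <- c("setosa", "versicolor", "versicolor")

votes <- rbind(tree1, tree2, tree3)

# For each observation (column), take the most frequent class as the forest prediction
apply(votes, 2, function(v) names(which.max(table(v))))
# "setosa" "versicolor" "virginica"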

Implementation of Random Forest#

  • Run the following in the console (only needed once):

install.packages("randomForest")

  • Then run the following to train the model and evaluate it on the testing data:

# method="rf" uses the randomForest package; prox=TRUE also stores the proximity matrix
ModFit_rf <- train(Species~.,data=training,method="rf",prox=TRUE)

predict_rf <- predict(ModFit_rf,testing)
confusionMatrix(predict_rf, testing$Species)

# flag which test rows were predicted correctly and plot them
testing$PredRight <- predict_rf==testing$Species
ggplot(testing,aes(x=Petal.Width,y=Petal.Length))+
  geom_point(aes(col=PredRight))


We can see that the Random Forest model gives better predictions on the testing data than the Decision Tree.
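
As a quick numerical check of this claim, we can compare the overall accuracy that confusionMatrix reports for the two models (the exact values depend on the seed used when splitting the data):

# Overall accuracy of each model on the testing set
acc_rpart <- confusionMatrix(predict_rpart, testing$Species)$overall["Accuracy"]
acc_rf    <- confusionMatrix(predict_rf,    testing$Species)$overall["Accuracy"]
acc_rpart
acc_rf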