Training a Machine Learning model using tree-based models#
Teaching: 20 min | Exercises: 0

Questions: How do we train a Machine Learning model using tree-based models?

Objectives: Learn to use different tree-based algorithms for Machine Learning training.

Key points: Decision Tree, Random Forest.
Training the model using Decision Trees#
- Decision trees can take inputs where some variables are numerical and some are categorical (see the sketch after this list).
- Very handy when dealing with demographic data (gender, age, education level).
- Work for continuous or categorical outputs.
- Very flexible: can capture complex relationships between inputs and outputs.
- Easy to interpret.
- Training can be slow.
- Prone to overfitting.
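To make the first point concrete, here is a minimal sketch that fits rpart on a small made-up data frame mixing a numerical and a categorical predictor; the data and variable names are invented purely for illustration:

library(rpart)
# toy data: one numerical predictor, one categorical predictor (hypothetical)
demo <- data.frame(
  age       = c(23, 45, 31, 52, 36, 28, 60, 41),
  education = factor(c("HS", "BS", "MS", "HS", "BS", "MS", "HS", "BS")),
  bought    = factor(c("no", "yes", "no", "yes", "yes", "no", "yes", "yes"))
)
# rpart accepts the mixed predictor types directly; no dummy coding is needed
fit_demo <- rpart(bought ~ age + education, data = demo,
                  control = rpart.control(minsplit = 2))
fit_demo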
Splitting algorithms#
- Gini impurity (categorical outputs)
- Chi-square (categorical outputs)
- Cross-entropy & information gain (categorical outputs)
- Variance reduction (continuous outputs)
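For reference, the two most common criteria can be written compactly. For a node whose samples fall into classes with proportions $p_k$:

$$G = \sum_k p_k (1 - p_k) = 1 - \sum_k p_k^2 \quad \text{(Gini impurity)}$$

$$H = -\sum_k p_k \log p_k \quad \text{(cross-entropy)}$$

A split is chosen to maximize the resulting decrease in impurity (the information gain, in the entropy case). For continuous outputs, variance reduction plays the same role: the algorithm picks the split that most reduces the variance of the response within the child nodes.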
Implementation#
Here we will use the built-in iris data set. First we partition it into training (60%) and testing (40%) sets:
library(caret)
data(iris)
set.seed(123)  # for reproducibility
# stratified split: 60% of each Species goes into the training set
indT <- createDataPartition(y = iris$Species, p = 0.6, list = FALSE)
training <- iris[indT, ]
testing  <- iris[-indT, ]
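As a quick sanity check, we can confirm the size of the split and that the classes stayed balanced (with 150 rows in iris, expect roughly 90 training and 60 testing rows):

nrow(training); nrow(testing)
table(training$Species)  # each species should be (nearly) equally represented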
Next we will train the model using method="rpart" with the Gini splitting criterion:
ModFit_rpart <- train(Species ~ ., data = training, method = "rpart",
                      parms = list(split = "gini"))
# rpart supports split = "gini" or split = "information" (information gain)
# a fancier plot of the fitted tree
install.packages("rattle")
library(rattle)
fancyRpartPlot(ModFit_rpart$finalModel)
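If rattle proves hard to install (it has a number of system dependencies), a plainer fallback, sketched here as an aside, is the base plotting method that ships with rpart:

library(rpart)
plot(ModFit_rpart$finalModel, margin = 0.1)  # draw the tree structure
text(ModFit_rpart$finalModel, use.n = TRUE)  # label splits and node counts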
Apply the decision tree model to predict the output of the testing data:
predict_rpart <- predict(ModFit_rpart, testing)
confusionMatrix(predict_rpart, testing$Species)
# flag which test samples were classified correctly and plot them
# (ggplot2 is attached together with caret, so ggplot() is available here)
testing$PredRight <- predict_rpart == testing$Species
ggplot(testing, aes(x = Petal.Width, y = Petal.Length)) +
  geom_point(aes(col = PredRight))
Training the model using Random Forest#
A single decision tree is prone to overfitting. A solution is to grow a number of decision trees, i.e. a forest.
- Each tree is grown on a random portion of the data (bootstrapping); see the sketch below.
- All trees are used to predict the output, and the final prediction is the majority vote (in classification) or the average (in regression).
- Very versatile, good generalization (especially for classification), handles missing data well.
- Quite slow and hard to interpret (prone to be used as a "black box").
Implementation of Random Forest#
Run the following in the console:
install.packages("randomForest")
Run the following:
# prox = TRUE additionally computes the proximity matrix between samples
ModFit_rf <- train(Species ~ ., data = training, method = "rf", prox = TRUE)
predict_rf <- predict(ModFit_rf, testing)
confusionMatrix(predict_rf, testing$Species)
# flag which test samples were classified correctly and plot them
testing$PredRight <- predict_rf == testing$Species
ggplot(testing, aes(x = Petal.Width, y = Petal.Length)) +
  geom_point(aes(col = PredRight))
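It can also be instructive to ask which predictors the forest relies on; caret's varImp() reports this (an optional aside; for iris the petal measurements usually dominate):

varImp(ModFit_rf)  # rank predictors by importance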
We can see that the Random Forest model gives better predictions than the single Decision Tree.
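To make the comparison concrete, we can extract the overall accuracy of each model on the same test split (the exact numbers depend on the seed and the partition):

confusionMatrix(predict_rpart, testing$Species)$overall["Accuracy"]
confusionMatrix(predict_rf,    testing$Species)$overall["Accuracy"]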