Decision trees can suffer from high variance, which makes their results fragile to the specific training data used.
Building multiple models from samples of your training data, called bagging, can reduce this variance, but the trees are highly correlated.
Random Forest is an extension of bagging that, in addition to building trees from multiple samples of your training data, also constrains the features that can be used to build the trees, forcing the trees to be different. This, in turn, can give a lift in performance.
In this tutorial, you will discover how to implement the Random Forest algorithm from scratch in Python.
After completing this tutorial, you will know:
- The difference between bagged decision trees and the random forest algorithm.
- How to construct bagged decision trees with more variation (less correlation) between the trees.
- How to apply the random forest algorithm to a predictive modeling problem.
Kick-start your project with my new book Machine Learning Algorithms From Scratch, including step-by-step tutorials and the Python source code files for all examples.
Let’s get started.
- Update Jan/2017: Changed the calculation of fold_size in cross_validation_split() to always be an integer. Fixes issues with Python 3.
- Update Feb/2017: Fixed a bug in build_tree.
- Update Aug/2017: Fixed a bug in Gini calculation, added the missing weighting of group Gini scores by group size (thanks Michael!).
- Update Aug/2018: Tested and updated to work with Python 3.6.
Description
This section provides a brief introduction to the Random Forest algorithm and the Sonar dataset used in this tutorial.
Random Forest Algorithm
Decision trees involve the greedy selection of the best split point from the dataset at each step.
This greedy approach makes decision trees susceptible to high variance if they are not pruned. This high variance can be reduced by creating multiple trees from different samples of the training dataset (different views of the problem) and combining their predictions. This approach is called bootstrap aggregation, or bagging for short.
A limitation of bagging is that the same greedy algorithm is used to create each tree, meaning that the same or very similar split points are likely to be chosen in each tree, making the trees very similar (correlated). This, in turn, makes their predictions similar, undermining the variance reduction originally sought.
We can force the decision trees to be different by limiting the features (columns) that the greedy algorithm can evaluate at each split point when creating the tree. This is called the Random Forest algorithm.
As in bagging, multiple samples of the training dataset are taken and a different tree is trained on each. The difference is that at each point where a split is made in the data and added to the tree, only a fixed-size random subset of attributes can be considered.
For classification problems, the type of problems we will look at in this tutorial, the number of attributes to be considered for the split is limited to the square root of the number of input features.
The result of this one small change is trees that are more different from each other (decorrelated), resulting in predictions that are more diverse and a combined prediction that often has better performance than a single tree or bagging alone.
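As a quick illustration of the feature-constraining idea (not the tutorial code itself; the column count and variable names here are only for demonstration), a fixed-size random subset of column indices might be drawn like this:

```python
from math import sqrt
from random import randrange

# Hypothetical dataset: 60 input columns plus the class label in the last column.
n_input_features = 60

# For classification, consider sqrt(number of input features) at each split.
n_features = int(sqrt(n_input_features))  # sqrt(60) ~ 7.74 -> 7

# Randomly choose that many distinct feature (column) indices, without replacement.
features = list()
while len(features) < n_features:
	index = randrange(n_input_features)
	if index not in features:
		features.append(index)

print(n_features, sorted(features))
```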
Sonar Dataset
The dataset we will use in this tutorial is the Sonar dataset.
This is a dataset that describes sonar chirp returns bouncing off different surfaces. The 60 input variables are the strength of the returns at different angles. It is a binary classification problem that requires a model to differentiate rocks from metal cylinders. There are 208 observations.
It is a well-understood dataset. All of the variables are continuous and generally in the range of 0 to 1. The output variable is a string “M” for mine and “R” for rock, which will need to be converted to integers 1 and 0.
By predicting the class with the most observations in the dataset (M or mines), the Zero Rule Algorithm can achieve an accuracy of 53%.
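For example, a minimal sketch of the Zero Rule baseline on a list of class labels (the labels list here is only illustrative, not the real data) might look like this:

```python
# Illustrative only: a list of class labels, e.g. taken from the last column of the dataset.
labels = ['M', 'R', 'M', 'M', 'R']

# Zero Rule: always predict the most common class and measure the resulting accuracy.
majority = max(set(labels), key=labels.count)
accuracy = labels.count(majority) / float(len(labels)) * 100.0
print('Predicting %s gives %.1f%% accuracy' % (majority, accuracy))
```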
You can learn more about this dataset at the UCI Machine Learning repository.
Download the dataset for free and place it in your working directory with the filename sonar.all-data.csv.
Tutorial
This tutorial is broken down into 2 steps.
- Calculating Splits.
- Sonar Dataset Case Study.
These steps provide the foundation that you need to implement and apply the Random Forest algorithm to your own predictive modeling problems.
1. Calculating Splits
In a decision tree, split points are chosen by finding the attribute and the value of that attribute that results in the lowest cost.
For classification problems, this cost function is often the Gini index, which calculates the purity of the groups of data created by the split point. A Gini index of 0 represents perfect purity, where class values are perfectly separated into two groups in the case of a two-class classification problem.
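As a rough sketch of this cost function, one way to compute a group-size-weighted Gini index for the groups produced by a candidate split is:

```python
# Calculate the Gini index for a split dataset (a list of groups of rows),
# weighting each group's score by its relative size.
def gini_index(groups, classes):
	# count all samples at the split point
	n_instances = float(sum([len(group) for group in groups]))
	gini = 0.0
	for group in groups:
		size = float(len(group))
		# avoid divide by zero for an empty group
		if size == 0:
			continue
		score = 0.0
		# score the group based on the proportion of each class
		for class_val in classes:
			p = [row[-1] for row in group].count(class_val) / size
			score += p * p
		# weight the group score by its relative size
		gini += (1.0 - score) * (size / n_instances)
	return gini
```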
Finding the best split point in a decision tree involves evaluating the cost of each value in the training dataset for each input variable.
For bagging and random forest, this procedure is executed upon a sample of the training dataset, made with replacement. Sampling with replacement means that the same row may be chosen and added to the sample more than once.
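A minimal sketch of drawing such a bootstrap sample, where the ratio argument controls the sample size relative to the original dataset, might look like this:

```python
from random import randrange

# Create a random subsample of the dataset with replacement.
def subsample(dataset, ratio=1.0):
	sample = list()
	n_sample = round(len(dataset) * ratio)
	while len(sample) < n_sample:
		index = randrange(len(dataset))
		sample.append(dataset[index])
	return sample
```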
We can update this procedure for Random Forest. Instead of enumerating all values of all input attributes in search of the split with the lowest cost, we can create a sample of the input attributes to consider.
This sample of input attributes can be chosen randomly and without replacement, meaning that each input attribute need only be considered once when looking for the split point with the lowest cost.
Below is a function named get_split() that implements this procedure. It takes a dataset and a fixed number of input features to evaluate as input arguments, where the dataset may be a sample of the actual training dataset.
The helper function test_split() is used to split the dataset by a candidate split point and gini_index() is used to evaluate the cost of a given split by the groups of rows created.
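A sketch of both functions is given below, assuming the group-size-weighted gini_index() sketched earlier in this section:

```python
from random import randrange

# Split a dataset into two groups based on an attribute index and an attribute value.
def test_split(index, value, dataset):
	left, right = list(), list()
	for row in dataset:
		if row[index] < value:
			left.append(row)
		else:
			right.append(row)
	return left, right

# Select the best split point for a dataset, considering only a random subset of features.
def get_split(dataset, n_features):
	class_values = list(set(row[-1] for row in dataset))
	b_index, b_value, b_score, b_groups = 999, 999, 999, None
	# randomly select n_features candidate feature indices without replacement
	features = list()
	while len(features) < n_features:
		index = randrange(len(dataset[0]) - 1)
		if index not in features:
			features.append(index)
	# evaluate each value of each candidate feature as a split point
	for index in features:
		for row in dataset:
			groups = test_split(index, row[index], dataset)
			gini = gini_index(groups, class_values)
			if gini < b_score:
				b_index, b_value, b_score, b_groups = index, row[index], gini, groups
	return {'index': b_index, 'value': b_value, 'groups': b_groups}
```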
We can see that a list of features is created by randomly selecting feature indices and adding them to a list (called features). This list of features is then enumerated and specific values in the training dataset are evaluated as split points.
Now that we know how a decision tree algorithm can be modified for use with the Random Forest algorithm, we can piece this together with an implementation of bagging and apply it to a real-world dataset.
2. Sonar Dataset Case Study
In this section, we will apply the Random Forest algorithm to the Sonar dataset.
The example assumes that a CSV copy of the dataset is in the current working directory with the file name sonar.all-data.csv.
The dataset is first loaded, the string values are converted to numeric, and the output column is converted from strings to the integer values 0 and 1. This is achieved with the helper functions load_csv(), str_column_to_float() and str_column_to_int() to load and prepare the dataset.
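One way these helpers might be written (a sketch using Python's built-in csv module; the lookup-based integer encoding is one reasonable choice) is:

```python
from csv import reader

# Load a CSV file into a list of rows.
def load_csv(filename):
	dataset = list()
	with open(filename, 'r') as file:
		csv_reader = reader(file)
		for row in csv_reader:
			if not row:
				continue
			dataset.append(row)
	return dataset

# Convert a string column to floating point values.
def str_column_to_float(dataset, column):
	for row in dataset:
		row[column] = float(row[column].strip())

# Convert a string class column to integer values (e.g. 'M'/'R' -> 1/0).
def str_column_to_int(dataset, column):
	class_values = [row[column] for row in dataset]
	lookup = dict()
	for i, value in enumerate(set(class_values)):
		lookup[value] = i
	for row in dataset:
		row[column] = lookup[row[column]]
	return lookup
```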
We will use k-fold cross validation to estimate the performance of the learned model on unseen data. This means that we will construct and evaluate k models and estimate the performance as the mean model error. Classification accuracy will be used to evaluate each model. These behaviors are provided in the cross_validation_split(), accuracy_metric() and evaluate_algorithm() helper functions.
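Sketches of these three helpers might look like the following, with cross_validation_split() using an integer fold_size as noted in the updates above:

```python
from random import randrange

# Split a dataset into k folds.
def cross_validation_split(dataset, n_folds):
	dataset_split = list()
	dataset_copy = list(dataset)
	fold_size = int(len(dataset) / n_folds)
	for _ in range(n_folds):
		fold = list()
		while len(fold) < fold_size:
			index = randrange(len(dataset_copy))
			fold.append(dataset_copy.pop(index))
		dataset_split.append(fold)
	return dataset_split

# Calculate accuracy percentage between actual and predicted labels.
def accuracy_metric(actual, predicted):
	correct = 0
	for i in range(len(actual)):
		if actual[i] == predicted[i]:
			correct += 1
	return correct / float(len(actual)) * 100.0

# Evaluate an algorithm using a cross-validation split.
def evaluate_algorithm(dataset, algorithm, n_folds, *args):
	folds = cross_validation_split(dataset, n_folds)
	scores = list()
	for fold in folds:
		train_set = list(folds)
		train_set.remove(fold)
		train_set = sum(train_set, [])
		test_set = list()
		for row in fold:
			row_copy = list(row)
			row_copy[-1] = None  # hide the class label from the algorithm
			test_set.append(row_copy)
		predicted = algorithm(train_set, test_set, *args)
		actual = [row[-1] for row in fold]
		scores.append(accuracy_metric(actual, predicted))
	return scores
```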
We will also use an implementation of the Classification and Regression Trees (CART) algorithm adapted for bagging, including the following helper functions:
- test_split() to split a dataset into groups.
- gini_index() to evaluate a split point.
- Our modified get_split() function discussed in the previous step.
- to_terminal(), split() and build_tree() used to create a single decision tree.
- predict() to make a prediction with a decision tree.
- subsample() to make a subsample of the training dataset.
- bagging_predict() to make a prediction with a list of decision trees.
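Sketches of the tree-building and bagging helpers, assuming the get_split() from the previous step and the subsample() sketched earlier, might look like this:

```python
# Create a terminal node value: the most common class in a group of rows.
def to_terminal(group):
	outcomes = [row[-1] for row in group]
	return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal nodes.
def split(node, max_depth, min_size, n_features, depth):
	left, right = node['groups']
	del(node['groups'])
	# check for a no split
	if not left or not right:
		node['left'] = node['right'] = to_terminal(left + right)
		return
	# check for maximum depth
	if depth >= max_depth:
		node['left'], node['right'] = to_terminal(left), to_terminal(right)
		return
	# process left child
	if len(left) <= min_size:
		node['left'] = to_terminal(left)
	else:
		node['left'] = get_split(left, n_features)
		split(node['left'], max_depth, min_size, n_features, depth + 1)
	# process right child
	if len(right) <= min_size:
		node['right'] = to_terminal(right)
	else:
		node['right'] = get_split(right, n_features)
		split(node['right'], max_depth, min_size, n_features, depth + 1)

# Build a decision tree from a training dataset.
def build_tree(train, max_depth, min_size, n_features):
	root = get_split(train, n_features)
	split(root, max_depth, min_size, n_features, 1)
	return root

# Make a prediction with a decision tree by walking it recursively.
def predict(node, row):
	if row[node['index']] < node['value']:
		if isinstance(node['left'], dict):
			return predict(node['left'], row)
		return node['left']
	else:
		if isinstance(node['right'], dict):
			return predict(node['right'], row)
		return node['right']

# Make a prediction with a list of bagged trees by majority vote.
def bagging_predict(trees, row):
	predictions = [predict(tree, row) for tree in trees]
	return max(set(predictions), key=predictions.count)
```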
A new function named random_forest() is developed that first creates a list of decision trees from subsamples of the training dataset and then uses them to make predictions.
As we stated above, the key difference between Random Forest and bagged decision trees is the one small change to the way that trees are created, here in the get_split() function.
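A sketch of random_forest(), following the argument order expected by evaluate_algorithm() above, could be:

```python
# Random Forest: build trees on bootstrap samples and combine their predictions.
def random_forest(train, test, max_depth, min_size, sample_size, n_trees, n_features):
	trees = list()
	for _ in range(n_trees):
		sample = subsample(train, sample_size)
		tree = build_tree(sample, max_depth, min_size, n_features)
		trees.append(tree)
	predictions = [bagging_predict(trees, row) for row in test]
	return predictions
```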
The complete example is listed below.
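As a sketch of how the driver portion might tie these pieces together (assuming all of the helper functions above are defined in the same script; the random seed and the specific tree counts of 1, 5 and 10 are illustrative choices):

```python
from math import sqrt
from random import seed

# Test the Random Forest algorithm on the Sonar dataset.
seed(2)  # fix the random number generator for repeatable results (arbitrary seed)

# load and prepare the data
filename = 'sonar.all-data.csv'
dataset = load_csv(filename)
for i in range(0, len(dataset[0]) - 1):
	str_column_to_float(dataset, i)
str_column_to_int(dataset, len(dataset[0]) - 1)

# evaluate the algorithm
n_folds = 5
max_depth = 10
min_size = 1
sample_size = 1.0
n_features = int(sqrt(len(dataset[0]) - 1))
for n_trees in [1, 5, 10]:
	scores = evaluate_algorithm(dataset, random_forest, n_folds, max_depth, min_size, sample_size, n_trees, n_features)
	print('Trees: %d' % n_trees)
	print('Scores: %s' % scores)
	print('Mean Accuracy: %.3f%%' % (sum(scores) / float(len(scores))))
```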
A k value of 5 was used for cross-validation, giving each fold 208/5 = 41.6 or just over 40 records to be evaluated upon each iteration.
Deep trees were constructed with a max depth of 10 and a minimum number of training rows at each node of 1. Samples of the training dataset were created with the same size as the original dataset, which is a default expectation for the Random Forest algorithm.
The number of features considered at each split point was set to sqrt(num_features) or sqrt(60)=7.74 rounded to 7 features.
A suite of 3 different numbers of trees was evaluated for comparison, showing the increasing skill as more trees are added.
Running the example prints the scores for each fold and mean score for each configuration.
Extensions
This section lists extensions to this tutorial that you may be interested in exploring.
- Algorithm Tuning. The configuration used in the tutorial was found with a little trial and error but was not optimized. Experiment with larger numbers of trees, different numbers of features and even different tree configurations to improve performance.
- More Problems. Apply the technique to other classification problems and even adapt it for regression with a new cost function and a new method for combining the predictions from trees.
Did you try any of these extensions?
Share your experiences in the comments below.
Review
In this tutorial, you discovered how to implement the Random Forest algorithm from scratch.
Specifically, you learned:
- The difference between Random Forest and Bagged Decision Trees.
- How to update the creation of decision trees to accommodate the Random Forest procedure.
- How to apply the Random Forest algorithm to a real world predictive modeling problem.
Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.