# Decision Tree Classification in Python Tutorial

In this tutorial, learn about Decision Tree Classification, attribute selection measures, and how to build and optimize a Decision Tree Classifier using Python's Scikit-learn package.
Updated Feb 2023  · 12 min read

As a marketing manager, you want to identify the customers who are most likely to purchase your product, so that you can spend your marketing budget on the right audience. As a loan manager, you need to identify risky loan applications to achieve a lower loan default rate. Sorting customers into potential and non-potential buyers, or loan applications into safe and risky ones, is known as a classification problem.

Classification is a two-step process: a learning step and a prediction step. In the learning step, the model is developed based on given training data. In the prediction step, the model is used to predict the response to given data. A decision tree is one of the easiest to understand and most popular classification algorithms. It can be used for both classification and regression problems.

## The Decision Tree Algorithm

A decision tree is a flowchart-like tree structure in which each internal node represents a feature (or attribute), each branch represents a decision rule, and each leaf node represents an outcome.

The topmost node in a decision tree is known as the root node. The tree learns to partition the data on the basis of attribute values, splitting it recursively in a process called recursive partitioning. This flowchart-like structure helps you in decision-making. Because it can be visualized as a flowchart diagram that mimics human-level thinking, a decision tree is easy to understand and interpret.

Image | Abid Ali Awan

A decision tree is a white-box type of ML algorithm. It exposes its internal decision-making logic, which is not available in black-box algorithms such as neural networks. It also typically trains faster than a neural network.

The time complexity of decision trees is a function of the number of records and attributes in the given data. The decision tree is a distribution-free or non-parametric method which does not depend upon probability distribution assumptions. Decision trees can handle high-dimensional data with good accuracy.

## How Does the Decision Tree Algorithm Work?

The basic idea behind any decision tree algorithm is as follows:

1. Select the best attribute using Attribute Selection Measures (ASM) to split the records.
2. Make that attribute a decision node and break the dataset into smaller subsets.
3. Build the tree by repeating this process recursively for each child until one of the following conditions is met:
• All the tuples belong to the same class.
• There are no more remaining attributes.
• There are no more instances.
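
The steps above can be sketched in a few lines of plain Python. This is a minimal illustration rather than how Scikit-learn builds trees: it assumes categorical attributes stored in dictionaries, uses weighted entropy as the selection measure, and the function names and the toy `outlook`/`windy` data are hypothetical.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (base 2) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def best_attribute(rows, labels, attributes):
    """Step 1: pick the attribute whose split leaves the lowest weighted entropy."""
    def weighted_entropy(attr):
        groups = {}
        for row, label in zip(rows, labels):
            groups.setdefault(row[attr], []).append(label)
        return sum(len(g) / len(labels) * entropy(g) for g in groups.values())
    return min(attributes, key=weighted_entropy)


def build_tree(rows, labels, attributes):
    """Return a class label (leaf) or a dict mapping (attribute, value) to a subtree."""
    if len(set(labels)) == 1:      # stop: all tuples belong to the same class
        return labels[0]
    if not attributes:             # stop: no attributes left, fall back to majority class
        return Counter(labels).most_common(1)[0][0]
    attr = best_attribute(rows, labels, attributes)
    remaining = [a for a in attributes if a != attr]
    node = {}
    # Step 2: branch only on observed values, so no branch is ever empty.
    for value in sorted({row[attr] for row in rows}):
        subset = [(r, l) for r, l in zip(rows, labels) if r[attr] == value]
        sub_rows, sub_labels = zip(*subset)
        # Step 3: recurse on each child subset.
        node[(attr, value)] = build_tree(list(sub_rows), list(sub_labels), remaining)
    return node


# Hypothetical toy data
rows = [
    {"outlook": "sunny", "windy": "no"},
    {"outlook": "sunny", "windy": "yes"},
    {"outlook": "sunny", "windy": "yes"},
    {"outlook": "rainy", "windy": "no"},
    {"outlook": "rainy", "windy": "yes"},
]
labels = ["play", "play", "play", "play", "stay"]

tree = build_tree(rows, labels, ["outlook", "windy"])
print(tree)
```

On this toy data the root splits on `outlook`; the `sunny` branch becomes a pure leaf immediately, while the `rainy` branch needs one more split on `windy`.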

## Attribute Selection Measures

An attribute selection measure (ASM) is a heuristic for selecting the splitting criterion that partitions data in the best possible manner. It is also known as a splitting rule because it helps determine the breakpoints for tuples at a given node. An ASM ranks each feature (or attribute) by how well it explains the given dataset, and the attribute with the best score is selected as the splitting attribute (Source). For a continuous-valued attribute, the split points for the branches also need to be defined. The most popular selection measures are Information Gain, Gain Ratio, and Gini Index.

### Information Gain

Claude Shannon introduced the concept of entropy in information theory, where it measures the impurity of the input set. In physics and mathematics, entropy refers to the randomness or impurity in a system; in information theory, it refers to the impurity in a group of examples. Information gain is the decrease in entropy: it is the difference between the entropy before the split and the average entropy after the split of the dataset on a given attribute's values. The ID3 (Iterative Dichotomiser 3) decision tree algorithm uses information gain.

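$$\text{Info}(D) = -\sum_{i=1}^{m} p_i \log_2(p_i)$$

where $p_i$ is the probability that an arbitrary tuple in $D$ belongs to class $C_i$, and $m$ is the number of classes.

$$\text{Info}_A(D) = \sum_{j=1}^{v} \frac{|D_j|}{|D|} \times \text{Info}(D_j)$$

$$\text{Gain}(A) = \text{Info}(D) - \text{Info}_A(D)$$

Where:

• $\text{Info}(D)$ is the average amount of information needed to identify the class label of a tuple in $D$.
• $|D_j|/|D|$ acts as the weight of the $j$th partition.
• $\text{Info}_A(D)$ is the expected information required to classify a tuple from $D$ based on the partitioning by $A$.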

The attribute A with the highest information gain, Gain(A), is chosen as the splitting attribute at node N.
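
To make these quantities concrete, here is a minimal sketch that computes Info(D), InfoA(D), and Gain(A) for a tiny made-up dataset. The function names and the toy data are illustrative assumptions, not part of any library.

```python
from collections import Counter
from math import log2


def info(labels):
    """Info(D): entropy of the class labels, in bits."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def info_after_split(values, labels):
    """Info_A(D): entropy of each partition D_j, weighted by |D_j| / |D|."""
    partitions = {}
    for value, label in zip(values, labels):
        partitions.setdefault(value, []).append(label)
    return sum(len(p) / len(labels) * info(p) for p in partitions.values())


def gain(values, labels):
    """Gain(A) = Info(D) - Info_A(D)."""
    return info(labels) - info_after_split(values, labels)


# Hypothetical toy data: one attribute value and one class label per tuple.
labels = ["yes", "yes", "no", "no"]
perfect = ["a", "a", "b", "b"]   # separates the two classes completely
useless = ["a", "b", "a", "b"]   # each partition is still a 50/50 mix

print(gain(perfect, labels))  # 1.0
print(gain(useless, labels))  # 0.0
```

The attribute that separates the classes completely recovers the full 1 bit of entropy in the labels, while the attribute that leaves every partition evenly mixed gains nothing.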

### Gain Ratio

Information gain is biased toward attributes with many outcomes; that is, it prefers attributes with a large number of distinct values. For instance, consider an attribute that is a unique identifier, such as customer_ID: every partition contains a single tuple and is therefore pure, so InfoA(D) is zero. This maximizes the information gain but creates a useless partitioning.

C4.5, an improvement on ID3, uses an extension of information gain known as the gain ratio. The gain ratio handles this bias by normalizing the information gain using Split Info. The Java implementation of the C4.5 algorithm is known as J48 and is available in the WEKA data mining tool.

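$$\text{SplitInfo}_A(D) = -\sum_{j=1}^{v} \frac{|D_j|}{|D|} \times \log_2\left(\frac{|D_j|}{|D|}\right)$$

Where:

• $|D_j|/|D|$ acts as the weight of the $j$th partition.
• $v$ is the number of discrete values in attribute $A$.

The gain ratio can be defined as

$$\text{GainRatio}(A) = \frac{\text{Gain}(A)}{\text{SplitInfo}_A(D)}$$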

The attribute with the highest gain ratio is chosen as the splitting attribute (Source).
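
The effect of the normalization can be seen in a small sketch. It assumes the standard definitions of information gain and Split Info (the entropy of the attribute's own value distribution); the helper names and the toy `customer_id` and `segment` columns are hypothetical.

```python
from collections import Counter
from math import log2


def entropy(items):
    """Entropy in bits of the empirical distribution of `items`."""
    total = len(items)
    return -sum((n / total) * log2(n / total) for n in Counter(items).values())


def gain(values, labels):
    """Information gain: Info(D) minus the weighted entropy of the partitions."""
    partitions = {}
    for value, label in zip(values, labels):
        partitions.setdefault(value, []).append(label)
    after = sum(len(p) / len(labels) * entropy(p) for p in partitions.values())
    return entropy(labels) - after


def gain_ratio(values, labels):
    """Gain(A) / SplitInfo_A(D), where SplitInfo is the entropy of the attribute values."""
    split_info = entropy(values)
    return gain(values, labels) / split_info if split_info > 0 else 0.0


labels = ["yes", "yes", "no", "no"]
customer_id = [1, 2, 3, 4]        # unique per tuple, so every partition is pure
segment = ["a", "a", "b", "b"]    # two values that line up with the classes

print(gain(customer_id, labels), gain_ratio(customer_id, labels))  # 1.0 0.5
print(gain(segment, labels), gain_ratio(segment, labels))          # 1.0 1.0
```

Both attributes reach the maximum information gain, but the identifier-like attribute is penalized by its large Split Info, so the gain ratio prefers `segment`.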

### Gini Index

Another decision tree algorithm, CART (Classification and Regression Tree), uses the Gini method to create split points.

$$\text{Gini}(D) = 1 - \sum_{i=1}^{m} p_i^2$$

where $p_i$ is the probability that a tuple in $D$ belongs to class $C_i$.

The Gini index considers a binary split for each attribute. You can compute a weighted sum of the impurity of each partition. If a binary split on attribute $A$ partitions data $D$ into $D_1$ and $D_2$, the Gini index of $D$ is:

$$\text{Gini}_A(D) = \frac{|D_1|}{|D|}\,\text{Gini}(D_1) + \frac{|D_2|}{|D|}\,\text{Gini}(D_2)$$

For a discrete-valued attribute, the subset that gives the minimum Gini index for that attribute is selected as its splitting subset. For a continuous-valued attribute, the strategy is to consider each pair of adjacent values as a possible split point and to choose the point with the smallest Gini index as the splitting point.

The attribute with the minimum Gini index is chosen as the splitting attribute.
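
As an illustration of the continuous-valued case, the sketch below scans candidate thresholds and keeps the one with the lowest weighted Gini index. It assumes the midpoint between adjacent sorted values is used as the candidate split point, which is one common choice; the function names are hypothetical, and the glucose readings and labels are the five rows shown in the `pima.head()` output later in this tutorial.

```python
from collections import Counter


def gini(labels):
    """Gini(D) = 1 - sum of squared class probabilities."""
    total = len(labels)
    return 1.0 - sum((n / total) ** 2 for n in Counter(labels).values())


def gini_split(left, right):
    """Gini_A(D): impurity of the two partitions, weighted by their sizes."""
    total = len(left) + len(right)
    return len(left) / total * gini(left) + len(right) / total * gini(right)


def best_threshold(values, labels):
    """Try the midpoint of each pair of adjacent sorted values; keep the lowest Gini."""
    pairs = sorted(zip(values, labels))
    best = None
    for (a, _), (b, _) in zip(pairs, pairs[1:]):
        if a == b:
            continue  # identical values give no usable split point
        threshold = (a + b) / 2
        left = [label for value, label in pairs if value <= threshold]
        right = [label for value, label in pairs if value > threshold]
        score = gini_split(left, right)
        if best is None or score < best[1]:
            best = (threshold, score)
    return best


glucose = [148, 85, 183, 89, 137]
label = [1, 0, 1, 0, 1]

print(best_threshold(glucose, label))  # (113.0, 0.0)
```

On these five rows a threshold of 113.0 separates the two classes perfectly, so the weighted Gini index drops to 0.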

## Decision Tree Classifier Building in Scikit-learn

### Importing Required Libraries

Let's first load the required libraries.

```python
# Load libraries
import pandas as pd
from sklearn.tree import DecisionTreeClassifier  # Decision tree classifier
from sklearn.model_selection import train_test_split  # train_test_split function
from sklearn import metrics  # scikit-learn metrics module for accuracy calculation
```

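Next, load the diabetes dataset into a pandas DataFrame with the column names below. The file name in this snippet is an assumption, since the loading step is not shown; use `header=None` instead if your copy of the file has no header row.

```python
col_names = ['pregnant', 'glucose', 'bp', 'skin', 'insulin', 'bmi', 'pedigree', 'age', 'label']
# Assumed file name and header row; adjust both to your local copy of the CSV.
pima = pd.read_csv("diabetes.csv", header=0, names=col_names)
```
```python
pima.head()
```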
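| | pregnant | glucose | bp | skin | insulin | bmi | pedigree | age | label |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 6 | 148 | 72 | 35 | 0 | 33.6 | 0.627 | 50 | 1 |
| 1 | 1 | 85 | 66 | 29 | 0 | 26.6 | 0.351 | 31 | 0 |
| 2 | 8 | 183 | 64 | 0 | 0 | 23.3 | 0.672 | 32 | 1 |
| 3 | 1 | 89 | 66 | 23 | 94 | 28.1 | 0.167 | 21 | 0 |
| 4 | 0 | 137 | 40 | 35 | 168 | 43.1 | 2.288 | 33 | 1 |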

### Feature Selection

Here, you need to divide the columns into two types of variables: the dependent variable (or target variable) and the independent variables (or feature variables).

```python
# Split dataset into features and target variable
feature_cols = ['pregnant', 'insulin', 'bmi', 'age', 'glucose', 'bp', 'pedigree']
X = pima[feature_cols]  # Features
y = pima.label  # Target variable
```

### Splitting Data

To understand model performance, dividing the dataset into a training set and a test set is a good strategy.

Let's split the dataset using the function `train_test_split()`. You need to pass three parameters: the features, the target, and the test set size.

```python
# Split dataset into training set and test set: 70% training and 30% test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
```

### Building Decision Tree Model

Let's create a decision tree model using Scikit-learn.

```python
# Create Decision Tree classifier object
clf = DecisionTreeClassifier()

# Train Decision Tree classifier
clf = clf.fit(X_train, y_train)

# Predict the response for the test dataset
y_pred = clf.predict(X_test)
```

### Evaluating the Model

Let's estimate how accurately the classifier can predict the diabetes label.

Accuracy can be computed by comparing actual test set values and predicted values.

```python
# Model accuracy: how often is the classifier correct?
print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
```
```
Accuracy: 0.6753246753246753
```

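We got a classification rate of 67.53%, which serves as a baseline. Because the classifier is created without a fixed `random_state`, your exact number may differ slightly. You may be able to improve this accuracy by tuning the parameters of the decision tree algorithm.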
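
For intuition, the accuracy reported above is simply the fraction of test rows whose predicted label matches the actual label. The following sketch reproduces that calculation by hand on made-up labels; the `accuracy` helper is hypothetical and is not Scikit-learn's implementation.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Hypothetical labels: 6 of the 8 predictions are correct.
y_true = [1, 0, 1, 0, 1, 0, 0, 1]
y_pred = [1, 0, 0, 0, 1, 1, 0, 1]

print(accuracy(y_true, y_pred))  # 0.75
```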

### Visualizing Decision Trees

You can use Scikit-learn's export_graphviz function to display the tree within a Jupyter notebook. To plot the tree, you also need to install graphviz and pydotplus.

`pip install graphviz`

`pip install pydotplus`

The export_graphviz function converts the decision tree classifier into a dot file, and pydotplus converts this dot file to png or displayable form on Jupyter.

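```python
from sklearn.tree import export_graphviz
# sklearn.externals.six was removed from scikit-learn; the standalone six
# package provides StringIO (on Python 3 it is the same class as io.StringIO).
from six import StringIO
from IPython.display import Image
import pydotplus

# Write the tree in dot format to an in-memory buffer
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data,
                filled=True, rounded=True,
                special_characters=True,
                feature_names=feature_cols, class_names=['0', '1'])

# Convert the dot data to a graph, save it as a PNG, and display it
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('diabetes.png')
Image(graph.create_png())
```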

In the decision tree chart, each internal node has a decision rule that splits the data. The gini value shown in each node is the Gini index, which measures the impurity of the node. A node is pure when all of its records belong to the same class; such nodes are leaf nodes.

Here, the resultant tree is unpruned. This unpruned tree is hard to explain and not easy to understand. In the next section, let's optimize it by pruning.

## Optimizing Decision Tree Performance

• criterion : optional (default=`"gini"`). Chooses the attribute selection measure. Supported criteria include `"gini"` for the Gini index and `"entropy"` for information gain.

• splitter : string, optional (default=`"best"`). Chooses the split strategy. Supported strategies are `"best"`, to choose the best split, and `"random"`, to choose the best random split.

• max_depth : int or None, optional (default=None). The maximum depth of the tree. If None, nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples. A higher maximum depth tends to cause overfitting, and a lower value tends to cause underfitting (Source).

This tutorial optimizes the decision tree classifier by pre-pruning, using the maximum depth of the tree as the control variable. (Recent Scikit-learn versions also support cost-complexity post-pruning through the `ccp_alpha` parameter.) In the following example, you fit a decision tree on the same data with max_depth=3. Besides the pre-pruning parameters, you can also try another attribute selection measure, such as entropy.

```python
# Create Decision Tree classifier object
clf = DecisionTreeClassifier(criterion="entropy", max_depth=3)

# Train Decision Tree classifier
clf = clf.fit(X_train, y_train)

# Predict the response for the test dataset
y_pred = clf.predict(X_test)

# Model accuracy: how often is the classifier correct?
print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
```
```
Accuracy: 0.7705627705627706
```

The classification rate increased to 77.06%, which is better than the previous model.
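
To see how the maximum depth trades underfitting against overfitting, you can compare training and test accuracy across several depths. The sketch below is only an illustration: it uses synthetic data from Scikit-learn's `make_classification` as a stand-in for the diabetes dataset, so the numbers it prints do not correspond to the results above, and the list of depths is an arbitrary choice.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (an assumption, not the diabetes dataset)
X, y = make_classification(n_samples=500, n_features=8, n_informative=4, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

scores = {}
for depth in [1, 2, 3, 5, 8, None]:
    model = DecisionTreeClassifier(criterion="entropy", max_depth=depth, random_state=1)
    model.fit(X_train, y_train)
    # Store (training accuracy, test accuracy) for this depth
    scores[depth] = (model.score(X_train, y_train), model.score(X_test, y_test))
    print(depth, scores[depth])
```

A large gap between training and test accuracy at higher depths is a typical sign of overfitting, while low accuracy on both at very small depths suggests underfitting.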

### Visualizing Decision Trees

Let's make our decision tree a little easier to understand using the following code:

```python
from six import StringIO
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus

dot_data = StringIO()
export_graphviz(clf, out_file=dot_data,
                filled=True, rounded=True,
                special_characters=True,
                feature_names=feature_cols, class_names=['0', '1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('diabetes.png')
Image(graph.create_png())
```

Here, we've completed the following steps:

• Imported the required libraries.
• Created a `StringIO` object called `dot_data` to hold the text representation of the decision tree.
• Exported the decision tree to the `dot` format using the `export_graphviz` function and wrote the output to the `dot_data` buffer.
• Created a `pydotplus` graph object from the `dot` format representation of the decision tree stored in the `dot_data` buffer.
• Written the generated graph to a PNG file named "diabetes.png".
• Displayed the generated PNG image of the decision tree using the `Image` object from the `IPython.display` module.

As you can see, this pruned model is less complex, more explainable, and easier to understand than the previous decision tree model plot.
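
If installing Graphviz is inconvenient, a shallow tree can also be printed as plain-text rules with Scikit-learn's `export_text` function. The sketch below is a stand-alone illustration: it assumes Scikit-learn's bundled iris dataset in place of the diabetes data, and `max_depth=2` is chosen only to keep the printout short.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Bundled iris data used as a stand-in (an assumption, not the diabetes dataset)
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=1)
clf.fit(iris.data, iris.target)

# Render the fitted tree as indented if/else style rules
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Each indented line in the output is a decision rule on one feature, and the lines ending in `class:` are the leaf nodes.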

## Decision Tree Pros

• Decision trees are easy to interpret and visualize.
• They can easily capture non-linear patterns.
• They require less data preprocessing from the user; for example, there is no need to normalize columns.
• They can be used for feature engineering, such as predicting missing values, and are suitable for variable selection.
• Decision trees make no assumptions about the data distribution because the algorithm is non-parametric. (Source)

## Decision Tree Cons

• Decision trees are sensitive to noisy data and can overfit it.
• A small variation in the data can result in a different decision tree (high variance). This can be reduced by bagging and boosting algorithms.
• Decision trees are biased on imbalanced datasets, so it is recommended to balance the dataset before creating the decision tree.

## Conclusion

Congratulations, you have made it to the end of this tutorial!

In this tutorial, you covered a lot of details about decision trees: how they work; attribute selection measures such as Information Gain, Gain Ratio, and Gini Index; and decision tree model building, visualization, and evaluation on a diabetes dataset using Python's Scikit-learn package. You also learned about their pros and cons and how to optimize decision tree performance using parameter tuning.

Hopefully, you can now utilize the decision tree algorithm to analyze your own datasets.

If you want to learn more about Machine Learning in Python, take DataCamp's Machine Learning with Tree-Based Models in Python course.

Check out our Kaggle Tutorial: Your First Machine Learning Model.
