
Introduction to Plotting with Matplotlib in Python

This tutorial demonstrates how to use Matplotlib, a powerful data visualization library in Python, to create line, bar, and scatter plots with stock market data.
Updated Mar 2023  · 25 min read

Matplotlib is a powerful and very popular data visualization library in Python. In this tutorial, we will discuss how to create line plots, bar plots, and scatter plots in Matplotlib using stock market data from 2022. These are the foundational plots that will allow you to start understanding, visualizing, and telling stories about data. Data visualization is an essential skill for all data analysts, and Matplotlib is one of the most popular libraries for creating visualizations.

This tutorial assumes some basic prior knowledge of NumPy arrays and pandas dataframes. When we use those libraries, we will briefly explain what we are doing. The main focus of this tutorial is Matplotlib, which works on top of these data structures to create visualizations.

Matplotlib is very flexible and customizable. The trade-off is that even fairly basic plots with few customizations can take a lot of code. When exploratory data analysis is the main goal, and you need many quickly drawn plots with less emphasis on aesthetics, seaborn is a great option: it builds on top of Matplotlib to create visualizations more quickly. Please see our Python Seaborn Tutorial For Beginners instead if exploratory data analysis or quick and easy graph creation is your main priority.

Matplotlib Examples

By the end of this tutorial, you will be able to make great-looking visualizations in Matplotlib. We will focus on creating line plots, bar plots, and scatter plots. We will also focus on how to make customization decisions, such as the use of color, how to label plots, and how to organize them in a clear way to tell a compelling story.

image14.png

image13.png

image7.png

The Dataset

Matplotlib is designed to work with NumPy arrays and pandas dataframes. The library makes it easy to make graphs from tabular data. For this tutorial, we will use the Dow Jones Industrial Average (DJIA) index’s historical prices from 2022-01-01 to 2022-12-31 (found here). You can set the date range on the page and then click the “download a spreadsheet” button.

We will load the CSV file, named HistoricalPrices.csv, using the pandas library and view the first rows using the .head() method.

import pandas as pd

djia_data = pd.read_csv('HistoricalPrices.csv')
djia_data.head()

image19.png

We see the data include 5 columns: Date, Open, High, Low, and Close. The latter 4 are related to the price of the index during the trading day. Below is a brief explanation of each variable.

  • Date: The day that the stock price information represents.
  • Open: The price of the DJIA at 9:30 AM ET when the stock market opens.
  • High: The highest price the DJIA reached during the day.
  • Low: The lowest price the DJIA reached during the day.
  • Close: The price of the DJIA when the market stopped trading at 4:00 PM ET.

As a quick clean-up step, we also need to use the rename() method in pandas, because the column names in the downloaded dataset start with an extra space.

djia_data = djia_data.rename(columns = {' Open': 'Open', ' High': 'High', ' Low': 'Low', ' Close': 'Close'})
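A possible alternative, sketched here under the assumption that the only problem with the headers is stray surrounding whitespace, is to strip every column name at once with the pandas `.str.strip()` string accessor instead of listing each one in `rename()`. The tiny DataFrame below is made-up stand-in data, not the real download.

```python
import pandas as pd

# Made-up stand-in for the downloaded file: price columns carry a leading space
djia_data = pd.DataFrame({
    'Date': ['12/30/22', '12/29/22'],
    ' Open': [100.0, 101.0],
    ' High': [102.0, 103.0],
    ' Low': [99.0, 100.0],
    ' Close': [101.0, 102.0],
})

# Strip surrounding whitespace from every column name in one step
djia_data.columns = djia_data.columns.str.strip()

print(list(djia_data.columns))
```

This keeps working if the file ever has more padded columns, while `rename()` is more explicit about exactly which names change.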

We will also make sure that the Date column has a datetime type and sort the rows by date in ascending order.

djia_data['Date'] = pd.to_datetime(djia_data['Date'])
djia_data = djia_data.sort_values(by = 'Date')

Loading Matplotlib

Next, we will load the pyplot submodule of Matplotlib so that we can draw our plots. The pyplot module contains all of the functions we will need to create and style plots. We will use the conventional alias plt. We will also load pandas, NumPy, and datetime for later parts of this tutorial.

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from datetime import datetime

Drawing Line Plots

The first plot we will create is a line plot. Line plots are a very important plot type because they do a great job of displaying time series data. It is often important to visualize how KPIs change over time to understand patterns in the data that can be acted on.

Line Plots with a Single Line


We can create a line plot in Matplotlib using the plt.plot() function, where the first argument is the x variable and the second argument is the y variable. Whenever we create a plot, we need to call plt.show() to display the graph we have created. We will visualize the DJIA's close price over time.

plt.plot(djia_data['Date'], djia_data['Close'])
plt.show()

image1.png

We can see that the index price started the year at its highest value, followed by fluctuations up and down throughout the year. The price was lowest around October, followed by a strong end-of-year increase.
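The same chart can also be drawn with Matplotlib's object-oriented interface, where `plt.subplots()` returns a Figure and an Axes and the plotting method is called on the Axes directly. This is a minimal sketch; the five-day DataFrame is made-up placeholder data standing in for `djia_data`.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up stand-in for djia_data: five business days with placeholder closes
djia_data = pd.DataFrame({
    'Date': pd.date_range('2022-01-03', periods = 5, freq = 'B'),
    'Close': [100.0, 102.0, 101.0, 99.0, 103.0],
})

# Create a Figure and a single Axes explicitly, then draw the line on the Axes
fig, ax = plt.subplots()
ax.plot(djia_data['Date'], djia_data['Close'])
plt.show()
```

Holding an explicit `ax` becomes useful once a figure contains several plots, because each Axes can then be styled on its own.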

Line Plots with Multiple Lines

We can visualize multiple lines on the same plot by adding another plt.plot() call before the plt.show() function.

plt.plot(djia_data['Date'], djia_data['Open'])
plt.plot(djia_data['Date'], djia_data['Close'])


plt.show()

image16.png

Over the course of the year, we see that the open and close prices of the DJIA were relatively close to each other for each given day with no clear pattern of one always being above or below the other.  
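pandas also offers a shortcut: `DataFrame.plot()` calls Matplotlib under the hood and can draw several columns as separate lines in a single call. A sketch with made-up placeholder prices:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up stand-in for djia_data with open and close columns
djia_data = pd.DataFrame({
    'Date': pd.date_range('2022-01-03', periods = 5, freq = 'B'),
    'Open': [100.0, 101.5, 101.0, 99.5, 100.0],
    'Close': [101.0, 101.0, 99.0, 100.5, 102.0],
})

# Each column listed in y becomes its own line; the Matplotlib Axes is returned
ax = djia_data.plot(x = 'Date', y = ['Open', 'Close'])
plt.show()
```

The returned `ax` is an ordinary Matplotlib Axes, so the styling techniques in the rest of this tutorial still apply.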

Adding a Legend

If we want to distinguish which line represents which column, we can add a legend, which creates a color-coded label box on the graph. We do this by passing a label argument to each plt.plot() call and then calling plt.legend().

plt.plot(djia_data['Date'], djia_data['Open'], label = 'Open')
plt.plot(djia_data['Date'], djia_data['Close'], label = 'Close')
plt.legend()
plt.show()

image2.png

We now see a legend with the specified labels, here in the upper right. By default (loc='best'), Matplotlib places the legend where it overlaps the data least; you can choose a position yourself with the loc argument of plt.legend().
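For example, a sketch that pins the legend to the lower-left corner (made-up placeholder prices again stand in for `djia_data`):

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up stand-in for djia_data
djia_data = pd.DataFrame({
    'Date': pd.date_range('2022-01-03', periods = 5, freq = 'B'),
    'Open': [100.0, 101.5, 101.0, 99.5, 100.0],
    'Close': [101.0, 101.0, 99.0, 100.5, 102.0],
})

plt.plot(djia_data['Date'], djia_data['Open'], label = 'Open')
plt.plot(djia_data['Date'], djia_data['Close'], label = 'Close')
# Fix the legend position instead of letting Matplotlib pick the "best" spot
legend = plt.legend(loc = 'lower left')
plt.show()
```

Other accepted strings include 'upper left', 'upper right', 'lower right', and 'center'.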

Drawing Bar Plots

Bar plots are very useful for comparing numerical values across categories. They are particularly helpful for finding the largest and smallest categories. 

For this section, we will aggregate the data into monthly averages using pandas so that we can compare the DJIA's monthly performance during 2022. To keep the plots easy to read, we will only keep the first 6 months (January to June).

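# Import the month names from the standard-library calendar module
from calendar import month_name

# Store month names as a categorical so months sort chronologically, not alphabetically
djia_data['Month'] = pd.Categorical(djia_data['Date'].dt.month_name(), categories = month_name[1:])

# Average the price columns within each month
djia_monthly_mean = djia_data \
    .groupby('Month', observed = True)[['Open', 'High', 'Low', 'Close']] \
    .mean() \
    .reset_index()

# Keep only the first 6 months (January to June)
djia_monthly_mean = djia_monthly_mean.head(6)
djia_monthly_mean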

image5.png
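As an alternative sketch (not the approach used in the rest of this tutorial), pandas can compute monthly averages by resampling on the date itself; the `'MS'` rule labels each month by its first day, and a datetime index keeps months in chronological order without a categorical. The daily values below are made up.

```python
import pandas as pd

# Made-up daily closes spanning two months
djia_data = pd.DataFrame({
    'Date': pd.to_datetime(['2022-01-03', '2022-01-04', '2022-02-01', '2022-02-02']),
    'Close': [100.0, 102.0, 90.0, 94.0],
})

# Index by date, resample to month-start frequency, and average within each month
monthly_close = djia_data.set_index('Date')['Close'].resample('MS').mean()
print(monthly_close)
```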

Vertical Bar Plots

We will start by creating a bar chart with vertical bars. This can be done using the plt.bar() method with the first argument being the x-axis variable (Month) and the height parameter being the y-axis (Close).  We then want to make sure to call plt.show() to show our plot.

plt.bar(djia_monthly_mean['Month'], height = djia_monthly_mean['Close'])
plt.show()

image9.png

We see that most of the monthly average close prices of the DJIA were similar to each other, with the lowest average close in June and the highest in January.

Reordering Bars in Bar Plots

If we want to show these bars in order from highest to lowest monthly average close price, we can sort the rows using the sort_values() method in pandas and then call plt.bar() the same way.

djia_monthly_mean_srtd = djia_monthly_mean.sort_values(by = 'Close', ascending = False)

plt.bar(djia_monthly_mean_srtd['Month'], height = djia_monthly_mean_srtd['Close'])
plt.show()

image10.png

As you can see, it is significantly easier to see which months had the highest average DJIA close price and which had the lowest. It is also easier to compare across months and rank them.
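To make the ranking even easier to read, each bar can be labeled with its value. A sketch using `plt.bar_label()` (available since Matplotlib 3.4), with made-up placeholder averages in place of `djia_monthly_mean_srtd`:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up monthly averages, already sorted from highest to lowest
djia_monthly_mean_srtd = pd.DataFrame({
    'Month': ['January', 'March', 'February'],
    'Close': [105.0, 101.0, 98.0],
})

bars = plt.bar(djia_monthly_mean_srtd['Month'], height = djia_monthly_mean_srtd['Close'])
# Write each bar's height above it, rounded to a whole number
labels = plt.bar_label(bars, fmt = '%.0f')
plt.show()
```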

Horizontal Bar Plots


It is sometimes easier to interpret bar charts and read the labels when the bars are horizontal. We can do this using the plt.barh() function. For horizontal bars, the bar length is set with the width argument rather than height.

plt.barh(djia_monthly_mean_srtd['Month'], width = djia_monthly_mean_srtd['Close'])
plt.show()

image17.png

As you can see, the labels of each category (month) are easier to read than when the bars were vertical, and we can still easily compare across groups. Horizontal bar charts are especially useful when there are many categories.
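One detail worth knowing: `plt.barh()` draws the first row at the bottom of the y-axis, so with the sorted data the highest month ends up at the bottom. A small sketch (made-up placeholder averages) that flips the axis with `invert_yaxis()` so the ranking reads top to bottom:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up monthly averages, sorted from highest to lowest
djia_monthly_mean_srtd = pd.DataFrame({
    'Month': ['January', 'March', 'February'],
    'Close': [105.0, 101.0, 98.0],
})

plt.barh(djia_monthly_mean_srtd['Month'], width = djia_monthly_mean_srtd['Close'])
# Flip the y-axis so the first (highest) row is drawn at the top
plt.gca().invert_yaxis()
plt.show()
```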

Drawing Scatter Plots

Scatter plots are very useful for identifying relationships between 2 numeric variables. They give you a sense of what to expect in one variable when the other changes, and they can also inform your choice of modeling technique, such as linear or non-linear regression.

Scatter Plots

Similar to the other plots, a scatter plot can be created using plt.scatter(), where the first argument is the x-axis variable and the second argument is the y-axis variable. In this example, we will look at the relationship between the open and close prices of the DJIA.

plt.scatter(djia_data['Open'], djia_data['Close'])
plt.show()

image4.png

On the x-axis we have the open price of the DJIA, and on the y-axis we have the close price. As we would expect, there is a strong positive relationship: days with a higher open price also tend to have a higher close price.
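With around 250 trading days, points can pile on top of each other. A sketch that shrinks the markers (`s`) and makes them semi-transparent (`alpha`) so dense regions show up darker; the open/close values are randomly generated placeholders, not DJIA data.

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up open/close pairs: close follows open plus random noise
rng = np.random.default_rng(0)
open_price = rng.uniform(29000, 37000, size = 250)
close_price = open_price + rng.normal(0, 300, size = 250)

# Smaller, semi-transparent markers make overlapping points easier to see
points = plt.scatter(open_price, close_price, s = 10, alpha = 0.5)
plt.show()
```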

Scatter Plots with a Trend Line

Next, we will add a trend line to the graph to show the linear relationship between the open and close variables more explicitly. To do this, we will use the NumPy functions polyfit() and poly1d(). The first gives us a least squares polynomial fit, where the first argument is the x variable, the second argument is the y variable, and the third argument is the degree of the fit (1 for linear). The second wraps the fitted coefficients in a one-dimensional polynomial object that we can call on x values to get the points of the trend line for plt.plot().

z = np.polyfit(djia_data['Open'], djia_data['Close'], 1)
p = np.poly1d(z)


plt.scatter(djia_data['Open'], djia_data['Close'])
plt.plot(djia_data['Open'], p(djia_data['Open']))
plt.show()

image20.png

As we can see, the trend line follows the scatter plot closely, as the relationship between open and close price is strongly linear. As the open price increases, the close price generally increases at a similar, linear rate.
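To put numbers on "strongly linear", a sketch can unpack the slope and intercept returned by `np.polyfit()` (coefficients come highest power first) and compute the Pearson correlation with `np.corrcoef()`. The arrays below are made up with an exact relationship, close = open + 50, so the results are easy to check.

```python
import numpy as np

# Made-up prices with an exact linear relationship: close = 1.0 * open + 50
open_price = np.array([100.0, 110.0, 120.0, 130.0])
close_price = open_price + 50.0

# For a degree-1 fit, polyfit returns [slope, intercept]
slope, intercept = np.polyfit(open_price, close_price, 1)

# Off-diagonal entry of the 2x2 correlation matrix is the Pearson r
r = np.corrcoef(open_price, close_price)[0, 1]
print(round(slope, 3), round(intercept, 3), round(r, 3))
```

On real data, an r close to 1 together with a slope near 1 would support the visual impression of the trend line.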

Setting the Plot Title and Axis Labels

Plot titles and axis labels make a visualization significantly easier to understand and let the viewer see at a glance what they are looking at. We can add them with plt.title(), plt.xlabel(), and plt.ylabel(), which we will demonstrate on the scatter plot from the previous section.

plt.scatter(djia_data['Open'], djia_data['Close'])
plt.title('DJIA 2022 Open vs. Close Price')
plt.xlabel('Open Price')
plt.ylabel('Close Price')
plt.show()

image8.png
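With the object-oriented interface, the same three settings can be applied in one call with `Axes.set()`. A sketch on a few made-up points:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# Made-up open/close points
ax.scatter([100.0, 110.0, 120.0], [101.0, 108.0, 123.0])
# Axes.set accepts several properties at once
ax.set(title = 'DJIA 2022 Open vs. Close Price', xlabel = 'Open Price', ylabel = 'Close Price')
plt.show()
```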

Changing Colors

Color can be a powerful tool in data visualizations for emphasizing certain points or for telling a consistent story by using the same color for the same idea. In Matplotlib, we can specify colors using named colors (e.g., "red", "blue"), hex codes (e.g., "#f4db9a", "#383c4a"), and red-green-blue tuples. Matplotlib expects RGB tuple values between 0 and 1 rather than 0 to 255, so a color such as (125, 100, 37) must be written as (125/255, 100/255, 37/255), roughly (0.49, 0.39, 0.15).
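A sketch of the three formats side by side. `rgb255_to_mpl` is a hypothetical helper written for this example (not part of Matplotlib) that rescales a 0-255 tuple to the 0-1 range; `matplotlib.colors.to_rgb()` is used only to show how Matplotlib reads each color.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb

def rgb255_to_mpl(rgb):
    """Hypothetical helper: scale a 0-255 (R, G, B) tuple to Matplotlib's 0-1 floats."""
    return tuple(channel / 255 for channel in rgb)

named = 'red'
hex_code = '#f4db9a'
rgb_tuple = rgb255_to_mpl((125, 100, 37))

# One line per color format
plt.plot([0, 1], [0, 1], color = named)
plt.plot([0, 1], [1, 0], color = hex_code)
plt.plot([0, 1], [0.5, 0.5], color = rgb_tuple)
plt.show()
```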

Lines

For a line plot, we can change a color using the color argument in plt.plot(). Below, we change the color of our open price line to “black” and our close price line to “red.”

plt.plot(djia_data['Date'], djia_data['Open'], color = 'black')
plt.plot(djia_data['Date'], djia_data['Close'], color = 'red')
plt.show()

image3.png

Bars

For bars, we can pass a list to the color argument to specify the color of each bar. Let’s say we want to highlight the average close price in January to make a point about how strong it was. We can do this by giving that bar a unique color to draw attention to it.

plt.bar(djia_monthly_mean_srtd['Month'], height = djia_monthly_mean_srtd['Close'], color = ['blue', 'gray', 'gray', 'gray', 'gray', 'gray'])
plt.show()

image11.png
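The hard-coded list above relies on January being the first bar. A sketch that builds the list from the data instead, so the highlight follows January wherever it lands in the sort order (made-up placeholder averages):

```python
import matplotlib.pyplot as plt
import pandas as pd

# Made-up monthly averages, sorted from highest to lowest
djia_monthly_mean_srtd = pd.DataFrame({
    'Month': ['January', 'March', 'February'],
    'Close': [105.0, 101.0, 98.0],
})

# Blue for January, gray for every other month, in the same order as the rows
bar_colors = ['blue' if month == 'January' else 'gray' for month in djia_monthly_mean_srtd['Month']]

plt.bar(djia_monthly_mean_srtd['Month'], height = djia_monthly_mean_srtd['Close'], color = bar_colors)
plt.show()
```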

Points

Finally, for scatter plots, we can change the color using the color argument of plt.scatter(). We will color all January points blue and all other points gray to tell the same story as the bar chart above.

# Boolean mask that is True for rows from January
is_january = djia_data['Month'] == 'January'

plt.scatter(djia_data.loc[is_january, 'Open'], djia_data.loc[is_january, 'Close'], color = 'blue')
plt.scatter(djia_data.loc[~is_january, 'Open'], djia_data.loc[~is_january, 'Close'], color = 'gray')

plt.show()

image18.png
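A variation that draws everything in a single call: the `c` argument of `plt.scatter()` also accepts one color per point, here built with `np.where()`. The rows are made-up placeholders.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Made-up stand-in for djia_data
djia_data = pd.DataFrame({
    'Month': ['January', 'January', 'February', 'March'],
    'Open': [100.0, 102.0, 98.0, 97.0],
    'Close': [101.0, 103.0, 97.0, 99.0],
})

# One color string per row: blue for January, gray otherwise
point_colors = np.where(djia_data['Month'] == 'January', 'blue', 'gray')

plt.scatter(djia_data['Open'], djia_data['Close'], c = point_colors)
plt.show()
```

The two-call version remains handy when each group needs its own legend label.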

Using Colormaps

Colormaps are built-in Matplotlib color scales that map numeric values to colors according to their magnitude (documentation here). They are designed to look good together and make it easy to see how values increase.

In the example below, we use a colormap by passing the close price (the y-variable) to the c argument and the plasma colormap to cmap. As the values increase, the associated color gets brighter and more yellow, while lower values are darker and purple.

plt.scatter(djia_data['Open'], djia_data['Close'], c=djia_data['Close'], cmap = plt.cm.plasma)

plt.show()

image6.png
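Without a legend for the colors, readers cannot tell which price a shade represents. A sketch that adds a color scale with `plt.colorbar()`; the prices are randomly generated placeholders.

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up open/close pairs
rng = np.random.default_rng(1)
open_price = rng.uniform(29000, 37000, size = 100)
close_price = open_price + rng.normal(0, 300, size = 100)

points = plt.scatter(open_price, close_price, c = close_price, cmap = plt.cm.plasma)
# Color scale mapping each shade back to a close price
cbar = plt.colorbar(points, label = 'Close')
plt.show()
```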

Setting Axis Limits

Sometimes, it is helpful to look at a specific range of values in a plot. For example, if the DJIA is currently trading around $30,000, we may only care about behavior around that price. We can pass a tuple to plt.xlim() and plt.ylim() to set the x and y limits, respectively. The first value in the tuple is the lower limit, and the second value is the upper limit.
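A sketch of zooming a scatter plot to the 29,000 to 31,000 range on both axes. The limits are an illustrative choice and the prices are randomly generated placeholders, not the DJIA data.

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up open/close pairs
rng = np.random.default_rng(2)
open_price = rng.uniform(28000, 37000, size = 200)
close_price = open_price + rng.normal(0, 300, size = 200)

plt.scatter(open_price, close_price)
# (lower, upper) limits for each axis
plt.xlim((29000, 31000))
plt.ylim((29000, 31000))
plt.show()
```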

image15.png

Saving Plots

Finally, we can save the plots we create in Matplotlib using plt.savefig(). We can save files in many formats, including PNG, PDF, and SVG. The first argument is the filename, and the format is inferred from the file extension (you can override this with the format argument).

plt.scatter(djia_data['Open'], djia_data['Close'])
plt.savefig('DJIA 2022 Scatterplot Open vs. Close.png')
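Two savefig parameters are often useful: `dpi` sets the resolution and `bbox_inches='tight'` trims extra whitespace around the figure. Call plt.savefig() before plt.show(), since the figure may be closed once it has been shown, which can leave you with an empty image. In this sketch the file goes to the system temporary directory only to keep the example self-contained; the points are made up.

```python
import os
import tempfile

import matplotlib.pyplot as plt

# Made-up open/close points
plt.scatter([100.0, 110.0, 120.0], [101.0, 108.0, 123.0])

out_path = os.path.join(tempfile.gettempdir(), 'djia_scatter.png')
# Save first, at 200 dots per inch with surrounding whitespace trimmed
plt.savefig(out_path, dpi = 200, bbox_inches = 'tight')
plt.show()
```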

Take it to the Next Level

We have covered the basics of Matplotlib in this tutorial and you can now make basic line graphs, bar graphs, and scatter plots. Matplotlib is an advanced library with a lot of great features for creating aesthetically pleasing visualizations. If you would like to take your Matplotlib skills to the next level, take our Introduction to Data Visualization with Matplotlib course. You can also download our Matplotlib Cheat Sheet: Plotting in Python for reference as you start creating your own visualizations. 
