Matplotlib is a powerful and very popular data visualization library in Python. In this tutorial, we will discuss how to create line plots, bar plots, and scatter plots in Matplotlib using stock market data in 2022. These are the foundational plots that will allow you to start understanding, visualizing, and telling stories about data. Data visualization is an essential skill for all data analysts and Matplotlib is one of the most popular libraries for creating visualizations.
This tutorial expects some basic prior knowledge in NumPy arrays and pandas dataframes. When we use those libraries, we will quickly explain what we are doing. The main focus of this tutorial is Matplotlib, which works on top of these data structures to create visualizations.
Matplotlib is very flexible and customizable for creating plots. The trade-off is that even basic plots with little customization can require a fair amount of code. When exploratory data analysis is the main goal, and you need many quickly drawn plots without much emphasis on aesthetics, the seaborn library is a great option, as it builds on top of Matplotlib to create visualizations more quickly. Please see our Python Seaborn Tutorial For Beginners instead if exploratory data analysis or quick and easy graph creation is your main priority.
Matplotlib Examples
By the end of this tutorial, you will be able to make great-looking visualizations in Matplotlib. We will focus on creating line plots, bar plots, and scatter plots. We will also focus on how to make customization decisions, such as the use of color, how to label plots, and how to organize them in a clear way to tell a compelling story.
The Dataset
Matplotlib is designed to work with NumPy arrays and pandas dataframes. The library makes it easy to make graphs from tabular data. For this tutorial, we will use the Dow Jones Industrial Average (DJIA) index’s historical prices from 2022-01-01 to 2022-12-31 (found here). You can set the date range on the page and then click the “download a spreadsheet” button.
We will load in the csv file, named HistoricalPrices.csv, using the pandas library and view the first rows using the .head() method.
import pandas as pd
djia_data = pd.read_csv('HistoricalPrices.csv')
djia_data.head()
We see the data include 5 columns: Date, Open, High, Low, and Close. The latter 4 are related to the price of the index during the trading day. Below is a brief explanation of each variable.
- Date: The day that the stock price information represents.
- Open: The price of the DJIA at 9:30 AM ET when the stock market opens.
- High: The highest price the DJIA reached during the day.
- Low: The lowest price the DJIA reached during the day.
- Close: The price of the DJIA when the market stopped trading at 4:00 PM ET.
As a quick cleanup step, we will also need to use the rename() method in pandas, as the dataset we downloaded has an extra space at the start of some column names.
djia_data = djia_data.rename(columns = {' Open': 'Open', ' High': 'High', ' Low': 'Low', ' Close': 'Close'})
We will also ensure that the Date variable is a datetime variable and sort in ascending order by the date.
djia_data['Date'] = pd.to_datetime(djia_data['Date'])
djia_data = djia_data.sort_values(by = 'Date')
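If you would rather not list every column by hand, the same cleanup can be done by stripping whitespace from all column names at once. The sketch below uses a tiny made-up CSV (the values and the ISO date format are assumptions, not the real DJIA file) so that it runs on its own.

```python
import io

import pandas as pd

# Made-up two-row sample that mimics a header where every column
# after Date starts with a space. The numbers are illustrative only.
sample_csv = io.StringIO(
    "Date, Open, High, Low, Close\n"
    "2022-01-04,102.0,104.0,101.0,103.0\n"
    "2022-01-03,100.0,102.0,99.0,101.0\n"
)
sample = pd.read_csv(sample_csv)

# Strip leading and trailing whitespace from every column name at once
sample.columns = sample.columns.str.strip()

# Convert Date to a datetime column and sort oldest first
sample['Date'] = pd.to_datetime(sample['Date'])
sample = sample.sort_values(by='Date').reset_index(drop=True)

print(sample.columns.tolist())
```

This approach keeps working if the file gains or loses columns, whereas rename() only touches the names you spell out.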
Loading Matplotlib
Next, we will load the pyplot submodule of Matplotlib so that we can draw our plots. The pyplot module contains all of the relevant functions we will need to create plots and style them. We will use the conventional alias plt. We will also load in pandas, numpy, and datetime for future parts of this tutorial.
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from datetime import datetime
Drawing Line Plots
The first plot we will create will be a line plot. Line plots are a very important plot type as they do a great job of displaying time series data. It is often important to visualize how KPIs change over time to understand patterns in data that can be acted on.
Line Plots with a Single Line
We can create a line plot in matplotlib using the plt.plot() function, where the first argument is the x variable and the second argument is the y variable in our line plot. Whenever we create a plot, we need to make sure to call plt.show() to ensure we see the graph we have created. We will visualize the close price over time of the DJIA.
plt.plot(djia_data['Date'], djia_data['Close'])
plt.show()
We can see that over the course of the year, the index price started at its highest value followed by some fluctuations up and down throughout the year. We see the price was lowest around October followed by a strong end of the year increase in price.
Line Plots with Multiple Lines
We can visualize multiple lines on the same plot by adding another plt.plot() call before the plt.show() call.
plt.plot(djia_data['Date'], djia_data['Open'])
plt.plot(djia_data['Date'], djia_data['Close'])
plt.show()
Over the course of the year, we see that the open and close prices of the DJIA were relatively close to each other for each given day with no clear pattern of one always being above or below the other.
Adding a Legend
If we want to distinguish which line represents which column, we can add a legend. This will create a color coded label in the corner of the graph. We can do this using plt.legend() and adding label parameters to each plt.plot() call.
plt.plot(djia_data['Date'], djia_data['Open'], label = 'Open')
plt.plot(djia_data['Date'], djia_data['Close'], label = 'Close')
plt.legend()
plt.show()
We now see a legend with the specified labels appear in the top right of the plot. By default, Matplotlib picks the location it judges to be best; a specific location can be set using the loc argument in plt.legend().
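As a small illustration of the loc argument, the sketch below pins the legend to the lower right. The numbers are made up and simply stand in for the Open and Close columns.

```python
import matplotlib.pyplot as plt

# Made-up values standing in for the Open and Close columns
days = [1, 2, 3, 4, 5]
open_prices = [100, 102, 101, 105, 107]
close_prices = [101, 101, 104, 106, 106]

plt.plot(days, open_prices, label='Open')
plt.plot(days, close_prices, label='Close')

# Pin the legend to the lower right instead of letting Matplotlib choose
legend = plt.legend(loc='lower right')
plt.show()
```

Other accepted location strings include 'upper left', 'center', and 'best'.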
Drawing Bar Plots
Bar plots are very useful for comparing numerical values across categories. They are particularly helpful for finding the largest and smallest categories.
For this section, we will aggregate the data into monthly averages using pandas so that we can compare monthly performance during 2022 for the DJIA. We will also keep only the first 6 months to make the data easier to visualize.
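# Import the list of month names from the calendar module
from calendar import month_name
# Make Month an ordered categorical so the months sort chronologically
djia_data['Month'] = pd.Categorical(djia_data['Date'].dt.month_name(), categories = list(month_name[1:]), ordered = True)
# Average each numeric column by month
djia_monthly_mean = djia_data \
    .groupby('Month') \
    .mean(numeric_only = True) \
    .reset_index()
# Keep only the first 6 months (January to June)
djia_monthly_mean = djia_monthly_mean.head(6)
djia_monthly_mean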
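To see why the month column is turned into a categorical, it may help to compare the two groupings side by side. This is a minimal sketch on made-up data; the toy DataFrame and its column MonthText are hypothetical and not part of the DJIA dataset.

```python
from calendar import month_name

import pandas as pd

# Made-up daily values covering three months, just to show the ordering
toy = pd.DataFrame({
    'Date': pd.to_datetime(['2022-01-03', '2022-01-04', '2022-02-01',
                            '2022-02-02', '2022-04-01']),
    'Close': [10.0, 12.0, 20.0, 22.0, 30.0],
})

# Plain strings: groupby sorts the month names alphabetically
toy['MonthText'] = toy['Date'].dt.month_name()
alphabetical = toy.groupby('MonthText')['Close'].mean()

# Ordered categorical: groupby follows the calendar order of the categories
toy['Month'] = pd.Categorical(toy['Date'].dt.month_name(),
                              categories=list(month_name[1:]),
                              ordered=True)
chronological = toy.groupby('Month', observed=True)['Close'].mean()

print(alphabetical.index.tolist())
print(chronological.index.tolist())
```

With plain strings, April comes before January; with the categorical, the months stay in calendar order, which is what we want on the x-axis of a bar plot.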
Vertical Bar Plots
We will start by creating a bar chart with vertical bars. This can be done using the plt.bar() function, with the first argument being the x-axis variable (Month) and the height parameter being the y-axis (Close). We then want to make sure to call plt.show() to show our plot.
plt.bar(djia_monthly_mean['Month'], height = djia_monthly_mean['Close'])
plt.show()
We see that most of the close prices of the DJIA were close to each other with the lowest average close value being in June and the highest average close value being in January.
Reordering Bars in Bar Plots
If we want to show these bars in order of highest to lowest monthly average close price, we can sort the data using the sort_values() method in pandas and then use the same plt.bar() function.
djia_monthly_mean_srtd = djia_monthly_mean.sort_values(by = 'Close', ascending = False)
plt.bar(djia_monthly_mean_srtd['Month'], height = djia_monthly_mean_srtd['Close'])
plt.show()
As you can see, it is significantly easier to see which months had the highest average DJIA close price and which months had the lower averages. It is also easier to compare across months and rank the months.
Horizontal Bar Plots
It is sometimes easier to interpret bar charts and read the labels when we make the bar plot with horizontal bars. We can do this using the plt.barh() function. Because the bars now extend horizontally, their length is set with the width parameter rather than height.
plt.barh(djia_monthly_mean_srtd['Month'], width = djia_monthly_mean_srtd['Close'])
plt.show()
As you can see, the labels of each category (month) are easier to read than when the bars were vertical. We can still easily compare across groups. This horizontal bar chart is especially useful when there are a lot of categories.
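One detail worth knowing: plt.barh() draws the first category at the bottom of the plot, so data sorted from highest to lowest ends up with the largest bar at the bottom. If you prefer the largest bar on top, one option is to invert the y-axis. The sketch below uses made-up months and values as stand-ins for the monthly averages.

```python
import matplotlib.pyplot as plt

# Made-up categories and values, already sorted from highest to lowest
months = ['January', 'February', 'March']
avg_close = [30.0, 20.0, 10.0]

plt.barh(months, width=avg_close)

# Flip the y-axis so the first (largest) bar is drawn at the top
ax = plt.gca()
ax.invert_yaxis()
plt.show()
```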
Drawing Scatter Plots
Scatterplots are very useful for identifying relationships between 2 numeric variables. This can give you a sense of what to expect in a variable when the other variable changes and can also be very informative in your decision to use different modeling techniques such as linear or non-linear regression.
Scatter Plots
Similar to the other plots, a scatter plot can be created using plt.scatter(), where the first argument is the x-axis variable and the second argument is the y-axis variable. In this example, we will look at the relationship between the open and close price of the DJIA.
plt.scatter(djia_data['Open'], djia_data['Close'])
plt.show()
On the x-axis we have the open price of the DJIA and on the y-axis we have the close price. As we would expect, the close price rises with the open price, and the relationship between the two is strong.
Scatter Plots with a Trend Line
Next, we will add a trend line to the graph to show the linear relationship between the open and close variables more explicitly. To do this, we will use the numpy polyfit() and poly1d() functions. The first will give us a least squares polynomial fit, where the first argument is the x variable, the second argument is the y variable, and the third argument is the degree of the fit (1 for linear). The second will give us a one-dimensional polynomial object that we can use to create a trend line using plt.plot().
z = np.polyfit(djia_data['Open'], djia_data['Close'], 1)
p = np.poly1d(z)
plt.scatter(djia_data['Open'], djia_data['Close'])
plt.plot(djia_data['Open'], p(djia_data['Open']))
plt.show()
As we can see, the line in the background of the graph follows the trend of the scatterplot closely as the relationship between open and close price is strongly linear. We see that as the open price increases, the close price generally increases at a similar and linear rate.
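To make the roles of polyfit() and poly1d() more concrete, here is a minimal sketch on made-up points that lie exactly on the line y = 2x + 1. For a degree 1 fit, polyfit() returns the slope first and the intercept second.

```python
import numpy as np

# Made-up points that lie exactly on y = 2x + 1
x = np.array([1.0, 2.0, 3.0, 4.0])
y = 2.0 * x + 1.0

# Degree 1 fit: coefficients come back highest power first
slope, intercept = np.polyfit(x, y, 1)

# poly1d wraps the coefficients in a callable polynomial
trend = np.poly1d([slope, intercept])

print(slope, intercept)   # close to 2.0 and 1.0
print(trend(5.0))         # close to 11.0
```

Calling trend on the x values, as in p(djia_data['Open']) above, gives the y values of the fitted line at those points.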
Setting the Plot Title and Axis Labels
Plot titles and axis labels make it significantly easier to understand a visualization and allow the viewer to quickly see what they are looking at. We can add them using plt.title(), plt.xlabel(), and plt.ylabel(), which we will demonstrate with the scatterplot we made in the previous section.
plt.scatter(djia_data['Open'], djia_data['Close'])
plt.title('DJIA Open vs. Close Price in 2022')
plt.xlabel('Open Price')
plt.ylabel('Close Price')
plt.show()
Changing Colors
Color can be a powerful tool in data visualizations for emphasizing certain points or telling a consistent story with consistent colors for a certain idea. In Matplotlib, we can change colors using named colors (e.g. "red", "blue", etc.), hex codes ("#f4db9a", "#383c4a", etc.), and red-green-blue tuples (e.g. (0.49, 0.39, 0.15), (0.12, 0.21, 0.47), etc.). Note that Matplotlib expects each value in an RGB tuple to be a float between 0 and 1, not an integer between 0 and 255.
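Matplotlib expects RGB tuples with values between 0 and 1, so colors written on the familiar 0 to 255 scale need to be divided by 255 first. The sketch below shows that conversion using the to_rgb() and to_hex() helpers from matplotlib.colors; the variable names are only for illustration.

```python
from matplotlib.colors import to_hex, to_rgb

# A named color resolves to an RGB tuple on the 0-1 scale
print(to_rgb('red'))

# A hex code resolves to the same kind of tuple
print(to_rgb('#383c4a'))

# An RGB color on the 0-255 scale must be rescaled before Matplotlib accepts it
rgb_255 = (125, 100, 37)
rgb_unit = tuple(channel / 255 for channel in rgb_255)
print(to_hex(rgb_unit))
```

Any of these forms can be passed to the color argument used in the sections that follow.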
Lines
For a line plot, we can change the color using the color argument in plt.plot(). Below, we change the color of our open price line to “black” and our close price line to “red”.
plt.plot(djia_data['Date'], djia_data['Open'], color = 'black')
plt.plot(djia_data['Date'], djia_data['Close'], color = 'red')
plt.show()
Bars
For bars, we can pass a list into the color argument to specify the color of each bar. Let’s say we want to highlight the average price in January to make a point about how strong the average close price was. We can do this by giving that bar a unique color to draw attention to it.
plt.bar(djia_monthly_mean_srtd['Month'], height = djia_monthly_mean_srtd['Close'], color = ['blue', 'gray', 'gray', 'gray', 'gray', 'gray'])
plt.show()
Points
Finally, for scatter plots, we can change the color using the color argument of plt.scatter(). We will color all points in January as blue and all other points as gray to show a similar story as in the above visualization.
plt.scatter(djia_data[djia_data['Month'] == 'January']['Open'], djia_data[djia_data['Month'] == 'January']['Close'], color = 'blue')
plt.scatter(djia_data[djia_data['Month'] != 'January']['Open'], djia_data[djia_data['Month'] != 'January']['Close'], color = 'gray')
plt.show()
Using Colormaps
Colormaps are built-in Matplotlib color scales that map the magnitude of a value to a color (documentation here). The colors in a colormap are designed to look good together and help tell a story as values increase.
In the example below, we use a colormap by passing the close price (y-variable) to the c argument and the plasma colormap to cmap. We see that as the values increase, the associated color gets brighter and more yellow, while the lower end of the values is purple and darker.
plt.scatter(djia_data['Open'], djia_data['Close'], c=djia_data['Close'], cmap = plt.cm.plasma)
plt.show()
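When colors encode values, a colorbar helps the reader translate a color back into a number. The sketch below adds one with plt.colorbar(); the data are made up, and the colormap is given by its name as a string, which works the same way as plt.cm.plasma.

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up values standing in for the open and close prices
values = np.linspace(0, 10, 20)

# Color each point by its value using the plasma colormap
points = plt.scatter(values, values, c=values, cmap='plasma')

# Add a colorbar so the reader can map colors back to values
colorbar = plt.colorbar(points, label='Close value')
plt.show()
```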
Setting Axis Limits
Sometimes, it is helpful to look at a specific range of values in a plot. For example, if the DJIA is currently trading around $30,000, we may only care about behavior around that price. We can pass a tuple into plt.xlim() and plt.ylim() to set the x and y limits respectively. The first value in the tuple is the lower limit, and the second value in the tuple is the upper limit.
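As a minimal sketch, the example below zooms a scatter plot in on the range around 30,000. The prices are made up for illustration and are not taken from the DJIA data.

```python
import matplotlib.pyplot as plt

# Made-up open and close prices, some of them far from 30,000
open_prices = [28000, 29500, 30500, 31000, 33000]
close_prices = [28200, 29800, 30300, 31200, 32800]

plt.scatter(open_prices, close_prices)

# Show only the region between 29,000 and 31,500 on both axes
plt.xlim((29000, 31500))
plt.ylim((29000, 31500))

# Calling xlim() with no arguments returns the current limits
x_limits = plt.xlim()
print(x_limits)
plt.show()
```

Points outside the limits are still part of the data; they are simply not drawn in the visible area.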
Saving Plots
Finally, we can save plots that we create in matplotlib using the plt.savefig() function. We can save the file in many different file formats including ‘png’, ‘pdf’, and ‘svg’. The first argument is the filename. The format is inferred from the file extension (or you can override this with the format argument).
plt.scatter(djia_data['Open'], djia_data['Close'])
plt.savefig('DJIA 2022 Scatterplot Open vs. Close.png')
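Two optional arguments are often useful when saving: dpi controls the resolution of raster formats, and bbox_inches='tight' trims extra whitespace around the figure. The sketch below saves a small made-up plot as a PDF in the system's temporary folder; the file name and location are assumptions chosen so that the example does not clutter your working directory.

```python
import os
import tempfile

import matplotlib.pyplot as plt

# Made-up points, just to have something to save
plt.scatter([1, 2, 3], [2, 4, 6])

# Build a path in the temporary folder; the .pdf extension sets the format
output_path = os.path.join(tempfile.gettempdir(), 'example_scatter.pdf')

# Save before calling plt.show(), since showing can clear the figure
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.close()

print(os.path.exists(output_path))
```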
Take it to the Next Level
We have covered the basics of Matplotlib in this tutorial and you can now make basic line graphs, bar graphs, and scatter plots. Matplotlib is an advanced library with a lot of great features for creating aesthetically pleasing visualizations. If you would like to take your Matplotlib skills to the next level, take our Introduction to Data Visualization with Matplotlib course. You can also download our Matplotlib Cheat Sheet: Plotting in Python for reference as you start creating your own visualizations.
Matplotlib FAQs
What is Matplotlib in Python?
Matplotlib is a popular data visualization library in Python. It's often used for creating static, interactive, and animated visualizations in Python. Matplotlib allows you to generate plots, histograms, bar charts, scatter plots, etc., with just a few lines of code.
Why should I use Matplotlib for data visualization?
There are several reasons. First, Matplotlib is flexible. It supports a broad array of graphs and plots, and it integrates well with many other Python libraries, like NumPy and pandas. Second, it's a mature and widely-used library, so it has a strong community and lots of resources and tutorials available. Lastly, because it's in Python, you can automate and customize your plots as part of your data pipelines.
How do I install Matplotlib?
You can install Matplotlib with pip, Python's package installer. Open your terminal and type: pip install matplotlib. If you're using a Jupyter notebook, you can run this command in a code cell by prepending an exclamation mark: !pip install matplotlib.
How do I create a basic plot in Matplotlib?
Here's a simple example. First, you'll need to import the Matplotlib library. The most commonly used module is pyplot, and it's typically imported under the alias plt:
import matplotlib.pyplot as plt
Then you can create a basic line plot like this:
plt.plot([1, 2, 3, 4])
plt.ylabel('Some Numbers')
plt.show()
In this example, plt.plot([1, 2, 3, 4]) is used to plot the specified list of numbers. The plt.ylabel('Some Numbers') line sets the label for the y-axis, and plt.show() displays the plot.
Data Science writer | Senior Technical Marketing Analyst at Wayfair | MSE in Data Science at University of Pennsylvania