Visualization in Python
One of the best ways to improve your data visualization skills is to try to replicate great visualizations you come across. In this live code-along, we will look at how to recreate some excellent visualizations using Python, so that you can take your data visualization skills to the next level.
import pandas as pd
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
import seaborn as sns
⚾ Strikeouts in Baseball
The first visualization we will try to replicate is a sports piece published by the New York Times in 2012. It is a beautiful visualization illustrating how strikeouts were on the rise. It shows the strikeouts per game for each team, as well as the average strikeouts per game for the league as a whole. Read the original article for more context, and study the visualization carefully before attempting to replicate it.
The data for this visualization comes from an excellent database compiled by Sean Lahman that contains complete batting and pitching statistics from 1871 to 2020, plus fielding statistics, standings, team stats, managerial records, post-season data, and more.
teams = pd.read_csv('teams.csv')[['yearID', 'franchID', 'name', 'G', 'SO']]
# G - Games; SO - Strikeouts
teams.head()
# Precompute data
# Team-level SOG: strikeouts per game
team_sog = (
    teams
    .query('yearID >= 1900')
    .assign(SOG=lambda d: d.SO / d.G)
)
red_sox_sog = team_sog.query('name == "Boston Red Sox"')
# League-level SOG: the average of all team-level SOGs in each year
league_sog = (
    team_sog
    .groupby('yearID', as_index=False)
    .agg(SOG=('SOG', 'mean'))
)
league_sog.head()
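The `league_sog` frame averages the team-level SOG values for each year, so every team counts equally. A league-wide rate can also be computed as total strikeouts divided by total team games, which weights each team by the number of games it played. The sketch below compares the two on a small table of made-up numbers (not Lahman data). When every team plays the same number of games the two approaches agree. Otherwise they differ slightly.

```python
import pandas as pd

# Made-up example data with the same columns as the Lahman `teams` frame
toy = pd.DataFrame({
    'yearID': [2000, 2000, 2001, 2001],
    'name':   ['A', 'B', 'A', 'B'],
    'G':      [160, 162, 162, 100],
    'SO':     [1000, 1200, 1100, 700],
})
toy = toy.assign(SOG=lambda d: d.SO / d.G)

# Unweighted: every team's SOG counts equally (the approach used for league_sog)
mean_of_ratios = toy.groupby('yearID')['SOG'].mean()

# Weighted: total strikeouts over total team games for the year
totals = toy.groupby('yearID')[['SO', 'G']].sum()
ratio_of_totals = totals['SO'] / totals['G']

comparison = pd.DataFrame({
    'mean_of_ratios': mean_of_ratios,
    'ratio_of_totals': ratio_of_totals,
})
print(comparison.round(3))
```

Over a full modern season teams play nearly the same number of games, so for the real data the choice mostly matters in seasons with uneven schedules.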
# Plot setup
plt.rcParams['figure.figsize'] = (20, 8)
plt.style.use('fivethirtyeight')
# Add a scatter plot layer for all team SOGs
plt.scatter(
    team_sog['yearID'],  # x-axis
    team_sog['SOG'],     # y-axis
    color='gray',
    alpha=0.2            # transparency
)
# Add a line plot layer for a specific team (Boston Red Sox)
plt.plot(
    red_sox_sog['yearID'],
    red_sox_sog['SOG'],
    color='orange',
    marker='o',
)
# Add a line plot layer for the entire league (blue line)
plt.plot(
    league_sog['yearID'],
    league_sog['SOG'],
    color='steelblue',
    marker='o',
)
# Set the y-axis limits and draw a black baseline at y = 0
plt.ylim(-0.1, 13)
plt.axhline(y=0, color='black')
# Add a text annotation layer
plt.text(1914, 1, s='US Enters World War I')
# Add a title and subtitle above the plot area (positioned in data coordinates)
plt.text(1888, 15, s="Strikeouts are on the rise", fontsize=24, fontweight='bold')
plt.text(1888, 14, s="There were more strikeouts...", fontsize=16)
plt.show()
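The pyplot calls above draw every layer on the "current" axes. The same layering can be written with Matplotlib's object-oriented interface, which makes it easy to redraw the chart for a different team. This is a minimal sketch, assuming frames shaped like `team_sog` and `league_sog`. The helper name `plot_strikeouts` and the demo data are hypothetical.

```python
import matplotlib.pyplot as plt
import pandas as pd


def plot_strikeouts(team_sog, league_sog, team_name, ax=None):
    """Hypothetical helper: all teams in gray, one team in orange, league in blue."""
    if ax is None:
        _, ax = plt.subplots(figsize=(20, 8))
    highlight = team_sog[team_sog['name'] == team_name]
    # Layer 1: every team-season as a faint gray point
    ax.scatter(team_sog['yearID'], team_sog['SOG'], color='gray', alpha=0.2)
    # Layer 2: the highlighted team as an orange line
    ax.plot(highlight['yearID'], highlight['SOG'],
            color='orange', marker='o', label=team_name)
    # Layer 3: the league average as a blue line
    ax.plot(league_sog['yearID'], league_sog['SOG'],
            color='steelblue', marker='o', label='League average')
    # Baseline and fixed y-range, as in the pyplot version
    ax.axhline(y=0, color='black')
    ax.set_ylim(-0.1, 13)
    return ax


# Made-up demo data with the same columns as team_sog / league_sog
demo_team_sog = pd.DataFrame({
    'yearID': [2000, 2000, 2001, 2001],
    'name': ['Boston Red Sox', 'Other', 'Boston Red Sox', 'Other'],
    'SOG': [6.5, 7.0, 6.8, 7.2],
})
demo_league_sog = (
    demo_team_sog.groupby('yearID', as_index=False).agg(SOG=('SOG', 'mean'))
)
ax = plot_strikeouts(demo_team_sog, demo_league_sog, 'Boston Red Sox')
```

With the real data, a call such as `plot_strikeouts(team_sog, league_sog, 'New York Yankees')` would highlight a different team, provided that name appears in the `name` column.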
🦠 COVID Cases by State
The second visualization we will try to replicate is also from the New York Times. Published on March 21, 2020, it shows the spread of COVID-19 by state. Read the original article to get a better understanding.
You will need two datasets to replicate this plot. The first, provided by the New York Times, is a time series of cumulative COVID-19 cases for each state by date. The second maps each state to x-y coordinates on a grid. Use it to place the state panels so that they roughly follow the map of the United States.
# COVID Cases by State
covid_cases = pd.read_csv("https://raw.githubusercontent.com/nytimes/covid-19-data/master/us-states.csv")
# covid_cases.sort_values(by = 'date', ascending=False).head()
covid_cases.head()
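The `cases` column in the NYT file is a running (cumulative) total, while the chart shows new cases each day. One way to derive daily counts is to take the difference between consecutive rows within each state using `groupby(...).diff()`. This is a minimal sketch on made-up numbers: the state names are real, but the counts are not.

```python
import pandas as pd

# Made-up cumulative counts in the same long format as us-states.csv
cumulative = pd.DataFrame({
    'date':  ['2020-03-01', '2020-03-02', '2020-03-03',
              '2020-03-01', '2020-03-02', '2020-03-03'],
    'state': ['Ohio'] * 3 + ['Utah'] * 3,
    'cases': [1, 4, 10, 2, 2, 7],
})

# Sort so each state's rows are in date order before differencing
# (ISO date strings sort correctly as text)
daily = cumulative.sort_values(['state', 'date']).copy()
# New cases = today's cumulative total minus the previous day's, per state.
# The first row of each state has no previous value, so keep its total instead.
daily['new_cases'] = (
    daily.groupby('state')['cases'].diff().fillna(daily['cases'])
)
print(daily)
```

In the real data, corrections occasionally lower a cumulative total, so a few daily values can come out negative.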
# Grid Coordinates for States
# Source: https://github.com/hrbrmstr/statebins/blob/master/R/aaa.R
state_coords = pd.read_csv('state_coords.csv')
state_coords.head()
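Each row of `state_coords` gives a column (`x`) and a row (`y`) on a grid, and indexing a GridSpec as `gs[row, column]` turns that pair into one small panel. The sketch below uses three hypothetical coordinates (not the actual values in `state_coords.csv`) to show the mapping. `get_subplotspec()` reports which grid cell each panel landed in.

```python
import matplotlib.pyplot as plt

# Hypothetical grid positions; the real values come from state_coords.csv
coords = [
    {'abbrev': 'WA', 'x': 0, 'y': 0},
    {'abbrev': 'ME', 'x': 10, 'y': 0},
    {'abbrev': 'FL', 'x': 9, 'y': 7},
]

fig = plt.figure(figsize=(8, 6))
gs = fig.add_gridspec(nrows=8, ncols=11)
panels = {}
for state in coords:
    # gs[row, column]: y selects the row, x selects the column
    ax = fig.add_subplot(gs[state['y'], state['x']])
    # Label placed in axes coordinates, so it sits mid-panel regardless of data
    ax.text(0.5, 0.5, state['abbrev'], transform=ax.transAxes,
            ha='center', va='center')
    ax.set_xticks([])
    ax.set_yticks([])
    panels[state['abbrev']] = ax

for abbrev, ax in panels.items():
    spec = ax.get_subplotspec()
    print(abbrev, 'row', spec.rowspan.start, 'col', spec.colspan.start)
```

The grid only needs to be large enough to hold the largest `x` and `y` values, which is why the main plot can use a generous 13 x 13 grid.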
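# Plot setup
from matplotlib.patches import Rectangle
# The 'seaborn' style was renamed 'seaborn-v0_8' in Matplotlib 3.6
plt.style.use('seaborn-v0_8')
plt.rcParams['figure.figsize'] = (20, 20)
fig = plt.figure()
# Define a 13 x 13 grid of panels
gs = fig.add_gridspec(nrows=13, ncols=13)
# The NYT data reports cumulative cases, so derive new cases per day for each state
covid_cases = covid_cases.sort_values(['state', 'date'])
covid_cases['new_cases'] = covid_cases.groupby('state')['cases'].diff()
max_new_cases = covid_cases['new_cases'].max()
# Plot lines for each state
for state in state_coords.to_dict(orient="records"):
    ax = fig.add_subplot(gs[state['y'], state['x']])
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    state_name = state["state"]
    d = covid_cases.query('state == @state_name')
    ax.plot(d["date"], d["new_cases"], linewidth=1)
    # Use the same y-scale for every state so the panels are comparable
    ax.set_ylim(-1, max_new_cases)
    # Place the state label in axes coordinates (0 to 1), independent of the data
    ax.text(0.02, 0.8, s=state['abbrev'], transform=ax.transAxes,
            fontweight='bold', fontsize='large')
    ax.add_patch(Rectangle((40, 40), 360, max_new_cases, color='yellow', alpha=0.4))
plt.suptitle("Number of new cases each day", fontsize=24, fontweight='bold')
plt.tight_layout()
plt.show()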