Introduction to Image Classification with PyTorch
We'll be using computer vision to answer the question that never gets old on the internet: is it a sloth or a pain au chocolat? This is a binary image classification task.
In the age of deep learning, data scientists and machine learning engineers seldom create and train neural networks from scratch. A big chunk of what goes into performing a machine learning task, however, is collecting, preparing, and loading data to feed into a model. We'll perform a little bit of fine-tuning on the model, but this will not be the focus of the training.
In this session, we'll be adapting code from PyTorch.org's tutorials on loading custom datasets to load a dataset we have collected into PyTorch.
We'll then use this tutorial on transfer learning to perform an image processing task using a mostly-pretrained model, which we'll fine tune.
Specifically, we'll be labeling images with one of two labels: sloth or pain_au_chocolat.
Package Imports
Like all great Python projects, ours, too, will start with some package imports! We'll use:

- NumPy for manipulating numerical arrays
- Matplotlib.pyplot for plotting
- time, which provides time-related functions
- os, which provides functionality for interacting with the operating system
- copy, for copying objects
- various packages from torch, including:
  - torch
  - torch.nn, which contains the basic building blocks for neural networks
  - torch.optim, a package containing various optimization algorithms for PyTorch
  - lr_scheduler from torch.optim, for adjusting the learning rate based on the number of epochs
  - torch.backends.cudnn as cudnn, a means for PyTorch to talk to the GPU (although GPUs may not be supported in your workspace)
- torchvision, which provides additional functionality to manipulate and process images, including:
  - datasets, which contains built-in datasets
  - models, containing models for various tasks, including image processing
  - transforms, which we'll use to transform images in preparation for image processing
# Package imports go here
____
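One possible way to fill in the blank cell above is sketched below; the aliases np, plt, nn, and optim are common conventions rather than requirements.
# One possible set of imports matching the list above (a sketch)
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms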
Initializations
For faster runtime, let's begin our project by setting cudnn.benchmark to True. You can read more about this here.
# Enable cudnn benchmark
____
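For reference, the blank above only needs a single line, assuming torch.backends.cudnn was imported as cudnn as described in the import list:
# Let cuDNN pick the fastest convolution algorithms for the input sizes it sees
# (this only takes effect when a CUDA-capable GPU is available)
cudnn.benchmark = True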
Reading and transforming the data
While the PyTorch.org tutorial provides extensive information on loading, transforming, rescaling, cropping, and converting images to tensors using torch and torch.utils, we'll be using the torchvision package, which provides some frequently used data loaders and transforms out-of-the-box.
One of the things torchvision assumes is that the data is organized in a certain way. Navigate to the data folder to see how the data is structured. Within the directory called "data/sloths_versus_pain_au_chocolat", there are two folders called "train" and "val". Our dataset contains two labels, sloth and pain_au_chocolat, so our folders are named and organized accordingly. Note that the images contained in the sloth and pain_au_chocolat folders don't need to be named in any particular way, as long as the folders themselves are labelled correctly.
To adapt this tutorial to use different data, all you need to do is change the names of the sloth and pain_au_chocolat folders and upload different images into them.
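Putting the folder names above together, the layout that torchvision's ImageFolder expects looks roughly like this (image filenames inside the class folders can be anything):
data/sloths_versus_pain_au_chocolat/
    train/
        sloth/
        pain_au_chocolat/
    val/
        sloth/
        pain_au_chocolat/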
When running code in notebooks, a hidden folder called .ipynb_checkpoints can sometimes show up in our training and validation folders. We'll remove these with the lines below.
# Banish pesky .ipynb_checkpoints folders
____
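One possible way to fill in the blank above is sketched below; it assumes the data directory described earlier and uses shutil, which isn't in our import list and would need to be imported as well.
# A sketch of one way to remove stray .ipynb_checkpoints folders from the
# training and validation directories (assumes the folder layout shown above)
import shutil

for split in ["train", "val"]:
    checkpoint_dir = os.path.join("data/sloths_versus_pain_au_chocolat", split, ".ipynb_checkpoints")
    if os.path.isdir(checkpoint_dir):
        shutil.rmtree(checkpoint_dir)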
We'll begin loading and transforming our data by defining the specific transforms we'd like to use from torchvision.
The specific transforms we'll use on our training set are:

- RandomResizedCrop(), used to crop a random portion of an image and resize it to a given size, passed as the first argument to the function
- RandomHorizontalFlip(), used to horizontally flip an image randomly with a given probability (default is 0.5)
- ToTensor(), used to convert an image or numpy.ndarray to a tensor
- Normalize(), used to normalize a tensor image with given means and standard deviations, passed as lists as the first and second arguments, respectively (taking tensors as input). If the images are similar to ImageNet images, we can use the mean and standard deviation of the ImageNet dataset: mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

The specific transforms we'll use on our validation set are:

- Resize(), used to resize an input to a given size, passed as the first argument
- CenterCrop(), used to crop a given image at the center, based on dimensions provided in the first argument
- ToTensor()
- Normalize()
# Create data transforms
data_transforms = {
____
}
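A sketch of one way to fill in the dictionary above; the crop and resize sizes (224 and 256) follow values commonly used with ImageNet-pretrained models and are assumptions, not requirements.
# A possible data_transforms dictionary (sketch; 224/256 are assumed sizes)
data_transforms = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),   # crop a random portion, resize to 224x224
        transforms.RandomHorizontalFlip(),   # flip horizontally with probability 0.5
        transforms.ToTensor(),               # convert to a tensor
        transforms.Normalize([0.485, 0.456, 0.406],    # ImageNet means
                             [0.229, 0.224, 0.225]),   # ImageNet standard deviations
    ]),
    "val": transforms.Compose([
        transforms.Resize(256),              # resize the shorter side to 256
        transforms.CenterCrop(224),          # crop the central 224x224 region
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ]),
}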
Next, we'll:

- create a data directory path containing our dataset
- pass our directory to datasets.ImageFolder() to create a dataset called image_datasets, where the images are arranged in the same way our folders are currently structured
- use image_datasets to obtain our training and validation dataset_sizes and class_names
- pass image_datasets to torch.utils.data.DataLoader(), which enables us to sample from our dataset, using:
  - batch_size = 4, which uses 4 images per batch
  - shuffle = True, which will shuffle the data at every epoch
# Provide data directory
data_dir = ____
# Create image folders for our training and validation data
image_datasets = ____
# Obtain dataset sizes from image_datasets
dataset_sizes = ____
# Obtain class_names from image_datasets
class_names = ____
# Use image_datasets to sample from the dataset
dataloaders = ____
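A sketch of one way to fill in the blanks above, assuming the folder structure described earlier, the imports from the first cell, and the data_transforms dictionary defined previously.
# Possible completions for the cell above (a sketch)
data_dir = "data/sloths_versus_pain_au_chocolat"

# One ImageFolder dataset per split, each with its matching transforms
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ["train", "val"]
}

# Number of images in each split
dataset_sizes = {split: len(image_datasets[split]) for split in ["train", "val"]}

# Class names are inferred from the subfolder names (sorted alphabetically)
class_names = image_datasets["train"].classes

# DataLoaders that yield shuffled batches of 4 images
dataloaders = {
    split: torch.utils.data.DataLoader(image_datasets[split], batch_size=4, shuffle=True)
    for split in ["train", "val"]
}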