Image Classification in PyTorch

    Introduction to Image Classification with PyTorch

    We'll be using computer vision to answer the question that never gets old on the internet: is it a sloth or a pain au chocolat? This is a binary image classification task.

    In the age of deep learning, data scientists and machine learning engineers seldom create and train neural networks from scratch. A big chunk of what goes into performing a machine learning task, however, is collecting, preparing, and loading data to feed into a model. We'll perform a little bit of fine-tuning on the model, but this will not be the focus of the training.

    In this session, we'll be adapting code from PyTorch.org's tutorials on loading custom datasets to load a dataset we have collected into PyTorch.

    We'll then follow the PyTorch.org tutorial on transfer learning to perform an image classification task using a mostly pretrained model, which we'll fine-tune.

    Specifically, we'll be labeling images with one of two labels: sloth, or pain_au_chocolat.

    Package Imports

    Like all great Python projects, ours will start with some package imports! We'll use:

    • NumPy for manipulating numerical arrays
    • Matplotlib.pyplot for plotting
    • time, which provides time-related functions
    • os, which provides functions for interacting with the operating system
    • copy, for copying objects
    • various packages from torch, including:
      • torch
      • torch.nn, which contains the basic building blocks for neural networks
      • torch.optim, a package containing various optimization algorithms for PyTorch
      • lr_scheduler from torch.optim, for adjusting the learning rate based on the number of epochs
      • cudnn from torch.backends, PyTorch's interface to cuDNN, NVIDIA's GPU-accelerated library of deep learning primitives (although GPUs may not be supported in your workspace)
    • torchvision, which provides additional functionality to manipulate and process images, including:
      • datasets, which contains built-in datasets and dataset classes such as ImageFolder
      • models, containing models for various tasks, including image processing
      • transforms, which we'll use to transform images in preparation for image processing
    # Package imports go here
    import numpy as np
    import matplotlib.pyplot as plt
    
    import time
    import os
    import copy
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.optim import lr_scheduler
    from torch.backends import cudnn
    
    import torchvision
    from torchvision import datasets, models, transforms

    Initializations

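    For fast runtime, let's begin our project by setting cudnn.benchmark to True. This lets cuDNN time several convolution algorithms and pick the fastest one for the hardware, which pays off when the input size stays the same from batch to batch.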

    # Enable cudnn benchmark
    cudnn.benchmark = True

    Reading and transforming the data

    While the PyTorch.org tutorial provides extensive information on loading, transforming, rescaling, cropping, and converting images to tensors using torch and torch.utils, we'll be using the torchvision package, which provides some frequently used data loaders and transforms out-of-the-box.

    One of the things it assumes is that data is organized in a certain way. Navigate to the data folder to see how the data is structured. Within the directory called "data/sloths_versus_pain_au_chocolat", there are two folders called "train" and "val". Our dataset contains two labels:

    • sloth, and
    • pain_au_chocolat

    so our folders are named and organized accordingly. Note that the images inside the sloth and pain_au_chocolat folders don't need to follow any naming scheme, as long as the folders themselves are labelled correctly.

    To adapt this tutorial to use different data, all you need to do is change the names of the sloth and pain_au_chocolat folders, and upload different images into them.
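
    To make the folder convention concrete, here is a minimal sketch that builds the expected layout in a temporary directory and lists the class folders roughly the way torchvision's ImageFolder discovers them: subfolder names sorted alphabetically, each mapped to an integer index. The helper names build_layout and list_class_folders are hypothetical, and the real ImageFolder additionally checks that each folder contains valid image files.

```python
import os
import tempfile


def build_layout(root, splits=("train", "val"), labels=("sloth", "pain_au_chocolat")):
    """Create root/<split>/<label>/ folders (hypothetical helper for illustration)."""
    for split in splits:
        for label in labels:
            os.makedirs(os.path.join(root, split, label), exist_ok=True)


def list_class_folders(split_dir):
    """Return (classes, class_to_idx) with subfolder names sorted alphabetically."""
    classes = sorted(entry.name for entry in os.scandir(split_dir) if entry.is_dir())
    class_to_idx = {name: index for index, name in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as tmp:
    root = os.path.join(tmp, "sloths_versus_pain_au_chocolat")
    build_layout(root)
    classes, class_to_idx = list_class_folders(os.path.join(root, "train"))

print(classes)       # ['pain_au_chocolat', 'sloth']
print(class_to_idx)  # {'pain_au_chocolat': 0, 'sloth': 1}
```

    Because the order is alphabetical, renaming the folders changes which label gets index 0, so it is worth printing the class names after loading rather than assuming an order.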

    When running code in notebooks, a hidden folder called .ipynb_checkpoints can sometimes show up in our training and validation folders. Because torchvision treats every subfolder as a class label, we'll remove these with the lines below.

    # Banish pesky .ipynb_checkpoints folders (-f: no error if they don't exist)
    !rm -rf data/sloths_versus_pain_au_chocolat/train/.ipynb_checkpoints
    
    !rm -rf data/sloths_versus_pain_au_chocolat/val/.ipynb_checkpoints
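
    The shell commands above only work where rm is available. As a portable alternative, the same cleanup can be sketched in pure Python. The function name remove_checkpoint_dirs is hypothetical; it assumes the train/val layout described earlier and silently skips folders that don't exist.

```python
import os
import shutil


def remove_checkpoint_dirs(data_dir, splits=("train", "val")):
    """Delete <data_dir>/<split>/.ipynb_checkpoints if present; return the removed paths."""
    removed = []
    for split in splits:
        path = os.path.join(data_dir, split, ".ipynb_checkpoints")
        if os.path.isdir(path):
            shutil.rmtree(path)
            removed.append(path)
    return removed


# Returns an empty list when there is nothing to clean up
removed = remove_checkpoint_dirs("data/sloths_versus_pain_au_chocolat")
print(removed)
```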

    We'll begin loading and transforming our data by defining the specific transforms we'd like to use from torchvision.

    The specific transforms we'll use on our training set are:

    • RandomResizedCrop(), which crops a random portion of an image and resizes it to a given size, passed as the first argument
    • RandomHorizontalFlip(), which horizontally flips an image with a given probability (default is 0.5)
    • ToTensor(), which converts a PIL image or numpy.ndarray to a tensor, scaling 8-bit pixel values to the range [0, 1]
    • Normalize(), which normalizes a tensor image channel by channel, using means and standard deviations passed as lists in the first and second arguments, respectively. If the images are similar to ImageNet images, we can use the mean and standard deviation of the ImageNet dataset. These are:
      • mean = [0.485, 0.456, 0.406]
      • std = [0.229, 0.224, 0.225].
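
    The arithmetic behind Normalize() is simply (value - mean) / std, applied separately to each channel. The sketch below works it through for a single RGB pixel in plain Python rather than on a tensor. The helper names normalize_pixel and denormalize_pixel are hypothetical and not part of torchvision.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply (value - mean) / std to each channel of one RGB pixel in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]


def denormalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert the normalization, e.g. before displaying an image."""
    return [v * s + m for v, m, s in zip(rgb, mean, std)]


white = [1.0, 1.0, 1.0]  # a pure white pixel after ToTensor()
normalized = normalize_pixel(white)
print([round(v, 3) for v in normalized])  # [2.249, 2.429, 2.64]
```

    A pixel equal to the channel means maps to zero, and the inverse function is handy when plotting normalized images with Matplotlib.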

    The specific transforms we'll use on our validation set are:

    • Resize(), used to resize an input to a given size, passed as the first argument (a single integer sets the length of the shorter side)
    • CenterCrop(), used to crop a given image at the center, based on dimensions provided in the first argument
    • ToTensor()
    • Normalize()
    # Create data transforms
    data_transforms = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),  # random flips add variability to the training data
            transforms.ToTensor(),  # the model expects tensors as input
            # ImageNet statistics, which the pretrained model was trained with
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
    }
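
    To see what the validation pipeline does to an image's geometry, the sketch below computes the size after Resize(256) and the window kept by CenterCrop(224) for a hypothetical 600x400 image. It is plain Python, not torchvision code: the helper names are made up, and torchvision's exact rounding may differ by a pixel for sizes that don't divide evenly.

```python
def resize_shorter_side(width, height, size=256):
    """Scale (width, height) so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop-by-crop window."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop


resized = resize_shorter_side(600, 400)  # hypothetical 600x400 input image
box = center_crop_box(*resized)
print(resized)  # (384, 256)
print(box)      # (80, 16, 304, 240)
```

    Unlike the random crops used for training, this path is deterministic, so the same validation image always produces the same 224x224 input.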

    Next, we'll:

    • create a data directory path containing our dataset
    • pass our directory to datasets.ImageFolder() to create a dictionary of datasets called image_datasets (one for training, one for validation), where the images are arranged in the same way our folders are currently structured
    • use image_datasets to obtain our training and validation dataset_sizes and class_names
    • pass each dataset in image_datasets to torch.utils.data.DataLoader(), which enables us to sample batches from our dataset, using
      • batch_size = 4, which uses 4 images per batch
      • shuffle = True, which will shuffle the data at every epoch
    # Provide data directory
    data_dir = 'data/sloths_versus_pain_au_chocolat'
    
    # Create image folders for our training and validation data 
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                             data_transforms[x]
                                             )
                     for x in ['train', 'val']}
    
    # Obtain dataset sizes from image_datasets
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    
    # Obtain class_names from image_datasets
    class_names = image_datasets['train'].classes
    print(class_names)
    
    # Use image_datasets to sample from the dataset
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                                 shuffle=True)
                  for x in ['train', 'val']}
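
    With batch_size=4, the number of batches per epoch follows directly from the dataset size. DataLoader keeps a smaller final batch by default (drop_last=False), so the count is the ceiling of the dataset size divided by the batch size. The sketch below does this arithmetic in plain Python for a hypothetical dataset of 10 images; the helper names are made up, and the real sizes come from dataset_sizes.

```python
import math


def batches_per_epoch(num_samples, batch_size=4, drop_last=False):
    """Number of batches one pass over the data yields."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples, batch_size=4):
    """Size of the final batch when drop_last=False (assumes num_samples > 0)."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


print(batches_per_epoch(10))                  # 3
print(last_batch_size(10))                    # 2
print(batches_per_epoch(10, drop_last=True))  # 2
```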