Pneumonia is one of the leading respiratory illnesses worldwide, and timely, accurate diagnosis is essential for effective treatment. Manually reviewing chest X-rays is a critical step in this process, and AI can provide valuable support by helping to speed up the assessment. In your role as a consultant data scientist, you will test how well a deep learning model can distinguish chest X-rays showing pneumonia from those of normal lungs.
Your task is to fine-tune a pre-trained convolutional neural network, specifically the ResNet-18 model, to classify X-ray images into two categories: normal lungs and lungs affected by pneumonia. Because you can reuse the model's already trained weights, you get an accurate classifier faster and with fewer resources than training from scratch.
The Data
You have a dataset of chest X-rays that have been preprocessed for use with a ResNet-18 model. You can see a sample of 5 images from each category above. After unzipping the chestxrays.zip file (code provided below), you will find the dataset in the data/chestxrays folder, split into train and test subfolders.
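ImageFolder treats each subfolder of a split directory as one class and assigns integer labels in sorted (alphabetical) order of the folder names. As a result, NORMAL becomes label 0 and PNEUMONIA becomes label 1, so a "positive" prediction later on means pneumonia. The sketch below reproduces that labeling rule with the standard library on a throwaway directory laid out like data/chestxrays/train. The helper `class_to_index` is hypothetical and is not a torchvision function.

```python
import os
import tempfile


def class_to_index(split_dir):
    """Hypothetical helper mirroring ImageFolder's labeling rule:
    every subfolder is a class, indexed in sorted order of folder names."""
    classes = sorted(entry.name for entry in os.scandir(split_dir) if entry.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


# Build a temporary directory with the same layout as data/chestxrays/train
with tempfile.TemporaryDirectory() as root:
    for name in ['PNEUMONIA', 'NORMAL']:  # creation order does not affect the labels
        os.makedirs(os.path.join(root, name))
    mapping = class_to_index(root)

print(mapping)  # {'NORMAL': 0, 'PNEUMONIA': 1}
```

On the real data, the same mapping is available as `train_dataset.class_to_idx` once the datasets are created.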
There are 150 training images and 50 test images for each category, NORMAL and PNEUMONIA (300 training and 100 test images in total). For your convenience, this data has already been loaded into a train_loader and a test_loader using the DataLoader class from the PyTorch library.
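The loaders use `batch_size=len(train_dataset) // 2` for training and the whole test set as a single batch. With 300 training images, each epoch therefore consists of two shuffled batches of 150, and the test loader yields one batch of 100. The sketch below checks this arithmetic. It uses random 3x8x8 tensors in a `TensorDataset` as stand-ins for the X-rays, which is an assumption made to keep it small.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in datasets: random 3x8x8 "images" with binary labels (real X-rays are larger)
train_data = TensorDataset(torch.randn(300, 3, 8, 8), torch.randint(0, 2, (300,)))
test_data = TensorDataset(torch.randn(100, 3, 8, 8), torch.randint(0, 2, (100,)))

# Same batch-size rules as the project's loaders
train_loader = DataLoader(train_data, batch_size=len(train_data) // 2, shuffle=True)
test_loader = DataLoader(test_data, batch_size=len(test_data))

# Collect the size of every batch produced in one pass over each loader
train_batch_sizes = [inputs.size(0) for inputs, _ in train_loader]
test_batch_sizes = [inputs.size(0) for inputs, _ in test_loader]
print(train_batch_sizes, test_batch_sizes)  # [150, 150] [100]
```

Each training epoch therefore performs exactly two optimizer steps, which is worth keeping in mind when choosing the number of epochs.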
# Make sure to run this cell to use torchmetrics.
!pip install torch torchvision torchmetrics

# Import required libraries
# -------------------------
# Data loading
import random
import numpy as np
from torchvision.transforms import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
# Train model
import torch
from torchvision import models
import torch.nn as nn
import torch.optim as optim
# Evaluate model
from torchmetrics import Accuracy, F1Score
# Set random seeds for reproducibility
torch.manual_seed(101010)
np.random.seed(101010)
random.seed(101010)

import os
import zipfile

# Unzip the data folder
if not os.path.exists('data/chestxrays'):
    with zipfile.ZipFile('data/chestxrays.zip', 'r') as zip_ref:
        zip_ref.extractall('data')

# Define the transformations to apply to the images for use with ResNet-18.
# The images need to be normalized to the same domain as the original training data of the ResNet-18 network.
# We normalize the X-rays using the transforms.Normalize function, which takes as input the means and
# standard deviations of the three color channels (R, G, B) of the original ResNet-18 training dataset (ImageNet).
transform_mean = [0.485, 0.456, 0.406]
transform_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=transform_mean, std=transform_std)])

# Apply the image transforms
train_dataset = ImageFolder('data/chestxrays/train', transform=transform)
test_dataset = ImageFolder('data/chestxrays/test', transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset) // 2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))
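The two transforms work in sequence. `ToTensor` turns an H x W x C image with 0-255 pixel values into a C x H x W float tensor scaled to [0, 1]. `Normalize` then computes (x - mean[c]) / std[c] separately for each channel c. For example, a red-channel value of 0.5 becomes (0.5 - 0.485) / 0.229, which is about 0.066. The sketch below reproduces both steps with plain tensor operations on a made-up 2x2 image. It mirrors the documented behavior of the transforms rather than calling torchvision.

```python
import torch

# ImageNet channel statistics used by the project, shaped to broadcast over C x H x W
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# A made-up 2x2 RGB image in H x W x C uint8 layout: all white except one black pixel
image = torch.full((2, 2, 3), 255, dtype=torch.uint8)
image[0, 0] = 0

# ToTensor-like step: H x W x C -> C x H x W, scaled to [0, 1]
x = image.permute(2, 0, 1).float() / 255.0

# Normalize-like step: subtract the channel mean, divide by the channel std
normalized = (x - mean) / std

print(normalized[:, 0, 0])  # black pixel, roughly -2.12, -2.04, -1.80
print(normalized[:, 1, 1])  # white pixel, roughly 2.25, 2.43, 2.64
```

ImageFolder's default loader converts every image to RGB. A grayscale X-ray therefore arrives with three identical channels, and each channel is then shifted and scaled by its own ImageNet statistics.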
Instantiate the model
# Load the pre-trained ResNet-18 model
resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

Modify the model
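Freezing every parameter and then assigning a new `nn.Linear` leaves only the new head trainable, because freshly constructed layers have `requires_grad=True`. For ResNet-18, `fc.in_features` is 512, so the new single-output head has 512 weights plus 1 bias. The sketch below demonstrates the effect on a small stand-in network so that nothing has to be downloaded. `TinyNet` and `count_trainable` are hypothetical names, and `TinyNet` exposes an `fc` attribute to mimic ResNet-18's.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in with a backbone and an `fc` head, like ResNet-18."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        return self.fc(self.backbone(x))


def count_trainable(model):
    # Sum the sizes of all parameters that will receive gradients
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = TinyNet()
before = count_trainable(model)  # 16*8 + 8 + 8*10 + 10 = 226

# Same two steps as the project: freeze everything, then swap in a 1-output head
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 1)

after = count_trainable(model)  # only the new head: 8 weights + 1 bias = 9
print(before, after)
```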
# Freeze the parameters of the model
for param in resnet18.parameters():
    param.requires_grad = False
# Modify the final layer for binary classification
resnet18.fc = nn.Linear(resnet18.fc.in_features, 1)

Define the training loop
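The modified head outputs one raw score (logit) per image, with shape (batch, 1). This is why the loop reshapes the integer labels with `labels.float().unsqueeze(1)` before computing the loss. This section does not define the criterion. A common choice for a single-logit head is `nn.BCEWithLogitsLoss`, which applies the sigmoid internally; it is assumed here rather than taken from the project. The sketch below uses made-up logits to check two facts the loop relies on:

- The loss on logits matches `BCELoss` on sigmoid probabilities.
- `sigmoid(z) > 0.5` picks the same images as `z > 0`.

```python
import torch
import torch.nn as nn

# Made-up logits for a batch of 4 images, shape (batch, 1) like the modified head
logits = torch.tensor([[2.0], [-1.0], [0.3], [-0.2]])
labels = torch.tensor([1, 0, 0, 1])  # integer class indices, as ImageFolder yields them

# Match the logits' shape and dtype, as the training loop does
targets = labels.float().unsqueeze(1)

# Assumed criterion: BCEWithLogitsLoss applies the sigmoid internally
loss_from_logits = nn.BCEWithLogitsLoss()(logits, targets)
loss_from_probs = nn.BCELoss()(torch.sigmoid(logits), targets)

# Thresholding the probability at 0.5 is the same as thresholding the logit at 0
preds = (torch.sigmoid(logits) > 0.5).float()
correct = (preds == targets).sum().item()

print(loss_from_logits.item(), loss_from_probs.item())
print(preds.squeeze(1).tolist(), correct)  # [1.0, 0.0, 1.0, 0.0] 2
```

Working on logits is generally preferred because `BCEWithLogitsLoss` is more numerically stable than applying a separate sigmoid followed by `BCELoss`.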
# Model training/fine-tuning loop
def train(model, train_loader, criterion, optimizer, num_epochs):
    # Number of training images, used to average the per-epoch loss and accuracy
    num_samples = len(train_loader.dataset)
    # Train the model for the specified number of epochs
    for epoch in range(num_epochs):
        # Set the model to train mode
        model.train()
        # Initialize the running loss and the running count of correct predictions
        running_loss = 0.0
        running_corrects = 0
        # Iterate over the batches of the train loader
        for inputs, labels in train_loader:
            # Zero the optimizer gradients
            optimizer.zero_grad()
            # Reshape labels from (batch,) to (batch, 1) floats to match the single-logit outputs
            labels = labels.float().unsqueeze(1)
            # Forward pass: the model returns raw logits
            outputs = model(inputs)
            # Predict pneumonia (class 1) when the sigmoid probability exceeds 0.5
            preds = (torch.sigmoid(outputs) > 0.5).float()
            loss = criterion(outputs, labels)
            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()
            # Update the running loss (weighted by batch size) and the number of correct predictions
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()
        # Calculate the train loss and accuracy for the current epoch
        train_loss = running_loss / num_samples
        train_acc = running_corrects / num_samples
        # Print the epoch results
        print('Epoch [{}/{}], train loss: {:.4f}, train acc: {:.4f}'
              .format(epoch+1, num_epochs, train_loss, train_acc))

Fine-tune the model
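This section ends before the fine-tuning call, so the criterion, optimizer, and learning rate are not shown here. The sketch below makes its own assumptions:

- The optimizer is Adam over `model.fc.parameters()` only, with `lr=0.01`.
- The criterion is `BCEWithLogitsLoss`.
- A hypothetical `nn.Sequential` stand-in with a frozen backbone and an `fc` head replaces ResNet-18.

It runs one training step on random data and confirms that the frozen layer is untouched while the head is updated.

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-in: a "backbone" layer plus a 1-output head named fc
model = nn.Sequential(OrderedDict([
    ('backbone', nn.Linear(16, 8)),
    ('relu', nn.ReLU()),
    ('fc', nn.Linear(8, 1)),
]))
# Freeze the backbone, as the project does for ResNet-18
for param in model.backbone.parameters():
    param.requires_grad = False

# Assumed settings (not taken from the project): BCEWithLogitsLoss, Adam on the head, lr=0.01
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.01)

# Snapshot the weights so the update can be checked afterwards
backbone_before = model.backbone.weight.clone()
head_before = model.fc.weight.clone()

# One training step on random stand-in data
inputs = torch.randn(4, 16)
targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()

print(torch.equal(model.backbone.weight, backbone_before))  # True: frozen layer unchanged
print(torch.equal(model.fc.weight, head_before))            # False: head was updated
```

Passing only the head's parameters to the optimizer and setting `requires_grad=False` on the backbone are complementary. The optimizer limits which weights are updated, while freezing also skips gradient computation for the frozen layers, which saves time and memory. With the real model, the analogous objects would be passed to the `train` function defined above, for example `train(resnet18, train_loader, criterion, optimizer, num_epochs=3)`, where the epoch count is likewise only an assumption.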