Detecting Tuberculosis in X-Rays

📖 Background

Tuberculosis (TB) is one of the most common and deadly respiratory diseases in the world, causing about 1.25 million deaths in 2023. Doctors often use chest X-rays to help detect TB. However, looking at many X-rays by hand can be slow and difficult.

In this challenge, you will build a simple machine learning model that can help classify chest X-ray images into two groups:

  • Healthy lungs
  • Lungs affected by TB

This is not about building a “perfect” model. The focus should be on how you describe your process, decisions, and learnings.

🩻 The Data

 

You are given a small dataset from the Sakha-TB dataset:

  • Training data: 150 healthy + 150 TB images (300 total)
  • Test data: 50 healthy + 50 TB images (100 total)

These images are in the data.zip file at the root of the notebook. Once unzipped, they are in the data/chestxrays folder, which is divided into test and train, each containing healthy and tb folders with the images inside.

💪 Challenge

You will train a model to classify chest X-rays. Your report should cover these questions:

  1. Preprocessing
    What steps did you take to make the images easier for a model to understand?
    Some ideas to think about:

    • Did you resize the images to the same size?
    • Did you convert them to grayscale or normalize the pixel values?
    • Did you try any simple image transformations (e.g., small rotations)?
  2. Modeling
    Try at least two models and compare them.

    • One can be a simple model you build yourself (like a small CNN).
    • Another can be a pre-trained model (like ResNet or MobileNet).
      Explain what you tried and what differences you observed.
  3. Evaluation
    Evaluate your models on the test set. Report the following metrics in plain words:

    • Sensitivity (Recall for TB): How many TB cases your model correctly finds.
    • Specificity: How many healthy cases your model correctly identifies.
    • Positive Predictive Value (PPV): When your model says “TB”, how often it’s right.
    • Negative Predictive Value (NPV): When your model says “Healthy”, how often it’s right.

    👉 Tip: You don’t need to get the “best” numbers. Focus on explaining what the metrics mean and what you learned.

  4. [Optional] ROC Curve
    If you want, you can also draw a Receiver Operating Characteristic (ROC) curve to visualize how your model performance changes with different thresholds.

import os
import zipfile

# Unzip the data folder
if not os.path.exists('data/chestxrays'):
    with zipfile.ZipFile('data.zip', 'r') as zip_ref:
        zip_ref.extractall()
import tensorflow as tf

# Define the paths to the training and test directories
train_dir = 'data/chestxrays/train'
test_dir = 'data/chestxrays/test'

# Check if the directories exist before proceeding
if not os.path.isdir(train_dir):
    raise FileNotFoundError(f"Training directory not found: {train_dir}")
if not os.path.isdir(test_dir):
    raise FileNotFoundError(f"Test directory not found: {test_dir}")

# Set up image dimensions and batch size
img_height = 224
img_width = 224
batch_size = 32

# Load the training dataset
train_data = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode='binary' # for only two classes
)

# Load the test dataset
test_data = tf.keras.utils.image_dataset_from_directory(
    test_dir,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode='binary'
)

# Print a summary of the loaded datasets
print("\nTraining Data Info:")
print(f"Number of batches: {tf.data.experimental.cardinality(train_data).numpy()}")
print(f"Class names: {train_data.class_names}")

print("\nTest Data Info:")
print(f"Number of batches: {tf.data.experimental.cardinality(test_data).numpy()}")
print(f"Class names: {test_data.class_names}")
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Data generators for preprocessing and augmentation
# Augment and normalize the training data
train_datagen = ImageDataGenerator(
    rescale=1./255,             # Scale pixel values from [0, 255] to [0, 1]
    rotation_range=10,          # Rotate randomly by up to 10 degrees in either direction
    horizontal_flip=True,       # Randomly flip images horizontally
    zoom_range=0.1,             # Zoom in or out randomly by up to 10%
    validation_split=0.2        # Reserve 20% for validation (only applied when `subset` is passed to flow_from_directory)
)

# normalize for test data
test_datagen = ImageDataGenerator(rescale=1./255)

These preprocessing steps were chosen to standardize the data and reduce overfitting. Rescaling the pixel values to the range 0 to 1 leads to faster and more stable weight updates. The random rotations, flips, and zooms show the model slightly different versions of the original training images, so it learns patterns that hold up under small changes instead of memorizing individual images.
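To make the rescaling and flipping steps concrete, here is a minimal NumPy sketch of what they do to a single image. It is an illustration only, not the Keras implementation: the helper names `rescale` and `horizontal_flip` are hypothetical, and the image is random noise standing in for an X-ray.

```python
import numpy as np

def rescale(image_uint8):
    """Map 8-bit pixel values from [0, 255] to floats in [0, 1]."""
    return image_uint8.astype(np.float32) / 255.0

def horizontal_flip(image):
    """Mirror an image of shape (height, width, channels) left to right."""
    return image[:, ::-1, :]

# Hypothetical stand-in for one 224x224 RGB chest X-ray
rng = np.random.default_rng(0)
fake_xray = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

scaled = rescale(fake_xray)
flipped = horizontal_flip(scaled)

print("Pixel range after rescaling:", scaled.min(), "to", scaled.max())
print("Shape is unchanged by the flip:", flipped.shape)
```

The flip changes where structures appear in the image but not the pixel values themselves, which is why it can act as cheap extra training data.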

from tensorflow.keras import layers, models
import numpy as np

# Reuse img_height, img_width, and the augmenting train_datagen defined earlier.
# Re-creating train_datagen here with only rescale=1./255 would silently drop the augmentation.

# a simple CNN model
# convolutional layers to extract features and dense layers for classification
simple_model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),  
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(1, activation='sigmoid') # for binary classification
])

# Compile the model
simple_model.compile(optimizer='adam',
                     loss='binary_crossentropy',
                     metrics=['accuracy'])

# create a train generator 
train_generator = train_datagen.flow_from_directory(
    'data/chestxrays/train',
    target_size=(img_height, img_width),
    batch_size=32,
    class_mode='binary'
)

# Train the model 
simple_model.fit(
    train_generator,
    epochs=10
)

I designed a CNN from scratch and trained it for 10 epochs. The training accuracy started at 0.47 and improved to about 0.71 by epoch 9, while the loss decreased from 1.16 to about 0.58. This means the model is learning gradually. A training accuracy above 70% suggests the CNN is capturing useful patterns from the X-ray dataset, although this is measured on the training images and not on held-out data. The fluctuations in the accuracy curve could indicate that the training data is limited.

# Create a test generator (shuffle=False keeps predictions aligned with test_generator.classes)
test_generator = test_datagen.flow_from_directory(
    'data/chestxrays/test',
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary',
    shuffle=False
)
from sklearn.metrics import confusion_matrix
import numpy as np

# Get true labels and predictions from the test generator
true_labels = test_generator.classes
predictions_prob = simple_model.predict(test_generator)
predictions = (predictions_prob > 0.5).astype(int) # Convert probabilities to binary predictions

# Compute the confusion matrix
cm = confusion_matrix(true_labels, predictions)
tn, fp, fn, tp = cm.ravel()

# Calculate the metrics
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
npv = tn / (tn + fn) if (tn + fn) > 0 else 0

print(f"Sensitivity (Recall for TB): {sensitivity:.2f}")
print(f"Specificity: {specificity:.2f}")
print(f"Positive Predictive Value (PPV): {ppv:.2f}")
print(f"Negative Predictive Value (NPV): {npv:.2f}")

  • Sensitivity (recall) measures how well the model identifies positive cases. A recall of 0.60 means my model found 60% of the patients who actually had TB.
  • Specificity measures how well the model identifies negative cases. A specificity of 0.76 means that, out of all the healthy patients, my model correctly identified 76% of them.
  • PPV measures how often the model is right when it predicts TB. A PPV of 0.71 means that when my model says "TB", it is correct 71% of the time.
  • NPV measures how often the model is right when it predicts healthy. An NPV of 0.66 means that when my model says "Healthy", it is correct 66% of the time.
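As a worked check, the four metrics can be recomputed from raw counts. The counts below are a reconstruction, not output copied from the notebook: with 50 TB and 50 healthy test images, a sensitivity of 0.60 and a specificity of 0.76 imply 30 true positives, 20 false negatives, 38 true negatives, and 12 false positives. The function name `diagnostic_metrics` is illustrative.

```python
def diagnostic_metrics(tp, fp, tn, fn):
    """Return sensitivity, specificity, PPV, and NPV from confusion-matrix counts."""
    return {
        "sensitivity": tp / (tp + fn),  # share of TB cases that were found
        "specificity": tn / (tn + fp),  # share of healthy cases that were cleared
        "ppv": tp / (tp + fp),          # share of "TB" predictions that were right
        "npv": tn / (tn + fn),          # share of "Healthy" predictions that were right
    }

# Counts inferred from the reported metrics and the 50/50 test split (an assumption)
metrics = diagnostic_metrics(tp=30, fp=12, tn=38, fn=20)

for name, value in metrics.items():
    print(f"{name}: {value:.2f}")
```

These counts reproduce the reported PPV (30 of 42 "TB" predictions, about 0.71) and NPV (38 of 58 "Healthy" predictions, about 0.66). In practical terms, the model missed 20 of the 50 TB cases.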

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

predictions_prob = simple_model.predict(test_generator).ravel() # .ravel() flattens the array

# Calculate the ROC curve metrics
fpr, tpr, thresholds = roc_curve(true_labels, predictions_prob)
roc_auc = auc(fpr, tpr)

# Plot the ROC curve
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Classifier')
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()
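Each point on a ROC curve comes from applying one threshold to the predicted probabilities. The sketch below shows that trade-off by hand, using a small set of made-up labels and scores (hypothetical values, not predictions from this notebook). The helper `rates_at_threshold` is an illustrative name.

```python
import numpy as np

def rates_at_threshold(y_true, y_score, threshold):
    """Return (true positive rate, false positive rate) at one threshold."""
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_score) > threshold).astype(int)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    tn = int(np.sum((y_pred == 0) & (y_true == 0)))
    return tp / (tp + fn), fp / (fp + tn)

# Hypothetical labels (1 = TB, 0 = healthy) and predicted TB probabilities
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_score = [0.1, 0.3, 0.45, 0.7, 0.4, 0.6, 0.8, 0.9]

for threshold in (0.35, 0.5, 0.75):
    tpr, fpr = rates_at_threshold(y_true, y_score, threshold)
    print(f"threshold={threshold}: sensitivity={tpr:.2f}, false positive rate={fpr:.2f}")
```

Lowering the threshold finds more TB cases but raises more false alarms, and raising it does the opposite. The ROC curve traces this trade-off across all thresholds.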
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model

# MobileNetV2 model, pretrained on imagenet
base_model = MobileNetV2(
    input_shape=(img_height, img_width, 3),
    include_top=False, # the final classification layer is not needed
    weights='imagenet'
)

# Freeze the base model layers so they aren't trained
base_model.trainable = False

# new layers for our classification task
x = base_model.output
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(1, activation='sigmoid')(x)

# Create new model
transfer_model = Model(inputs=base_model.input, outputs=x)

# Compile the model
transfer_model.compile(optimizer='adam',
                       loss='binary_crossentropy',
                       metrics=['accuracy'])

# Train the model
transfer_model.fit(
    train_generator,
    epochs=5
)

For the second model, I used a pre-trained MobileNetV2 and trained it for 5 epochs. The accuracy started at about 0.53, which is higher than the first model's starting accuracy, and improved to 0.67 after only 5 epochs. The loss decreased from 0.77 to 0.59. This suggests that MobileNetV2 could exceed the simple CNN's performance if trained for more epochs, because it started from a higher accuracy and reached 0.67 in only 5 epochs, helped by the features it already learned from the large ImageNet dataset.
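One way to put the two models side by side is to count how many weights each one actually trains. The sketch below is a hand calculation from the layer definitions in this notebook, not output from `model.summary()`. It assumes the Keras defaults (3x3 convolutions with valid padding and stride 1, 2x2 pooling that rounds odd sizes down) and that MobileNetV2 without its top layer outputs 1280 feature channels. The helper names are illustrative.

```python
def conv_params(kernel, in_channels, out_channels):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * in_channels * out_channels + out_channels

def dense_params(in_units, out_units):
    """Weights plus biases of a Dense layer."""
    return in_units * out_units + out_units

# Simple CNN: three Conv2D(3x3) + MaxPooling2D(2x2) stages on a 224x224x3 input
size, channels, simple_cnn_params = 224, 3, 0
for filters in (32, 64, 128):
    simple_cnn_params += conv_params(3, channels, filters)
    size = (size - 2) // 2   # valid 3x3 convolution, then 2x2 pooling
    channels = filters

flattened = size * size * channels
simple_cnn_params += dense_params(flattened, 128) + dense_params(128, 1)

# Transfer model: the base is frozen, so only the final Dense(1) layer is trained
transfer_head_params = dense_params(1280, 1)

print(f"Simple CNN trainable parameters: {simple_cnn_params:,}")
print(f"MobileNetV2 head trainable parameters: {transfer_head_params:,}")
```

Under these assumptions the simple CNN trains about 11.2 million weights on 300 images, most of them in the first Dense layer, while the frozen MobileNetV2 setup trains only 1,281. That difference is one plausible reason the two models behave differently on such a small dataset.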

from sklearn.metrics import confusion_matrix
import numpy as np

# Get true labels and predictions from the test generator
true_labels = test_generator.classes
predictions_prob = transfer_model.predict(test_generator)
predictions = (predictions_prob > 0.5).astype(int) # Convert probabilities to binary predictions

# Compute the confusion matrix
cm = confusion_matrix(true_labels, predictions)
tn, fp, fn, tp = cm.ravel()

# Calculate the metrics
sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
npv = tn / (tn + fn) if (tn + fn) > 0 else 0

print(f"Sensitivity (Recall for TB): {sensitivity:.2f}")
print(f"Specificity: {specificity:.2f}")
print(f"Positive Predictive Value (PPV): {ppv:.2f}")
print(f"Negative Predictive Value (NPV): {npv:.2f}")