Automated TB Detection from Chest X-Rays: Building and Comparing CNN Classifiers
Author: Najib Yusuf Ubandiya
Date: October 1, 2025
Executive Summary
This notebook presents a comparative analysis of two machine learning approaches for automated tuberculosis (TB) detection from chest X-ray images. Using a balanced set of 302 training images and 100 test images from the Sakha-TB dataset, we developed and evaluated:
- Custom CNN: A lightweight convolutional neural network built from scratch
- MobileNetV2: A transfer learning approach using pre-trained ImageNet weights
Results
| Model | Accuracy | Sensitivity | Specificity | AUC |
|---|---|---|---|---|
| Custom CNN | 51.0% | 66.0% | 36.0% | 0.6452 |
| MobileNetV2 | 82.0% | 64.0% | 100.0% | 0.9952 |
Main Findings
- Transfer learning outperformed custom CNN, achieving 82% accuracy compared to 51%
- MobileNetV2 demonstrated perfect specificity (100%), meaning zero false positives: no healthy patients were incorrectly flagged
- Perfect positive predictive value (100%) means that every positive TB prediction on this test set was correct
- Moderate sensitivity (64%) indicates that 36% of TB cases are missed, suggesting the model works best as a screening tool that requires follow-up confirmation
- Near-perfect AUC (0.9952) confirms excellent discriminative ability
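On the 50 TB / 50 healthy test split, these rates translate into a concrete confusion matrix for MobileNetV2: TP = 32, FN = 18, TN = 50, FP = 0. From those counts, accuracy = (32 + 50)/100 = 82%, PPV = 32/(32 + 0) = 100%, and NPV = 50/(50 + 18) ≈ 73.5%, so a negative prediction still carries a meaningful chance of a missed TB case.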
Clinical Implications
MobileNetV2 is well-suited for initial TB screening where avoiding false alarms is prioritized. However, its moderate sensitivity necessitates combining it with additional diagnostic methods to catch missed cases. The model's conservative behavior (perfect specificity and PPV) builds trust in positive diagnoses while minimizing unnecessary patient anxiety.
Technical Approach
- Preprocessing: Normalized pixel values and applied moderate augmentation (rotation, shift, zoom, flip)
- Training Strategy: Used early stopping and learning rate reduction to prevent overfitting (see the sketch after this list)
- Evaluation: Computed medical-specific metrics (Sensitivity, Specificity, PPV, NPV) alongside standard ML metrics
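A minimal sketch of the training callbacks, using the classes imported in Section 1.1; the monitor, patience, and factor values here are illustrative assumptions, not necessarily the notebook's exact settings:
early_stop = EarlyStopping(
    monitor='val_loss',           # watch validation loss for signs of overfitting
    patience=10,                  # stop after 10 epochs without improvement
    restore_best_weights=True     # roll back to the best epoch's weights
)
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,                   # halve the learning rate when validation loss plateaus
    patience=5,
    min_lr=1e-6
)
# Both are passed to model.fit(..., callbacks=[early_stop, reduce_lr])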
This analysis demonstrates that transfer learning is highly effective when working with limited medical imaging data, and that careful evaluation using domain-appropriate metrics is critical for healthcare applications.
SECTION 1: INTRODUCTION, SETUP & DATA EXPLORATION
1.1: Import Necessary Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
from PIL import Image
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import (
confusion_matrix, classification_report,
roc_curve, auc, roc_auc_score
)
import warnings
warnings.filterwarnings('ignore')
np.random.seed(42)
tf.random.set_seed(42)
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")1.2: Extract and Set Up Data Paths
import zipfile
# Extract the data
zip_path = 'data.zip'
extract_path = '.' # Extract to current directory
# Check if already extracted
if not os.path.exists('data/chestxrays'):
print("Extracting data.zip...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_path)
print("Data extracted successfully!")
else:
print("Data already extracted.")
# Define paths (note: chestxrays with 's')
base_dir = Path('data/chestxrays')
train_dir = base_dir / 'train'
test_dir = base_dir / 'test'
train_healthy = train_dir / 'healthy'
train_tb = train_dir / 'tb'
test_healthy = test_dir / 'healthy'
test_tb = test_dir / 'tb'
# Verify structure and count images (excluding .DS_Store files)
def count_images(path):
return len([f for f in path.glob('*') if f.suffix.lower() in ['.jpg', '.jpeg', '.png']])
print(f"\nDataset Summary:")
print(f"{'='*50}")
print(f"Train - Healthy images: {count_images(train_healthy)}")
print(f"Train - TB images: {count_images(train_tb)}")
print(f"Test - Healthy images: {count_images(test_healthy)}")
print(f"Test - TB images: {count_images(test_tb)}")
print(f"{'='*50}")
print(f"Total training images: {count_images(train_healthy) + count_images(train_tb)}")
print(f"Total test images: {count_images(test_healthy) + count_images(test_tb)}")1.2.1 Data Loading and Setup
Data Structure
The dataset has been successfully extracted and organized into the following structure:
- data/chestxrays/train/ - Training images
- data/chestxrays/test/ - Test images
Each folder contains two subdirectories: healthy and tb, representing the two classes we aim to classify.
Dataset Distribution
| Split | Healthy | TB | Total |
|---|---|---|---|
| Train | 151 | 151 | 302 |
| Test | 50 | 50 | 100 |
Observations
- Balanced Dataset: Both classes have equal representation in training and test sets, which eliminates class imbalance concerns
- Small Dataset: With only 302 training images, we'll need to be mindful of overfitting and may benefit from data augmentation
- Clean Split: The train/test split is approximately 75/25, which is reasonable for model evaluation
The balanced nature of this dataset means that accuracy will be a meaningful metric, although we'll also compute medical-specific metrics (Sensitivity, Specificity, PPV, NPV) for a more comprehensive evaluation.
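As a reference for how those metrics are computed, here is a minimal sketch of a confusion-matrix helper relying on the imports from Section 1.1; the function name medical_metrics and the example labels (chosen to mirror MobileNetV2's reported test results) are illustrative, not code from later sections:
def medical_metrics(y_true, y_pred):
    """Derive Sensitivity, Specificity, PPV, and NPV from binary labels (1 = TB)."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {
        'sensitivity': tp / (tp + fn),  # proportion of TB cases correctly flagged
        'specificity': tn / (tn + fp),  # proportion of healthy cases correctly cleared
        'ppv': tp / (tp + fp),          # how trustworthy a positive prediction is
        'npv': tn / (tn + fn)           # how trustworthy a negative prediction is
    }
# Example: 50 TB / 50 healthy cases, 32 TB cases detected, no false positives
y_true = np.array([1] * 50 + [0] * 50)
y_pred = np.array([1] * 32 + [0] * 18 + [0] * 50)
print(medical_metrics(y_true, y_pred))  # sensitivity 0.64, specificity 1.0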
1.3: Data Exploration
1.3.1: Exploratory Data Analysis
Before building our models, let's explore the X-ray images to understand:
- Image dimensions and formats
- Visual characteristics of healthy vs TB-affected lungs
- Any preprocessing needs
1.3.2: Visualize Sample Images
def display_sample_images(healthy_path, tb_path, n_samples=5):
"""Display sample images from both classes side by side"""
fig, axes = plt.subplots(2, n_samples, figsize=(15, 6))
healthy_images = list(healthy_path.glob('*.jpg'))[:n_samples]
tb_images = list(tb_path.glob('*.jpg'))[:n_samples]
for idx, img_path in enumerate(healthy_images):
img = Image.open(img_path)
axes[0, idx].imshow(img, cmap='gray')
axes[0, idx].axis('off')
axes[0, idx].set_title(f'Healthy\n{img.size[0]}x{img.size[1]}', fontsize=10)
for idx, img_path in enumerate(tb_images):
img = Image.open(img_path)
axes[1, idx].imshow(img, cmap='gray')
axes[1, idx].axis('off')
axes[1, idx].set_title(f'TB\n{img.size[0]}x{img.size[1]}', fontsize=10)
# Row labels come from each subplot's title; set_ylabel would have no effect after axis('off')
plt.tight_layout()
plt.show()
print("Sample Chest X-rays from Training Set:")
display_sample_images(train_healthy, train_tb, n_samples=5)
1.3.4: Analyze Image Properties
# Function to analyze image properties
def analyze_image_properties(path, label):
"""Analyze dimensions and color modes of images"""
dimensions = []
modes = []
for img_path in path.glob('*.jpg'):
img = Image.open(img_path)
dimensions.append(img.size)
modes.append(img.mode)
# Get unique dimensions and modes
unique_dims = set(dimensions)
unique_modes = set(modes)
print(f"\n{label} Images:")
print(f" Total: {len(dimensions)}")
print(f" Unique dimensions: {unique_dims}")
print(f" Color modes: {unique_modes}")
print(f" Most common size: {max(set(dimensions), key=dimensions.count)}")
return dimensions
# Analyze both training sets
print("="*60)
print("IMAGE PROPERTY ANALYSIS")
print("="*60)
healthy_dims = analyze_image_properties(train_healthy, "Healthy")
tb_dims = analyze_image_properties(train_tb, "TB")
# Visualize dimension distribution
all_widths = [d[0] for d in healthy_dims + tb_dims]
all_heights = [d[1] for d in healthy_dims + tb_dims]
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].hist(all_widths, bins=20, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Width (pixels)')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Image Width Distribution')
axes[0].grid(alpha=0.3)
axes[1].hist(all_heights, bins=20, edgecolor='black', alpha=0.7)
axes[1].set_xlabel('Height (pixels)')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Image Height Distribution')
axes[1].grid(alpha=0.3)
plt.tight_layout()
plt.show()
print(f"\nDimension Statistics:")
print(f" Width range: {min(all_widths)} - {max(all_widths)} pixels")
print(f" Height range: {min(all_heights)} - {max(all_heights)} pixels")1.3.5 Findings of the Exploratory Data Analysis
Image Properties
After analyzing all training images, we discovered:
- Uniform Dimensions: All images are already resized to 224×224 pixels
- Color Mode: All images are in grayscale mode ('L'), which is appropriate for medical X-ray images
- Consistency: No variation in size across the dataset, which simplifies preprocessing
Visual Observations
From the sample X-rays displayed above:
- Healthy lungs typically show clear, well-defined lung fields with normal anatomical structures
- TB-affected lungs may show opacities, infiltrates, or abnormal patterns indicating infection
Preprocessing Implications
Since images are already:
- ✅ Uniformly sized (224×224)
- ✅ In grayscale format
- ✅ Consistent across all samples
Our preprocessing will focus on:
- Normalization: Scaling pixel values to [0, 1] range for neural network training
- Data Augmentation: Applying random transformations (rotation, zoom, flip) to increase training data diversity and prevent overfitting
- Train-Validation Split: Creating a validation set from training data to monitor model performance during training
SECTION 2: DATA PREPROCESSING AND AUGMENTATION
Preprocessing Strategy
Given our small dataset (302 training images), we'll implement the following preprocessing pipeline:
- Normalization: Scale pixel values from [0, 255] to [0, 1] for stable neural network training
- Data Augmentation: Apply random transformations to training data to artificially expand our dataset and improve model generalization
- Validation Split: Reserve 20% of training data for validation to monitor overfitting
Augmentation Techniques
We'll apply moderate augmentations appropriate for medical images:
- Rotation: ±15 degrees (lungs can appear at slight angles)
- Width/Height Shift: ±10% (account for positioning variations)
- Zoom: ±10% (handle different magnifications)
- Horizontal Flip: Mirror images (anatomically valid)
We'll avoid aggressive augmentations that might alter clinically relevant features.
2.1: Configure Image Data Generators
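Below is a sketch of the generator setup implied by the strategy above, using the ImageDataGenerator imported in Section 1.1. The augmentation ranges, [0, 1] rescaling, and 20% validation split come directly from the text; the batch size and seed are assumptions rather than the notebook's verbatim settings:
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,        # normalize pixel values from [0, 255] to [0, 1]
    rotation_range=15,        # rotate up to +/- 15 degrees
    width_shift_range=0.1,    # shift horizontally up to +/- 10%
    height_shift_range=0.1,   # shift vertically up to +/- 10%
    zoom_range=0.1,           # zoom in/out up to +/- 10%
    horizontal_flip=True,     # anatomically valid mirroring
    validation_split=0.2      # reserve 20% of training data for validation
)
test_datagen = ImageDataGenerator(rescale=1.0 / 255)  # evaluation data is never augmented
train_generator = train_datagen.flow_from_directory(
    str(train_dir), target_size=(224, 224), batch_size=16,
    class_mode='binary', subset='training', seed=42
)
val_generator = train_datagen.flow_from_directory(
    str(train_dir), target_size=(224, 224), batch_size=16,
    class_mode='binary', subset='validation', seed=42
)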