Detecting Tuberculosis in X-Rays
📖 Background
Tuberculosis (TB) is one of the most common and deadly respiratory diseases in the world, causing about 1.25 million deaths in 2023. Doctors often use chest X-rays to help detect TB. However, reviewing many X-rays by hand can be slow and difficult.
In this challenge, you will build a simple machine learning model that can help classify chest X-ray images into two groups:
- Healthy lungs
- Lungs affected by TB
This is not about building a “perfect” model. The focus should be on how you describe your process, decisions, and learnings.
🩻 The Data
You are given a small subset of the Sakha-TB dataset:
- Training data: 150 healthy + 150 TB images (300 total)
- Test data: 50 healthy + 50 TB images (100 total)
The images are in the data.zip file at the root of the notebook. Once unzipped, they are in the data/chestxrays folder, which is divided into test and train, each containing healthy and tb folders with the images inside.
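A quick sanity check on the folder layout described above can be sketched as follows. This is a minimal sketch, not part of the original challenge: the helper name `count_images` and the set of image extensions are assumptions, and the root folder is passed in as a parameter.

```python
from pathlib import Path

# Assumed image extensions; adjust if the dataset uses other formats.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def count_images(root):
    """Return {(split, label): number of image files} under root.

    Expects the layout root/<train|test>/<healthy|tb>/<image files>.
    Missing folders are reported as 0 rather than raising an error.
    """
    counts = {}
    for split in ("train", "test"):
        for label in ("healthy", "tb"):
            folder = Path(root) / split / label
            if not folder.is_dir():
                counts[(split, label)] = 0
                continue
            counts[(split, label)] = sum(
                1
                for f in folder.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts
```

Calling `count_images` on the unzipped data folder (the path used in the notebook's code cells) should report 150 images for each training class and 50 for each test class if the extraction worked as described.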
💪 Challenge
You will train a model to classify chest X-rays. Your report should cover these questions:
- Preprocessing
  What steps did you take to make the images easier for a model to understand?
  Some ideas to think about:
  - Did you resize the images to the same size?
  - Did you convert them to grayscale or normalize the pixel values?
  - Did you try any simple image transformations (e.g., small rotations)?
- Modeling
  Try at least two models and compare them.
  - One can be a simple model you build yourself (like a small CNN).
  - Another can be a pre-trained model (like ResNet or MobileNet).
  Explain what you tried and what differences you observed.
- Evaluation
  Evaluate your models on the test set. Report the following metrics in plain words:
  - Sensitivity (Recall for TB): Of all actual TB cases, what fraction your model correctly finds.
  - Specificity: Of all actual healthy cases, what fraction your model correctly identifies.
  - Positive Predictive Value (PPV): When your model says “TB”, how often it’s right.
  - Negative Predictive Value (NPV): When your model says “Healthy”, how often it’s right.
  👉 Tip: You don’t need to get the “best” numbers. Focus on explaining what the metrics mean and what you learned.
- [Optional] ROC Curve
  If you want, you can also draw a Receiver Operating Characteristic (ROC) curve to visualize how your model's performance changes with different thresholds.
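The four evaluation metrics listed above all come from the same four confusion-matrix counts. The sketch below shows the arithmetic on made-up numbers; the function name `diagnostic_metrics` and the example counts are hypothetical and are not results from any model.

```python
def diagnostic_metrics(tp, fp, tn, fn):
    """Compute sensitivity, specificity, PPV and NPV from raw counts.

    tp: TB cases predicted as TB        fn: TB cases predicted as healthy
    tn: healthy predicted as healthy    fp: healthy predicted as TB
    A metric whose denominator is zero is reported as 0.0.
    """
    def ratio(numerator, denominator):
        return numerator / denominator if denominator else 0.0

    return {
        "sensitivity": ratio(tp, tp + fn),  # share of TB cases that were found
        "specificity": ratio(tn, tn + fp),  # share of healthy cases cleared
        "ppv": ratio(tp, tp + fp),          # how often a "TB" call is right
        "npv": ratio(tn, tn + fn),          # how often a "Healthy" call is right
    }


# Hypothetical counts for a 100-image test set (50 TB, 50 healthy).
example = diagnostic_metrics(tp=45, fp=10, tn=40, fn=5)
for name, value in example.items():
    print(f"{name}: {value:.2f}")
```

With these made-up counts, sensitivity is 45/50 and specificity is 40/50, while PPV (45/55) and NPV (40/45) divide by the number of predictions of each kind rather than by the number of true cases.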
✅ Checklist before publishing
- Rename your workspace to make it descriptive of your work. N.B. you should leave the notebook name as notebook.ipynb.
- Remove redundant cells like the introduction to data science notebooks, so the workbook is focused on your story.
- Check that all the cells run without error.
⌛️ Time is ticking. Good luck!
import os
import zipfile

# Unzip the data folder
if not os.path.exists('data/chestxrays'):
    with zipfile.ZipFile('data.zip', 'r') as zip_ref:
        zip_ref.extractall()

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# ============================
# 1. Data preprocessing
# ============================
img_size = (224, 224)
batch_size = 16
train_dir = "data/chestxrays/train"
test_dir = "data/chestxrays/test"
# Check that the data folders exist
if not os.path.isdir(train_dir):
    raise FileNotFoundError(f"Training folder does not exist: {train_dir}")
if not os.path.isdir(test_dir):
    raise FileNotFoundError(f"Test folder does not exist: {test_dir}")
# Data augmentation for training
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True
)
# No augmentation for the test images, only the MobileNetV2 pixel scaling
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
train_gen = train_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode="binary"
)
test_gen = test_datagen.flow_from_directory(
    test_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode="binary",
    shuffle=False  # keep file order so predictions line up with test_gen.classes
)

# ============================
# 2. Simple CNN model
# ============================
def build_simple_cnn(input_shape=(224, 224, 3)):
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(128, activation='relu'),
        layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

cnn_model = build_simple_cnn()
history_cnn = cnn_model.fit(
    train_gen,
    epochs=10,
    validation_data=test_gen  # note: the test set doubles as validation data here
)

# ============================
# 3. Pre-trained MobileNetV2 model
# ============================
base_model = MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights="imagenet")
base_model.trainable = False  # Freeze the base layers
mobilenet_model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid')
])
mobilenet_model.compile(optimizer='adam',
                        loss='binary_crossentropy',
                        metrics=['accuracy'])
history_mobilenet = mobilenet_model.fit(
    train_gen,
    epochs=10,
    validation_data=test_gen
)
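When comparing the two models defined above, it can help to count how many weights each one actually trains on the 300 training images. The sketch below does the arithmetic by hand instead of calling Keras. It assumes the Keras defaults used in the code (3x3 convolutions with "valid" padding and stride 1, 2x2 max pooling) and that the frozen MobileNetV2 base yields 1280 features after global average pooling; the helper names are hypothetical.

```python
def conv_out(size, kernel=3):
    """Spatial size after a 'valid' convolution with stride 1."""
    return size - kernel + 1


def pool_out(size, pool=2):
    """Spatial size after max pooling with 'valid' padding (floor division)."""
    return size // pool


def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a Conv2D layer."""
    return kernel * kernel * c_in * c_out + c_out


def dense_params(n_in, n_out):
    """Weights plus biases of a Dense layer."""
    return n_in * n_out + n_out


def simple_cnn_param_count(input_size=224, channels=3):
    """Return (flattened feature count, total trainable weights) of the simple CNN."""
    size, c_in, total = input_size, channels, 0
    for filters in (32, 64, 128):
        total += conv_params(3, c_in, filters)
        size = pool_out(conv_out(size))
        c_in = filters
    flat = size * size * c_in
    total += dense_params(flat, 128) + dense_params(128, 1)
    return flat, total


def mobilenet_head_param_count(features=1280):
    """Trainable weights of the Dense head placed on the frozen base."""
    return dense_params(features, 128) + dense_params(128, 1)


flat, cnn_total = simple_cnn_param_count()
print(f"Simple CNN: {flat} flattened features, {cnn_total:,} trainable weights")
print(f"MobileNetV2 head: {mobilenet_head_param_count():,} trainable weights")
```

Under these assumptions, almost all of the simple CNN's weights sit in the first Dense layer after Flatten, while the transfer-learning model trains only a small head. That difference is one possible angle for explaining any gap in how the two models behave on such a small dataset.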
# ============================
# 4. Model evaluation
# ============================
def evaluate_model(model, test_gen, name="Model"):
    # Class indices follow alphabetical folder order: healthy = 0, tb = 1
    y_true = test_gen.classes
    y_pred_prob = model.predict(test_gen).ravel()
    y_pred = (y_pred_prob > 0.5).astype(int)
    # labels=[0, 1] guarantees a 2x2 matrix even if one class is never predicted
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0
    print(f"\n📊 Results for {name}:")
    print(f"Sensitivity (Recall TB): {sensitivity:.2f}")
    print(f"Specificity: {specificity:.2f}")
    print(f"Positive Predictive Value (PPV): {ppv:.2f}")
    print(f"Negative Predictive Value (NPV): {npv:.2f}")
    # ROC curve, drawn on the current matplotlib figure
    fpr, tpr, _ = roc_curve(y_true, y_pred_prob)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f'{name} (AUC = {roc_auc:.2f})')
    return sensitivity, specificity, ppv, npv

plt.figure(figsize=(6, 6))
evaluate_model(cnn_model, test_gen, "Simple CNN")
evaluate_model(mobilenet_model, test_gen, "MobileNetV2")
plt.plot([0, 1], [0, 1], 'k--')  # chance-level reference line
plt.xlabel("1 - Specificity (FPR)")
plt.ylabel("Sensitivity (TPR)")
plt.title("ROC curves")
plt.legend()
plt.show()
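Each point on a ROC curve corresponds to one decision threshold. The sketch below illustrates this on made-up labels and scores (not model output): lowering the threshold finds more TB cases but also flags more healthy lungs. The function name `rates_at_threshold` is hypothetical, and it uses the same `>` comparison as the 0.5 cut-off in the evaluation code.

```python
def rates_at_threshold(y_true, y_prob, threshold):
    """Return (TPR, FPR) when scores above the threshold are called TB (label 1)."""
    tp = sum(1 for t, p in zip(y_true, y_prob) if t == 1 and p > threshold)
    fn = sum(1 for t, p in zip(y_true, y_prob) if t == 1 and p <= threshold)
    fp = sum(1 for t, p in zip(y_true, y_prob) if t == 0 and p > threshold)
    tn = sum(1 for t, p in zip(y_true, y_prob) if t == 0 and p <= threshold)
    tpr = tp / (tp + fn) if (tp + fn) else 0.0  # sensitivity
    fpr = fp / (fp + tn) if (fp + tn) else 0.0  # 1 - specificity
    return tpr, fpr


# Made-up example: 4 healthy (0) and 4 TB (1) images with predicted TB scores.
y_true = [0, 0, 0, 0, 1, 1, 1, 1]
y_prob = [0.1, 0.3, 0.45, 0.7, 0.35, 0.6, 0.8, 0.9]

for thr in (0.2, 0.5, 0.75):
    tpr, fpr = rates_at_threshold(y_true, y_prob, thr)
    print(f"threshold={thr:.2f}  TPR={tpr:.2f}  FPR={fpr:.2f}")
```

In a screening setting one might prefer a lower threshold to miss fewer TB cases, at the cost of more false alarms. Which trade-off is appropriate depends on the use case and is worth discussing in the report.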