
Detecting Tuberculosis in X-Rays

📖 Background

Tuberculosis (TB) is one of the most common and deadly respiratory diseases in the world, causing about 1.25 million deaths in 2023. Doctors often use chest X-rays to help detect TB, but reviewing many X-rays by hand can be slow and difficult.

In this challenge, you will build a simple machine learning model that can help classify chest X-ray images into two groups:

  • Healthy lungs
  • Lungs affected by TB

This is not about building a “perfect” model. The focus should be on how you describe your process, decisions, and learnings.

🩻 The Data

 

You are given a small subset of the Sakha-TB dataset:

  • Training data: 150 healthy + 150 TB images (300 total)
  • Test data: 50 healthy + 50 TB images (100 total)

These images are in the data.zip file at the root of the notebook. Once extracted, they live in the data/chestxrays folder, which is split into train and test folders, each containing healthy and tb subfolders with the images inside.

💪 Challenge

You will train a model to classify chest X-rays. Your report should cover these questions:

  1. Preprocessing
    What steps did you take to make the images easier for a model to understand?
    Some ideas to think about:

    • Did you resize the images to the same size?
    • Did you convert them to grayscale or normalize the pixel values?
    • Did you try any simple image transformations (e.g., small rotations)?
  2. Modeling
    Try at least two models and compare them.

    • One can be a simple model you build yourself (like a small CNN).
    • Another can be a pre-trained model (like ResNet or MobileNet).
      Explain what you tried and what differences you observed.
  3. Evaluation
    Evaluate your models on the test set. Report the following metrics in plain words:

    • Sensitivity (Recall for TB): How many TB cases your model correctly finds.
    • Specificity: How many healthy cases your model correctly identifies.
    • Positive Predictive Value (PPV): When your model says “TB”, how often it’s right.
    • Negative Predictive Value (NPV): When your model says “Healthy”, how often it’s right.

    👉 Tip: You don’t need to get the “best” numbers. Focus on explaining what the metrics mean and what you learned. A small worked example follows this list.

  4. [Optional] ROC Curve
    If you want, you can also plot a Receiver Operating Characteristic (ROC) curve to visualize how your model’s performance changes at different classification thresholds.
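
To make the evaluation metrics concrete, here is a small worked example with made-up numbers (purely illustrative, not taken from any real model). Suppose that, on the 100 test images, a model produces 45 true positives (TB correctly flagged), 5 false negatives (TB missed), 40 true negatives (healthy correctly identified) and 10 false positives (healthy flagged as TB):

  • Sensitivity = 45 / (45 + 5) = 0.90
  • Specificity = 40 / (40 + 10) = 0.80
  • PPV = 45 / (45 + 10) ≈ 0.82
  • NPV = 40 / (40 + 5) ≈ 0.89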

✅ Checklist before publishing

  • Rename your workspace to make it descriptive of your work. N.B. you should leave the notebook name as notebook.ipynb.
  • Remove redundant cells, such as the introduction to data science notebooks, so the notebook stays focused on your story.
  • Check that all the cells run without error.

⌛️ Time is ticking. Good luck!

import os
import zipfile

# Unzip the data folder
if not os.path.exists('data/chestxrays'):
    with zipfile.ZipFile('data.zip', 'r') as zip_ref:
        zip_ref.extractall()
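
As a quick, optional sanity check (assuming the data/chestxrays layout described above), you can count how many images landed in each class folder after extraction:

# Optional sanity check: number of images per class folder
for split in ("train", "test"):
    for label in ("healthy", "tb"):
        folder = os.path.join("data", "chestxrays", split, label)
        n_files = len(os.listdir(folder)) if os.path.isdir(folder) else 0
        print(f"{split}/{label}: {n_files} images")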
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
# ============================
# 1. Data preprocessing
# ============================

img_size = (224, 224)
batch_size = 16

train_dir = "data/chestxrays/train"
test_dir = "data/chestxrays/test"

# Check that the data folders exist
if not os.path.isdir(train_dir):
    raise FileNotFoundError(f"Training folder not found: {train_dir}")
if not os.path.isdir(test_dir):
    raise FileNotFoundError(f"Test folder not found: {test_dir}")

# Data augmentation for the training set
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True
)

test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

train_gen = train_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode="binary"
)

test_gen = test_datagen.flow_from_directory(
    test_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode="binary",
    shuffle=False
)
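
One detail worth confirming before computing TB-specific metrics: flow_from_directory assigns labels alphabetically, so healthy should map to 0 and tb to 1. A quick optional check:

# Confirm the label mapping (alphabetical by default: healthy -> 0, tb -> 1)
print(train_gen.class_indices)
print(test_gen.class_indices)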
# ============================
# 2. Simple CNN model
# ============================

def build_simple_cnn(input_shape=(224,224,3)):
    model = models.Sequential([
        layers.Conv2D(32, (3,3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D(2,2),
        layers.Conv2D(64, (3,3), activation='relu'),
        layers.MaxPooling2D(2,2),
        layers.Conv2D(128, (3,3), activation='relu'),
        layers.MaxPooling2D(2,2),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(128, activation='relu'),
        layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

cnn_model = build_simple_cnn()
history_cnn = cnn_model.fit(
    train_gen,
    epochs=10,
    validation_data=test_gen
)
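
Plotting the training curves is an optional but useful way to spot overfitting on such a small dataset (note that the test generator doubles as the validation set here). A minimal sketch using the history object returned by fit:

# Optional: training vs. validation accuracy for the simple CNN
plt.figure(figsize=(6, 4))
plt.plot(history_cnn.history["accuracy"], label="train accuracy")
plt.plot(history_cnn.history["val_accuracy"], label="validation accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Simple CNN training curves")
plt.legend()
plt.show()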
# ============================
# 3. Pre-trained MobileNetV2 model
# ============================

base_model = MobileNetV2(input_shape=(224,224,3), include_top=False, weights="imagenet")
base_model.trainable = False  # Freeze the pre-trained base layers

mobilenet_model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid')
])

mobilenet_model.compile(optimizer='adam',
                        loss='binary_crossentropy',
                        metrics=['accuracy'])

history_mobilenet = mobilenet_model.fit(
    train_gen,
    epochs=10,
    validation_data=test_gen
)
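
Before the detailed evaluation, a rough side-by-side comparison of the two models can be made from their final validation accuracy (keeping in mind that accuracy alone hides the sensitivity/specificity trade-off):

# Rough comparison: final validation accuracy of each model
print(f"Simple CNN  : final val accuracy = {history_cnn.history['val_accuracy'][-1]:.2f}")
print(f"MobileNetV2 : final val accuracy = {history_mobilenet.history['val_accuracy'][-1]:.2f}")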
# ============================
# 4. Model evaluation
# ============================

def evaluate_model(model, test_gen, name="Model"):
    y_true = test_gen.classes
    test_gen.reset()  # ensure predictions line up with test_gen.classes
    y_pred_prob = model.predict(test_gen).ravel()
    y_pred = (y_pred_prob > 0.5).astype(int)

    cm = confusion_matrix(y_true, y_pred)
    tn, fp, fn, tp = cm.ravel()

    sensitivity = tp / (tp + fn) if (tp+fn)>0 else 0
    specificity = tn / (tn + fp) if (tn+fp)>0 else 0
    ppv = tp / (tp + fp) if (tp+fp)>0 else 0
    npv = tn / (tn + fn) if (tn+fn)>0 else 0

    print(f"\n📊 Résultats pour {name}:")
    print(f"Sensibilité (Recall TB): {sensitivity:.2f}")
    print(f"Spécificité: {specificity:.2f}")
    print(f"Valeur Prédictive Positive (PPV): {ppv:.2f}")
    print(f"Valeur Prédictive Négative (NPV): {npv:.2f}")

    # ROC Curve
    fpr, tpr, _ = roc_curve(y_true, y_pred_prob)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f'{name} (AUC = {roc_auc:.2f})')
    return sensitivity, specificity, ppv, npv

plt.figure(figsize=(6,6))
evaluate_model(cnn_model, test_gen, "Simple CNN")
evaluate_model(mobilenet_model, test_gen, "MobileNetV2")
plt.plot([0,1], [0,1], 'k--')
plt.xlabel("1 - Spécificité (FPR)")
plt.ylabel("Sensibilité (TPR)")
plt.title("Courbes ROC")
plt.legend()
plt.show()
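
If you want to push the optional ROC analysis one step further, a common way to pick an operating point is Youden's J statistic (TPR - FPR). The sketch below applies it to the MobileNetV2 probabilities; in practice, a threshold should be chosen on a separate validation set rather than the test set.

# Optional: choose a threshold with Youden's J statistic (TPR - FPR),
# illustrated on the MobileNetV2 predictions
y_true = test_gen.classes
test_gen.reset()
y_prob = mobilenet_model.predict(test_gen).ravel()
fpr, tpr, thresholds = roc_curve(y_true, y_prob)
best_idx = np.argmax(tpr - fpr)
print(f"Threshold maximising Youden's J: {thresholds[best_idx]:.2f} "
      f"(sensitivity = {tpr[best_idx]:.2f}, specificity = {1 - fpr[best_idx]:.2f})")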