Skip to content

DigiNsure Inc. is an innovative insurance company focused on making claims processing and customer service interactions more efficient. Its newest initiative is digitizing all historical insurance claim documents, which includes improving the labeling of ID codes scanned from paper documents and identifying each one as a primary or secondary ID.

To help with this effort, you'll use multi-modal learning to train an Optical Character Recognition (OCR) model. To improve the classification, the model takes two inputs: the image of each scanned document and its insurance type (home, life, auto, health, or other). Integrating different data modalities (such as images and text) helps the model perform better in complex scenarios by capturing more nuanced information. For each image–insurance-type pair, the model is trained to predict one of two labels: primary ID or secondary ID.

# Import the necessary libraries
import matplotlib.pyplot as plt
import numpy as np
from project_utils import ProjectDataset
import pickle 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Load the pickled dataset object. The `project_utils.ProjectDataset` import
# above is presumably what lets pickle rebuild it — confirm if the class moves.
# The `with` block closes the file handle once loading finishes (a bare
# open() would keep it open for the rest of the session).
# SECURITY: pickle.load can execute arbitrary code; only load this file from
# a trusted source.
with open('ocr_insurance_dataset.pkl', 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)

# Helpers to visualize scanned codes with their insurance types and labels


def _sample_title(dataset, img, lbl):
    """Build the two-line subplot title (insurance type, then label) for one sample.

    Args:
        dataset: object exposing ``type_mapping`` and ``label_mapping`` dicts.
        img: sample input pair; ``img[1]`` is a one-hot vector whose hot
            position is used as an index into ``type_mapping``'s key order.
            NOTE(review): this relies on ``type_mapping`` insertion order
            matching the one-hot layout — confirm against ProjectDataset.
        lbl: label value, looked up among ``label_mapping``'s values.

    Returns:
        The formatted title string.
    """
    type_names = list(dataset.type_mapping.keys())
    type_name = type_names[img[1].tolist().index(1)]

    label_names = list(dataset.label_mapping.keys())
    label_values = list(dataset.label_mapping.values())
    label_name = label_names[label_values.index(lbl)]

    return f"Type: {type_name}\nLabel: {label_name}"


def show_dataset_images(dataset, num_images=5):
    """Plot up to ``num_images`` randomly chosen samples in a single row.

    Each panel shows the image (rescaled from [0, 1] floats to uint8 and
    reshaped to 64x64) with its insurance type and label as the title.
    Samples are drawn without replacement. Does nothing when there is
    nothing to show (empty dataset or ``num_images <= 0``).

    Args:
        dataset: indexable dataset returning ``(img, lbl)`` where ``img[0]``
            is the image tensor and ``img[1]`` the one-hot insurance type.
        num_images: maximum number of samples to display.
    """
    count = min(num_images, len(dataset))
    if count <= 0:
        return  # plt.subplots(1, 0) would raise; nothing to plot

    # squeeze=False keeps `axes` a 2-D array even when count == 1; otherwise
    # subplots returns a bare Axes, which zip() cannot iterate.
    _, axes = plt.subplots(1, count, figsize=(20, 4), squeeze=False)
    indices = np.random.choice(len(dataset), count, False)

    for ax, idx in zip(axes[0], indices):
        img, lbl = dataset[idx]
        pixels = (img[0].numpy() * 255).astype(np.uint8).reshape(64, 64)
        ax.imshow(pixels, cmap='gray')
        ax.axis('off')
        ax.set_title(_sample_title(dataset, img, lbl))
    plt.show()

# Preview five randomly chosen scanned codes with their type and label
show_dataset_images(dataset, num_images=5)
# Start coding here

class OCRModel(nn.Module):
    """Two-branch (image + insurance type) classifier for scanned ID codes.

    The image branch is a small CNN and the insurance-type branch a one-layer
    MLP over a one-hot vector. Their features are concatenated and passed to
    a classifier head that returns raw logits (no softmax), as expected by
    ``nn.CrossEntropyLoss``.

    Args:
        num_types: length of the one-hot insurance-type vector (default 5).
        num_classes: number of output classes (default 2: primary vs.
            secondary ID).
        image_size: height/width of the square, single-channel input image
            (default 64).

    The defaults reproduce the original fixed architecture exactly, so
    existing checkpoints (state_dict keys and shapes) still load.
    """

    def __init__(self, num_types=5, num_classes=2, image_size=64):
        super().__init__()

        # The padding=1 convolutions keep the spatial size, and each
        # MaxPool2d(2) halves it (flooring), so after two pools the side is
        # image_size // 2 // 2 (16 for the default 64).
        pooled_side = image_size // 2 // 2
        if pooled_side < 1:
            raise ValueError(f"image_size must be at least 4, got {image_size}")
        flat_features = 32 * pooled_side * pooled_side  # 8192 for 64x64

        # Image branch: 1 input channel (grayscale) -> 16 -> 32 channels.
        # Layer order is kept so state_dict keys stay image_layer.{0,3,7}.*
        self.image_layer = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(flat_features, 128),
            nn.ReLU(),
        )

        # Insurance-type branch over the one-hot type vector
        self.type_layer = nn.Sequential(
            nn.Linear(num_types, 16),
            nn.ReLU(),
        )

        # Classifier head: 128 image features + 16 type features -> logits
        self.classifier = nn.Sequential(
            nn.Linear(128 + 16, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def forward(self, image, insurance_type):
        """Return class logits for a batch of (image, insurance type) pairs.

        Args:
            image: tensor of shape (batch, 1, image_size, image_size).
            insurance_type: tensor of shape (batch, num_types).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x_image = self.image_layer(image)
        x_type = self.type_layer(insurance_type)
        # Fuse the two modalities along the feature dimension
        x_combined = torch.cat([x_image, x_type], dim=1)
        return self.classifier(x_combined)
# Training DataLoader, reshuffled every epoch; lower batch_size if memory is tight.
dataloader_train = DataLoader(dataset, shuffle=True, batch_size=32)

# Sanity-check the batch layout by peeking at the first batch:
# inputs[0] holds the images, inputs[1] the insurance types.
first_batch = next(iter(dataloader_train), None)
if first_batch is not None:
    images, labels = first_batch
    image_batch = images[0]
    type_batch = images[1]
    print(f"Batch de imágenes: {image_batch.shape}")
    print(f"Tipo de seguro: {type_batch.shape}")
    print(f"Etiquetas: {labels.shape}")
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate the model and move it to the target device *before* building
# the optimizer (as the torch.optim docs recommend), so the optimizer is
# constructed over the parameters that will actually be trained.
model = OCRModel().to(device)
print(f"Entrenando en: {device}")

# Cross-entropy loss on the raw logits returned by OCRModel
criterion = nn.CrossEntropyLoss()

# Adam optimizer, learning rate 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Number of training epochs
num_epochs = 10

# Per-epoch metric history (used later for plotting and the checkpoint)
train_losses = []
train_accuracies = []

print("Iniciando entrenamiento...")
print("=" * 50)

for epoch in range(num_epochs):
    model.train()  # training mode: enables the classifier's Dropout

    running_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0

    for inputs, labels in dataloader_train:
        # Each batch input is an (images, insurance types) pair
        images = inputs[0].to(device)
        insurance_types = inputs[1].to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images, insurance_types)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # CrossEntropyLoss() (default reduction='mean') returns the batch mean;
        # weight it by batch size so a smaller final batch does not skew the
        # epoch average.
        running_loss += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise RuntimeError("dataloader_train produced no samples; cannot train")

    # Epoch metrics: per-sample mean loss and accuracy in percent
    epoch_loss = running_loss / total
    epoch_accuracy = 100 * correct / total

    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_accuracy)

    print(f"Época [{epoch+1}/{num_epochs}] - "
          f"Pérdida: {epoch_loss:.4f} - "
          f"Precisión: {epoch_accuracy:.2f}%")

print("=" * 50)
print("¡Entrenamiento completado!")
# Side-by-side plots of the per-epoch training loss (left) and accuracy (right)
epoch_numbers = range(1, num_epochs + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

panel_specs = (
    (ax1, train_losses, 'b-', 'Pérdida', 'Pérdida durante el Entrenamiento'),
    (ax2, train_accuracies, 'g-', 'Precisión (%)', 'Precisión durante el Entrenamiento'),
)
for panel, series, line_fmt, y_label, heading in panel_specs:
    panel.plot(epoch_numbers, series, line_fmt, marker='o')
    panel.set_xlabel('Época')
    panel.set_ylabel(y_label)
    panel.set_title(heading)
    panel.grid(True)

plt.tight_layout()
plt.show()
def evaluate_model(model, dataloader, device=None):
    """Compute and print the classification accuracy of ``model`` on ``dataloader``.

    Switches the model to eval mode (disabling dropout) and runs inference
    under ``torch.no_grad()``. The model is left in eval mode afterwards.

    Args:
        model: module called as ``model(images, insurance_types)``; the
            predicted class is the argmax over dimension 1 of its output.
        dataloader: iterable of ``(inputs, labels)`` batches, where
            ``inputs[0]`` are the images and ``inputs[1]`` the insurance types.
        device: device to move each batch to. Defaults to the device of the
            model's first parameter (CPU if the model has none), so callers no
            longer need a module-level ``device`` variable.

    Returns:
        Accuracy as a percentage in [0, 100].

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            images = inputs[0].to(device)
            insurance_types = inputs[1].to(device)
            labels = labels.to(device)

            predicted = model(images, insurance_types).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("dataloader produced no samples; cannot compute accuracy")

    accuracy = 100 * correct / total
    print(f"Precisión final del modelo: {accuracy:.2f}%")
    return accuracy

# Final evaluation. No held-out split is loaded in this script, so this
# measures accuracy on the training loader, in eval mode.
final_accuracy = evaluate_model(model, dataloader_train)

# Bundle the weights, optimizer state and last-epoch training metrics into
# a single checkpoint file.
checkpoint = {
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': train_losses[-1],
    'accuracy': train_accuracies[-1],
}
torch.save(checkpoint, 'ocr_model_digiNsure.pth')

print("Modelo guardado exitosamente!")

# ---
# 🎓 Conceptos Clave del Proyecto

# Multi-Modal Learning
# - Combina imagen (visual) + tipo de seguro (categórico)
# - Cada modalidad se procesa por separado y luego se fusiona
# - Mejora el rendimiento al capturar información complementaria

# Arquitectura
# Imagen (64x64) → image_layer → 128 features
#                                         ↓ concatenar
# Tipo (5 clases) → type_layer → 16 features
#                                         ↓
#                                 144 features → classifier → 2 clases