DigiNsure Inc. is an innovative insurance company focused on enhancing the efficiency of processing claims and customer service interactions. Their newest initiative is digitizing all historical insurance claim documents, which includes improving the labeling of some IDs scanned from paper documents and identifying them as primary or secondary IDs.
To help them in their effort, you'll use multi-modal learning to train an Optical Character Recognition (OCR) model. To improve classification, the model will take two inputs for each scanned document: the image itself and its insurance type (home, life, auto, health, or other). Integrating different data modalities, such as image and text, lets the model capture more nuanced information and perform better in complex scenarios. For each image-insurance-type pair, the model will be trained to predict one of two labels: primary ID or secondary ID.
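As an illustrative picture of what one training sample looks like, the sketch below builds dummy tensors with the shapes used later in this project: a 1x64x64 grayscale scan, a 5-way one-hot insurance-type vector, and a 0/1 label. The type ordering and label encoding here are assumptions for illustration, not the actual ProjectDataset mappings.
import torch
# Illustrative sketch only (not the ProjectDataset implementation):
# one sample structured as ((image, insurance_type), label)
insurance_types = ["home", "life", "auto", "health", "other"]  # assumed ordering
image = torch.rand(1, 64, 64)                       # dummy 64x64 grayscale scan
type_one_hot = torch.zeros(len(insurance_types))
type_one_hot[insurance_types.index("auto")] = 1.0   # e.g. an auto-insurance claim
label = torch.tensor(0)                             # e.g. 0 = primary ID, 1 = secondary ID
sample = ((image, type_one_hot), label)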
# Import the necessary libraries
import matplotlib.pyplot as plt
import numpy as np
from project_utils import ProjectDataset
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# Load the data
dataset = pickle.load(open('ocr_insurance_dataset.pkl', 'rb'))
# Define a function to visualize images with their corresponding insurance types and labels
def show_dataset_images(dataset, num_images=5):
    n = min(num_images, len(dataset))
    fig, axes = plt.subplots(1, n, figsize=(20, 4))
    for ax, idx in zip(axes, np.random.choice(len(dataset), n, replace=False)):
        img, lbl = dataset[idx]
        # img[0] is the grayscale scan, img[1] is the one-hot insurance type
        ax.imshow((img[0].numpy() * 255).astype(np.uint8).reshape(64, 64), cmap='gray')
        ax.axis('off')
        type_name = list(dataset.type_mapping.keys())[img[1].tolist().index(1)]
        label_name = list(dataset.label_mapping.keys())[list(dataset.label_mapping.values()).index(lbl)]
        ax.set_title(f"Type: {type_name}\nLabel: {label_name}")
    plt.show()
# Inspect 5 images from the dataset
show_dataset_images(dataset)
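Optionally, since DataLoader is already imported, you can sanity-check how samples batch together. This sketch assumes each dataset item is structured as ((image_tensor, type_tensor), label), as used by show_dataset_images above, and that the default collate preserves that nesting.
# Optional sanity check: batch a few samples with the DataLoader imported earlier
loader = DataLoader(dataset, batch_size=8, shuffle=True)
(batch_images, batch_types), batch_labels = next(iter(loader))
print(batch_images.shape)  # expected: torch.Size([8, 1, 64, 64])
print(batch_types.shape)   # expected: torch.Size([8, 5])
print(batch_labels.shape)  # expected: torch.Size([8])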

1. Define the OCRModel class
class OCRModel(nn.Module):
    def __init__(self):
        super().__init__()
        # CNN branch for the 64x64 grayscale scan
        self.image_layer = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ELU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 16 * 16, 256),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128)
        )
        # MLP branch for the one-hot insurance type
        self.type_layer = nn.Sequential(
            nn.Linear(5, 16),
            nn.ELU(),
            nn.Linear(16, 8)
        )
        # Classifier over the concatenated image and type features
        self.classifier = nn.Sequential(
            nn.Linear(128 + 8, 64),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(64, 2)
        )

    def forward(self, x_image, x_type):
        x_image = self.image_layer(x_image)
        x_type = self.type_layer(x_type)
        x = torch.cat((x_image, x_type), dim=1)
        return self.classifier(x)
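A quick way to confirm the wiring (for instance, that two rounds of 2x2 pooling turn a 64x64 input into 16x16, so the flattened CNN features are 64*16*16, and that the output has 2 logits) is to push dummy tensors through a throwaway instance. This check is only a sketch, not part of the project workflow.
# Optional shape check with dummy inputs
_check = OCRModel()
print(_check(torch.rand(4, 1, 64, 64), torch.rand(4, 5)).shape)  # expected: torch.Size([4, 2])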

2. Define optimizer and loss functions
model = OCRModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
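Note that nn.CrossEntropyLoss expects raw logits of shape [batch, 2] and integer class indices of shape [batch], with no one-hot encoding and no softmax. The values below are made up purely to illustrate the call.
# Illustration only: raw logits plus integer class indices
example_logits = torch.tensor([[2.0, -1.0], [0.3, 0.8]])  # 2 samples, 2 classes
example_labels = torch.tensor([0, 1])
print(criterion(example_logits, example_labels))          # a scalar loss tensor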

3. Train the model
batch_losses = []   # Store the loss of every batch
batch_numbers = []  # Batch counter for the x-axis
for epoch in range(2):
    for batch_idx, (img, lbl) in enumerate(dataset):
        optimizer.zero_grad()
        # Data prep: samples are processed one at a time, so add a batch dimension
        x_image = img[0].unsqueeze(0)                # [1, 1, 64, 64]
        x_type = img[1].unsqueeze(0)                 # [1, 5]
        lbl_tensor = torch.tensor(lbl).unsqueeze(0)  # [1]
        # Forward pass
        outputs = model(x_image, x_type)
        loss = criterion(outputs, lbl_tensor)
        # Backward pass
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
        batch_numbers.append(epoch * len(dataset) + batch_idx)
        print(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}")
# Detailed visualization of the training loss
plt.figure(figsize=(10, 5))
plt.plot(batch_numbers, batch_losses, label='Batch Loss', linewidth=1, alpha=0.6)
plt.xlabel('Batch Number')
plt.ylabel('Loss')
plt.title('Detailed Training Loss Curve')
plt.legend()
plt.grid(True)
plt.show()
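If you want to keep the trained weights before evaluating, or reload them later without retraining, a minimal sketch with torch.save follows; the checkpoint filename is a placeholder.
# Optional: persist the trained weights (filename is a placeholder)
torch.save(model.state_dict(), 'ocr_model.pt')
# To restore later: model.load_state_dict(torch.load('ocr_model.pt'))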

4. Evaluation
model.eval()
correct = 0
total = 0
with torch.no_grad():  # Disable gradient computation to save memory
    for img, lbl in dataset:
        # Add a batch dimension (as during training)
        x_image = img[0].unsqueeze(0)  # [1, 1, 64, 64]
        x_type = img[1].unsqueeze(0)   # [1, 5]
        # Prediction
        outputs = model(x_image, x_type)
        _, predicted = torch.max(outputs.data, 1)  # Get the predicted class (0 or 1)
        # Update the counters
        total += 1  # Increment total by 1 for each sample
        correct += (predicted == lbl).sum().item()
accuracy = 100 * correct / total
print(f"Précision du modèle : {accuracy:.2f}%")