Skip to content
New Workbook
Sign up
Project: Developing Multi-Input Models For OCR

DigiNsure Inc. is an innovative insurance company focused on making claims processing and customer-service interactions more efficient. Its newest initiative is digitizing all historical insurance claim documents, which includes improving the labels of IDs scanned from paper documents and classifying each one as a primary or secondary ID.

To support this effort, you'll use multi-modal learning to train an Optical Character Recognition (OCR) model. To improve classification, the model takes two inputs: the image of each scanned document and that document's insurance type (home, life, auto, health, or other). Combining different data modalities (such as images and text) lets the model capture more nuanced information and perform better in complex scenarios. For each image–insurance-type pair, the model is trained to predict one of two labels: primary ID or secondary ID.

! pip install torchvision
Hidden output
# Import the necessary libraries
import matplotlib.pyplot as plt
import numpy as np
from project_utils import ProjectDataset
import pickle 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Load the data. The context manager guarantees the file handle is closed.
# NOTE: pickle.load can execute arbitrary code while deserializing, so only
# load files from a trusted source.
with open('ocr_insurance_dataset.pkl', 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)

# Define a function to visualize codes with their corresponding types and labels
def show_dataset_images(dataset, num_images=5):
    """Plot randomly chosen samples titled with their insurance type and label.

    Each ``dataset[idx]`` is unpacked as ``(img, lbl)``. ``img[0]`` is scaled
    by 255, cast to uint8 and reshaped to 64x64 for display (presumably pixel
    values in [0, 1] -- TODO confirm). ``img[1]`` is looked up by the position
    of its ``1`` entry in ``dataset.type_mapping``, and ``lbl`` by its position
    among the values of ``dataset.label_mapping``.

    Args:
        dataset: Indexable dataset exposing ``type_mapping`` and
            ``label_mapping`` dicts.
        num_images: Maximum number of samples to show; capped at
            ``len(dataset)``. Nothing is plotted if this resolves to <= 0.
    """
    n = min(num_images, len(dataset))
    if n <= 0:
        # plt.subplots rejects a zero-column grid, and there is nothing to show.
        return
    # squeeze=False keeps `axes` a 2-D array even when n == 1; otherwise
    # subplots returns a single, non-iterable Axes object.
    _, axes = plt.subplots(1, n, figsize=(20, 4), squeeze=False)
    type_names = list(dataset.type_mapping.keys())
    label_names = list(dataset.label_mapping.keys())
    label_values = list(dataset.label_mapping.values())
    for ax, idx in zip(axes[0], np.random.choice(len(dataset), n, replace=False)):
        img, lbl = dataset[idx]
        ax.imshow((img[0].numpy() * 255).astype(np.uint8).reshape(64, 64), cmap='gray')
        ax.axis('off')
        type_name = type_names[img[1].tolist().index(1)]
        label_name = label_names[label_values.index(lbl)]
        ax.set_title(f"Type: {type_name}\nLabel: {label_name}")
    plt.show()

# Preview a few random samples, then print the raw structure of the first one:
# the shape of its image component, its second component, and its label.
show_dataset_images(dataset)
img, lbl = dataset[0]
for caption, component in (("first comp:\n", img[0].shape), ("second comp:\n", img[1])):
    print(caption, component)
print(lbl)
# Multi-input model combining the scanned image with its insurance type.
class OCRModel(nn.Module):
    """Two-branch classifier fusing a single-channel image with a type vector.

    Image branch: 3x3 conv (16 channels, same padding) -> 2x2 max-pool -> ELU
    -> flatten -> linear projection to 128 features.
    Type branch: linear projection of the type vector to 8 features -> ELU.
    The two feature vectors are concatenated and mapped to ``num_classes``
    logits.

    Args:
        num_classes: Number of output logits per sample.
        num_types: Length of the insurance-type input vector (default 5).
        image_size: Side length of the square, single-channel input image
            (default 64).
    """

    def __init__(self, num_classes, num_types=5, image_size=64):
        super().__init__()
        # Same-padded 3x3 conv keeps the spatial size; MaxPool2d(2) floors it
        # to image_size // 2, giving 16 * pooled * pooled flattened features.
        pooled = image_size // 2
        self.image_layer = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ELU(),
            nn.Flatten(),
            nn.Linear(16 * pooled * pooled, 128),
        )
        self.type_layer = nn.Sequential(
            nn.Linear(num_types, 8),
            nn.ELU(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 + 8, num_classes),
        )

    def forward(self, x_image, x_type):
        """Return logits of shape (batch, num_classes).

        Args:
            x_image: Image batch of shape (batch, 1, image_size, image_size).
            x_type: Insurance-type batch of shape (batch, num_types).
        """
        image_features = self.image_layer(x_image)
        type_features = self.type_layer(x_type)
        # Fuse the two modalities along the feature dimension.
        fused = torch.cat((image_features, type_features), dim=1)
        return self.classifier(fused)
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
# Dataset definition
class DigiDataset(Dataset):
    """Wrap ``(img, label)`` samples and yield ``(image, type, label)`` triples.

    Each sample's ``img`` is a pair: ``img[0]`` holds the pixel tensor and
    ``img[1]`` the insurance-type vector. The pixels are converted to an 8-bit
    array and passed through ``transform``; the type vector and label are
    returned unchanged.

    Args:
        transform: Callable applied to the uint8 pixel array, e.g. a
            ``transforms.Compose`` pipeline beginning with ``ToTensor``.
        samples: Indexable collection of ``(img, label)`` pairs.
    """

    def __init__(self, transform, samples):
        self.transform = transform
        self.samples = samples

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(transformed_image, insurance_type, label)`` for ``idx``."""
        img, label = self.samples[idx]
        # Presumably pixel values lie in [0, 1] (TODO confirm). Clipping makes
        # out-of-range values saturate instead of wrapping around in uint8.
        pixels = np.clip(img[0].numpy() * 255, 0, 255).astype(np.uint8)
        # ToTensor interprets a 3-D array as (H, W, C), so a channel-first
        # (1, H, W) array would be misread. Drop the singleton channel and
        # pass a 2-D (H, W) array; ToTensor then yields a (1, H, W) tensor.
        if pixels.ndim == 3 and pixels.shape[0] == 1:
            pixels = pixels[0]
        return self.transform(pixels), img[1], label
    
# Training data: each image goes through ToTensor followed by a 64x64 Resize,
# and the loader serves shuffled mini-batches of three samples.
dataset_train = DigiDataset(
    samples=dataset,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Resize((64, 64))]),
)
dataloader_train = DataLoader(dataset_train, batch_size=3, shuffle=True)
# Training loop
# Two output classes, one per ID label (primary / secondary).
model = OCRModel(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
# Every sample is visited once per epoch (no drop_last), so this is the
# denominator for a per-sample epoch loss.
num_samples = len(dataloader_train.dataset)

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for img, typ, labels in dataloader_train:
        optimizer.zero_grad()
        outputs = model(img, typ)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the batch *mean*; re-weight by batch size so
        # a smaller final batch is not over-weighted in the epoch average.
        running_loss += loss.item() * len(labels)
    epoch_loss = running_loss / max(num_samples, 1)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")