# CAPTCHA OCR training script (CNN + BiLSTM + CTC)
# 1. Import libraries
# Standard library
import os
import subprocess
import sys
from collections import Counter
from pathlib import Path

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# OpenCV may be missing from the environment. Install the pinned version
# only when the import fails, instead of unconditionally shelling out to
# pip (and possibly up/downgrading an existing install) on every run.
try:
    import cv2
except ImportError:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install',
                           'opencv-python==4.5.3.56'])
    import cv2

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

print("Tensorflow version: ", tf.__version__)
print("Numpy version: ", np.__version__)

# Fix the random seeds so runs are reproducible.
seed = 5647
np.random.seed(seed)
tf.random.set_seed(seed)
# 2. Train dataset
# Path to the data directory
data_dir = Path("./data/train/")
# Get list of all the images
images = list(data_dir.glob("*.png"))
print("Number of images found: ", len(images))

# Let's take a look at some samples first.
# Always look at your data!
sample_images = images[:4]
_, ax = plt.subplots(2, 2, figsize=(5, 3))
for i in range(4):
    img = cv2.imread(str(sample_images[i]))
    print("Shape of image: ", img.shape)
    # cv2.imread returns BGR channel order; convert to RGB so that
    # matplotlib displays the samples with correct colors.
    ax[i // 2, i % 2].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax[i // 2, i % 2].axis('off')
plt.show()
# Path to the data directory
data_dir = Path("./data/train/")
# Get list of all the images
images = list(data_dir.glob("*.png"))

# Store all the characters in a set
characters = set()
# A list to store the length of each captcha
captcha_length = []
# Store image-label pairs
dataset = []

# Iterate over the dataset and collect the information needed.
for img_path in images:
    # 1. The label is the file name minus the ".png" extension.
    label = img_path.name.split(".png")[0]
    # 2. Store the length of this captcha
    captcha_length.append(len(label))
    # 3. Store the image-label pair info
    dataset.append((str(img_path), label))
    # 4. Store the characters present
    characters.update(label)

# Sort the characters so the char <-> id mapping is deterministic.
characters = sorted(characters)

# Convert the dataset info into a dataframe
dataset = pd.DataFrame(dataset, columns=["img_path", "label"], index=None)
# Shuffle the dataset; seed the sampler explicitly for reproducibility.
dataset = dataset.sample(frac=1., random_state=seed).reset_index(drop=True)

print("Number of unique characters in the whole dataset: ", len(characters))
print("Maximum length of any captcha: ", max(captcha_length))
print("Characters present: ", characters)
print("Total number of samples in the dataset: ", len(dataset))
dataset.head()
# 3. Split train data into train and valid sets
# Hold out 10% of the samples for validation; the fixed seed keeps the
# split stable between runs.
training_data, validation_data = train_test_split(dataset, test_size=0.10, random_state=seed)
training_data = training_data.reset_index(drop=True)
validation_data = validation_data.reset_index(drop=True)
print("Number of training samples: ", len(training_data))
print("Number of validation samples: ", len(validation_data))

# Lookup tables between characters and their integer class ids.
labels_to_char = dict(enumerate(characters))                    # id -> char
char_to_labels = {ch: i for i, ch in labels_to_char.items()}    # char -> id
# Sanity check for corrupted images
# Sanity check for corrupted images
def is_valid_captcha(captcha, valid_chars=None):
    """Return True if every character of *captcha* is a known character.

    Args:
        captcha: label string to validate.
        valid_chars: optional iterable of allowed characters; defaults to
            the module-level ``characters`` vocabulary, preserving the
            original call signature.
    """
    if valid_chars is None:
        valid_chars = characters
    return all(ch in valid_chars for ch in captcha)
# Store arrays in memory as it's not a very big dataset
def generate_arrays(df, resize=True, img_height=50, img_width=200):
    """Generates image array and labels array from a dataframe.

    Args:
        df: dataframe from which we want to read the data
        resize (bool): whether to resize images or not
        img_height (int): height of the resized images
        img_width (int): width of the resized images

    Returns:
        images (ndarray): grayscale images scaled to [0, 1],
            shape (N, img_height, img_width)
        labels (ndarray): corresponding label strings

    Samples whose label contains an unknown character are skipped
    entirely. (Previously they left an all-zero image row and an
    integer 0 in the labels array, which corrupts the arrays that the
    data generator later encodes via the char map.)
    """
    images = []
    labels = []
    for i in range(len(df)):
        label = df["label"][i]
        # Include the pair only if it is a valid captcha.
        if not is_valid_captcha(label):
            continue
        img = cv2.imread(df["img_path"][i])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        if resize:
            img = cv2.resize(img, (img_width, img_height))
        images.append((img / 255.).astype(np.float32))
        labels.append(label)
    return np.array(images, dtype=np.float32), np.array(labels)
# Build training data.
# NOTE(review): this rebinds `training_data` / `validation_data` from the
# DataFrames created by the split above to ndarrays of images — the original
# frames are no longer reachable after this point.
training_data, training_labels = generate_arrays(df=training_data)
print("Number of training images: ", training_data.shape)
print("Number of training labels: ", training_labels.shape)
# Build validation data.
validation_data, validation_labels = generate_arrays(df=validation_data)
print("Number of validation images: ", validation_data.shape)
print("Number of validation labels: ", validation_labels.shape)
class DataGenerator(keras.utils.Sequence):
    """Generates batches from a given dataset.

    Args:
        data: training or validation images, shape (N, img_height, img_width)
        labels: corresponding label strings
        char_map: dictionary mapping characters to integer class ids
        batch_size: size of a single batch
        img_width: width of the (already resized) images
        img_height: height of the (already resized) images
        downsample_factor: by what factor the CNN downsampled the images
        max_length: maximum length of any captcha
        shuffle: whether to shuffle the sample order after each epoch

    Each item is a pair (batch_inputs, dummy_targets) where batch_inputs
    is the dict of named inputs the CTC model expects and dummy_targets
    is all zeros (the real loss is added inside the model by CTCLayer).
    """

    def __init__(self,
                 data,
                 labels,
                 char_map,
                 batch_size=16,
                 img_width=200,
                 img_height=50,
                 downsample_factor=4,
                 max_length=5,
                 shuffle=True
                 ):
        self.data = data
        self.labels = labels
        self.char_map = char_map
        self.batch_size = batch_size
        self.img_width = img_width
        self.img_height = img_height
        self.downsample_factor = downsample_factor
        self.max_length = max_length
        self.shuffle = shuffle
        self.indices = np.arange(len(data))
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(np.ceil(len(self.data) / self.batch_size))

    def __getitem__(self, idx):
        # 1. Indices of the samples belonging to this batch.
        curr_batch_idx = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        # 2. Size the arrays to the actual batch length — the last batch
        # may not have batch_size elements, and this saves some memory.
        batch_len = len(curr_batch_idx)
        # 3. Instantiate batch arrays. Images are stored transposed as
        # (width, height, 1) because the model treats width as the time
        # axis for the RNN/CTC stage.
        batch_images = np.ones((batch_len, self.img_width, self.img_height, 1),
                               dtype=np.float32)
        batch_labels = np.ones((batch_len, self.max_length), dtype=np.float32)
        # Timesteps the CTC loss sees per sample; the "-2" accounts for
        # the frames ctc_batch_cost drops.
        input_length = np.ones((batch_len, 1), dtype=np.int64) * \
            (self.img_width // self.downsample_factor - 2)
        label_length = np.zeros((batch_len, 1), dtype=np.int64)

        # BUGFIX: the loop variable was named `idx`, shadowing the batch
        # index parameter above; renamed to `data_idx`.
        for j, data_idx in enumerate(curr_batch_idx):
            # 1. Get the image and transpose it (H, W) -> (W, H).
            img = self.data[data_idx].T
            # 2. Add the trailing channel dimension.
            img = np.expand_dims(img, axis=-1)
            # 3. Get the corresponding label.
            text = self.labels[data_idx]
            # 4. Include the pair only if the captcha is valid.
            if is_valid_captcha(text):
                label = [self.char_map[ch] for ch in text]
                batch_images[j] = img
                batch_labels[j] = label
                label_length[j] = len(text)

        batch_inputs = {
            'input_data': batch_images,
            'input_label': batch_labels,
            'input_length': input_length,
            'label_length': label_length,
        }
        # Dummy targets: the loss is computed inside the model by CTCLayer.
        return batch_inputs, np.zeros(batch_len).astype(np.float32)

    def on_epoch_end(self):
        # Reshuffle the sample order between epochs when requested.
        if self.shuffle:
            np.random.shuffle(self.indices)
# ----- Generator configuration -----
# Batch size for training and validation
batch_size = 16
# Desired image dimensions
img_width = 200
img_height = 50
# Factor by which the convolutional blocks downsample the image
downsample_factor = 4
# Maximum length of any captcha in the data
max_length = 5

# Keyword arguments shared by both generators.
_gen_kwargs = dict(char_map=char_to_labels,
                   batch_size=batch_size,
                   img_width=img_width,
                   img_height=img_height,
                   downsample_factor=downsample_factor,
                   max_length=max_length)

# Generator over the training data (reshuffled every epoch).
train_data_generator = DataGenerator(data=training_data,
                                     labels=training_labels,
                                     shuffle=True,
                                     **_gen_kwargs)

# Generator over the validation data (fixed order).
valid_data_generator = DataGenerator(data=validation_data,
                                     labels=validation_labels,
                                     shuffle=False,
                                     **_gen_kwargs)
# 4. Create model
class CTCLayer(layers.Layer):
    """Layer that computes the CTC loss and registers it via add_loss.

    The model is compiled without an explicit loss function; this layer
    injects the CTC loss into the graph at training time and returns the
    computed loss tensor as its output.
    """

    def __init__(self, name=None, **kwargs):
        # Forward extra kwargs (dtype, trainable, ...) to the base Layer so
        # the layer can be reconstructed from a saved config; the original
        # __init__ dropped them.
        super().__init__(name=name, **kwargs)
        self.loss_fn = keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred, input_length, label_length):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)
        # At test time, just return the computed loss.
        return loss
def build_model():
    """Build and compile the OCR model (CNN -> BiLSTM -> CTC).

    Reads the module-level ``img_width``, ``img_height``, ``max_length``
    and ``characters`` globals. Returns a compiled keras Model whose loss
    is injected by CTCLayer, so compile() takes only the optimizer.
    """
    # Model inputs: the image plus the label/length tensors that the
    # CTC loss needs at training time.
    input_img = layers.Input(shape=(img_width, img_height, 1),
                             name='input_data',
                             dtype='float32')
    labels = layers.Input(name='input_label', shape=[max_length], dtype='float32')
    input_length = layers.Input(name='input_length', shape=[1], dtype='int64')
    label_length = layers.Input(name='label_length', shape=[1], dtype='int64')

    # Convolutional feature extractor: two conv + max-pool blocks.
    features = layers.Conv2D(32, (3, 3), activation='relu',
                             kernel_initializer='he_normal',
                             padding='same', name='Conv1')(input_img)
    features = layers.MaxPooling2D((2, 2), name='pool1')(features)
    features = layers.Conv2D(64, (3, 3), activation='relu',
                             kernel_initializer='he_normal',
                             padding='same', name='Conv2')(features)
    features = layers.MaxPooling2D((2, 2), name='pool2')(features)

    # The two 2x2 poolings shrink each spatial dimension by 4 and the
    # last conv layer has 64 filters: fold (height/4, 64) into a single
    # feature axis, keeping width/4 as the sequence axis for the RNNs.
    new_shape = ((img_width // 4), (img_height // 4) * 64)
    features = layers.Reshape(target_shape=new_shape, name='reshape')(features)
    features = layers.Dense(64, activation='relu', name='dense1')(features)
    features = layers.Dropout(0.2)(features)

    # Sequence modelling along the width (time) axis.
    features = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.2))(features)
    features = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(features)

    # Per-timestep class scores: one class per character plus the CTC blank.
    softmax_out = layers.Dense(len(characters) + 1,
                               activation='softmax',
                               name='dense2',
                               kernel_initializer='he_normal')(features)

    # CTC loss is computed and registered inside the model by this layer.
    output = CTCLayer(name='ctc_loss')(labels, softmax_out, input_length, label_length)

    model = keras.models.Model(inputs=[input_img,
                                       labels,
                                       input_length,
                                       label_length],
                               outputs=output,
                               name='ocr_model_v1')

    # SGD with Nesterov momentum and gradient-norm clipping; no loss
    # argument because CTCLayer already added it via add_loss.
    sgd = keras.optimizers.SGD(learning_rate=0.002,
                               decay=1e-6,
                               momentum=0.9,
                               nesterov=True,
                               clipnorm=5)
    model.compile(optimizer=sgd)
    return model
# Instantiate the model and print its architecture for inspection.
model = build_model()
model.summary()
# 5. Train model
# Stop training once validation loss stops improving for 5 epochs,
# restoring the best weights seen so far.
es = keras.callbacks.EarlyStopping(monitor='val_loss',
                                   patience=5,
                                   restore_best_weights=True)
# Train the model
history = model.fit(train_data_generator,
                    validation_data=valid_data_generator,
                    epochs=50,
                    callbacks=[es])

# Plot the learning curves. plt.legend() is required for the `label=`
# arguments below to actually produce a legend on the figure (the
# original code passed labels but never rendered them).
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Valid Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()