Skip to content
Project: Building an E-Commerce Clothing Classifier Model with Keras
Fashion Forward is a new AI-based e-commerce clothing retailer. They want to use image classification to automatically categorize new product listings, making it easier for customers to find what they're looking for. The classifier will also help with inventory management by sorting items quickly.
As a data scientist tasked with implementing a garment classifier, your primary objective is to develop a machine learning model that accurately categorizes images of clothing items into distinct garment types, such as shirts, trousers, and shoes.
# Run the cells below first
# All Keras names come from tensorflow.keras so that layers and models are
# guaranteed to belong to the same Keras implementation.
from tensorflow.keras import datasets, layers, models, Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten
from tensorflow.keras.utils import to_categorical

# Load the Fashion-MNIST train/test splits as (images, labels) pairs.
(train_images, train_labels), (test_images, test_labels) = datasets.fashion_mnist.load_data()
import matplotlib.pyplot as plt

# Preview the first 25 training images in a 5x5 grid with axes decorations hidden.
plt.figure(figsize=(5, 5))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # Divide by 255.0 to match the pixel scaling applied during preprocessing.
    plt.imshow(train_images[i] / 255.0, cmap='gray')
# Render the whole grid once, after every subplot has been drawn.
plt.show()
# Preprocess the data
# Add the trailing channel axis expected by the (28, 28, 1) Conv2D input.
# -1 lets NumPy infer the sample count, so this keeps working if the dataset
# is subset or resized instead of relying on fixed split sizes.
train_images = train_images.reshape((-1, 28, 28, 1))
test_images = test_images.reshape((-1, 28, 28, 1))
# Scale pixels to [0, 1]. Cast to float32 first: dividing the raw uint8 arrays
# returned by load_data() by 255.0 directly would produce float64 arrays,
# doubling memory even though Keras' default float type is float32.
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0
# One-hot encode the labels for categorical_crossentropy. Pin num_classes so the
# encoding always matches the 10-unit softmax output, even if a subset of the
# data happens not to contain the highest-numbered class.
train_labels = to_categorical(train_labels, num_classes=10)
test_labels = to_categorical(test_labels, num_classes=10)
# Define the CNN classifier model: two stacked 3x3 ReLU convolutions on a
# (28, 28, 1) input, flattened into a softmax layer with one unit per class
# (10 classes in Fashion MNIST).
model = Sequential([
    Conv2D(32, (3, 3), strides=1, activation='relu', input_shape=(28, 28, 1)),
    Conv2D(16, (3, 3), strides=1, activation='relu'),
    Flatten(),
    Dense(10, activation='softmax'),
])
# Compile with the Adam optimizer and categorical cross-entropy (the labels were
# one-hot encoded with to_categorical), reporting accuracy during training.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# Train for a single pass over the training set in mini-batches of 32.
model.fit(
    x=train_images,
    y=train_labels,
    batch_size=32,
    epochs=1,
)
# Score the trained model on the held-out test split. With one metric compiled,
# evaluate() yields the loss followed by the accuracy.
test_loss, test_accuracy = model.evaluate(x=test_images, y=test_labels, verbose=2)
print(f"Test Accuracy: {test_accuracy}")