
Fashion Forward is a new AI-based e-commerce clothing retailer. They want to use image classification to automatically categorize new product listings, making it easier for customers to find what they're looking for. The classifier will also assist inventory management by quickly sorting items.

As a data scientist tasked with implementing a garment classifier, your primary objective is to develop a machine learning model capable of accurately categorizing images of clothing items into distinct garment types such as shirts, trousers, shoes, etc.

# Run the cells below first
!pip install torchmetrics
!pip install torchvision
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy, Precision, Recall
# Load datasets
from torchvision import datasets
import torchvision.transforms as transforms

train_data = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
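  • transforms.ToTensor() converts each 28×28 grayscale image into a float tensor of shape (1, 28, 28), with pixel values scaled from [0, 255] to [0, 1], which is the format PyTorch layers expect.
  • The datasets.FashionMNIST class loads the training set (60,000 images) and the test set (10,000 images), downloading them to ./data if needed and applying the tensor transformation.
  • DataLoader groups each dataset into batches of 64 images for efficient processing during training and evaluation. The training data is reshuffled every epoch; the test data is not.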
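To make the preprocessing and batching concrete, here is a plain-Python sketch (no torch required) of what the bullets above describe. `to_tensor_like` and `num_batches` are hypothetical helpers written for illustration, not torchvision or PyTorch functions. The first mimics how `transforms.ToTensor()` adds a channel dimension and rescales 8-bit grayscale pixels to [0, 1]. The second counts how many batches `DataLoader` yields per epoch, assuming its default `drop_last=False`.

```python
import math


def to_tensor_like(image_rows):
    """Mimic transforms.ToTensor() for one grayscale image given as rows of
    uint8 pixel values: add a leading channel dimension (C=1) and scale
    each pixel from [0, 255] to [0.0, 1.0]."""
    return [[[pixel / 255.0 for pixel in row] for row in image_rows]]


def num_batches(dataset_size, batch_size, drop_last=False):
    """Count the batches a DataLoader yields per epoch. With drop_last=False
    the final, smaller batch is kept, so the count rounds up."""
    if drop_last:
        return dataset_size // batch_size
    return math.ceil(dataset_size / batch_size)


# A tiny 2x3 "image" standing in for a real 28x28 Fashion-MNIST sample
image = [[0, 128, 255], [255, 0, 51]]
tensor_like = to_tensor_like(image)
print(len(tensor_like), len(tensor_like[0]), len(tensor_like[0][0]))  # 1 2 3

# Fashion-MNIST has 60,000 training and 10,000 test images; batch_size=64
print(num_batches(60000, 64), num_batches(10000, 64))  # 938 157
```

The last training batch therefore holds only 32 images (60,000 - 937 × 64), and the last test batch holds 16.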
# Define the CNN architecture
import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 64 channels x 7x7: two 2x2 max-pools shrink 28x28 to 7x7
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize the model, loss function, and optimizer
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
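The `64 * 7 * 7` input size of `fc1` follows from the layer settings. The plain-Python sketch below applies the standard output-size formula floor((n + 2p - k) / s) + 1 for convolution and pooling (dilation 1 assumed) to trace a 28×28 image through the network, then counts the learnable parameters. The helper names are illustrative, not PyTorch functions. In PyTorch itself, `sum(p.numel() for p in model.parameters())` gives the parameter count directly.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a 2D convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def maxpool2d_out(size, kernel_size, stride=None):
    """Spatial output size of max pooling; stride defaults to kernel_size,
    matching nn.functional.max_pool2d(x, 2)."""
    stride = kernel_size if stride is None else stride
    return (size - kernel_size) // stride + 1


def conv_params(in_ch, out_ch, kernel_size):
    """Convolution weights plus one bias per output channel."""
    return in_ch * out_ch * kernel_size * kernel_size + out_ch


def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features


size = 28                                         # Fashion-MNIST images are 28x28
size = conv2d_out(size, 3, stride=1, padding=1)   # conv1 keeps 28x28
size = maxpool2d_out(size, 2)                     # pool -> 14x14
size = conv2d_out(size, 3, stride=1, padding=1)   # conv2 keeps 14x14
size = maxpool2d_out(size, 2)                     # pool -> 7x7
flat_features = 64 * size * size                  # 64 channels leave conv2
print(size, flat_features)  # 7 3136

total = (conv_params(1, 32, 3) + conv_params(32, 64, 3)
         + linear_params(flat_features, 128) + linear_params(128, 10))
print(total)  # 421642
```

Most of the parameters (401,536) sit in `fc1`, which is why changing the input resolution or pooling scheme requires updating its input size.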

Model Training and Evaluation:

  • A CNN architecture is defined with convolutional layers, ReLU activations, max-pooling, and fully connected layers.
  • The model is trained for a single epoch using the Adam optimizer (learning rate 0.001) and cross-entropy loss.
  • Finally, predictions are generated on the test set, and the overall accuracy and the per-class precision and recall are calculated and displayed.
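For intuition about the loss and the prediction step, here is a minimal plain-Python sketch. For a single image, `nn.CrossEntropyLoss` applies log-softmax to the model's 10 raw outputs (logits) and takes the negative log-probability of the true class, averaging over the batch by default. At evaluation time, `torch.max(outputs, 1)` picks the index of the largest logit as the predicted class. The functions below are illustrative re-implementations, not PyTorch's code.

```python
import math


def cross_entropy(logits, label):
    """Negative log-softmax probability of the true label.
    Subtracting the max logit (log-sum-exp trick) avoids overflow."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def batch_cross_entropy(batch_logits, labels):
    """Mean loss over a batch, like CrossEntropyLoss's default 'mean' reduction."""
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


def predict(logits):
    """Index of the largest logit, the class torch.max(outputs, 1) returns."""
    return max(range(len(logits)), key=lambda i: logits[i])


# All 10 logits equal (no preference between classes): loss is log(10)
uniform = [0.0] * 10
print(round(cross_entropy(uniform, 3), 4))  # 2.3026

# A confident, correct prediction has a much smaller loss
confident = [0.0] * 10
confident[7] = 5.0
print(predict(confident), cross_entropy(confident, 7))
```

A freshly initialized 10-class model typically starts with a loss near log(10) ≈ 2.30, so that value is a useful sanity check for the first training batches.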
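# Train the model for 1 epoch
for epoch in range(1):
    model.train()  # put layers in training mode
    for images, labels in train_loader:
        optimizer.zero_grad()              # clear gradients from the previous batch
        outputs = model(images)            # forward pass: logits of shape (batch, 10)
        loss = criterion(outputs, labels)  # mean cross-entropy over the batch
        loss.backward()                    # backpropagate to compute gradients
        optimizer.step()                   # Adam update of the weights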
import torch
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall

# Evaluate the model and store predictions
model.eval()
predictions = []
targets = []
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        predictions.extend(predicted.cpu().numpy())
        targets.extend(labels.cpu().numpy())

# Calculate accuracy, precision, and recall
num_classes = 10
preds_tensor = torch.tensor(predictions)
targets_tensor = torch.tensor(targets)

# average="micro" gives overall accuracy (correct predictions / all predictions);
# average=None returns one precision and one recall value per class
accuracy = MulticlassAccuracy(num_classes=num_classes, average="micro")(preds_tensor, targets_tensor).item()
precision = MulticlassPrecision(num_classes=num_classes, average=None)(preds_tensor, targets_tensor).tolist()
recall = MulticlassRecall(num_classes=num_classes, average=None)(preds_tensor, targets_tensor).tolist()

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
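The precision and recall lists print as ten bare numbers in label order. The small formatting helper sketched below pairs each value with its garment name, assuming the standard Fashion-MNIST label order (the same names torchvision exposes as `train_data.classes`). `format_per_class` is a hypothetical helper, not part of torchmetrics.

```python
# Fashion-MNIST label names, indexed by class id 0-9
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def format_per_class(precision, recall, class_names=FASHION_MNIST_CLASSES):
    """Return a fixed-width text table with one row per class."""
    if not (len(precision) == len(recall) == len(class_names)):
        raise ValueError("precision, recall and class_names must have the same length")
    rows = [f"{'Class':<12} {'Precision':>9} {'Recall':>7}"]
    for name, p, r in zip(class_names, precision, recall):
        rows.append(f"{name:<12} {p:>9.3f} {r:>7.3f}")
    return "\n".join(rows)


# Placeholder values for illustration only; in the notebook, pass the
# `precision` and `recall` lists computed above instead.
demo_precision = [0.5] * 10
demo_recall = [1.0] * 10
print(format_per_class(demo_precision, demo_recall))
```

In the notebook, `print(format_per_class(precision, recall))` would replace the last two `print` calls.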

Evaluation of Metrics:

  • Accuracy: The fraction of all test items classified correctly. High accuracy means most garment items were assigned to the right category.
  • Precision: For each garment type, the fraction of items predicted as that type that actually belong to it, TP / (TP + FP), where TP and FP count true and false positives. High precision means fewer false positives.
  • Recall: For each garment type, the fraction of items of that type that the model correctly identified, TP / (TP + FN), where FN counts false negatives. High recall means fewer false negatives.
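The three definitions can be computed directly from lists of predicted and true labels by counting TP, FP and FN per class. The plain-Python sketch below shows the arithmetic; it is not a replacement for torchmetrics. It also computes macro-averaged accuracy (the mean of per-class recall), which, per the torchmetrics documentation, is what `MulticlassAccuracy` reports with its default `average="macro"`. On a class-balanced test set such as Fashion-MNIST's (1,000 images per class) the two accuracies coincide; on the imbalanced toy data below they differ. Returning 0.0 for a zero denominator is an assumption made for this sketch.

```python
def classification_metrics(predictions, targets, num_classes):
    """Overall accuracy plus per-class precision and recall.
    A class that is never predicted (or never present) gets 0.0
    instead of dividing by zero."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for pred, true in zip(predictions, targets):
        if pred == true:
            tp[true] += 1
        else:
            fp[pred] += 1  # predicted this class, but wrongly
            fn[true] += 1  # missed an item of the true class

    def ratio(num, den):
        return num / den if den else 0.0

    accuracy = ratio(sum(tp), len(targets))
    precision = [ratio(tp[c], tp[c] + fp[c]) for c in range(num_classes)]
    recall = [ratio(tp[c], tp[c] + fn[c]) for c in range(num_classes)]
    return accuracy, precision, recall


preds = [0, 1, 1, 2, 2, 2]
truth = [0, 1, 2, 2, 2, 0]
acc, prec, rec = classification_metrics(preds, truth, num_classes=3)
print(acc)   # 0.6666666666666666 (4 of 6 correct)
print(prec)  # [1.0, 0.5, 0.6666666666666666]
print(rec)   # [0.5, 1.0, 0.6666666666666666]

macro_accuracy = sum(rec) / len(rec)  # mean per-class recall
print(round(macro_accuracy, 4))       # 0.7222
```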

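The notebook computes these metrics on the 10,000 held-out test images. They estimate how reliably the classifier would categorize new product listings, which determines how well it helps customers find products and supports inventory sorting. Per-class precision and recall also show which garment types the model handles least reliably.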