Skip to content
Project: Building an E-Commerce Clothing Classifier Model
  • AI Chat
  • Code
  • Report
  • Fashion Forward is a new AI-based e-commerce clothing retailer. They want to use image classification to automatically categorize new product listings, making it easier for customers to find what they're looking for. The classifier will also help with inventory management by quickly sorting items.

    As a data scientist tasked with implementing a garment classifier, your primary objective is to develop a machine learning model capable of accurately categorizing images of clothing items into distinct garment types, such as shirts, trousers, and shoes.

    !pip install torchmetrics
    Hidden output
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    from torchmetrics import Accuracy, Precision, Recall
    # Load datasets: the FashionMNIST train split first, then the test split,
    # each downloaded into ./data and converted to tensors on access.
    from torchvision import datasets
    import torchvision.transforms as transforms
    train_data, test_data = (
        datasets.FashionMNIST(root='./data', train=is_train, download=True, transform=transforms.ToTensor())
        for is_train in (True, False)
    )
    # Class names as exposed by the dataset, and how many there are
    classes = train_data.classes
    num_classes = len(classes)
    # Model hyper-parameters shared by the CNN definition below
    num_input_channels = 1
    num_output_channels = 16
    # Side length taken from dimension 1 of the first training image tensor
    image_size = train_data[0][0].shape[1]
    # Define CNN
    class MultiClassImageClassifier(nn.Module):
        """Single-conv-layer CNN that maps square images to class logits.

        Architecture: 3x3 conv (stride 1, padding 1, so the spatial size is
        kept) -> ReLU -> 2x2 max-pool (each side becomes ``input_size // 2``)
        -> flatten -> linear layer producing ``num_classes`` scores.

        Args:
            num_classes: Number of output logits.
            in_channels: Channels in the input images. Defaults to the
                module-level ``num_input_channels`` when omitted.
            out_channels: Feature maps produced by the conv layer. Defaults to
                the module-level ``num_output_channels`` when omitted.
            input_size: Side length of the square input images. Defaults to
                the module-level ``image_size`` when omitted.
        """

        def __init__(self, num_classes, in_channels=None, out_channels=None, input_size=None):
            super().__init__()
            # Fall back to the notebook globals so the original one-argument
            # call ``MultiClassImageClassifier(num_classes)`` keeps working.
            if in_channels is None:
                in_channels = num_input_channels
            if out_channels is None:
                out_channels = num_output_channels
            if input_size is None:
                input_size = image_size
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.relu = nn.ReLU()
            self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.flatten = nn.Flatten()
            # The fully connected layer sees out_channels maps of side input_size // 2.
            self.fc = nn.Linear(out_channels * (input_size // 2) ** 2, num_classes)

        def forward(self, x):
            """Return unnormalized class scores of shape (batch, num_classes).

            ``x`` must be (batch, in_channels, input_size, input_size); any other
            spatial size no longer matches the in_features of ``self.fc``.
            """
            x = self.conv1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.flatten(x)
            return self.fc(x)
    # Define the training set DataLoader: shuffled mini-batches of train_data.
    # NOTE(review): the original call was truncated; batch_size=10 is a
    # reconstruction — tune as needed.
    dataloader_train = DataLoader(train_data, batch_size=10, shuffle=True)
    # Define training function
    def train_model(optimizer, net, num_epochs, dataloader=None):
        """Train ``net`` with cross-entropy loss and return the last epoch's mean loss.

        Args:
            optimizer: Optimizer over ``net``'s parameters; stepped once per batch.
            net: Model mapping a feature batch to class logits.
            num_epochs: Number of full passes over the data.
            dataloader: Iterable of ``(features, labels)`` batches. Defaults to
                the module-level ``dataloader_train`` when omitted.

        Returns:
            Mean per-sample training loss of the final epoch, or NaN if no
            samples were processed (``num_epochs == 0`` or an empty loader).
        """
        if dataloader is None:
            dataloader = dataloader_train
        criterion = nn.CrossEntropyLoss()
        net.train()
        train_loss = float("nan")
        for epoch in range(num_epochs):
            running_loss = 0.0
            num_processed = 0
            for features, labels in dataloader:
                # Standard step: clear stale gradients, forward, backprop, update.
                optimizer.zero_grad()
                outputs = net(features)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                # loss is a batch *mean*; weight it by batch size so the epoch
                # average is per-sample even when the last batch is smaller.
                batch_size = len(labels)
                running_loss += loss.item() * batch_size
                num_processed += batch_size
            train_loss = running_loss / num_processed if num_processed else float("nan")
            print(f'epoch {epoch}, loss: {train_loss}')
        return train_loss
    # Train for 1 epoch: build the CNN, attach an Adam optimizer, and run one
    # pass of train_model over the training DataLoader.
    net = MultiClassImageClassifier(num_classes)
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    train_model(optimizer=optimizer, net=net, num_epochs=1)
    # Test the model on the test set
    # Define the test set DataLoader. Evaluation does not need shuffling.
    # NOTE(review): the original call was truncated; batch_size=10 is a
    # reconstruction.
    dataloader_test = DataLoader(test_data, batch_size=10, shuffle=False)
    # Define the metrics: overall accuracy plus per-class precision and recall
    # (average=None keeps one value per class instead of averaging them).
    accuracy_metric = Accuracy(task='multiclass', num_classes=num_classes)
    precision_metric = Precision(task='multiclass', num_classes=num_classes, average=None)
    recall_metric = Recall(task='multiclass', num_classes=num_classes, average=None)
    # Run model on test set.
    # eval() + no_grad(): inference only, so no autograd graph is built for the
    # whole test set; eval() also future-proofs against dropout/batch-norm layers.
    predicted = []  # predicted class index for every test sample, in loader order
    net.eval()
    with torch.no_grad():
        for features, labels in dataloader_test:
            output = net(features)
            cat = torch.argmax(output, dim=-1)
            predicted.extend(cat.tolist())
            # update() only accumulates metric state; calling the metric
            # object would also compute a per-batch value that is discarded.
            accuracy_metric.update(cat, labels)
            precision_metric.update(cat, labels)
            recall_metric.update(cat, labels)
    # Compute the metrics over all accumulated batches
    accuracy = accuracy_metric.compute().item()
    precision = precision_metric.compute().tolist()
    recall = recall_metric.compute().tolist()
    print('Accuracy:', accuracy)
    print('Precision (per class):', precision)
    print('Recall (per class):', recall)