Demystifying AI: Forecasting Demand with SHAP Insights


Description

Uncover the power of Explainable AI by forecasting weekly demand for 'Home Decor' products. Using SHAP, learn to interpret model predictions, identify critical features, and refine your forecasting model. Gain hands-on experience in creating transparent, impactful, and data-driven solutions.

Rizzles Bizzles is a Home Decor e-commerce store based in Bristol, UK. Its existing data science team has transformed the raw transactions dataset into the format you will see below. Their goal is to build an accurate weekly demand forecasting model so that the ops team can order the right quantity of products from their suppliers.

You have been asked to review the initial feature set and default model, using explainable AI (specifically, the shap library) to explain the model's predictions and answer stakeholder questions.

The Data

They have provided you with a single dataset, split into a train file and a test file. A summary and preview are provided below.

It is a modified version of the original data, which is publicly available on Kaggle.

home_decor_train.csv / home_decor_test.csv

Column | Description
'ProductName' | The product or item name.
'WeekEndingDate' | The date (always a Sunday) marking the end of the week for which the data applies.
'WeeklyMinimumPrice' | The minimum transacted price (in £) for that product during the given week.
'WeeklyMedianPrice' | The median transacted price (in £) for that product during the given week.
'WeeklyMeanPrice' | The mean (average) transacted price (in £) for that product during the given week.
'WeeklyMaxPrice' | The maximum transacted price (in £) for that product during the given week.
'WeeklyRevenue' | The total revenue (in £) for that product in the given week.
'WeeklyQuantityDemanded' | The total quantity demanded (units sold) of the product in the given week.
'L1ProdCat' | The broadest (Level 1) inferred product category to which the product belongs.
'L2ProdCat' | The mid-level (Level 2) inferred product category to which the product belongs.
'L3ProdCat' | The most granular (Level 3) inferred product category to which the product belongs.
'Momentum_WeeklyQuantity' | Week-over-week percentage change in WeeklyQuantityDemanded. Computed with .pct_change(), so it compares the current week's demand to the immediately previous week for each product.
'AbsoluteMomentum_WeeklyQuantity' | Week-over-week absolute change in WeeklyQuantityDemanded. Computed with .diff(), so it captures how many units demand has risen or fallen compared to the previous week.
'SmoothedMomentum_WeeklyQuantity' | Exponential weighted moving average (EWMA) of Momentum_WeeklyQuantity (span=2), giving more weight to recent momentum values and smoothing out short-term fluctuations.
'MomentumVolatility' | Rolling standard deviation of Momentum_WeeklyQuantity over a 3-week window (.rolling(window=3)). Reflects how much the week-to-week percentage changes fluctuate.
'Acceleration_WeeklyQuantity' | The second derivative of demand. Computed as the week-over-week change of Momentum_WeeklyQuantity (i.e., .diff() on the momentum column). Indicates how quickly momentum is increasing or decreasing.
'PriceMomentumImpact' | Ratio of WeeklyMeanPrice to (1 + abs(Momentum_WeeklyQuantity)). Gauges whether price is large relative to the magnitude of weekly momentum (higher values suggest weaker momentum relative to price).
'RollingMomentumMean' | 3-week rolling average of Momentum_WeeklyQuantity (.rolling(window=3)). Smooths out outliers to show the short-term average momentum trend.
'MomentumDirection' | Binary indicator (1 if Momentum_WeeklyQuantity > 0, else 0), showing whether week-over-week demand is trending upward or downward.
'LongTermMomentumMean' | 6-week rolling average of Momentum_WeeklyQuantity (.rolling(window=6)) to capture longer-term demand shifts.
'CumulativeQuantityDemanded' | Cumulative sum of WeeklyQuantityDemanded within each product's time series (.cumsum()). Tracks total units sold up to the current week.
'RollingQuantityGrowthRate' | Computes a 4-week rolling mean of WeeklyQuantityDemanded, then compares it to the 4-week rolling mean from 4 weeks prior: $(\text{Current4WMean} - \text{Prev4WMean}) \,/\, \text{Prev4WMean}$. Indicates medium-term growth or decline.
'WeekNumber' | ISO week number of the year (.dt.isocalendar().week). Captures potential weekly seasonality patterns (e.g., holiday effects).
'AccelerationDecay' | EWMA (span=4) of Acceleration_WeeklyQuantity. Smooths out the second-derivative changes to detect persistent accelerations or decelerations in demand.
'IsPeak_ST' | Short-term peak indicator: 1 if WeeklyQuantityDemanded equals the rolling 3-week max; otherwise 0. Flags a local high within the last 3 weeks.
'IsPeak_LT' | Long-term peak indicator: 1 if WeeklyQuantityDemanded equals the rolling 9-week max; otherwise 0. Flags a local high within the last 9 weeks.
'IsTrough_ST' | Short-term trough indicator: 1 if WeeklyQuantityDemanded equals the rolling 3-week min; otherwise 0. Flags a local low within the last 3 weeks.
'IsTrough_LT' | Long-term trough indicator: 1 if WeeklyQuantityDemanded equals the rolling 9-week min; otherwise 0. Flags a local low within the last 9 weeks.
'MomentumStability' | Coefficient of variation (rolling std / rolling mean) of Momentum_WeeklyQuantity over the past 4 weeks. Reflects how steady or volatile the weekly momentum has been recently.
'AvgCategoryMomentum' | For each WeekEndingDate and L3ProdCat, the average Momentum_WeeklyQuantity across all products in that category. Merged back into each row, so it reflects category-wide momentum for the same date.
'MinCategoryMomentum' | The minimum momentum value within the same WeekEndingDate & L3ProdCat. Indicates the slowest-growing or steepest-declining product in the category that week.
'MaxCategoryMomentum' | The maximum momentum value within the same WeekEndingDate & L3ProdCat. Indicates the fastest-growing product in the category that week.
'CategoryMomentumVolatility' | Standard deviation of Momentum_WeeklyQuantity within the same WeekEndingDate & L3ProdCat. Shows how spread out momentum is across the category.
'TotalCategoryDemand' | Sum of WeeklyQuantityDemanded for all products in the same L3ProdCat on that particular WeekEndingDate. Reflects the total category size for that week.
'AvgCategoryDemand' | Average of WeeklyQuantityDemanded across all products in the same L3ProdCat and WeekEndingDate. Complements TotalCategoryDemand for a per-product view.
'RelativeMomentumToCategory' | Product’s momentum minus the category’s average momentum (i.e., Momentum_WeeklyQuantity - AvgCategoryMomentum). Shows whether the product is gaining faster or slower than its peers.
'RelativeDemandToCategory' | Ratio of the product’s WeeklyQuantityDemanded to the category’s average demand (AvgCategoryDemand) on the same date. Values > 1 mean above-average demand.
'IsTopPerformerInCategory' | Boolean (True / False) whether the product’s WeeklyQuantityDemanded is the highest among all products in its L3ProdCat on that WeekEndingDate.
'CategoryDemandGrowthRate' | For each L3ProdCat, percentage change in TotalCategoryDemand from last week to the current week: $(\text{Demand}_t - \text{Demand}_{t-1}) \,/\, \text{Demand}_{t-1}$.
'NextWeeklyQuantity' | Target Variable: The following week’s demand for the same product. Created by shifting WeeklyQuantityDemanded one row backward in time (so row N has the demand from row N+1).

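Several of the columns described above are simple per-product transformations of WeeklyQuantityDemanded. The sketch below shows how a handful of them could be reproduced with pandas, assuming the data is grouped by ProductName and ordered by WeekEndingDate. The helper name add_momentum_features and the toy data are illustrative assumptions, not the data science team's actual pipeline.

```python
import pandas as pd


def add_momentum_features(df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: derive a few of the documented columns per product."""
    df = df.sort_values(["ProductName", "WeekEndingDate"]).reset_index(drop=True)
    qty = df.groupby("ProductName")["WeeklyQuantityDemanded"]

    # Week-over-week percentage change and absolute change
    df["Momentum_WeeklyQuantity"] = qty.pct_change()
    df["AbsoluteMomentum_WeeklyQuantity"] = qty.diff()
    # Running total of units sold per product
    df["CumulativeQuantityDemanded"] = qty.cumsum()
    # 1 if demand rose versus the previous week, else 0
    df["MomentumDirection"] = (df["Momentum_WeeklyQuantity"] > 0).astype(int)
    # Target: next week's demand for the same product (row N gets row N+1)
    df["NextWeeklyQuantity"] = qty.shift(-1)
    return df


# Toy data (made up for illustration only)
demo = pd.DataFrame(
    {
        "ProductName": ["A", "A", "A", "A", "B", "B"],
        "WeekEndingDate": pd.to_datetime(
            ["2024-01-07", "2024-01-14", "2024-01-21", "2024-01-28",
             "2024-01-07", "2024-01-14"]
        ),
        "WeeklyQuantityDemanded": [10, 20, 10, 30, 5, 5],
    }
)
demo_features = add_momentum_features(demo)
print(demo_features)
```

The first week of each product has no previous week, so its momentum columns are NaN, and the last week of each product has no target value.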
import numpy as np
np.random.seed(42)

import pandas as pd
import shap
from shap.plots import initjs

initjs()

# Read in the train and test datasets
train = pd.read_csv("./data/home_decor_train.csv")
test = pd.read_csv("./data/home_decor_test.csv")
display(train.head())
display(test.head())

from utils import train_model

# Define constants
feature_cols = [
    col
    for col in train.columns
    if col
    not in [
        "ProductName",
        "WeekEndingDate",
        "NextWeeklyQuantity",
        "L1ProdCat",
        "L2ProdCat",
        "L3ProdCat",
    ]
]
target_col = "NextWeeklyQuantity"

# Train initial model
training_artifacts = train_model(train, test, feature_cols, target_col)

Instructions

Explore the training_artifacts dictionary and use Explainable AI to answer the following questions:
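The contents of training_artifacts are not documented here, so a reasonable first step is to list its keys with the type and shape of each value. The sketch below does this on a stand-in dictionary. The keys "model", "shap_values", and "X_train" are hypothetical placeholders, and the real dictionary returned by train_model may use different names.

```python
import numpy as np
import pandas as pd


def describe_artifacts(artifacts: dict) -> pd.DataFrame:
    """Summarise each entry of a dictionary by key, type name, and shape (if any)."""
    rows = []
    for key, value in artifacts.items():
        rows.append(
            {
                "key": key,
                "type": type(value).__name__,
                # Arrays and DataFrames expose .shape; other objects get None
                "shape": getattr(value, "shape", None),
            }
        )
    return pd.DataFrame(rows)


# Stand-in dictionary with HYPOTHETICAL keys, for illustration only
demo_artifacts = {
    "model": object(),
    "shap_values": np.zeros((5, 3)),
    "X_train": pd.DataFrame(np.zeros((5, 3))),
}
summary = describe_artifacts(demo_artifacts)
print(summary)
```

In the notebook, the same helper could be called as describe_artifacts(training_artifacts) to see what is available before answering the questions.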

1. Which are the top 10 most impactful features based on SHAP during training? Store them as a pandas Series called top_10_shap_features, where the index holds the feature names and the values are sorted from most to least impactful.

# TODO: INSERT CODE HERE
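One common way to rank features globally is by mean absolute SHAP value per feature. The sketch below applies that ranking to a synthetic array standing in for the training SHAP values. The helper name top_k_shap_features and the demo data are assumptions, and how the SHAP values are stored inside training_artifacts is not known from this brief (a shap.Explanation object, for instance, exposes the array as .values and the names as .feature_names).

```python
import numpy as np
import pandas as pd


def top_k_shap_features(shap_values, feature_names, k=10) -> pd.Series:
    """Rank features by mean |SHAP| and return the k largest, in descending order."""
    mean_abs = np.abs(np.asarray(shap_values)).mean(axis=0)
    ranked = pd.Series(mean_abs, index=list(feature_names))
    return ranked.sort_values(ascending=False).head(k)


# Synthetic stand-in: 100 rows x 12 features, where column i has |SHAP| = i + 1
demo_names = [f"feature_{i}" for i in range(12)]
signs = np.tile(np.array([1.0, -1.0])[:, None], (50, 1))  # alternating +1 / -1 rows
demo_shap = signs * np.arange(1, 13)

demo_top_10 = top_k_shap_features(demo_shap, demo_names, k=10)
print(demo_top_10)
```

Taking the absolute value before averaging matters: in the demo the signed SHAP values average to zero for every column, yet the features clearly differ in impact.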

2. Create a scatter plot of the SHAP values for the feature CumulativeQuantityDemanded. Identify whether the relationship between the feature and its SHAP values is positive or negative, and save the answer as one of ["positive", "negative"] in a variable called shap_relationship_of_CumulativeQuantityDemanded. Then identify any inflection point in the scatter plot and save the value of CumulativeQuantityDemanded at which it occurs as point_of_inflection_in_CumulativeQuantityDemanded.

# TODO: INSERT CODE HERE
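Reading the direction and the turning point off a scatter plot is a visual judgment, but both can be cross-checked numerically. The sketch below, on synthetic data, takes the sign of the rank correlation as the direction and uses a binned-mean curve to locate where the slope changes most, which is one rough way to read "inflection point" here. The function names, the bin count, and the synthetic data are assumptions, and the estimate should be confirmed against the actual plot.

```python
import numpy as np
import pandas as pd


def shap_direction(feature_values, shap_values) -> str:
    """Sign of the rank (Spearman-style) correlation between feature and SHAP values."""
    x_rank = pd.Series(np.asarray(feature_values)).rank()
    s_rank = pd.Series(np.asarray(shap_values)).rank()
    return "positive" if x_rank.corr(s_rank) > 0 else "negative"


def estimate_slope_change(feature_values, shap_values, n_bins=20) -> float:
    """Rough estimate of the feature value where the binned SHAP curve bends most."""
    df = pd.DataFrame({"x": np.asarray(feature_values), "s": np.asarray(shap_values)})
    df["bin"] = pd.qcut(df["x"], q=n_bins, duplicates="drop")
    binned = df.groupby("bin", observed=True)[["x", "s"]].mean()
    x_mid, s_mid = binned["x"].to_numpy(), binned["s"].to_numpy()
    slopes = np.diff(s_mid) / np.diff(x_mid)      # slope between neighbouring bins
    k = int(np.argmax(np.abs(np.diff(slopes))))   # largest change in slope
    return float(x_mid[k + 1])


# Synthetic stand-in: SHAP rises with the feature until 50, then flattens
rng = np.random.default_rng(0)
demo_x = rng.uniform(0, 100, 2000)
demo_s = 0.1 * np.minimum(demo_x, 50) + rng.normal(0, 0.01, 2000)

demo_direction = shap_direction(demo_x, demo_s)
demo_kink = estimate_slope_change(demo_x, demo_s)
print(demo_direction, round(demo_kink, 1))

# For the plot itself, shap provides shap.plots.scatter, e.g. (assuming a
# shap.Explanation object named `explanation` is available):
# shap.plots.scatter(explanation[:, "CumulativeQuantityDemanded"])
```

The binned estimate is only as fine as the bin width, so it returns a bin centre near the bend rather than an exact value.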

3. There are quite a few features we suspect are redundant. Check how correlated the features are and use insights from SHAP to remove the redundant ones. Retrain the model by calling train_model with the reduced feature set and save the output as a variable called training_artifacts_reduced_feature_set.

# TODO: INSERT CODE HERE
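One possible approach is to find pairs of features whose absolute correlation exceeds a threshold and, within each pair, drop the feature with the lower mean absolute SHAP value. The sketch below illustrates this on synthetic data. The helper name drop_redundant_features, the 0.9 threshold, and the importance values are assumptions chosen for illustration, not a prescribed method.

```python
import numpy as np
import pandas as pd


def drop_redundant_features(features: pd.DataFrame, importance: pd.Series,
                            threshold: float = 0.9) -> list:
    """Keep the more important feature from each highly correlated pair."""
    corr = features.corr().abs()
    cols = list(features.columns)
    to_drop = set()
    for i, a in enumerate(cols):
        for b in cols[i + 1:]:
            if a in to_drop or b in to_drop:
                continue  # one of the pair has already been removed
            if corr.loc[a, b] > threshold:
                # Drop whichever feature has the lower SHAP-based importance
                to_drop.add(a if importance[a] < importance[b] else b)
    return [c for c in cols if c not in to_drop]


# Synthetic stand-in: "b" is a near-copy of "a"; "c" is independent
rng = np.random.default_rng(1)
a = rng.normal(size=500)
demo_X = pd.DataFrame({"a": a,
                       "b": a + rng.normal(0, 0.01, 500),
                       "c": rng.normal(size=500)})
# Made-up mean |SHAP| values for the three demo features
demo_importance = pd.Series({"a": 0.5, "b": 0.1, "c": 0.3})

reduced_cols = drop_redundant_features(demo_X, demo_importance)
print(reduced_cols)

# In the notebook, the retraining step would then follow the original call:
# training_artifacts_reduced_feature_set = train_model(
#     train, test, reduced_cols, target_col
# )
```

Correlation only captures linear redundancy, so the threshold and the resulting feature list are worth checking against the retrained model's performance.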