"Drift" is a term used in machine learning to describe how the performance of a machine learning model in production slowly gets worse over time. This can happen for a number of reasons, such as changes in the distribution of the input data over time or the relationship between the input (x) and the desired target (y) changing.
Drift can be a serious problem when machine learning is used in the real world, where data is dynamic and constantly changing. This article takes a deep dive into why models drift, the different types of drift, and the algorithms used to detect them, and wraps up with an open-source implementation of drift detection in Python.
What is Drift?
Machine learning models are trained on historical data, but once they are deployed in the real world, they may become outdated and lose accuracy over time because of a phenomenon called drift. Drift is a change over time in the statistical properties of the data a model sees, relative to the data it was trained on. This can cause the model to become less accurate or to behave differently than it was designed to.
In other words, "drift" is the decline in a model's ability to make accurate predictions due to changes in the environment in which it is being used.
Why do Machine Learning Models Drift?
There are several reasons why machine learning models can drift over time.
One common reason is simply that the data that the model was trained on becomes outdated or no longer represents the current conditions.
For example, consider a machine learning model trained to predict the stock price of a company based on historical data. If we train the model with data from a stable market, it might do well at first. However, if the market becomes more volatile over time, the model might not be able to accurately predict the stock price anymore because the statistical properties of the data have changed.
Another reason for model drift is that the model was not designed to handle changes in the data. Some machine learning models can handle changes in the data better than others, but no model can avoid drift completely.
Types of Drift
Let's explore the two different types of drift to consider:
1. Concept Drift
Concept drift, also known as model drift, occurs when the task that the model was designed to perform changes over time, that is, when the relationship between the input and the target changes. For example, imagine that a machine learning model was trained to detect spam emails based on the content of the email. If the types of spam emails that people receive change significantly, the model may no longer be able to accurately detect spam.
Concept Drift can be further divided into four categories (Learning under Concept Drift: A Review, Jie Lu et al.):
- Sudden Drift
- Gradual Drift
- Incremental Drift
- Recurring Concepts
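To make the four categories more concrete, here is a minimal sketch that simulates each pattern on a synthetic stream. The function names `drift_schedule` and `make_stream`, the switch points, and the choice of a simple shift in the mean are all illustrative assumptions, not part of the cited review.

```python
import numpy as np


def drift_schedule(kind, n=1000, seed=0):
    """Return, for each time step, how much of the NEW concept is active (0 to 1)."""
    rng = np.random.default_rng(seed)
    t = np.arange(n)
    # Linear ramp from 0 to 1 over the middle half of the stream (assumed shape)
    ramp = np.clip((t - n * 0.25) / (n * 0.5), 0.0, 1.0)
    if kind == "sudden":
        # Old concept, then an abrupt switch at the midpoint
        return (t >= n // 2).astype(float)
    if kind == "gradual":
        # Each sample comes from either the old or the new concept;
        # the probability of drawing from the new one rises over time
        return (rng.random(n) < ramp).astype(float)
    if kind == "incremental":
        # The concept itself moves smoothly through intermediate states
        return ramp
    if kind == "recurring":
        # The old concept comes back after a period of the new one
        return ((t // (n // 4)) % 2).astype(float)
    raise ValueError(f"unknown drift kind: {kind}")


def make_stream(kind, old_mean=0.0, new_mean=3.0, noise=0.5, n=1000, seed=0):
    """Simulate a target whose mean follows the chosen drift pattern."""
    rng = np.random.default_rng(seed + 1)
    weight = drift_schedule(kind, n=n, seed=seed)
    return (1 - weight) * old_mean + weight * new_mean + rng.normal(0, noise, n)


if __name__ == "__main__":
    for kind in ["sudden", "gradual", "incremental", "recurring"]:
        stream = make_stream(kind)
        print(kind, "first-quarter mean:", round(stream[:250].mean(), 2),
              "last-quarter mean:", round(stream[750:].mean(), 2))
```

Plotting the four streams side by side is a quick way to see why a detector tuned for sudden changes may react late to incremental ones.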
2. Data Drift
Data drift, also known as covariate shift, occurs when the distribution of the input data changes over time. For example, consider a machine learning model that was trained to predict the likelihood of a customer purchasing a product based on their age and income. If the distribution of ages and incomes of the customers changes significantly over time, the model may no longer be able to predict the likelihood of a purchase accurately.
It is important to be aware of both concept drift and data drift and take steps to prevent or mitigate their effects. Some strategies for addressing drift include continuously monitoring and evaluating the performance of a model, updating the model with new data, and using machine learning models that are more robust to drift.
You can learn more about post-deployment data science, such as drift, in our DataFramed podcast episode.
How do You Detect Drift?
There are two ways we can detect drift:
1. Machine Learning Model-Based Approach: A model is trained to detect whether the incoming input data has drifted or not.
2. Statistical Tests: There are many statistical tests to detect data drift. They are primarily divided into three categories:
- Sequential analysis methods
- A custom model to detect drift
- Time distribution methods, which are very common.
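One common way to realize the model-based approach from the list above is a so-called domain classifier: label reference rows 0 and current rows 1, then check whether a classifier can tell them apart. The sketch below is one possible version under stated assumptions; the function name `domain_classifier_auc`, the use of logistic regression, and 5-fold cross-validation are illustrative choices, not something the article prescribes.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score


def domain_classifier_auc(reference, current):
    """Cross-validated ROC AUC of a classifier separating reference from current rows.

    Both inputs are 2-D numeric arrays with the same columns.
    An AUC near 0.5 means the two samples look alike; an AUC well above 0.5
    means the classifier found a difference, which hints at data drift.
    """
    X = np.vstack([reference, current])
    y = np.concatenate([np.zeros(len(reference)), np.ones(len(current))])
    clf = LogisticRegression(max_iter=1000)
    scores = cross_val_score(clf, X, y, cv=5, scoring="roc_auc")
    return float(scores.mean())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    reference = rng.normal(0, 1, size=(500, 2))
    drifted = rng.normal(1, 1, size=(500, 2))  # mean of both features shifted
    print("AUC, drifted sample:", domain_classifier_auc(reference, drifted))
```

A linear classifier only picks up differences it can express (such as a shift in the mean); a tree-based model could be swapped in to catch more complex changes.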
Time distribution-based methods use statistical methods to calculate the difference between two probability distributions to detect drift. These methods include the Population Stability Index, KL Divergence, JS Divergence, KS Test, and the Wasserstein Metric.
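As a rough illustration of the distance measures named above, the following sketch computes KL divergence, Jensen-Shannon distance, and the Wasserstein metric for a single numeric feature with SciPy. The helper names, the shared 20-bin histogram, and the small smoothing constant `eps` are assumptions made for the example.

```python
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.stats import entropy, wasserstein_distance


def binned_probs(reference, current, bins=20, eps=1e-6):
    """Histogram both samples on shared bin edges and return smoothed probabilities."""
    edges = np.histogram_bin_edges(np.concatenate([reference, current]), bins=bins)
    p, _ = np.histogram(reference, bins=edges)
    q, _ = np.histogram(current, bins=edges)
    # eps avoids division by zero in empty bins (an assumed smoothing choice)
    p = (p + eps) / (p + eps).sum()
    q = (q + eps) / (q + eps).sum()
    return p, q


def distribution_distances(reference, current, bins=20):
    """Return several distances between the reference and current samples."""
    p, q = binned_probs(reference, current, bins=bins)
    return {
        # KL(p || q): asymmetric, so swapping the arguments changes the value
        "kl_divergence": float(entropy(p, q)),
        # SciPy returns the JS *distance* (square root of the JS divergence);
        # with base=2 it lies between 0 and 1
        "js_distance": float(jensenshannon(p, q, base=2)),
        # Wasserstein works on the raw samples, no binning needed
        "wasserstein": float(wasserstein_distance(reference, current)),
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    reference = rng.normal(0, 1, 5000)
    current = rng.normal(2, 1, 5000)
    print(distribution_distances(reference, current))
```

None of these values comes with a built-in threshold; what counts as "too far" is typically decided per feature from historical data.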
Algorithms for Detecting Data Drift
Kolmogorov-Smirnov (K-S) test
The Kolmogorov-Smirnov (K-S) test is a nonparametric statistical test that is used to determine whether two sets of data come from the same distribution. It is often used to test whether a sample of data comes from a specific population or to compare two samples to determine if they come from the same population.
The null hypothesis in this test is that the two distributions are the same. If this hypothesis is rejected, it suggests that the data has drifted.
Because it is nonparametric, the K-S test does not assume that the data follows any particular distribution, which makes it a convenient tool for comparing a reference dataset against current data.
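A minimal sketch of the two-sample K-S test for one numeric feature, using `scipy.stats.ks_2samp`. The wrapper name `ks_drift` and the significance level of 0.05 are assumptions chosen for illustration.

```python
import numpy as np
from scipy.stats import ks_2samp


def ks_drift(reference, current, alpha=0.05):
    """Run a two-sample K-S test and flag drift when the null hypothesis is rejected."""
    result = ks_2samp(reference, current)
    return {
        "statistic": float(result.statistic),  # largest gap between the two empirical CDFs
        "p_value": float(result.pvalue),
        "drift": bool(result.pvalue < alpha),  # True means "same distribution" was rejected
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    reference = rng.normal(0, 1, 1000)
    current = rng.normal(0.5, 1, 1000)  # mean shifted by half a standard deviation
    print(ks_drift(reference, current))
```

With large samples the test becomes very sensitive and may flag shifts that are too small to matter in practice, so the p-value is often read together with the size of the statistic.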
Population Stability Index
The Population Stability Index (PSI) is a statistical measure of how much the distribution of a variable has changed between two samples or over time. It is computed over the categories of a variable, so a continuous variable is first grouped into bins. PSI is commonly used to monitor changes in the characteristics of a population and to identify potential problems with the performance of a machine learning model.
The PSI was originally developed to monitor changes in the distribution of a score in risk scorecards, but it is now used to examine distributional shifts for all model-related attributes, including both dependent and independent variables.
A high PSI value indicates that there is a significant difference between the distributions of the variable in the two datasets, which may suggest that there is a drift in the model.
If the distribution of a variable has changed significantly, or if several variables have changed to some extent, it may be necessary to recalibrate or rebuild the model to improve its performance.
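The PSI for a numeric variable can be sketched in a few lines of NumPy using the usual formula, the sum over bins of (current share - reference share) × ln(current share / reference share). The function name, the choice of ten quantile bins taken from the reference sample, and the `eps` floor for empty bins are assumptions made for this example.

```python
import numpy as np


def population_stability_index(reference, current, bins=10, eps=1e-6):
    """PSI between two samples of one numeric variable, binned on reference quantiles."""
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)
    # Interior cut points; values beyond the reference range fall into the outer bins
    cuts = np.quantile(reference, np.linspace(0, 1, bins + 1))[1:-1]
    ref_counts = np.bincount(np.searchsorted(cuts, reference), minlength=bins)
    cur_counts = np.bincount(np.searchsorted(cuts, current), minlength=bins)
    # Floor the shares at eps so that empty bins do not produce log(0)
    ref_pct = np.clip(ref_counts / len(reference), eps, None)
    cur_pct = np.clip(cur_counts / len(current), eps, None)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    reference = rng.normal(0, 1, 5000)
    current = rng.normal(1, 1, 5000)
    print("PSI:", population_stability_index(reference, current))
```

A frequently quoted rule of thumb treats a PSI below 0.1 as no meaningful change, 0.1 to 0.25 as a moderate shift, and above 0.25 as a significant shift, though these cut-offs are conventions rather than statistical guarantees.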
Page-Hinkley method
The Page-Hinkley method is a statistical method used to detect changes in the mean of a series of data over time. It is commonly used to monitor the performance of machine learning models and detect changes in the distribution of the data that may indicate model drift.
To use the Page-Hinkley method, the first step is to define a threshold value and a decision function. The threshold value is a value above which a change in the mean is considered significant, and the decision function is a function that returns a value of 1 if a change has been detected and a value of 0 if no change has been detected.
Next, the running mean of the data series is updated at each time step, the cumulative deviation of the observations from that mean is tracked, and the decision function is applied to determine if a change has occurred. If the decision function returns a value of 1, it indicates that a change has been detected and the model may be drifting.
The Page-Hinkley method is a simple and effective way to detect changes in the mean of a data series over time. It is particularly useful for detecting small changes in the mean that may not be immediately apparent when looking at the data. However, it is important to carefully select the threshold value and decision function to ensure that the method is sensitive enough to detect changes in the data but not so sensitive that it generates false alarms.
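The steps described above can be sketched as a small streaming detector. This is a minimal one-sided version that only watches for an increase in the mean; the class name, the default values of `delta`, `threshold`, and `min_samples`, and the helper `first_alarm` are assumptions for illustration and would need tuning on real data.

```python
class PageHinkley:
    """One-sided Page-Hinkley test for an increase in the mean of a data stream."""

    def __init__(self, delta=0.005, threshold=50.0, min_samples=30):
        self.delta = delta              # tolerance: small deviations to ignore
        self.threshold = threshold      # alarm level for the test statistic
        self.min_samples = min_samples  # warm-up period before alarms are allowed
        self.n = 0
        self.mean = 0.0     # running mean of the stream
        self.cum = 0.0      # cumulative deviation from the running mean
        self.cum_min = 0.0  # smallest cumulative deviation seen so far

    def update(self, x):
        """Decision function: return 1 if a change is detected at this step, else 0."""
        self.n += 1
        self.mean += (x - self.mean) / self.n
        self.cum += x - self.mean - self.delta
        self.cum_min = min(self.cum_min, self.cum)
        if self.n < self.min_samples:
            return 0
        return 1 if self.cum - self.cum_min > self.threshold else 0


def first_alarm(stream, **kwargs):
    """Return the index of the first detected change, or None if there is none."""
    detector = PageHinkley(**kwargs)
    for i, x in enumerate(stream):
        if detector.update(x) == 1:
            return i
    return None


if __name__ == "__main__":
    # Example: a model's error is 0 for 200 steps, then jumps to 5
    errors = [0.0] * 200 + [5.0] * 200
    print("first alarm at index:", first_alarm(errors))
```

Lowering `threshold` makes the detector react sooner at the cost of more false alarms, which is the trade-off the paragraph above describes. To catch decreases in the mean as well, a second detector can be run on the negated stream.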
Drift Detection Implementation in Python
In this section, we will use Evidently to detect drift. Evidently is an open-source Python library made for data scientists and engineers who work with machine learning. It helps them test, evaluate, and keep track of how well their models work from validation to production.
```
# import libraries
import pandas as pd
import numpy as np
from sklearn import datasets
from evidently.report import Report
from evidently.metrics import DataDriftTable
from evidently.metrics import DatasetDriftMetric
```
Import Dataset and Create Reference and Current Partitions
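```
# create ref and cur dataset for drift detection
adult_data = datasets.fetch_openml(name='adult', version=2, as_frame='auto')
adult = adult_data.frame

# split on education level so the two partitions differ by construction
adult_ref = adult[~adult.education.isin(['Some-college', 'HS-grad', 'Bachelors'])]
# .copy() so the assignment below modifies this partition, not a view of `adult`
adult_cur = adult[adult.education.isin(['Some-college', 'HS-grad', 'Bachelors'])].copy()

# blank out two columns in the first 2000 current rows to simulate missing data
adult_cur.iloc[:2000, 3:5] = np.nan
```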
Generate Drift Report
```
# dataset-level metrics
data_drift_dataset_report = Report(metrics=[
    DatasetDriftMetric(),
    DataDriftTable(),
])

data_drift_dataset_report.run(reference_data=adult_ref, current_data=adult_cur)
data_drift_dataset_report
```
Drift Detection Dashboard - created using EvidentlyAI
Export Drift Report in JSON format
```
# report in a JSON format
data_drift_dataset_report.json()
```
Check out the complete Datacamp Notebook here.
Conclusion
Data and model drift can pose significant challenges to machine learning systems in production. By understanding the causes and effects of drift and implementing effective drift monitoring practices, you can help keep your machine learning models accurate and reliable over time.
Monitoring the performance of your models, using a drift detection model, and regularly retraining on updated data are just a few of the best practices you can follow to mitigate the risks of drift. By being proactive about drift monitoring, you can ensure that your machine learning system continues to deliver value to your organization.
Monitoring machine learning models for drift is just one aspect of a broader field called MLOps. Understanding MLOps concepts is essential for any data scientist, engineer, or leader to take machine learning models from a local notebook to a functioning model in production.
If you would like to take a deep dive into understanding MLOps and how it can benefit you in your career, check out our MLOps Concepts course. Here, you’ll learn what MLOps is, understand the different phases in MLOps processes, and identify different levels of MLOps maturity. After learning about the essential MLOps concepts, you’ll be well-equipped in your journey to implement machine learning continuously, reliably, and efficiently.
Drift Detection FAQs
What is machine learning model drift?
Machine learning model drift is when a model's performance in production degrades over time compared with how it performed when it was trained and validated. This can happen for a variety of reasons, including changes in the distribution of data over time, the addition of new data that doesn't fit the original model's assumptions, or the model's own inability to adapt to changing conditions.
Why is model drift a problem?
Model drift can significantly impact the performance and accuracy of a machine learning model. As the model's predictions become less reliable, it can lead to incorrect decisions or actions that can have negative consequences. For example, in a healthcare setting, model drift could lead to incorrect diagnoses or treatment recommendations, while in a finance setting, it could result in poor investment decisions.
How do you detect model drift?
There are several ways to determine if a model is drifting, such as statistical tests, drift detection algorithms, and monitoring the model's performance metrics. Some of these methods are designed to find drift in real time, while others are better suited to checks at scheduled intervals or on batches of data. It's important to choose the right technique for the specific application and data environment.
How do you prevent model drift?
Preventing model drift requires a combination of careful model selection, regular monitoring and testing, and proactive intervention. This may involve using algorithms that are more robust to drift, regularly retraining models on new data, or implementing strategies to actively address drift when it is detected. It's also important to have a clear understanding of the factors that can cause drift so that you can take steps to prevent them.
How does data distribution affect model drift?
The data distribution can significantly affect the performance of a machine learning model. If the distribution of data changes over time, it can lead to model drift, as the model may no longer be able to accurately predict new data that doesn't match its original assumptions. This can happen in a variety of ways, such as through natural variations in the data, the addition of new data sources, or changes in the underlying processes or systems that generate the data.
Is model drift reversible?
In some cases, model drift can be reversible by retraining the model on new data or adjusting its parameters. However, this is not always possible, especially if the data distribution has changed significantly or the model has become overly complex or specialized. In these situations, it may be necessary to start over with a new model.
Is it possible to completely eliminate model drift?
Completely eliminating model drift is difficult, if not impossible. Even the most robust and well-designed machine learning models can be affected by changes in the data or the underlying processes that generate it. The best approach is to manage and mitigate the impact of model drift through regular monitoring, testing, and intervention.
How does model drift impact model performance?
Model drift can have a significant impact on the performance of a machine learning model. As the model's predictions become less accurate, it can lead to reduced performance on important metrics such as accuracy, precision, recall, and overall model effectiveness. In some cases, model drift can even cause a model to fail completely, resulting in incorrect or unreliable predictions.
How does model drift affect model accuracy?
Model drift has a negative impact on the accuracy of a machine learning model: as production data moves away from the data the model was trained on, a growing share of its predictions are wrong, and decisions based on those predictions suffer. It's important to regularly monitor and test for model drift in order to maintain the accuracy of the model.