
AdamW Optimizer in PyTorch Tutorial

Discover how the AdamW optimizer improves model performance by decoupling weight decay from gradient updates. This tutorial explains the key differences between Adam and AdamW, their use cases and provides a step-by-step guide to implementing AdamW in PyTorch.
Oct 21, 2024  · 12 min read

Optimization algorithms play a crucial role in deep learning: they fine-tune model weights to minimize loss functions during training. One such algorithm is the Adam optimizer.

Adam became extremely popular in deep learning due to its ability to combine the advantages of momentum and adaptive learning rates. This made it highly efficient for training deep neural networks. It also requires minimal tuning of hyperparameters, thus making it widely accessible and effective across various tasks.

In 2017, Ilya Loshchilov and Frank Hutter introduced a more advanced version of the popular Adam algorithm in their paper "Decoupled Weight Decay Regularization." They named it AdamW, which stands out for decoupling weight decay from the gradient update process. This separation is a crucial improvement over Adam and helps with better model generalization.

AdamW has become increasingly important in modern deep learning applications, particularly in handling large-scale models. Its superior ability to regulate weight updates has contributed to its adoption in tasks that demand high performance and stability. 

In this tutorial, we will cover the key differences between Adam and AdamW, look at their different use cases, and walk through a step-by-step guide to implementing AdamW in PyTorch.

Adam vs AdamW

Adam and AdamW are both adaptive optimizers widely used in deep learning. The big difference between them is how they handle weight regularization, which impacts their effectiveness in different scenarios. 

While Adam combines momentum and adaptive learning rates to offer efficient optimization, it incorporates L2 regularization in a way that can hinder performance. AdamW addresses this by decoupling weight decay from the adaptive gradient update, providing a more effective approach for large models and improving generalization. Weight decay, which is closely related to L2 regularization, penalizes large weights in the model. Adam folds the penalty into the gradient, so it is rescaled by the adaptive learning rates, while AdamW applies the decay to the weights directly, separately from the gradient-based update.

Here are some other ways they differ: 

Key differences between Adam and AdamW

Although both optimizers are designed to manage momentum and adjust learning rates dynamically, they differ fundamentally in their treatment of weight decay. 

In Adam, weight decay is applied indirectly as part of the gradient update, which can unintentionally modify the learning dynamics and interfere with the optimization process. AdamW, however, separates weight decay from the gradient step, ensuring regularization directly impacts the parameters without altering the adaptive learning mechanism. 

This design leads to more precise regularization, helping models generalize better, particularly in tasks that involve large and complex datasets. As a result, the two optimizers often have very different use cases.  
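To make the difference concrete, here is a minimal sketch (an illustration added here, not taken from the paper) that applies a single optimizer step to one weight whose gradient is exactly zero, so any movement comes from weight decay alone. The values `lr=0.1` and `weight_decay=0.1` and the helper `one_step` are arbitrary choices for the demonstration.

```python
import torch

def one_step(optimizer_cls):
    # A single weight initialised to 1.0, with a gradient of exactly zero
    p = torch.nn.Parameter(torch.tensor([1.0]))
    opt = optimizer_cls([p], lr=0.1, weight_decay=0.1)
    p.grad = torch.zeros_like(p)
    opt.step()
    return p.item()

adam_value = one_step(torch.optim.Adam)
adamw_value = one_step(torch.optim.AdamW)
print(f"Adam  + weight_decay: {adam_value:.4f}")
print(f"AdamW + weight_decay: {adamw_value:.4f}")
```

With Adam, the decay term is added to the gradient and then normalized by the adaptive scaling, so the weight moves by roughly the full learning rate (to about 0.9). With AdamW, the weight is simply shrunk by the factor 1 - lr * weight_decay (to about 0.99).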

Use cases for Adam

Adam performs better in tasks where regularization is less critical or when computational efficiency is prioritized over generalization. Examples include:

  • Smaller neural networks. For tasks like basic image classification using small CNNs (Convolutional Neural Networks) on datasets like MNIST or CIFAR-10, where the model complexity is low, Adam can efficiently optimize without needing extensive regularization.
  • Simple regression problems. In straightforward regression tasks with limited feature sets, such as predicting house prices using a linear regression model, Adam can quickly converge without needing advanced regularization techniques.
  • Early-stage prototyping. During the initial stages of model development, where rapid experimentation is needed, Adam allows for quick iterations on simpler architectures, enabling researchers to identify potential problems without the overhead of tuning regularization parameters.
  • Less noisy data. When working with clean datasets with minimal noise, such as well-curated text data for sentiment analysis, Adam can effectively learn patterns without the risk of overfitting that might necessitate heavier regularization.
  • Short training cycles. In scenarios with time constraints, such as rapid model deployment for real-time applications, Adam’s efficient optimization can help deliver satisfactory results quickly, even if they may not be fully optimized for generalization.

Use cases for AdamW

AdamW excels in scenarios where overfitting is a concern and model size is substantial. For example: 

  • Large-scale transformers. In natural language processing tasks, such as fine-tuning models like GPT on extensive text corpora, AdamW's ability to manage weight decay effectively prevents overfitting, ensuring better generalization.
  • Complex computer vision models. For tasks involving deep convolutional neural networks (CNNs) trained on large datasets like ImageNet, AdamW helps maintain model stability and performance by decoupling weight decay, which is crucial for achieving high accuracy.
  • Multi-task learning. In scenarios where a model is trained on multiple tasks simultaneously, AdamW provides the flexibility to handle diverse datasets and prevent overfitting on any single task.
  • Generative models. For training generative adversarial networks (GANs), where maintaining a balance between the generator and discriminator is critical, AdamW's improved regularization can help stabilize training and enhance the quality of generated outputs.
  • Reinforcement learning. In reinforcement learning applications where models must adapt to complex environments and learn robust policies, AdamW helps mitigate overfitting to specific states or actions, improving the model's general performance in varied situations.

Advantages of AdamW Over Adam

But why would anyone want to use AdamW over Adam? Simple. AdamW offers several key benefits that enhance its performance, particularly in complex modeling scenarios. 

It addresses some of the limitations found in the Adam optimizer, thus making it more effective at optimization and contributing to improved model training and robustness. 

Here are some more of the standout advantages:

  • Decoupled weight decay. By separating weight decay from gradient updates, AdamW allows for more precise control over regularization, leading to better model generalization.
  • Enhanced generalization. AdamW reduces the risk of overfitting, especially in large-scale models, making it suitable for tasks involving extensive datasets and intricate architectures.
  • Stability during training. The design of AdamW helps maintain stability throughout the training process, which is essential for models that require careful tuning of their hyperparameters.
  • Scalability. AdamW is particularly effective for scaling up models, as it can handle the increased complexity of deep networks without sacrificing performance, allowing it to be applied in state-of-the-art architectures.

How AdamW Works

AdamW's core strength lies in its approach to weight decay, which is decoupled from the adaptive gradient updates typical of Adam. This adjustment ensures regularization is applied directly to the model's weights, improving generalization without negatively impacting the learning rate dynamics.

The optimizer builds upon Adam's adaptive nature, maintaining the benefits of momentum and per-parameter learning rate adjustments. Applying weight decay independently addresses one of Adam’s key shortcomings: its tendency to affect gradient updates during regularization. This separation allows AdamW to maintain stable learning, even in complex and large-scale models, while keeping overfitting in check.

In the following sections, we’ll explore the theory behind weight decay and regularization and the mathematics that underpin AdamW’s optimization process.

Theory Behind Weight Decay and L2 Regularization

L2 regularization is a technique used to prevent overfitting. It achieves this objective by adding a penalty term to the loss function, discouraging large weight values. This technique helps create simpler models that generalize better to new data. 
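As a worked example, the penalty and its effect on a single update can be written out as follows. Plain gradient descent is assumed here purely for simplicity; η denotes the learning rate and λ the regularization strength.

```latex
L_{\text{reg}}(\theta) = L(\theta) + \frac{\lambda}{2}\lVert\theta\rVert_2^2
\quad\Longrightarrow\quad
\nabla_\theta L_{\text{reg}}(\theta) = \nabla_\theta L(\theta) + \lambda\theta

\theta_{t+1}
  = \theta_t - \eta\left(\nabla_\theta L(\theta_t) + \lambda\theta_t\right)
  = (1 - \eta\lambda)\,\theta_t - \eta\,\nabla_\theta L(\theta_t)
```

Under that assumption, the penalty amounts to shrinking every weight by the factor (1 − ηλ) at each step, which is where the name "weight decay" comes from. With an adaptive optimizer, the two forms are no longer equivalent, because the λθ term is rescaled along with the rest of the gradient.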

In traditional optimizers, such as Adam, weight decay is applied as part of the gradient update, which inadvertently affects learning rates and can lead to suboptimal performance.

AdamW improves upon this by decoupling weight decay from the gradient computation. In other words, rather than applying weight decay during the gradient update, AdamW treats it as a separate step, applying it directly to the weights after the gradient update. This prevents weight decay from interfering with the optimization process, leading to more stable training and better generalization.

Mathematical Foundation of AdamW

AdamW modifies the traditional Adam optimizer by changing how weight decay is applied. The core equations for AdamW can be represented as follows:

  1. Momentum and adaptive learning rate: Similar to Adam, AdamW uses momentum and adaptive learning rates to compute parameter updates based on the moving averages of gradients and squared gradients:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$

  2. Bias-corrected estimates: The first and second-moment estimates are corrected for bias using the following:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

  3. Parameter update with decoupled weight decay: In AdamW, weight decay is applied directly to the parameters, separately from the adaptive gradient step. The update rule is:

$$\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right)$$

Here, g_t is the gradient at step t, β1 and β2 are the decay rates of the two moving averages, ε is a small constant for numerical stability, η is the learning rate, λ is the weight decay factor, and θ_t represents the parameters. The decoupled weight decay term λθ_t ensures that regularization is applied independently of the adaptive gradient scaling, which is the key difference from Adam.
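The three steps above can be sketched in plain Python for a single scalar parameter. This is an illustrative sketch rather than PyTorch's actual implementation; the function name `adamw_step` and the example values are assumptions chosen for the demonstration.

```python
import math

def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1e-2):
    """Apply one AdamW update to a scalar parameter and return the new state."""
    # 1. Moving averages of the gradient and the squared gradient
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    # 2. Bias correction (t counts steps, starting at 1)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # 3. Adaptive step plus decoupled weight decay on the current weight
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v

theta, m, v = 1.0, 0.0, 0.0
theta, m, v = adamw_step(theta, grad=2.0, m=m, v=v, t=1, lr=0.1)
print(round(theta, 4))  # 0.899
```

On the first step the bias-corrected ratio is almost exactly 1, so the weight moves by lr * (1 + weight_decay * theta) = 0.1 * 1.01, giving 0.899.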

Implementation of AdamW in PyTorch

Implementing AdamW in PyTorch is straightforward; this section provides a comprehensive guide to setting it up. Follow these steps to learn how to fine-tune models effectively with the AdamW optimizer.

A step-by-step guide to AdamW in PyTorch

Note: this tutorial assumes you already have PyTorch installed. Refer to the Documentation for any guidance. 

Step 1: Import the necessary libraries

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

Step 2: Define the model

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) 
        self.fc1 = nn.Linear(64 * 8 * 8, 128) 
        self.fc2 = nn.Linear(128, 10) 

    def forward(self, x): 
        x = self.pool(F.relu(self.conv1(x))) 
        x = self.pool(F.relu(self.conv2(x))) 
        x = x.view(-1, 64 * 8 * 8) 
        x = F.relu(self.fc1(x)) 
        x = self.fc2(x)
        return x  # Return the class logits
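The `64 * 8 * 8` input size of the first fully connected layer follows from the layer settings above, assuming 32x32 input images (the size of CIFAR-10 images, used later in this tutorial). A small sketch of the arithmetic, with `conv_out` as a hypothetical helper name:

```python
def conv_out(size, kernel, stride, padding):
    # Standard output-size formula for convolution and pooling layers
    return (size + 2 * padding - kernel) // stride + 1

size = 32                                             # CIFAR-10 images are 32x32
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv1 -> 32
size = conv_out(size, kernel=2, stride=2, padding=0)  # pool  -> 16
size = conv_out(size, kernel=3, stride=1, padding=1)  # conv2 -> 16
size = conv_out(size, kernel=2, stride=2, padding=0)  # pool  -> 8
flat_features = 64 * size * size                      # 64 channels from conv2
print(size, flat_features)  # 8 4096
```

If you change the input resolution or the number of pooling layers, the first argument of `fc1` has to change accordingly.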

Step 3: Set the hyperparameters 

learning_rate = 1e-4
weight_decay = 1e-2
num_epochs = 10 # number of epochs

Step 4: Initialize the AdamW optimizer and set up the loss function

model = SimpleCNN()  # Instantiate the model defined in Step 2
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
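If you leave out `lr` and `weight_decay`, PyTorch falls back to the optimizer's defaults. The short sketch below prints them side by side for Adam and AdamW; the `nn.Linear(4, 2)` layer is only a stand-in to supply some parameters.

```python
import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(4, 2)  # stand-in module, used only to supply parameters

adam = optim.Adam(layer.parameters())
adamw = optim.AdamW(layer.parameters())

for name in ("lr", "betas", "eps", "weight_decay"):
    print(f"{name:>12}: Adam={adam.defaults[name]}  AdamW={adamw.defaults[name]}")
```

The notable difference is `weight_decay`: it defaults to 0 for Adam and to 0.01 for AdamW, so swapping one class for the other also switches regularization on unless you set the value explicitly.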

Voila! 

Now, you are ready to start training your CNN model, and that’s what we will do in the next section. 

Practical Example: Fine-Tuning a Model Using AdamW

Above, we defined the model, set the hyperparameters, initialized the optimizer (AdamW), and set up the loss function.

To train the model, we will need to import a few more modules:

from torch.utils.data import DataLoader # provides an iterable of the dataset
import torchvision
import torchvision.transforms as transforms

Next, define the dataset and dataloaders. For this example, we will use the CIFAR-10 dataset: 

# Define transformations for the training set
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Load CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Since we’ve already defined our model, the next step is to implement the training loop to optimize the model using AdamW. 

Here is how it looks: 

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()  # Clear gradients
        outputs = model(inputs)  # Forward pass
        loss = criterion(outputs, labels)  # Calculate loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
        running_loss += loss.item()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
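Before running the full loop on CIFAR-10, it can help to sanity-check the same zero_grad / forward / backward / step pattern on a tiny synthetic problem. The sketch below is an assumption-laden stand-in, not part of the tutorial's pipeline: `toy_model` replaces `SimpleCNN`, and the random, linearly separable data replaces the real dataset.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Synthetic, linearly separable data: the label depends only on the first feature
inputs = torch.randn(64, 8)
labels = (inputs[:, 0] > 0).long()

toy_model = nn.Linear(8, 2)  # hypothetical stand-in for SimpleCNN
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(toy_model.parameters(), lr=0.05, weight_decay=1e-2)

losses = []
for step in range(200):
    optimizer.zero_grad()                        # Clear gradients
    loss = criterion(toy_model(inputs), labels)  # Forward pass and loss
    loss.backward()                              # Backward pass
    optimizer.step()                             # Update weights
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

If the loss does not go down on a problem this easy, the issue is likely in the loop itself rather than in the model or the data.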

The last step is to validate the model’s performance on the validation dataset we created earlier. 

Here’s the code: 

model.eval()  # Set the model to evaluation mode
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs)  # Forward pass
        _, predicted = torch.max(outputs, 1)  # Get predicted class
        total += labels.size(0)  # Update total samples
        correct += (predicted == labels).sum().item()  # Update correct predictions
accuracy = 100 * correct / total
print(f'Validation Accuracy: {accuracy:.2f}%')

And there you have it. 

You now know how to implement AdamW in PyTorch. 

Common Use Cases for AdamW

Okay, so we’ve established that AdamW gained popularity due to its more effective management of weight decay than its predecessor, Adam. 

But what are some common use cases for this optimizer? 

We will get into that in this section… 

Large-Scale Deep Learning Models

AdamW is particularly beneficial in training large models like BERT, GPT, and other transformer architectures. Such models typically have millions or even billions of parameters, which often means they demand efficient optimization algorithms that handle complex weight updates and generalization challenges.

Computer Vision and NLP Tasks

AdamW has become the optimizer of choice in computer vision tasks involving CNNs and NLP tasks involving transformers. Its ability to prevent overfitting makes it ideal for tasks involving large datasets and complex architectures. The decoupling of weight decay means AdamW avoids the issues encountered by Adam in over-regularizing models.

Hyperparameter Tuning in AdamW

Hyperparameter tuning is the process of selecting the best values for parameters that govern the training of a machine learning model but are not learned from the data itself. These parameters directly influence how the model optimizes and converges. 

Proper tuning of these hyperparameters in AdamW is essential for achieving efficient training, avoiding overfitting, and ensuring the model generalizes well to unseen data. 

In this section, we’ll explore how to fine-tune AdamW’s key hyperparameters for optimal performance.

Best practices for choosing learning rates and weight decay

The learning rate is a hyperparameter that controls how much to adjust the model weights with respect to the loss gradient during each training step. A higher learning rate speeds up training but may cause the model to overshoot optimal weights, while a lower rate allows for more fine-tuned adjustments but can make training slower or get stuck in local minima.

Weight decay, on the other hand, is a regularization technique used to prevent overfitting by penalizing large weights in the model. Specifically, weight decay adds a small penalty proportional to the size of the model weights during training, helping to reduce model complexity and improve generalization to new data.

To choose optimal learning rates and weight decay values for AdamW: 

  1. Start with a moderate learning rate. For AdamW, a learning rate around 1e-3 is often a good starting point. You can adjust it based on how well the model converges, lowering it if the model struggles to converge or increasing it if training is too slow.
  2. Experiment with weight decay. Start with a value around 1e-2 to 1e-4, depending on the model size and dataset. A slightly higher weight decay can help prevent overfitting for larger, complex models, while smaller models may need less regularization.
  3. Use learning rate scheduling. Implement learning rate schedules (like step decay or cosine annealing) to dynamically reduce the learning rate as training progresses, helping the model fine-tune its parameters as it approaches convergence.
  4. Monitor performance. Continuously track model performance on the validation set. If you observe overfitting, consider increasing weight decay, or if the training loss plateaus, lower the learning rate for better optimization.
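As an illustration of the learning rate scheduling tip, here is a minimal sketch that pairs AdamW with PyTorch's `CosineAnnealingLR`. The stand-in `nn.Linear` model, the dummy training step, and the choice of `T_max=num_epochs` are assumptions made for the example, not settings from this tutorial.

```python
import torch
import torch.nn as nn
import torch.optim as optim

num_epochs = 10
model = nn.Linear(4, 2)  # stand-in model for illustration
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

lrs = []
for epoch in range(num_epochs):
    lrs.append(scheduler.get_last_lr()[0])  # Learning rate used this epoch
    # One epoch of training would go here; a dummy step stands in for it
    optimizer.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    optimizer.step()
    scheduler.step()  # Advance the schedule once per epoch

print([f"{lr:.2e}" for lr in lrs])
```

The learning rate starts at the value passed to the optimizer and follows a cosine curve downward, so the largest steps happen early and the smallest ones near the end of training.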

Final Thoughts

AdamW has emerged as one of the most effective optimizers in deep learning, especially for large-scale models. This is due to its ability to decouple weight decay from gradient updates. Specifically, the design of AdamW improves regularization and helps models generalize better, particularly when dealing with complex architectures and extensive datasets.

As demonstrated in this tutorial, implementing AdamW in PyTorch is straightforward—it just requires a few adjustments from Adam. However, hyperparameter tuning remains a crucial step for maximizing the effectiveness of AdamW. Finding the right balance between learning rate and weight decay is essential for ensuring the optimizer works efficiently without overfitting or underfitting the model.

Now you know enough to implement AdamW in your own models.

AdamW Optimizer FAQs

What is the AdamW optimizer and how does it differ from Adam?

AdamW is an improved version of Adam that decouples weight decay from the gradient update process, leading to better generalization and reduced overfitting.

When should I use AdamW instead of Adam?

AdamW is ideal for large-scale models like transformers or tasks prone to overfitting, whereas Adam may suffice for simpler models and datasets.

How do I implement AdamW in PyTorch?

You can implement AdamW in PyTorch using the torch.optim.AdamW class, specifying parameters such as learning rate and weight decay.

Why is weight decay important in AdamW?

Weight decay helps prevent overfitting by penalizing large weight values. In AdamW, it’s applied directly to the weights, improving regularization without affecting learning rates.

Can AdamW be used for both NLP and computer vision tasks?

Yes, AdamW is effective for both NLP and computer vision tasks, particularly for large models like BERT and CNNs, where proper weight regularization is essential for performance.


Author
Kurtis Pykes