Pular para o conteúdo principal

Tutorial do otimizador AdamW no PyTorch

Descubra como o otimizador AdamW melhora o desempenho do modelo ao desacoplar o decaimento do peso das atualizações de gradiente. Este tutorial explica as principais diferenças entre o Adam e o AdamW, seus casos de uso e fornece um guia passo a passo para você implementar o AdamW no PyTorch.
Actualizado 22 de out. de 2024  · 12 min de leitura

Os algoritmos de otimização desempenham um papel fundamental na aprendizagem profunda: eles ajustam os pesos do modelo para minimizar as funções de perda durante o treinamento. Um desses algoritmos é o otimizador Adam.

O Adam se tornou extremamente popular na aprendizagem profunda devido à sua capacidade de combinar as vantagens das taxas de aprendizagem adaptativa e dinâmica. Isso o tornou altamente eficiente para o treinamento de redes neurais profundas. Ele também exige um ajuste mínimo dos hiperparâmetros, o que o torna amplamente acessível e eficaz em várias tarefas.

Em 2017, Ilya Loshchilov e Frank Hutter apresentaram uma versão mais avançada do popular algoritmo Adam em seu artigo "Regularização de decaimento de peso desacoplado."Eles o chamaram de AdamW, que se destaca por desacoplar o decaimento do peso do processo de atualização do gradiente. Essa separação é um aprimoramento crucial em relação ao Adam e ajuda a melhorar a generalização do modelo.

O AdamW tem se tornado cada vez mais importante nos aplicativos modernos de aprendizagem profunda, principalmente no manuseio de modelos de grande escala. Sua capacidade superior de regular as atualizações de peso contribuiu para sua adoção em tarefas que exigem alto desempenho e estabilidade. 

Neste tutorial, abordaremos as principais diferenças entre o Adam e o AdamW, bem como os diferentes casos de uso, e implementaremos um guia passo a passo para implementar o AdamW no PyTorch.

Adam vs AdamW

O Adam e o AdamW são otimizadores adaptativos amplamente usados na aprendizagem profunda. A grande diferença entre eles é como lidam com a regularização de peso, o que afeta sua eficácia em diferentes cenários. 

Embora o Adam combine momentum e taxas de aprendizado adaptáveis para oferecer uma otimização eficiente, ele incorpora a regularização L2 de uma forma que pode prejudicar o desempenho. O AdamW aborda essa questão desacoplando o decaimento do peso da atualização da taxa de aprendizado, fornecendo uma abordagem mais eficaz para modelos grandes e melhorando a generalização. O decaimento de peso, uma forma de regularização L2, penaliza pesos grandes no modelo. O Adam incorpora a redução de peso no processo de atualização do gradiente, enquanto o AdamW a aplica separadamente após a atualização do gradiente

Aqui estão algumas outras maneiras pelas quais eles diferem: 

Principais diferenças entre o Adam e o AdamW

Embora ambos os otimizadores tenham sido projetados para gerenciar o momentum e ajustar as taxas de aprendizado dinamicamente, eles diferem fundamentalmente no tratamento da deterioração do peso. 

Na Adam, a redução do peso é aplicada indiretamente como parte da atualização do gradiente, o que pode modificar involuntariamente a dinâmica do aprendizado e interferir no processo de otimização. O AdamW, no entanto, separa o decaimento do peso da etapa do gradiente, garantindo que a regularização afete diretamente os parâmetros sem alterar o mecanismo de aprendizagem adaptativa. 

Esse design leva a uma regularização mais precisa, ajudando os modelos a se generalizarem melhor, principalmente em tarefas que envolvem conjuntos de dados grandes e complexos. Como resultado, os dois otimizadores geralmente têm casos de uso muito diferentes.  

Casos de uso para Adam

O Adam tem um desempenho melhor em tarefas em que a regularização é menos crítica ou quando a eficiência computacional é priorizada em relação à generalização. Os exemplos incluem:

  • Redes neurais menores. Para tarefas como classificação básica de imagens usando pequenas CNNs(Redes Neurais Convolucionais) em conjuntos de dados como MNIST ou CIFAR-10, em que a complexidade do modelo é baixa, o Adam pode otimizar com eficiência sem a necessidade de regularização extensiva.
  • Problemas de regressão simples. Em tarefas de regressão simples com conjuntos de recursos limitados, como a previsão de preços de imóveis usando um modelo de regressão linear, a Adam pode convergir rapidamente sem a necessidade de técnicas avançadas de regularização.
  • Prototipagem em estágio inicial. Durante os estágios iniciais do desenvolvimento do modelo, em que é necessária uma experimentação rápida, o Adam permite iterações rápidas em arquiteturas mais simples, possibilitando que os pesquisadores identifiquem possíveis problemas sem a sobrecarga de ajustar os parâmetros de regularização.
  • Dados menos ruidosos. Ao trabalhar com conjuntos de dados limpos e com o mínimo de ruído, como dados de texto bem selecionados para análise de sentimentos, o Adam pode aprender padrões de forma eficaz sem o risco de ajuste excessivo que poderia exigir uma regularização mais pesada.
  • Ciclos de treinamento curtos. Em cenários com restrições de tempo, como a implantação rápida de modelos para aplicativos em tempo real, a otimização eficiente do Adam pode ajudar a fornecer resultados satisfatórios rapidamente, mesmo que eles não sejam totalmente otimizados para generalização.

Casos de uso do AdamW

O AdamW é excelente em cenários em que o excesso de ajuste é uma preocupação e o tamanho do modelo é substancial. Por exemplo: 

  • Transformadores de grande porte. Em tarefas de processamento de linguagem natural, como o ajuste fino de modelos como o GPT em corpora de texto extensos, a capacidade do AdamW de gerenciar a diminuição do peso evita efetivamente o ajuste excessivo, garantindo uma melhor generalização.
  • Modelos complexos de visão computacional. Para tarefas que envolvem redes neurais convolucionais profundas (CNNs) treinadas em grandes conjuntos de dados, como o ImageNet, o AdamW ajuda a manter a estabilidade e o desempenho do modelo, desacoplando a deterioração do peso, o que é fundamental para obter alta precisão.
  • Aprendizagem multitarefa. Nos cenários em que um modelo é treinado em várias tarefas simultaneamente, o AdamW oferece a flexibilidade necessária para lidar com diversos conjuntos de dados e evitar o ajuste excessivo em uma única tarefa.
  • Modelos generativos. Para o treinamento de redes adversárias generativas (GANs), em que é fundamental manter um equilíbrio entre o gerador e o discriminador, a regularização aprimorada do AdamW pode ajudar a estabilizar o treinamento e melhorar a qualidade dos resultados gerados.
  • Aprendizagem por reforço. Em aplicações de aprendizado por reforço em que os modelos precisam se adaptar a ambientes complexos e aprender políticas robustas, o AdamW ajuda a reduzir o excesso de ajuste a estados ou ações específicas, melhorando o desempenho geral do modelo em situações variadas.

Vantagens do AdamW em relação ao Adam

Mas por que você gostaria de usar o AdamW em vez do Adam? Simples. O AdamW oferece vários benefícios importantes que melhoram seu desempenho, especialmente em cenários de modelagem complexos. 

Ele aborda algumas das limitações encontradas no otimizador Adam, tornando-o mais eficaz na otimização e contribuindo para melhorar o treinamento e a robustez do modelo. 

Aqui estão algumas das vantagens mais importantes:

  • Decaimento de peso desacoplado. Ao separar o decaimento do peso das atualizações de gradiente, o AdamW permite um controle mais preciso sobre a regularização, levando a uma melhor generalização do modelo.
  • Generalização aprimorada. O AdamW reduz o risco de ajuste excessivo, especialmente em modelos de grande escala, tornando-o adequado para tarefas que envolvem conjuntos de dados extensos e arquiteturas complexas.
  • Estabilidade durante o treinamento. O design do AdamW ajuda a manter a estabilidade durante todo o processo de treinamento, o que é essencial para modelos que exigem um ajuste cuidadoso de seus hiperparâmetros.
  • Escalabilidade. O AdamW é particularmente eficaz para ampliar os modelos, pois pode lidar com o aumento da complexidade das redes profundas sem sacrificar o desempenho, permitindo que ele seja aplicado em arquiteturas de última geração.

Como o AdamW funciona

O ponto forte do AdamW está em sua abordagem de redução de peso, que é desacoplada das atualizações de gradiente adaptativo típicas do Adam. Esse ajuste garante que a regularização seja aplicada diretamente aos pesos do modelo, melhorando a generalização sem afetar negativamente a dinâmica da taxa de aprendizado.

O otimizador se baseia na natureza adaptativa do Adam, mantendo os benefícios dos ajustes da taxa de aprendizado por parâmetro e por momento. A aplicação do decaimento de peso aborda independentemente uma das principais deficiências do Adam: sua tendência de afetar as atualizações de gradiente durante a regularização. Essa separação permite que o AdamW mantenha um aprendizado estável, mesmo em modelos complexos e de grande escala, ao mesmo tempo em que mantém o excesso de ajuste sob controle.

Nas seções a seguir, exploraremos a teoria por trás da redução e regularização do peso e a matemática que sustenta o processo de otimização do AdamW.

Teoria por trás do decaimento de peso e da regularização L2

A regularização L2 é uma técnica usada para evitar o ajuste excessivo. Ele atinge esse objetivo adicionando um termo de penalidade à função de perda, desencorajando valores de peso grandes. Essa técnica ajuda a criar modelos mais simples que se generalizam melhor para novos dados. 

Nos otimizadores tradicionais, como o Adam, a redução do peso é aplicada como parte da atualização do gradiente, o que afeta inadvertidamente as taxas de aprendizado e pode levar a um desempenho abaixo do ideal.

O AdamW aprimora isso ao desacoplar o decaimento do peso do cálculo do gradiente. Em outras palavras, em vez de aplicar o decaimento do peso durante a atualização do gradiente, o AdamW o trata como uma etapa separada, aplicando-o diretamente aos pesos após a atualização do gradiente. Isso evita que a deterioração do peso interfira no processo de otimização, levando a um treinamento mais estável e a uma melhor generalização.

Fundamentos matemáticos do AdamW

O AdamW modifica o otimizador Adam tradicional, alterando a forma como a redução de peso é aplicada. As principais equações do AdamW podem ser representadas da seguinte forma:

  1. Momentum e taxa de aprendizado adaptável: Semelhante ao Adam, o AdamW usa momentum e taxas de aprendizado adaptáveis para calcular atualizações de parâmetros com base nas médias móveis de gradientes e gradientes quadrados.

A equação para o momento e a taxa de aprendizagem adaptativa

  1. Estimativas com correção de viés:As estimativas de primeiro e segundo momentos são corrigidas quanto à tendência usando o seguinte:

A fórmula para estimativas com correção de viés

  1. Atualização de parâmetros com decaimento de peso desacoplado: No AdamW, o decaimento do peso é aplicado diretamente aos parâmetros após a atualização do gradiente. A regra de atualização é:

Atualização de parâmetros com decaimento de peso desacoplado

Aqui, η é a taxa de aprendizado, λ é o fator de decaimento do peso e θt representa os parâmetros. Esse termo de decaimento de peso desacoplado λθt garante que a regularização seja aplicada independentemente da atualização do gradiente, que é a principal diferença em relação ao Adam.

Implementação do AdamW no PyTorch

Implementação do AdamW no PyTorch é simples; esta seção fornece um guia abrangente para configurá-lo. Siga estas etapas para saber como fazer o ajuste fino dos modelos de forma eficaz com o Adam Optimizer. 

Um guia passo a passo para o AdamW no PyTorch

Observação: este tutorial pressupõe que você já tenha o PyTorch instalado. Consulte a Documentação para obter orientação.

Etapa 1: Importar as bibliotecas necessárias

import torch
import torch.nn as nn
import torch.optim as optim
Import torch.nn.functional as F

Etapa 2: Definir o modelo

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) 
        self.fc1 = nn.Linear(64 * 8 * 8, 128) 
        self.fc2 = nn.Linear(128, 10) 

    def forward(self, x): 
        x = self.pool(F.relu(self.conv1(x))) 
        x = self.pool(F.relu(self.conv2(x))) 
        x = x.view(-1, 64 * 8 * 8) 
        x = F.relu(self.fc1(x)) 
        x = self.fc2(x) 

Etapa 3: Definir os hiperparâmetros 

learning_rate = 1e-4
weight_decay = 1e-2
num_epochs = 10 # number of epochs

Etapa 4: Inicialize o otimizador AdamW e configure a função de perda

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) 
criterion = nn.CrossEntropyLoss()

Voila! 

Agora, você está pronto para começar a treinar o modelo CNN, e é isso que faremos na próxima seção. 

Exemplo prático: Ajuste fino de um modelo usando o AdamW

Acima, definimos o modelo, definimos os hiperparâmetros, inicializamos o otimizador (AdamW) e configuramos a função de perda. 

Para treinar o modelo, precisaremos importar mais alguns módulos; 

from torch.utils.data import DataLoader # provides an iterable of the dataset
import torchvision
import torchvision.transforms as transforms

Em seguida, defina o conjunto de dados e os carregadores de dados. Para este exemplo, usaremos o conjunto de dados CIFAR-10: 

# Define transformations for the training set
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Load CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Como já definimos nosso modelo, a próxima etapa é implementar o loop de treinamento para otimizar o modelo usando o AdamW. 

Aqui está a aparência: 

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()  # Clear gradients
        outputs = model(inputs)  # Forward pass
        loss = criterion(outputs, labels)  # Calculate loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
        running_loss += loss.item()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

A última etapa é validar o desempenho do modelo no conjunto de dados de validação que criamos anteriormente. 

Aqui está o código: 

model.eval()  # Set the model to evaluation mode
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs)  # Forward pass
        _, predicted = torch.max(outputs.data, 1)  # Get predicted class
        total += labels.size(0)  # Update total samples
        correct += (predicted == labels).sum().item()  # Update correct predictions
accuracy = 100 * correct / total
print(f'Validation Accuracy: {accuracy:.2f}%')

E aí está o que você precisa. 

Agora você sabe como implementar o AdamW no PyTorch. 

Casos de uso comuns para o AdamW

Ok, então estabelecemos que o AdamW ganhou popularidade devido ao seu gerenciamento mais eficaz da perda de peso do que seu antecessor, o Adam. 

Mas quais são alguns casos de uso comuns para esse otimizador? 

Vamos falar sobre isso nesta seção... 

Modelos de aprendizagem profunda em grande escala

O AdamW é particularmente útil no treinamento de modelos grandes, como BERT, GPT e outras arquiteturas de transformadores. Esses modelos normalmente têm milhões ou até bilhões de parâmetros, o que geralmente significa que eles exigem algoritmos de otimização eficientes que lidem com atualizações de peso complexas e desafios de generalização.

Visão computacional e tarefas de PNL

O AdamW tornou-se o otimizador preferido em tarefas de visão computacional que envolvem CNNs e tarefas de PNL que envolvem transformadores. Sua capacidade de evitar o ajuste excessivo o torna ideal para tarefas que envolvem grandes conjuntos de dados e arquiteturas complexas. O desacoplamento do decaimento do peso significa que o AdamW evita os problemas encontrados pelo Adam em modelos com excesso de regularização.

Ajuste de hiperparâmetros no AdamW

O ajuste de hiperparâmetros é o processo de seleção dos melhores valores para os parâmetros que regem o treinamento de um modelo de aprendizado de máquina, mas que não são aprendidos com os próprios dados. Esses parâmetros influenciam diretamente a forma como o modelo é otimizado e converge. 

O ajuste adequado desses hiperparâmetros no AdamW é essencial para obter um treinamento eficiente, evitar o ajuste excessivo e garantir que o modelo seja bem generalizado para dados não vistos. 

Nesta seção, exploraremos como fazer o ajuste fino dos principais hiperparâmetros do AdamW para obter o melhor desempenho.

Práticas recomendadas para escolher taxas de aprendizado e decaimento de peso

A taxa de aprendizado é um hiperparâmetro que controla o quanto ajustar os pesos do modelo em relação ao gradiente de perda durante cada etapa de treinamento. Uma taxa de aprendizado mais alta acelera o treinamento, mas pode fazer com que o modelo ultrapasse os pesos ideais, enquanto uma taxa mais baixa permite ajustes mais finos, mas pode tornar o treinamento mais lento ou ficar preso em mínimos locais.

A redução de peso, por outro lado, é uma técnica de regularização usada para evitar o ajuste excessivo, penalizando pesos grandes no modelo. Ou seja, o decaimento de peso adiciona uma pequena penalidade proporcional ao tamanho dos pesos do modelo durante o treinamento, ajudando a reduzir a complexidade do modelo e a melhorar a generalização para novos dados.

Para escolher as taxas de aprendizado ideais e os valores de decaimento de peso para o AdamW: 

  1. Comece com uma taxa de aprendizado moderada - Para o AdamW, uma taxa de aprendizado em torno de 1e-3 costuma ser um bom ponto de partida. Você pode ajustá-lo com base na convergência do modelo, diminuindo-o se o modelo tiver dificuldades para convergir ou aumentando-o se o treinamento for muito lento.
  2. Experimente o decaimento de peso. Comece com um valor em torno de 1e-2 a 1e-4, dependendo do tamanho do modelo e do conjunto de dados. Um decaimento de peso ligeiramente maior pode ajudar a evitar o ajuste excessivo de modelos maiores e complexos, enquanto modelos menores podem precisar de menos regularização.
  3. Use a programação da taxa de aprendizado. Implemente programações de taxa de aprendizagem (como decaimento de etapas ou recozimento de cosseno) para reduzir dinamicamente a taxa de aprendizagem à medida que o treinamento progride, ajudando o modelo a ajustar seus parâmetros à medida que se aproxima da convergência.
  4. Monitore o desempenho. Acompanhe continuamente o desempenho do modelo no conjunto de validação. Se você observar um ajuste excessivo, considere aumentar o decaimento do peso ou, se a perda de treinamento atingir um platô, reduza a taxa de aprendizado para uma melhor otimização.

Considerações finais

O AdamW surgiu como um dos otimizadores mais eficazes na aprendizagem profunda, especialmente para modelos de grande escala. Isso se deve à sua capacidade de desacoplar a redução de peso das atualizações de gradiente. Ou seja, o design do AdamW melhora a regularização e ajuda os modelos a se generalizarem melhor, principalmente ao lidar com arquiteturas complexas e conjuntos de dados extensos. 

Conforme demonstrado neste tutorial, a implementação do AdamW no PyTorch é simples - requer apenas alguns ajustes do Adam. No entanto, o ajuste de hiperparâmetros continua sendo uma etapa crucial para maximizar a eficácia do AdamW. Encontrar o equilíbrio certo entre a taxa de aprendizado e o decaimento do peso é essencial para garantir que o otimizador funcione de forma eficiente, sem que o modelo seja superajustado ou subajustado.

Agora você sabe o suficiente para implementar o AdamW em seus próprios modelos. Para continuar seu aprendizado, confira alguns desses recursos: 

Perguntas frequentes sobre o AdamW Optimizer

O que é o otimizador AdamW e como ele difere do Adam?

O AdamW é uma versão aprimorada do Adam que desacopla o decaimento do peso do processo de atualização do gradiente, o que leva a uma melhor generalização e redução do excesso de ajuste.

Quando devo usar AdamW em vez de Adam?

O AdamW é ideal para modelos de grande escala, como transformadores, ou tarefas propensas a excesso de ajuste, enquanto o Adam pode ser suficiente para modelos e conjuntos de dados mais simples.

Como faço para implementar o AdamW no PyTorch?

Você pode implementar o AdamW no PyTorch usando a classe torch.optim.AdamW especificando parâmetros como a taxa de aprendizado e o decaimento do peso.

Por que o decaimento de peso é importante no AdamW?

O decaimento do peso ajuda a evitar o ajuste excessivo, penalizando valores grandes de peso. No AdamW, ele é aplicado diretamente aos pesos, melhorando a regularização sem afetar as taxas de aprendizado.

O AdamW pode ser usado tanto para tarefas de PNL quanto de visão computacional?

Sim, o AdamW é eficaz para tarefas de PNL e visão computacional, especialmente para modelos grandes como BERT e CNNs, em que a regularização adequada do peso é essencial para o desempenho.


Kurtis Pykes 's photo
Author
Kurtis Pykes
LinkedIn
Temas

Principais cursos da DataCamp

Certificação disponível

curso

Aprendizagem profunda intermediária com PyTorch

4 hr
8.8K
Aprenda sobre arquiteturas fundamentais de aprendizagem profunda, como CNNs, RNNs, LSTMs e GRUs para modelar imagens e dados sequenciais.
Ver DetalhesRight Arrow
Iniciar curso
Ver maisRight Arrow
Relacionado

tutorial

Tutorial do Adam Optimizer: Intuição e implementação em Python

Compreender e implementar o otimizador Adam em Python. Com o PyTorch, você aprenderá a intuição, a matemática e as aplicações práticas do machine learning
Bex Tuychiev's photo

Bex Tuychiev

14 min

tutorial

Criando um transformador com o PyTorch

Saiba como criar um modelo Transformer usando o PyTorch, uma ferramenta avançada de aprendizado de máquina moderno.
Arjun Sarkar's photo

Arjun Sarkar

26 min

tutorial

Autoencodificadores variacionais: Como eles funcionam e por que são importantes

Aprenda os princípios básicos, as aplicações e os benefícios práticos dos autoencodificadores variacionais e acompanhe uma implementação passo a passo com o PyTorch.
Kurtis Pykes 's photo

Kurtis Pykes

30 min

tutorial

Como treinar um LLM com o PyTorch

Domine o processo de treinamento de grandes modelos de linguagem usando o PyTorch, desde a configuração inicial até a implementação final.
Zoumana Keita 's photo

Zoumana Keita

8 min

tutorial

Otimização em Python: Técnicas, pacotes e práticas recomendadas

Este artigo ensina a você sobre otimização numérica, destacando diferentes técnicas. Ele discute os pacotes Python, como SciPy, CVXPY e Pyomo, e fornece um notebook DataLab prático para você executar exemplos de código.
Kurtis Pykes 's photo

Kurtis Pykes

19 min

tutorial

Guia de torchchat do PyTorch: Configuração local com Python

Saiba como configurar o torchchat do PyTorch localmente com Python neste tutorial prático, que fornece orientação e exemplos passo a passo.
François Aubry's photo

François Aubry

Ver maisVer mais