programa
Decodificação especulativa: Um guia com exemplos de implementação
LLMs são muito eficientes, mas muitas vezes podem ser um pouco lentos, o que não é ideal em cenários em que precisamos de velocidade. Decodificação especulativa é uma técnica projetada para acelerar os LLMs, gerando respostas mais rapidamente sem comprometer a qualidade.
Em essência, é uma maneira de "adivinhar" no processo de geração de texto, fazendo previsões sobre as palavras que podem vir a seguir, ao mesmo tempo em que permite a precisão e a profundidade que esperamos dos LLMs.
Neste blog, explicarei o que é a decodificação especulativa, como ela funciona e como implementá-la com os modelos Gemma 2.
O que é decodificação especulativa?
A decodificação especulativa acelera os LLMs ao incorporar um modelo menor e mais rápido que gera previsões preliminares. Esse modelo menor, geralmente chamado de modelo de "rascunho", gera um lote de tokens que o LLM principal, mais avançado, pode confirmar ou refinar. O modelo de rascunho funciona como uma primeira passagem, produzindo vários tokens que aceleram o processo de geração.
Em vez de o LLM principal gerar tokens sequencialmente, o modelo de rascunho fornece um conjunto de candidatos prováveis, e o modelo principal os avalia em paralelo. Esse método reduz a carga computacional do LLM principal ao descarregar as previsões iniciais, permitindo que ele se concentre apenas em correções ou validações.
Pense nisso como um escritor com um editor. O principal LLM é o escritor, capaz de produzir textos de alta qualidade, mas em um ritmo mais lento. Um modelo de "rascunho" menor e mais rápido atua como editor, gerando rapidamente possíveis continuações do texto. Em seguida, o LLM principal avalia essas sugestões, incorporando as mais precisas e descartando as demais. Isso permite que o LLM processe vários tokens ao mesmo tempo, acelerando a geração de texto.
Vamos dividir o processo de decodificação especulativa em etapas simples:
- Geração de rascunho: O modelo menor (por exemplo, Gemma2-2B-it) gera várias sugestões de tokens com base no prompt de entrada. Esses tokens são gerados de forma especulativa, o que significa que o modelo não tem certeza de que estão corretos, mas os fornece como tokens de "rascunho".
- Paralelo verificação: O modelo maior (por exemplo, Gemma2-9B-it) verifica esses tokens em paralelo, comparando a probabilidade deles com a distribuição aprendida do modelo. Se os tokens forem considerados aceitáveis, eles serão usados na saída final. Caso contrário, o modelo maior os corrige.
- Resultado final: Depois que os tokens são verificados (ou corrigidos), eles são passados para o usuário como o resultado final. Todo esse processo é muito mais rápido do que a decodificação tradicional de um token por vez.
Decodificação tradicional vs. Decodificação de dados. Decodificação especulativa
A decodificação tradicional processa os tokens um por vez, o que leva a uma alta latência, mas a decodificação especulativa permite que um modelo menor gere tokens em massa, com o modelo maior verificando-os. Isso pode reduzir o tempo de resposta em 30 a 40%, diminuindo a latência de 25 a 30 segundos para apenas 15 a 18 segundos.
Além disso, a decodificação especulativa otimiza o uso da memória, transferindo a maior parte da geração de tokens para o modelo menor, reduzindo os requisitos de memória de 26 GB para cerca de 14 GB e fazendo com que o no dispositivo inferência no dispositivo mais acessível.
Por fim, ele reduz as demandas de computação em 50%, pois o modelo maior apenas verifica em vez de gerar tokens, permitindo um desempenho mais suave em dispositivos móveis com energia limitada e evitando o superaquecimento.
Exemplo prático: Decodificação especulativa com modelos Gemma2
Para implementar um exemplo prático de decodificação especulativa usando os modelos Gemma2. Exploraremos como a decodificação especulativa se compara à inferência padrão em termos delatência e desempenho do .
Etapa 1: Configuração do modelo e do tokenizador
Para começar, importe as dependências e defina a semente.
Em seguida, verifique se a GPU está disponível na máquina em que você está operando. Isso é necessário principalmente para modelos grandes, como Gemma 2-9B-it ou LLama2-13B.
Por fim, carregamos os modelos pequeno e grande junto com seus tokenizadores. Aqui, estamos usando o modelo Gemma2-2b-it (instruído) para o modelo preliminar e o modelo Gemma2-9b-it para verificação.
Há alguns outros modelos que também podem ser usados como alternativa. Por exemplo:
- Gemma 7B (principal) e Gemma 2B (rascunho)
- Mixtral-8x7B (principal) e Mistral 7B (rascunho)
- Pythia 12B (principal) e Pythia 70M (rascunho)
- Llama 2 13B (principal) e TinyLlama 1.1B (rascunho)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
# Set Seed
set_seed(42)
# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the smaller Gemma2 model (draft generation)
small_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", device_map="auto")
small_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", device_map="auto", torch_dtype=torch.bfloat16)
# Load the larger Gemma2 model (verification)
big_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it", device_map="auto")
big_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", device_map="auto", torch_dtype=torch.bfloat16)
Etapa 2: Inferência autorregressiva (normal)
Primeiro, realizamos a inferência apenas no modelo grande (Gemma2-9b-it) e geramos resultados. Comece por tokenização o prompt de entrada e movendo os tokens para o dispositivo correto (GPU, se disponível). O método generate
produz a saída do modelo, gerando até max_new_tokens
. O resultado é então decodificado a partir de IDs de token para um texto legível por humanos.
def normal_inference(big_model, big_tokenizer, prompt, max_new_tokens=50):
inputs = big_tokenizer(prompt, return_tensors='pt').to(device)
outputs = big_model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)
return big_tokenizer.decode(outputs[0], skip_special_tokens=True)
Etapa 3: Decodificação especulativa
Em seguida, tentamos o método de decodificação especulativa, no qual seguimos as seguintes etapas:
- Geração de rascunho: O modelo menor gera um rascunho do texto a partir do prompt fornecido.
- Verificação: Em seguida, o modelo maior verifica o rascunho calculando a probabilidade de log para cada token no rascunho.
- Cálculo da probabilidade lógica: Calculamos uma probabilidade média de log para determinar a probabilidade de o modelo grande considerar correto o rascunho do modelo pequeno.
def speculative_decoding(small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens=50):
# Step 1: Use the small model to generate the draft
inputs = small_tokenizer(prompt, return_tensors='pt').to(device)
small_outputs = small_model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)
draft = small_tokenizer.decode(small_outputs[0], skip_special_tokens=True)
# Step 2: Verify the draft with the big model
big_inputs = big_tokenizer(draft, return_tensors='pt').to(device)
# Step 3: Calculate log-likelihood of the draft tokens under the large model
with torch.no_grad():
outputs = big_model(big_inputs['input_ids'])
log_probs = torch.log_softmax(outputs.logits, dim=-1)
draft_token_ids = big_inputs['input_ids']
log_likelihood = 0
for i in range(draft_token_ids.size(1) - 1):
token_id = draft_token_ids[0, i + 1]
log_likelihood += log_probs[0, i, token_id].item()
avg_log_likelihood = log_likelihood / (draft_token_ids.size(1) - 1)
# Return the draft and its log-likelihood score
return draft, avg_log_likelihood
Observação: O log-likelihood é o logaritmo da probabilidade que um modelo atribui a uma sequência específica de tokens. Aqui, ele reflete a probabilidade ou a "confiança" do modelo de que a sequência de tokens (o texto gerado) é válida, considerando os tokens anteriores.
Etapa 4: Medição da latência
Depois de implementar as duas técnicas, podemos medir suas respectivas latências. Para a decodificação especulativa, avaliamos o desempenho examinando o valor de log-likelihood. Um valor de log-likelihood próximo de zero, especialmente na faixa negativa, indica que o texto gerado é preciso.
def measure_latency(small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens=50):
# Measure latency for normal inference (big model only)
start_time = time.time()
normal_output = normal_inference(big_model, big_tokenizer, prompt, max_new_tokens)
normal_inference_latency = time.time() - start_time
print(f"Normal Inference Output: {normal_output}")
print(f"Normal Inference Latency: {normal_inference_latency:.4f} seconds")
print("\n\n")
# Measure latency for speculative decoding
start_time = time.time()
speculative_output, log_likelihood = speculative_decoding(
small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens
)
speculative_decoding_latency = time.time() - start_time
print(f"Speculative Decoding Output: {speculative_output}")
print(f"Speculative Decoding Latency: {speculative_decoding_latency:.4f} seconds")
print(f"Log Likelihood (Verification Score): {log_likelihood:.4f}")
return normal_inference_latency, speculative_decoding_latency
Isso retorna:
- Probabilidade logarítmica (pontuação de verificação): -0.5242
- Latência de inferência normal: 17,8899 segundos
- Latência de decodificação especulativa: 10,5291 segundos (cerca de 70% mais rápido)
A latência mais baixa se deve ao menor tempo gasto pelo modelo menor para a geração de texto e ao menor tempo gasto pelo modelo maior apenas para verificar o texto gerado.
Testando a decodificação especulativa em cinco prompts
Vamos comparar a decodificação especulativa com a inferência autorregressiva usando cinco prompts:
# List of 5 prompts
prompts = [
"The future of artificial intelligence is ",
"Machine learning is transforming the world by ",
"Natural language processing enables computers to understand ",
"Generative models like GPT-3 can create ",
"AI ethics and fairness are important considerations for "
]
# Inference settings
max_new_tokens = 200
# Initialize accumulators for latency, memory, and tokens per second
total_tokens_per_sec_normal = 0
total_tokens_per_sec_speculative = 0
total_normal_latency = 0
total_speculative_latency = 0
# Perform inference on each prompt and accumulate the results
for prompt in prompts:
normal_latency, speculative_latency, _, _, tokens_per_sec_normal, tokens_per_sec_speculative = measure_latency_and_memory(
small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens
)
total_tokens_per_sec_normal += tokens_per_sec_normal
total_tokens_per_sec_speculative += tokens_per_sec_speculative
total_normal_latency += normal_latency
total_speculative_latency += speculative_latency
# Calculate averages
average_tokens_per_sec_normal = total_tokens_per_sec_normal / len(prompts)
average_tokens_per_sec_speculative = total_tokens_per_sec_speculative / len(prompts)
average_normal_latency = total_normal_latency / len(prompts)
average_speculative_latency = total_speculative_latency / len(prompts)
# Output the averages
print(f"Average Normal Inference Latency: {average_normal_latency:.4f} seconds")
print(f"Average Speculative Decoding Latency: {average_speculative_latency:.4f} seconds")
print(f"Average Normal Inference Tokens per second: {average_tokens_per_sec_normal:.2f} tokens/second")
print(f"Average Speculative Decoding Tokens per second: {average_tokens_per_sec_speculative:.2f} tokens/second")
Average Normal Inference Latency: 25.0876 seconds
Average Speculative Decoding Latency: 15.7802 seconds
Average Normal Inference Tokens per second: 7.97 tokens/second
Average Speculative Decoding Tokens per second: 12.68 tokens/second
Isso mostra que a decodificação especulativa é mais eficiente, gerando mais tokens por segundo do que a inferência normal. Essa melhoria se deve ao fato de o modelo menor lidar com a maior parte da geração de texto, enquanto a função do modelo maior se limita à verificação, reduzindo a carga computacional geral em termos de latência e memória.
Com esses requisitos de memória, podemos implantar facilmente técnicas de decodificação especulativa em dispositivos de borda e ganhar velocidade em nossos aplicativos no dispositivo, como chatbots, tradutores de idiomas, jogos e muito mais.
Decodificação especulativa otimizada com quantização
A abordagem acima é eficiente, mas há uma compensação entre a latência e a otimização da memória para a inferência no dispositivo. Para resolver isso, vamos aplicar a quantização a modelos pequenos e grandes. Você pode experimentar e tentar aplicar a quantização somente ao modelo grande, pois o modelo pequeno já ocupa o menor espaço.
A quantificação é aplicada a modelos menores e maiores usando a configuração BitsAndBytesConfig
da biblioteca Hugging Face transformers
. A quantização nos permite reduzir significativamente o uso da memória e, em muitos casos, melhorar a velocidade de inferência, convertendo os pesos do modelo em uma forma mais compacta.
Adicione o seguinte trecho de código ao código acima para que você veja os efeitos da quantização.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True, # Enables 4-bit quantization
bnb_4bit_quant_type="nf4", # Specifies the quantization type (nf4)
bnb_4bit_compute_dtype=torch.bfloat16, # Uses bfloat16 for computation
bnb_4bit_use_double_quant=False, # Disables double quantization to avoid additional overhead
)
# Load the smaller and larger Gemma2 models with quantization applied
small_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", device_map="auto", quantization_config=bnb_config)
big_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", device_map="auto", quantization_config=bnb_config)
Aqui está uma rápida comparação das saídas para mostrar os efeitos da decodificação especulativa com e sem quantização:
Quantização de 4 bits (compressão de peso)
A configuração especifica load_in_4bit=True
, o que significa que os pesos do modelo são quantizados de seu formato original de ponto flutuante de 32 ou 16 bits para inteiros de 4 bits. Isso reduz o espaço de memória do modelo. A quantização comprime os pesos do modelo, o que nos permite armazená-los e operá-los com mais eficiência. Essas são aseconomias concretasde memória do :
- Ao reduzir a precisão de flutuantes de 32 ou 16 bits para inteiros de 4 bits, cada peso agora ocupa 1/4 ou 1/8 do espaço original, reduzindo significativamente o uso da memória.
- Isso se reflete no uso da memória como:
- Uso normal da memória de inferência: 26,458 MB
- Uso de memória de decodificação especulativa: 8.993 MB.
bfloat16 para computação (uso eficiente de Tensor Cores)
A configuração bnb_4bit_compute_dtype=torch.bfloat16
especifica que o cálculo é realizado em bfloat16 (BF16), um formato de ponto flutuante mais eficiente. O BF16 tem uma faixa dinâmica mais ampla do que o FP16, mas ocupa metade da memória em comparação com o FP32, o que o torna um bom equilíbrio entre precisão e desempenho.
O uso do BF16, especialmente em GPUs NVIDIA (como a A100), utiliza Tensor Cores, que são otimizados para operações do BF16. Isso resulta em multiplicações mais rápidas de matrizes e outros cálculos durante a inferência.
Para decodificação especulativa, observamos uma melhora na latência:
- Latência de inferência normal: 27,65 segundos
- Latência de decodificação especulativa: 15,56 segundos
O menor espaço de memória significa acesso mais rápido à memória e uso mais eficiente dos recursos da GPU, levando a uma geração mais rápida.
Tipo de quantização NF4 (precisão otimizada)
A opção bnb_4bit_quant_type="nf4"
especifica Norm-Four Quantization (NF4), que é otimizada para redes neurais. A quantização NF4 ajuda a manter a precisão de partes importantes do modelo, mesmo que os pesos sejam representados em 4 bits. Isso minimiza a degradação do desempenho do modelo em comparação com a quantização simples de 4 bits.
O NF4 ajuda a alcançar um equilíbrio entre a compactação da quantização de 4 bits e a precisão das previsões do modelo, garantindo que o desempenho permaneça próximo ao original e, ao mesmo tempo, reduzindo drasticamente o uso da memória.
Quantização dupla desativada
A quantização dupla (bnb_4bit_use_double_quant=False
) introduz uma camada adicional de quantização sobre os pesos de 4 bits, o que pode reduzir ainda mais o uso da memória, mas também aumenta a sobrecarga de computação. Nesse caso, a quantização dupla é desativada para evitar que você diminua a velocidade da inferência.
Aplicações da decodificação especulativa
As possíveis aplicações da decodificação especulativa são vastas e empolgantes. Aqui estão alguns exemplos:
- Chatbots e assistentes virtuais: Para que essas conversas com a IA sejam mais naturais e fluidas, com tempos de resposta mais rápidos.
- Tradução em tempo real: A decodificação especulativa reduz a latência na tradução em tempo real.
- Geração de conteúdo: A decodificação especulativa acelera a criação de conteúdo.
- Aplicativos de jogos e interativos: Para melhorar a capacidade de resposta de personagens ou elementos de jogos acionados por IA para uma experiência mais imersiva, a decodificação especulativa pode nos ajudar a obter respostas quase em tempo real.
Desafios da decodificação especulativa
Embora a decodificação especulativa seja muito promissora, ela não está isenta de desafios:
- Sobrecarga de memória: A manutenção de vários estados de modelo (tanto para o rascunho quanto para o LLM principal) pode aumentar o uso da memória, especialmente quando modelos maiores são usados para verificação.
- Ajuste do modelo de rascunho: A escolha do modelo de rascunho correto e o ajuste de seus parâmetros são fundamentais para atingir o equilíbrio certo entre velocidade e precisão, pois um modelo excessivamente simplista pode levar a falhas frequentes na verificação.
- Complexidade da implementação: A implementação da decodificação especulativa é mais complexa do que os métodos tradicionais, exigindo uma sincronização cuidadosa entre o modelo de rascunho pequeno e o modelo de verificação maior, além de um tratamento eficiente de erros.
- Compatibilidade com estratégias de decodificação: Atualmente, a decodificação especulativa é compatível apenas com pesquisa e amostragem gananciosas, limitando seu uso a estratégias de decodificação mais sofisticadas, como pesquisa de feixe ou amostragem diversificada.
- Sobrecarga de verificação: Se o modelo de rascunho menor gerar tokens que frequentemente falham na verificação, os ganhos de eficiência podem ser reduzidos, pois o modelo maior precisará regenerar partes da saída, o que pode anular as vantagens de velocidade.
- Suporte limitado para processamento em lote: Normalmente, a decodificação especulativa não oferece suporte a entradas em lote, o que pode reduzir sua eficácia em sistemas de alto rendimento que exigem o processamento paralelo de várias solicitações.
Conclusão
A decodificação especulativa é uma técnica avançada que tem o potencial de revolucionar a maneira como interagimos com grandes modelos de linguagem. Ele pode acelerar significativamente a inferência LLM sem comprometer a qualidade do texto gerado. Embora existam desafios a serem superados, os benefícios da decodificação especulativa são inegáveis, e podemos esperar que sua adoção cresça nos próximos anos, possibilitando uma nova geração de aplicativos de IA mais rápidos, mais responsivos e mais eficientes.
Sou Google Developers Expert em ML (Gen AI), Kaggle 3x Expert e Women Techmakers Ambassador com mais de 3 anos de experiência em tecnologia. Fui cofundador de uma startup de tecnologia de saúde em 2020 e estou fazendo mestrado em ciência da computação na Georgia Tech, com especialização em machine learning.
Aprenda IA com estes cursos!
curso
Developing AI Systems with the OpenAI API
curso
AI Security and Risk Management
blog
O que é um modelo generativo?
blog
O que é IA? Um guia rápido para iniciantes
tutorial
Autoencodificadores variacionais: Como eles funcionam e por que são importantes
tutorial
DeepSeek-Coder-V2 Tutorial: Exemplos, instalação, padrões de referência
Dimitri Didmanidze
8 min
tutorial
Visão GPT-4: Um guia abrangente para iniciantes
tutorial