Programa
LLMs são muito eficientes, mas muitas vezes podem ser um pouco lentos, o que não é ideal em cenários em que precisamos de velocidade. Decodificação especulativa é uma técnica projetada para acelerar os LLMs, gerando respostas mais rapidamente sem comprometer a qualidade.
Em essência, é uma maneira de "adivinhar" no processo de geração de texto, fazendo previsões sobre as palavras que podem vir a seguir, ao mesmo tempo em que permite a precisão e a profundidade que esperamos dos LLMs.
Neste blog, explicarei o que é a decodificação especulativa, como ela funciona e como implementá-la com os modelos Gemma 2.
O que é decodificação especulativa?
A decodificação especulativa acelera os LLMs ao incorporar um modelo menor e mais rápido que gera previsões preliminares. Esse modelo menor, geralmente chamado de modelo de "rascunho", gera um lote de tokens que o LLM principal, mais avançado, pode confirmar ou refinar. O modelo de rascunho funciona como uma primeira passagem, produzindo vários tokens que aceleram o processo de geração.
Em vez de o LLM principal gerar tokens sequencialmente, o modelo de rascunho fornece um conjunto de candidatos prováveis, e o modelo principal os avalia em paralelo. Esse método reduz a carga computacional do LLM principal ao descarregar as previsões iniciais, permitindo que ele se concentre apenas em correções ou validações.
Pense nisso como um escritor com um editor. O principal LLM é o escritor, capaz de produzir textos de alta qualidade, mas em um ritmo mais lento. Um modelo de "rascunho" menor e mais rápido atua como editor, gerando rapidamente possíveis continuações do texto. Em seguida, o LLM principal avalia essas sugestões, incorporando as mais precisas e descartando as demais. Isso permite que o LLM processe vários tokens ao mesmo tempo, acelerando a geração de texto.
Vamos dividir o processo de decodificação especulativa em etapas simples:
- Geração de rascunho: O modelo menor (por exemplo, Gemma2-2B-it) gera várias sugestões de tokens com base no prompt de entrada. Esses tokens são gerados de forma especulativa, o que significa que o modelo não tem certeza de que estão corretos, mas os fornece como tokens de "rascunho".
- Paralelo verificação: O modelo maior (por exemplo, Gemma2-9B-it) verifica esses tokens em paralelo, comparando a probabilidade deles com a distribuição aprendida do modelo. Se os tokens forem considerados aceitáveis, eles serão usados na saída final. Caso contrário, o modelo maior os corrige.
- Resultado final: Depois que os tokens são verificados (ou corrigidos), eles são passados para o usuário como o resultado final. Todo esse processo é muito mais rápido do que a decodificação tradicional de um token por vez.
Decodificação tradicional vs. Decodificação de dados. Decodificação especulativa
A decodificação tradicional processa os tokens um por vez, o que leva a uma alta latência, mas a decodificação especulativa permite que um modelo menor gere tokens em massa, com o modelo maior verificando-os. Isso pode reduzir o tempo de resposta em 30 a 40%, diminuindo a latência de 25 a 30 segundos para apenas 15 a 18 segundos.
Além disso, a decodificação especulativa otimiza o uso da memória, transferindo a maior parte da geração de tokens para o modelo menor, reduzindo os requisitos de memória de 26 GB para cerca de 14 GB e fazendo com que o no dispositivo inferência no dispositivo mais acessível.
Por fim, ele reduz as demandas de computação em 50%, pois o modelo maior apenas verifica em vez de gerar tokens, permitindo um desempenho mais suave em dispositivos móveis com energia limitada e evitando o superaquecimento.
Exemplo prático: Decodificação especulativa com modelos Gemma2
Para implementar um exemplo prático de decodificação especulativa usando os modelos Gemma2. Exploraremos como a decodificação especulativa se compara à inferência padrão em termos delatência e desempenho do .
Etapa 1: Configuração do modelo e do tokenizador
Para começar, importe as dependências e defina a semente.
Em seguida, verifique se a GPU está disponível na máquina em que você está operando. Isso é necessário principalmente para modelos grandes, como Gemma 2-9B-it ou LLama2-13B.
Por fim, carregamos os modelos pequeno e grande junto com seus tokenizadores. Aqui, estamos usando o modelo Gemma2-2b-it (instruído) para o modelo preliminar e o modelo Gemma2-9b-it para verificação.
Há alguns outros modelos que também podem ser usados como alternativa. Por exemplo:
- Gemma 7B (principal) e Gemma 2B (rascunho)
- Mixtral-8x7B (principal) e Mistral 7B (rascunho)
- Pythia 12B (principal) e Pythia 70M (rascunho)
- Llama 2 13B (principal) e TinyLlama 1.1B (rascunho)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
# Set Seed
set_seed(42)
# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the smaller Gemma2 model (draft generation)
small_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", device_map="auto")
small_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", device_map="auto", torch_dtype=torch.bfloat16)
# Load the larger Gemma2 model (verification)
big_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it", device_map="auto")
big_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", device_map="auto", torch_dtype=torch.bfloat16)
Etapa 2: Inferência autorregressiva (normal)
Primeiro, realizamos a inferência apenas no modelo grande (Gemma2-9b-it) e geramos resultados. Comece por tokenização o prompt de entrada e movendo os tokens para o dispositivo correto (GPU, se disponível). O método generate
produz a saída do modelo, gerando até max_new_tokens
. O resultado é então decodificado a partir de IDs de token para um texto legível por humanos.
def normal_inference(big_model, big_tokenizer, prompt, max_new_tokens=50):
inputs = big_tokenizer(prompt, return_tensors='pt').to(device)
outputs = big_model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)
return big_tokenizer.decode(outputs[0], skip_special_tokens=True)
Etapa 3: Decodificação especulativa
Em seguida, tentamos o método de decodificação especulativa, no qual seguimos as seguintes etapas:
- Geração de rascunho: O modelo menor gera um rascunho do texto a partir do prompt fornecido.
- Verificação: Em seguida, o modelo maior verifica o rascunho calculando a probabilidade de log para cada token no rascunho.
- Cálculo da probabilidade lógica: Calculamos uma probabilidade média de log para determinar a probabilidade de o modelo grande considerar correto o rascunho do modelo pequeno.
def speculative_decoding(small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens=50):
# Step 1: Use the small model to generate the draft
inputs = small_tokenizer(prompt, return_tensors='pt').to(device)
small_outputs = small_model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)
draft = small_tokenizer.decode(small_outputs[0], skip_special_tokens=True)
# Step 2: Verify the draft with the big model
big_inputs = big_tokenizer(draft, return_tensors='pt').to(device)
# Step 3: Calculate log-likelihood of the draft tokens under the large model
with torch.no_grad():
outputs = big_model(big_inputs['input_ids'])
log_probs = torch.log_softmax(outputs.logits, dim=-1)
draft_token_ids = big_inputs['input_ids']
log_likelihood = 0
for i in range(draft_token_ids.size(1) - 1):
token_id = draft_token_ids[0, i + 1]
log_likelihood += log_probs[0, i, token_id].item()
avg_log_likelihood = log_likelihood / (draft_token_ids.size(1) - 1)
# Return the draft and its log-likelihood score
return draft, avg_log_likelihood
Observação: O log-likelihood é o logaritmo da probabilidade que um modelo atribui a uma sequência específica de tokens. Aqui, ele reflete a probabilidade ou a "confiança" do modelo de que a sequência de tokens (o texto gerado) é válida, considerando os tokens anteriores.
Etapa 4: Medição da latência
Depois de implementar as duas técnicas, podemos medir suas respectivas latências. Para a decodificação especulativa, avaliamos o desempenho examinando o valor de log-likelihood. Um valor de log-likelihood próximo de zero, especialmente na faixa negativa, indica que o texto gerado é preciso.
def measure_latency(small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens=50):
# Measure latency for normal inference (big model only)
start_time = time.time()
normal_output = normal_inference(big_model, big_tokenizer, prompt, max_new_tokens)
normal_inference_latency = time.time() - start_time
print(f"Normal Inference Output: {normal_output}")
print(f"Normal Inference Latency: {normal_inference_latency:.4f} seconds")
print("\n\n")
# Measure latency for speculative decoding
start_time = time.time()
speculative_output, log_likelihood = speculative_decoding(
small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens
)
speculative_decoding_latency = time.time() - start_time
print(f"Speculative Decoding Output: {speculative_output}")
print(f"Speculative Decoding Latency: {speculative_decoding_latency:.4f} seconds")
print(f"Log Likelihood (Verification Score): {log_likelihood:.4f}")
return normal_inference_latency, speculative_decoding_latency
Isso retorna:
- Probabilidade logarítmica (pontuação de verificação): -0.5242
- Latência de inferência normal: 17,8899 segundos
- Latência de decodificação especulativa: 10,5291 segundos (cerca de 70% mais rápido)
A latência mais baixa se deve ao menor tempo gasto pelo modelo menor para a geração de texto e ao menor tempo gasto pelo modelo maior apenas para verificar o texto gerado.
Testando a decodificação especulativa em cinco prompts
Vamos comparar a decodificação especulativa com a inferência autorregressiva usando cinco prompts:
# List of 5 prompts
prompts = [
"The future of artificial intelligence is ",
"Machine learning is transforming the world by ",
"Natural language processing enables computers to understand ",
"Generative models like GPT-3 can create ",
"AI ethics and fairness are important considerations for "
]
# Inference settings
max_new_tokens = 200
# Initialize accumulators for latency, memory, and tokens per second
total_tokens_per_sec_normal = 0
total_tokens_per_sec_speculative = 0
total_normal_latency = 0
total_speculative_latency = 0
# Perform inference on each prompt and accumulate the results
for prompt in prompts:
normal_latency, speculative_latency, _, _, tokens_per_sec_normal, tokens_per_sec_speculative = measure_latency_and_memory(
small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens
)
total_tokens_per_sec_normal += tokens_per_sec_normal
total_tokens_per_sec_speculative += tokens_per_sec_speculative
total_normal_latency += normal_latency
total_speculative_latency += speculative_latency
# Calculate averages
average_tokens_per_sec_normal = total_tokens_per_sec_normal / len(prompts)
average_tokens_per_sec_speculative = total_tokens_per_sec_speculative / len(prompts)
average_normal_latency = total_normal_latency / len(prompts)
average_speculative_latency = total_speculative_latency / len(prompts)
# Output the averages
print(f"Average Normal Inference Latency: {average_normal_latency:.4f} seconds")
print(f"Average Speculative Decoding Latency: {average_speculative_latency:.4f} seconds")
print(f"Average Normal Inference Tokens per second: {average_tokens_per_sec_normal:.2f} tokens/second")
print(f"Average Speculative Decoding Tokens per second: {average_tokens_per_sec_speculative:.2f} tokens/second")
Average Normal Inference Latency: 25.0876 seconds
Average Speculative Decoding Latency: 15.7802 seconds
Average Normal Inference Tokens per second: 7.97 tokens/second
Average Speculative Decoding Tokens per second: 12.68 tokens/second
Isso mostra que a decodificação especulativa é mais eficiente, gerando mais tokens por segundo do que a inferência normal. Essa melhoria se deve ao fato de o modelo menor lidar com a maior parte da geração de texto, enquanto a função do modelo maior se limita à verificação, reduzindo a carga computacional geral em termos de latência e memória.
Com esses requisitos de memória, podemos implantar facilmente técnicas de decodificação especulativa em dispositivos de borda e ganhar velocidade em nossos aplicativos no dispositivo, como chatbots, tradutores de idiomas, jogos e muito mais.
Decodificação especulativa otimizada com quantização
A abordagem acima é eficiente, mas há uma compensação entre a latência e a otimização da memória para a inferência no dispositivo. Para resolver isso, vamos aplicar a quantização a modelos pequenos e grandes. Você pode experimentar e tentar aplicar a quantização somente ao modelo grande, pois o modelo pequeno já ocupa o menor espaço.
A quantificação é aplicada a modelos menores e maiores usando a configuração BitsAndBytesConfig
da biblioteca Hugging Face transformers
. A quantização nos permite reduzir significativamente o uso da memória e, em muitos casos, melhorar a velocidade de inferência, convertendo os pesos do modelo em uma forma mais compacta.
Adicione o seguinte trecho de código ao código acima para que você veja os efeitos da quantização.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True, # Enables 4-bit quantization
bnb_4bit_quant_type="nf4", # Specifies the quantization type (nf4)
bnb_4bit_compute_dtype=torch.bfloat16, # Uses bfloat16 for computation
bnb_4bit_use_double_quant=False, # Disables double quantization to avoid additional overhead
)
# Load the smaller and larger Gemma2 models with quantization applied
small_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", device_map="auto", quantization_config=bnb_config)
big_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", device_map="auto", quantization_config=bnb_config)
Aqui está uma rápida comparação das saídas para mostrar os efeitos da decodificação especulativa com e sem quantização:
Quantização de 4 bits (compressão de peso)
A configuração especifica load_in_4bit=True
, o que significa que os pesos do modelo são quantizados de seu formato original de ponto flutuante de 32 ou 16 bits para inteiros de 4 bits. Isso reduz o espaço de memória do modelo. A quantização comprime os pesos do modelo, o que nos permite armazená-los e operá-los com mais eficiência. Essas são aseconomias concretasde memória do :
- Ao reduzir a precisão de flutuantes de 32 ou 16 bits para inteiros de 4 bits, cada peso agora ocupa 1/4 ou 1/8 do espaço original, reduzindo significativamente o uso da memória.
- Isso se reflete no uso da memória como:
- Uso normal da memória de inferência: 26,458 MB
- Uso de memória de decodificação especulativa: 8.993 MB.
bfloat16 para computação (uso eficiente de Tensor Cores)
A configuração bnb_4bit_compute_dtype=torch.bfloat16
especifica que o cálculo é realizado em bfloat16 (BF16), um formato de ponto flutuante mais eficiente. O BF16 tem uma faixa dinâmica mais ampla do que o FP16, mas ocupa metade da memória em comparação com o FP32, o que o torna um bom equilíbrio entre precisão e desempenho.
O uso do BF16, especialmente em GPUs NVIDIA (como a A100), utiliza Tensor Cores, que são otimizados para operações do BF16. Isso resulta em multiplicações mais rápidas de matrizes e outros cálculos durante a inferência.
Para decodificação especulativa, observamos uma melhora na latência:
- Latência de inferência normal: 27,65 segundos
- Latência de decodificação especulativa: 15,56 segundos
O menor espaço de memória significa acesso mais rápido à memória e uso mais eficiente dos recursos da GPU, levando a uma geração mais rápida.
Tipo de quantização NF4 (precisão otimizada)
A opção bnb_4bit_quant_type="nf4"
especifica Norm-Four Quantization (NF4), que é otimizada para redes neurais. A quantização NF4 ajuda a manter a precisão de partes importantes do modelo, mesmo que os pesos sejam representados em 4 bits. Isso minimiza a degradação do desempenho do modelo em comparação com a quantização simples de 4 bits.
O NF4 ajuda a alcançar um equilíbrio entre a compactação da quantização de 4 bits e a precisão das previsões do modelo, garantindo que o desempenho permaneça próximo ao original e, ao mesmo tempo, reduzindo drasticamente o uso da memória.
Quantização dupla desativada
A quantização dupla (bnb_4bit_use_double_quant=False
) introduz uma camada adicional de quantização sobre os pesos de 4 bits, o que pode reduzir ainda mais o uso da memória, mas também aumenta a sobrecarga de computação. Nesse caso, a quantização dupla é desativada para evitar que você diminua a velocidade da inferência.
Aplicações da decodificação especulativa
As possíveis aplicações da decodificação especulativa são vastas e empolgantes. Aqui estão alguns exemplos:
- Chatbots e assistentes virtuais: Para que essas conversas com a IA sejam mais naturais e fluidas, com tempos de resposta mais rápidos.
- Tradução em tempo real: A decodificação especulativa reduz a latência na tradução em tempo real.
- Geração de conteúdo: A decodificação especulativa acelera a criação de conteúdo.
- Aplicativos de jogos e interativos: Para melhorar a capacidade de resposta de personagens ou elementos de jogos acionados por IA para uma experiência mais imersiva, a decodificação especulativa pode nos ajudar a obter respostas quase em tempo real.
Desafios da decodificação especulativa
Embora a decodificação especulativa seja muito promissora, ela não está isenta de desafios:
- Sobrecarga de memória: A manutenção de vários estados de modelo (tanto para o rascunho quanto para o LLM principal) pode aumentar o uso da memória, especialmente quando modelos maiores são usados para verificação.
- Ajuste do modelo de rascunho: A escolha do modelo de rascunho correto e o ajuste de seus parâmetros são fundamentais para atingir o equilíbrio certo entre velocidade e precisão, pois um modelo excessivamente simplista pode levar a falhas frequentes na verificação.
- Complexidade da implementação: A implementação da decodificação especulativa é mais complexa do que os métodos tradicionais, exigindo uma sincronização cuidadosa entre o modelo de rascunho pequeno e o modelo de verificação maior, além de um tratamento eficiente de erros.
- Compatibilidade com estratégias de decodificação: Atualmente, a decodificação especulativa é compatível apenas com pesquisa e amostragem gananciosas, limitando seu uso a estratégias de decodificação mais sofisticadas, como pesquisa de feixe ou amostragem diversificada.
- Sobrecarga de verificação: Se o modelo de rascunho menor gerar tokens que frequentemente falham na verificação, os ganhos de eficiência podem ser reduzidos, pois o modelo maior precisará regenerar partes da saída, o que pode anular as vantagens de velocidade.
- Suporte limitado para processamento em lote: Normalmente, a decodificação especulativa não oferece suporte a entradas em lote, o que pode reduzir sua eficácia em sistemas de alto rendimento que exigem o processamento paralelo de várias solicitações.
Conclusão
A decodificação especulativa é uma técnica avançada que tem o potencial de revolucionar a maneira como interagimos com grandes modelos de linguagem. Ele pode acelerar significativamente a inferência LLM sem comprometer a qualidade do texto gerado. Embora existam desafios a serem superados, os benefícios da decodificação especulativa são inegáveis, e podemos esperar que sua adoção cresça nos próximos anos, possibilitando uma nova geração de aplicativos de IA mais rápidos, mais responsivos e mais eficientes.

Sou Google Developers Expert em ML (Gen AI), Kaggle 3x Expert e Women Techmakers Ambassador com mais de 3 anos de experiência em tecnologia. Fui cofundador de uma startup de tecnologia de saúde em 2020 e estou fazendo mestrado em ciência da computação na Georgia Tech, com especialização em machine learning.