Saltar al contenido principal

Descodificación especulativa: Una guía con ejemplos de aplicación

Aprende qué es la descodificación especulativa, cómo funciona, cuándo utilizarla y cómo implementarla utilizando modelos Gemma2.
Actualizado 8 nov 2024  · 12 min de lectura

Los LLM son tan potentes, pero a menudo pueden ser un poco lentos, y esto no es lo ideal en situaciones en las que necesitamos velocidad. Descodificación especulativa es una técnica diseñada para acelerar los LLM generando respuestas más rápidamente sin comprometer la calidad.

En esencia, es una forma de "adivinar por adelantado" en el proceso de generación de texto, haciendo predicciones sobre las palabras que podrían venir a continuación sin dejar de permitir la precisión y profundidad que esperamos de los LLM.

En este blog, te explicaré qué es la descodificación especulativa, cómo funciona y cómo aplicarla con los modelos Gemma 2.

¿Qué es la descodificación especulativa?

La descodificación especulativa acelera los LLM incorporando un modelo más pequeño y rápido que genera predicciones preliminares. Este modelo más pequeño, a menudo llamado modelo "borrador", genera un lote de fichas que el LLM principal, más potente, puede confirmar o refinar. El modelo borrador actúa como una primera pasada, produciendo múltiples fichas que aceleran el proceso de generación.

En lugar de que el LLM principal genere tokens secuencialmente, el modelo borrador proporciona un conjunto de candidatos probables, y el modelo principal los evalúa en paralelo. Este método reduce la carga computacional del LLM principal al descargar las predicciones iniciales, permitiéndole centrarse sólo en las correcciones o validaciones.

Canal de descodificación especulativa

Piensa en ello como un escritor con un editor. El LLM principal es el escritor, capaz de producir textos de alta calidad pero a un ritmo más lento. Un modelo de "borrador" más pequeño y rápido actúa como editor, generando rápidamente posibles continuaciones del texto. A continuación, el LLM principal evalúa estas sugerencias, incorporando las acertadas y descartando el resto. Esto permite al LLM procesar varios tokens simultáneamente, acelerando la generación de texto.

Desglosemos el proceso de descodificación especulativa en sencillos pasos:

  1. ​Generación de borradores: El modelo más pequeño (por ejemplo, Gemma2-2B-it) genera múltiples sugerencias de fichas basándose en la petición de entrada. Estas fichas se generan especulativamente, lo que significa que el modelo no está seguro de que sean correctas, pero las proporciona como fichas "borrador".
  2. En paralelo ​verificación: El modelo mayor (por ejemplo, Gemma2-9B-it) verifica estos tokens en paralelo, comprobando su probabilidad con la distribución aprendida del modelo. Si las fichas se consideran aceptables, se utilizan en el resultado final. Si no, el modelo mayor los corrige.
  3. Resultado final: Una vez verificados (o corregidos), se pasan al usuario como salida final. Todo este proceso es mucho más rápido que la descodificación tradicional de una ficha a la vez.

Descodificación tradicional frente a Descodificación especulativa

La descodificación tradicional procesa los tokens de uno en uno, lo que provoca una gran latencia, pero la descodificación especulativa permite que un modelo más pequeño genere tokens en masa, y que el modelo más grande los verifique. Esto puede reducir el tiempo de respuesta en un 30-40%, recortando la latencia de 25-30 segundos a tan sólo 15-18 segundos.

Descodificación tradicional frente a descodificación especulativa

Además, la descodificación especulativa optimiza el uso de memoria desplazando la mayor parte de la generación de tokens al modelo más pequeño, reduciendo los requisitos de memoria de 26 GB a unos 14 GB y haciendo que en el dispositivo en el dispositivo.

Por último, reduce las demandas de cálculo en un 50%, ya que el modelo más grande sólo verifica en lugar de generar fichas, lo que permite un rendimiento más suave en dispositivos móviles con potencia limitada y evita el sobrecalentamiento.

Ejemplo práctico: Descodificación especulativa con modelos Gemma2

Poner en práctica un ejemplo práctico de descodificación especulativa utilizando los modelos Gemma2. Exploraremos cómo se compara la descodificación especulativa con la inferencia estándar, tanto en términos delatencia como de rendimiento.

Paso 1: Configuración del modelo y del tokenizador

Para empezar, importa las dependencias y establece la semilla.

A continuación, comprueba si la GPU está disponible en la máquina en la que estás operando. Esto es necesario principalmente para modelos grandes como Gemma 2-9B-it o LLama2-13B.

Por último, cargamos el modelo pequeño y el grande junto con sus tokenizadores. Aquí utilizamos el modelo Gemma2-2b-it (instruir) para el proyecto de modelo y el modelo Gemma2-9b-it para la verificación.

También hay otros modelos que pueden utilizarse alternativamente. Por ejemplo:

  • Gemma 7B (principal) y Gemma 2B (borrador)
  • Mixtral-8x7B (principal) y Mistral 7B (borrador)
  • Pythia 12B (principal) y Pythia 70M (borrador)
  • Llama 2 13B (principal) y TinyLlama 1.1B (borrador)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

# Set Seed
set_seed(42)

# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the smaller Gemma2 model (draft generation)
small_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", device_map="auto")
small_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", device_map="auto", torch_dtype=torch.bfloat16)

# Load the larger Gemma2 model (verification)
big_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it", device_map="auto")
big_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", device_map="auto", torch_dtype=torch.bfloat16)

Paso 2: Inferencia autorregresiva (normal) 

En primer lugar, realizamos la inferencia sólo en el modelo grande (Gemma2-9b-it) y generamos el resultado. Empieza por tokenizando el prompt de entrada y moviendo los tokens al dispositivo correcto (GPU si está disponible). El método generate produce la salida del modelo, generando hasta max_new_tokens. A continuación, el resultado se descodifica a partir de los identificadores de los tokens para volver a convertirlos en texto legible por humanos.

def normal_inference(big_model, big_tokenizer, prompt, max_new_tokens=50):
    inputs = big_tokenizer(prompt, return_tensors='pt').to(device)
    outputs = big_model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)
    return big_tokenizer.decode(outputs[0], skip_special_tokens=True)

Paso 3: Descodificación especulativa 

A continuación, probamos el método de descodificación especulativa, en el que seguimos los siguientes pasos:

  1. Generación de borradores: El modelo más pequeño genera un borrador del texto a partir de la indicación dada.
  2. Verificación: A continuación, el modelo mayor verifica el borrador calculando la log-verosimilitud de cada token del borrador.
  3. Cálculo de la log-verosimilitud: Calculamos una log-verosimilitud media para determinar la probabilidad de que el modelo grande considere correcto el esbozo del modelo pequeño.
def speculative_decoding(small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens=50):
    # Step 1: Use the small model to generate the draft
    inputs = small_tokenizer(prompt, return_tensors='pt').to(device)
    small_outputs = small_model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)
    draft = small_tokenizer.decode(small_outputs[0], skip_special_tokens=True)

    # Step 2: Verify the draft with the big model
    big_inputs = big_tokenizer(draft, return_tensors='pt').to(device)

    # Step 3: Calculate log-likelihood of the draft tokens under the large model
    with torch.no_grad():
        outputs = big_model(big_inputs['input_ids'])
        log_probs = torch.log_softmax(outputs.logits, dim=-1)

    draft_token_ids = big_inputs['input_ids']
    log_likelihood = 0
    for i in range(draft_token_ids.size(1) - 1):
        token_id = draft_token_ids[0, i + 1]
        log_likelihood += log_probs[0, i, token_id].item()

    avg_log_likelihood = log_likelihood / (draft_token_ids.size(1) - 1)

    # Return the draft and its log-likelihood score
    return draft, avg_log_likelihood

Nota: La log-verosimilitud es el logaritmo de la probabilidad que un modelo asigna a una secuencia concreta de fichas. Aquí, refleja la probabilidad o "confianza" que tiene el modelo en que la secuencia de tokens (el texto generado) es válida dados los tokens anteriores.

Paso 4: Medir la latencia

Tras aplicar ambas técnicas, podemos medir sus respectivas latencias. Para la descodificación especulativa, evaluamos el rendimiento examinando el valor de log-verosimilitud. Un valor de log-verosimilitud próximo a cero, sobre todo en el rango negativo, indica que el texto generado es exacto.

def measure_latency(small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens=50):
    # Measure latency for normal inference (big model only)
    start_time = time.time()
    normal_output = normal_inference(big_model, big_tokenizer, prompt, max_new_tokens)
    normal_inference_latency = time.time() - start_time
    print(f"Normal Inference Output: {normal_output}")
    print(f"Normal Inference Latency: {normal_inference_latency:.4f} seconds")
    print("\n\n")

    # Measure latency for speculative decoding
    start_time = time.time()
    speculative_output, log_likelihood = speculative_decoding(
        small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens
    )
    speculative_decoding_latency = time.time() - start_time
    print(f"Speculative Decoding Output: {speculative_output}")
    print(f"Speculative Decoding Latency: {speculative_decoding_latency:.4f} seconds")
    print(f"Log Likelihood (Verification Score): {log_likelihood:.4f}")

    return normal_inference_latency, speculative_decoding_latency

Esto vuelve:

  • Probabilidad logarítmica (puntuación de verificación): -0.5242
  • Latencia normal de inferencia: 17,8899 segundos
  • Latencia de descodificación especulativa: 10,5291 segundos (aproximadamente un 70% más rápido)

La menor latencia se debe a que el modelo más pequeño tarda menos en generar el texto y el modelo más grande tarda menos sólo en verificar el texto generado.

Pruebas de descodificación especulativa con cinco instrucciones

Comparemos la descodificación especulativa con la inferencia autorregresiva utilizando cinco indicaciones:

# List of 5 prompts
prompts = [
    "The future of artificial intelligence is ",
    "Machine learning is transforming the world by ",
    "Natural language processing enables computers to understand ",
    "Generative models like GPT-3 can create ",
    "AI ethics and fairness are important considerations for "
]

# Inference settings
max_new_tokens = 200

# Initialize accumulators for latency, memory, and tokens per second
total_tokens_per_sec_normal = 0
total_tokens_per_sec_speculative = 0
total_normal_latency = 0
total_speculative_latency = 0

# Perform inference on each prompt and accumulate the results
for prompt in prompts:
    normal_latency, speculative_latency, _, _, tokens_per_sec_normal, tokens_per_sec_speculative = measure_latency_and_memory(
        small_model, big_model, small_tokenizer, big_tokenizer, prompt, max_new_tokens
    )
    total_tokens_per_sec_normal += tokens_per_sec_normal
    total_tokens_per_sec_speculative += tokens_per_sec_speculative
    total_normal_latency += normal_latency
    total_speculative_latency += speculative_latency

# Calculate averages
average_tokens_per_sec_normal = total_tokens_per_sec_normal / len(prompts)
average_tokens_per_sec_speculative = total_tokens_per_sec_speculative / len(prompts)
average_normal_latency = total_normal_latency / len(prompts)
average_speculative_latency = total_speculative_latency / len(prompts)

# Output the averages
print(f"Average Normal Inference Latency: {average_normal_latency:.4f} seconds")
print(f"Average Speculative Decoding Latency: {average_speculative_latency:.4f} seconds")
print(f"Average Normal Inference Tokens per second: {average_tokens_per_sec_normal:.2f} tokens/second")
print(f"Average Speculative Decoding Tokens per second: {average_tokens_per_sec_speculative:.2f} tokens/second")
Average Normal Inference Latency: 25.0876 seconds
Average Speculative Decoding Latency: 15.7802 seconds
Average Normal Inference Tokens per second: 7.97 tokens/second
Average Speculative Decoding Tokens per second: 12.68 tokens/second

Esto demuestra que la descodificación especulativa es más eficaz, ya que genera más fichas por segundo que la inferencia normal. Esta mejora se debe a que el modelo más pequeño se encarga de la mayor parte de la generación del texto, mientras que el papel del modelo más grande se limita a la verificación, lo que reduce la carga computacional total en términos de latencia y memoria. 

Con estos requisitos de memoria, podemos implantar fácilmente técnicas de descodificación especulativa en los dispositivos periféricos y aumentar la velocidad de nuestras aplicaciones en los dispositivos, como chatbots, traductores de idiomas, juegos, etc.

Descodificación especulativa optimizada con cuantificación  

El planteamiento anterior es eficiente, pero hay un compromiso entre la latencia y la optimización de la memoria para la inferencia en el dispositivo. Para solucionarlo, apliquemos la cuantización tanto a los modelos pequeños como a los grandes. Puedes experimentar y probar a aplicar la cuantización sólo al modelo grande, puesto que el modelo pequeño ya ocupa el menor espacio.

La cuantificación se aplica a modelos más pequeños y más grandes utilizando la configuración BitsAndBytesConfig de la biblioteca Cara abrazada transformers. La cuantización nos permite reducir significativamente el uso de memoria y, en muchos casos, mejorar la velocidad de inferencia convirtiendo los pesos del modelo a una forma más compacta.

Añade el siguiente fragmento de código al código anterior para comprobar los efectos de la cuantización.

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Enables 4-bit quantization
    bnb_4bit_quant_type="nf4",  # Specifies the quantization type (nf4)
    bnb_4bit_compute_dtype=torch.bfloat16,  # Uses bfloat16 for computation
    bnb_4bit_use_double_quant=False,  # Disables double quantization to avoid additional overhead
)

# Load the smaller and larger Gemma2 models with quantization applied
small_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", device_map="auto", quantization_config=bnb_config)
big_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", device_map="auto", quantization_config=bnb_config)

He aquí una rápida comparación de los resultados para mostrar los efectos de la descodificación especulativa con y sin cuantización:

Comparaciones con y sin cuantización

Cuantización de 4 bits (compresión del peso)

La configuración especifica load_in_4bit=True, lo que significa que los pesos del modelo se cuantifican desde su formato original de coma flotante de 32 o 16 bits a enteros de 4 bits. Esto reduce la huella de memoria del modelo. La cuantización comprime los pesos del modelo, lo que nos permite almacenarlos y operar con ellos de forma más eficaz. Estos son losahorros concretosde memoria :

  • Al reducir la precisión de flotantes de 32 ó 16 bits a enteros de 4 bits, cada peso ocupa ahora 1/4 ó 1/8 del espacio original, lo que reduce significativamente el uso de memoria.
  • Esto se refleja en el uso de memoria como
    • Uso normal de la memoria de inferencia: 26.458 MB
    • Uso de memoria de descodificación especulativa: 8.993 MB.

bfloat16 para el cálculo (uso eficiente de los Tensor Cores)

La configuración bnb_4bit_compute_dtype=torch.bfloat16 especifica que el cálculo se realice en bfloat16 (BF16), un formato de coma flotante más eficiente. El BF16 tiene un rango dinámico más amplio que el FP16, pero ocupa la mitad de memoria que el FP32, lo que lo convierte en un buen equilibrio entre precisión y rendimiento.

El uso de BF16, especialmente en GPUs NVIDIA (como la A100), utiliza Tensor Cores, que están optimizados para las operaciones de BF16. Así se consiguen multiplicaciones de matrices y otros cálculos más rápidos durante la inferencia.

Para la descodificación especulativa, observamos una mejora de la latencia:

  • Latencia normal de inferencia: 27,65 segundos
  • Latencia de descodificación especulativa: 15,56 segundos

La menor huella de memoria implica un acceso más rápido a la memoria y un uso más eficiente de los recursos de la GPU, lo que se traduce en una generación más rápida.

Tipo de cuantización NF4 (precisión optimizada)

La opción bnb_4bit_quant_type="nf4" especifica Cuantización Norma-Cuatro (NF4), que está optimizada para redes neuronales. La cuantización NF4 ayuda a conservar la precisión de partes importantes del modelo, aunque los pesos se representen en 4 bits. Esto minimiza la degradación del rendimiento del modelo en comparación con la simple cuantización de 4 bits.

NF4 ayuda a conseguir un equilibrio entre la compacidad de la cuantización de 4 bits y la precisión de las predicciones del modelo, garantizando que el rendimiento se mantenga próximo al original, al tiempo que se reduce drásticamente el uso de memoria.

Doble cuantización desactivada 

La doble cuantización (bnb_4bit_use_double_quant=False) introduce una capa adicional de cuantización sobre los pesos de 4 bits, lo que puede reducir aún más el uso de memoria, pero también añade sobrecarga de cálculo. En este caso, se desactiva la doble cuantización para evitar ralentizar la inferencia.

Aplicaciones de la descodificación especulativa

Las aplicaciones potenciales de la descodificación especulativa son amplias y apasionantes. He aquí algunos ejemplos:

  • Chatbots y asistentes virtuales: Para que esas conversaciones con la IA resulten más naturales y fluidas, con tiempos de respuesta más rápidos.
  • Traducción en tiempo real: La descodificación especulativa reduce la latencia en la traducción en tiempo real.
  • Generación de contenidos: La descodificación especulativa acelera la creación de contenidos.
  • Juegos y aplicaciones interactivas: Para mejorar la capacidad de respuesta de los personajes o de los elementos del juego impulsados por la IA y conseguir una experiencia más envolvente, la descodificación especulativa puede ayudarnos a obtener respuestas casi en tiempo real.

Aplicaciones de la descodificación especulativa

Desafíos de la descodificación especulativa

Aunque la descodificación especulativa es muy prometedora, no está exenta de dificultades:

  • Sobrecarga de memoria: Mantener múltiples estados del modelo (tanto para el borrador como para el LLM principal) puede aumentar el uso de memoria, especialmente cuando se utilizan modelos más grandes para la verificación.
  • Ajuste del modelo de calado: Elegir el modelo de calado adecuado y ajustar sus parámetros es crucial para lograr el equilibrio adecuado entre velocidad y precisión, ya que un modelo demasiado simplista puede dar lugar a frecuentes fallos de verificación.
  • Complejidad de la aplicación: Implementar la descodificación especulativa es más complejo que los métodos tradicionales, ya que requiere una sincronización cuidadosa entre el modelo de borrador pequeño y el modelo de verificación más grande, así como un tratamiento eficaz de los errores.
  • Compatibilidad con las estrategias de descodificación: Actualmente, la descodificación especulativa sólo admite la búsqueda codiciosa y el muestreo, lo que limita su uso a estrategias de descodificación más sofisticadas, como la búsqueda por haz o el muestreo diverso.
  • Gastos generales de verificación: Si el modelo de borrador más pequeño genera fichas que fallan frecuentemente en la verificación, las ganancias de eficiencia pueden disminuir, ya que el modelo más grande necesitará regenerar partes de la salida, anulando potencialmente las ventajas de velocidad.
  • Soporte limitado para el procesamiento por lotes: La descodificación especulativa no suele admitir entradas por lotes, lo que puede reducir su eficacia en sistemas de alto rendimiento que requieran el procesamiento paralelo de múltiples peticiones.

Conclusión

La descodificación especulativa es una potente técnica que puede revolucionar la forma en que interactuamos con grandes modelos lingüísticos. Puede acelerar significativamente la inferencia LLM sin comprometer la calidad del texto generado. Aunque hay retos que superar, las ventajas de la descodificación especulativa son innegables, y podemos esperar que su adopción crezca en los próximos años, permitiendo una nueva generación de aplicaciones de IA más rápidas, con mayor capacidad de respuesta y más eficientes.


Photo of Aashi Dutt
Author
Aashi Dutt
LinkedIn
Twitter

Soy una Google Developers Expert en ML(Gen AI), una Kaggle 3x Expert y una Women Techmakers Ambassador con más de 3 años de experiencia en tecnología. Cofundé una startup de tecnología sanitaria en 2020 y estoy cursando un máster en informática en Georgia Tech, especializándome en aprendizaje automático.

Temas

Aprende IA con estos cursos

programa

Desarrollo de aplicaciones de IA

23 horas hr
Aprende a crear aplicaciones potenciadas por IA con las últimas herramientas para desarrolladores de IA, como la API OpenAI, Hugging Face y LangChain.
Ver detallesRight Arrow
Comienza El Curso
Certificación disponible

curso

Desarrollar sistemas de IA con la API OpenAI

3 hr
3.7K
Aprovecha la API OpenAI para preparar tus aplicaciones de IA para la producción.
Ver másRight Arrow
Relacionado
An AI juggles tasks

blog

Cinco proyectos que puedes crear con modelos de IA generativa (con ejemplos)

Aprende a utilizar modelos de IA generativa para crear un editor de imágenes, un chatbot similar a ChatGPT con pocos recursos y una aplicación clasificadora de aprobación de préstamos y a automatizar interacciones PDF y un asistente de voz con GPT.
Abid Ali Awan's photo

Abid Ali Awan

10 min

blog

¿Qué es un modelo generativo?

Los modelos generativos utilizan el machine learning para descubrir patrones en los datos y generar datos nuevos. Conoce su importancia y sus aplicaciones en la IA.
Abid Ali Awan's photo

Abid Ali Awan

11 min

blog

Clasificación en machine learning: Introducción

Aprende sobre la clasificación en machine learning viendo qué es, cómo se utiliza y algunos ejemplos de algoritmos de clasificación.
Zoumana Keita 's photo

Zoumana Keita

14 min

tutorial

Autocodificadores variacionales: Cómo funcionan y por qué son importantes

Aprende los principios fundamentales, las aplicaciones y las ventajas prácticas de los autocodificadores variacionales y sigue una implementación paso a paso con PyTorch.
Kurtis Pykes 's photo

Kurtis Pykes

30 min

tutorial

Tutorial de DeepSeek-Coder-V2: Ejemplos, instalación, puntos de referencia

DeepSeek-Coder-V2 es un modelo de lenguaje de código de código abierto que rivaliza con el rendimiento de GPT-4, Gemini 1.5 Pro, Claude 3 Opus, Llama 3 70B o Codestral.
Dimitri Didmanidze's photo

Dimitri Didmanidze

8 min

tutorial

Visión GPT-4: Guía completa para principiantes

Este tutorial le presentará todo lo que necesita saber sobre GPT-4 Vision, desde cómo acceder a él hasta ejemplos prácticos del mundo real y sus limitaciones.
Arunn Thevapalan's photo

Arunn Thevapalan

12 min

See MoreSee More