Late Chunking for RAG: Implementation With Jina AI

Learn how to implement late chunking with Jina AI to improve context preservation and retrieval accuracy in RAG applications.
Nov 20, 2024  · 7 min read

In RAG applications, there's a constant trade-off between two approaches: embedding the whole document for better context or breaking it into smaller chunks for more precise retrieval.

Embedding the entire document captures the big picture but can lose important details, while shorter chunks keep the details but often miss the overall context.

Late chunking offers a solution by keeping the full document context intact while splitting it into smaller, easier-to-handle chunks.

In this article, I will introduce late chunking as a better alternative to traditional naive chunking methods and show you how to implement it step by step.

Naive Chunking and Its Limitations in RAG

In a RAG pipeline, documents are broken into smaller chunks before being embedded and stored in a vector database. Each chunk is processed independently and used for retrieval when queries are made. However, this "naive chunking" approach often loses important long-distance context.

The problem arises because traditional chunking splits documents without considering how information is connected. For example, in a document about Paris, the phrase "the city" might end up in a different chunk from where "Paris" is mentioned. Without the full context, the retrieval model may struggle to link these references, leading to less accurate results. This issue is even worse in long documents where key context is spread over multiple sections.
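To make the failure mode concrete, here is a minimal sketch of naive chunking in which each sentence simply becomes its own chunk (an illustration only, not the chunking strategy used later in this tutorial):

# Naive chunking illustration: split the document into sentence-level chunks
# first; each chunk would then be embedded on its own, with no shared context.
document = (
    "Paris is the capital of France. "
    "The city is known for its museums and cafés. "
    "It hosted the 2024 Summer Olympics."
)

chunks = [s.strip().rstrip(".") + "." for s in document.split(". ") if s.strip()]
for chunk in chunks:
    print(chunk)

# Paris is the capital of France.
# The city is known for its museums and cafés.
# It hosted the 2024 Summer Olympics.

The second and third chunks never mention "Paris", so embedding each one in isolation gives the retrieval model no way to connect "the city" or "it" back to the entity the query is actually about.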

Late Chunking: Preserving Context in Document Splitting

Late chunking solves the problem by changing when you split the document. Instead of breaking the document into chunks first, late chunking embeds the entire document using a long-context model. Only after this does it split the document into smaller chunks.

These are the key benefits of late chunking:

  • Keeps context: Late chunking ensures that each chunk retains the overall context by embedding the whole document first. This way, references and connections across the text remain intact in the chunk embeddings.
  • Better retrieval: The chunk embeddings created through late chunking are richer and more accurate, improving retrieval results in RAG systems because the model understands the document better.
  • Handles long texts: It's great for very long documents that traditional models can’t handle in one go due to token limits.

Using long-context models like Jina’s jinaai/jina-embeddings-v2-base-en, which supports up to 8192 tokens, late chunking allows for embedding large text sections effectively before splitting them into chunks.
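Before bringing in a real model, the mechanics can be sketched with plain NumPy. In this toy sketch (illustrative only; the actual implementation with Jina's model follows below), random vectors stand in for the token embeddings that a long-context model would produce in a single pass over the whole document, and chunk embeddings are formed afterwards by mean pooling over token spans:

import numpy as np

# Toy stand-in for token-level embeddings from one pass over the whole document
rng = np.random.default_rng(0)
num_tokens, dim = 12, 4
token_embeddings = rng.normal(size=(num_tokens, dim))

# Chunk boundaries are decided only after embedding, as (start, end) token spans
spans = [(0, 5), (5, 9), (9, 12)]

# Each chunk embedding is the mean of its token embeddings, all of which were
# computed while the model could still see the rest of the document
chunk_embeddings = [token_embeddings[start:end].mean(axis=0) for start, end in spans]
print(len(chunk_embeddings), chunk_embeddings[0].shape)  # 3 (4,)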

Implementing Late Chunking

Here’s a step-by-step guide to implementing late chunking using Jina’s long-context embedding model. You can get a Jina API key for free from Jina AI, and we will use the following input text as the demo:

input_text = """Berlin is the capital and largest city of Germany, both by area and by population.
Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.
The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."""

Step 1: Get chunks and span annotations

First, use your Jina API key and the helper function below to break your input text into chunks. These chunks come with span annotations that help split the document embedding later. Jina’s API uses natural boundaries like paragraph or sentence breaks to ensure the chunks make sense and retain their meaning.

import json
import requests

def custom_tokenize_jina_api(input_text: str):
    url = 'https://segment.jina.ai/'
    headers = {
        'Content-Type': 'application/json',
        'Authorization': 'Bearer ENTER_YOUR_JINA_API_KEY'
    }
    data = {
        "content": input_text,
        "tokenizer": "o200k_base",
        "return_tokens": "true",
        "return_chunks": "true",
        "max_chunk_length": "1000"
    }
    # Make the API request
    response = requests.post(url, headers=headers, json=data)
    response_data = response.json()
    chunks = response_data.get("chunks", [])
    # Build cumulative (start, end) token spans for each chunk; these are used
    # later to slice the token-level document embeddings into chunk embeddings.
    i = 1
    j = 1
    span_annotations = []
    for x in response_data['tokens']:  # one token list per chunk
        if j == 1:
            j = len(x)
        else:
            j = len(x) + i
        span_annotations.append((i, j))
        i = j
    return chunks, span_annotations

chunks, span_annotations = custom_tokenize_jina_api(input_text)

print(chunks)
print(span_annotations)
['Berlin is the capital and largest city of Germany, both by area and by population.\n\n', "Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.\n\n", 'The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.']
[(1, 17), (17, 44), (44, 69)]
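As a quick sanity check on this output, the span annotations should form contiguous token ranges, with each chunk starting exactly where the previous one ends:

# Spans are contiguous by construction: each chunk starts where the last ended
for (_, prev_end), (next_start, _) in zip(span_annotations, span_annotations[1:]):
    assert prev_end == next_start

print(f"{len(chunks)} chunks covering token positions 1 to {span_annotations[-1][1]}")
# 3 chunks covering token positions 1 to 69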

Step 2: Tokenize the text and generate token-level document embeddings

Next, tokenize the entire document with the tokenizer of your long-context model, here Jina’s jina-embeddings-v2-base-en. Then pass the tokens through the model to create an embedding for each token. This way, every token in the document gets its own embedding, computed while the model can attend to the full text.

from transformers import AutoModel
from transformers import AutoTokenizer

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
inputs = tokenizer(input_text, return_tensors='pt')
model_output = model(**inputs)
model_output[0].shape
torch.Size([1, 71, 768]) # 71 represents number of tokens in the entire document

Step 3: Late chunking

Once you have the token embeddings for the whole document, you’re ready for late chunking. Use the span annotations from step one to split these token embeddings into smaller chunks. Then, apply mean pooling to average the embeddings within each chunk, creating a single embedding for each chunk. We now have chunk embeddings with strong contextual information about the entire document.

def late_chunking(model_output, span_annotation: list, max_length=None):
    token_embeddings = model_output[0]
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # drop annotations that go beyond the model's max length
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)
    return outputs

embeddings = late_chunking(model_output, [span_annotations])[0]
len(embeddings)
3 # matches number of chunks in Step 1
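Before running the side-by-side comparison, here is a small usage sketch of how these chunk embeddings would serve retrieval in a RAG pipeline: embed a query with the same model, score every chunk by cosine similarity, and return the best match (the query string here is just an example):

import numpy as np

# Example retrieval step: embed a query and rank the late-chunked embeddings
query_embedding = model.encode("How many people live in Berlin?")
scores = [
    np.dot(query_embedding, chunk_embedding)
    / (np.linalg.norm(query_embedding) * np.linalg.norm(chunk_embedding))
    for chunk_embedding in embeddings
]
best_chunk = chunks[int(np.argmax(scores))]
print(best_chunk.strip())  # expected: the sentence about Berlin's 3.85 million inhabitants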

Step 4: Late chunking vs. traditional chunking results

To understand the benefits of late chunking, let’s compare it with traditional chunking:

import numpy as np

# Baseline: embed each chunk independently (traditional chunking)
embeddings_traditional_chunking = model.encode(chunks)

cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
q = "Berlin"
berlin_embedding = model.encode(q)

print(q)
print('\n')
for chunk, new_embedding, trad_embedding in zip(chunks, embeddings, embeddings_traditional_chunking):
    print(chunk.strip())
    print('Late chunking:', cos_sim(berlin_embedding, new_embedding))
    print('Traditional chunking:', cos_sim(berlin_embedding, trad_embedding))
    print("------------------------------------------------------------------")
Berlin
Berlin is the capital and largest city of Germany, both by area and by population.
Late chunking: 0.84954596
Traditional chunking: 0.84862185
------------------------------------------------------------------
Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits.
Late chunking: 0.82489026
Traditional chunking: 0.70843375
------------------------------------------------------------------
The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.
Late chunking: 0.84980094
Traditional chunking: 0.7534553
------------------------------------------------------------------

As you can see in the second and third chunks, traditional chunking produces cosine similarity scores of roughly 0.71–0.75 against the query "Berlin," because neither chunk mentions the city by name. With late chunking, which maintains the context of the entire document, these scores rise to roughly 0.82–0.85. This shows that late chunking does a better job of preserving context and creating more meaningful embeddings, resulting in more accurate retrieval.

Conclusion

Late chunking is a major improvement for document retrieval systems, especially in RAG pipelines. By waiting to split the document until after it's fully embedded, late chunking keeps the full context in each chunk. This results in more accurate and meaningful embeddings.

Author: Ryan Ong

Ryan is a lead data scientist specialising in building AI applications using LLMs. He is a PhD candidate in Natural Language Processing and Knowledge Graphs at Imperial College London, where he also completed his Master’s degree in Computer Science. Outside of data science, he writes a weekly Substack newsletter, The Limitless Playbook, where he shares one actionable idea from the world's top thinkers and occasionally writes about core AI concepts.
