Pular para o conteúdo principal

Ajuste fino do SAM 2 em um conjunto de dados personalizado: Tutorial

Saiba como fazer o ajuste fino do SAM 2 do Meta AI usando o conjunto de dados Chest CT Segmentation para melhorar o desempenho da segmentação de imagens do modelo na análise de imagens médicas.
Actualizado 4 de set. de 2024  · 14 min de leitura

O Segment Anything Model 2 (SAM 2) da Meta é a mais recente inovação em tecnologia de segmentação. É o primeiro modelo unificado do Meta que pode segmentar objetos em imagens e vídeos em tempo real.

Mas por que ajustar o SAM 2 se ele já pode segmentar qualquer coisa?

Embora o SAM 2 seja avançado e pronto para uso, seu desempenho em tarefas raras ou específicas do domínio nem sempre atende às expectativas. Ajuste fino permite que você adapte o SAM2 às suas necessidades específicas, melhorando sua precisão e eficiência para o seu caso de uso específico.

Neste artigo, vou guiar você passo a passo pelo processo de ajuste fino do SAM 2.

O que é SAM2?

O SAM2 é um modelo de base desenvolvido pela Meta para segmentação visual em imagens e vídeos. Ao contrário de seu antecessor, SAMque se concentrava principalmente em imagens estáticas, o SAM2 foi projetado para lidar também com as complexidades da segmentação de vídeo.

SAM2 - Tarefa, modelo e dados

SAM2 - Tarefa, modelo e dados (Fonte: Ravi et al., 2024)

Ele emprega uma arquitetura de transformador com memória de streaming, permitindo o processamento de vídeo em tempo real. O treinamento do SAM 2 envolveu um conjunto de dados vasto e variado, com o novo conjunto de dados SA-V, que inclui mais de 600.000 anotações de máscara em 51.000 vídeos.

Seu mecanismo de dados, que permite a coleta interativa de dados e o aprimoramento do modelo, dá a ele a capacidade de segmentar tudo o que for possível. Esse mecanismo permite que o SAM 2 aprenda e se adapte continuamente, tornando-o mais eficiente no tratamento de dados novos e desafiadores. No entanto, para tarefas específicas de domínio ou objetos raros, o ajuste fino é essencial para obter o desempenho ideal.

Por que fazer o ajuste fino do SAM2?

No contexto do SAM 2, o ajuste fino é o processo de treinamento adicional do modelo SAM 2 pré-treinado em um conjunto de dados específico para melhorar seu desempenho em uma tarefa ou domínio específico. Embora o SAM 2 seja uma ferramenta avançada treinada em um conjunto de dados amplo e diversificado, sua natureza de uso geral nem sempre produz resultados ideais para tarefas especializadas ou raras.

Por exemplo, se você estiver trabalhando em um projeto de imagens médicas que exija a identificação de tipos específicos de tumores, o desempenho do modelo poderá ficar aquém do esperado devido ao seu treinamento generalizado.

Ajuste fino do pipeline do processo

O processo de ajuste fino

O ajuste fino do SAM 2 resolve essa limitação, permitindo que você adapte o modelo ao seu conjunto de dados específico. Esse processo melhora a a precisão do modelo e o torna mais eficaz para seu caso de uso exclusivo.

Aqui estão os principais benefícios do ajuste fino do SAM 2:

  1. Precisão aprimorada: Com o ajuste fino do modelo em seu conjunto de dados específico, você pode aumentar significativamente sua precisão, garantindo um melhor desempenho em seu aplicativo específico.
  2. Segmentação especializada: O ajuste fino permite que o modelo se torne hábil na segmentação de tipos de objetos específicos, estilos visuais ou ambientes relevantes para o seu projeto, fornecendo resultados personalizados que um modelo de uso geral pode não alcançar.
  3. Eficiência: O ajuste fino geralmente é mais eficiente do que treinar um modelo do zero. Normalmente, requer menos dados e tempo, o que a torna uma solução prática para adaptar rapidamente o modelo a tarefas novas ou de nicho.

Primeiros passos com o ajuste fino do SAM 2: Pré-requisitos

Para começar a fazer o ajuste fino do SAM 2, você precisará ter os seguintes pré-requisitos:

  1. Acesso ao modelo SAM 2 e à base de código: Ter acesso ao modelo SAM 2 e à sua base de código. Você pode fazer o download do modelo SAM 2 pré-treinado no repositório GitHub da Meta.
  2. Um conjunto de dados adequado: Você precisará de um conjunto de dados que inclua máscaras de segmentação de verdade. Para este tutorial, usaremos o conjunto de dados Conjunto de dados de segmentação de TC de tóraxque você pode baixar e preparar para o treinamento.
  3. Recursos computacionais: O ajuste fino do SAM 2 requer hardware com potência computacional suficiente. As GPUs são altamente recomendadas para garantir que o processo seja eficiente e gerenciável, especialmente ao trabalhar com grandes conjuntos de dados ou modelos complexos. Neste exemplo, é usada uma GPU A100 no Google Colab.

Software e outros requisitos:

Preparando o conjunto de dados para o ajuste fino do SAM 2

A qualidade do seu conjunto de dados é crucial para o ajuste fino do modelo SAM 2. Imagens ou vídeos anotados de alta qualidade com máscaras de segmentação precisas são essenciais para que você obtenha o desempenho ideal. Anotações precisas permitem que o modelo aprenda os recursos corretos, levando a uma melhor precisão de segmentação e robustez em aplicativos do mundo real.

1. Aquisição de dados

A primeira etapa envolve a aquisição do conjunto de dados, que forma a espinha dorsal do processo de treinamento. Obtivemos nossos dados do Kaggleuma plataforma confiável que fornece uma gama diversificada de conjuntos de dados. Usando a API do Kaggle, baixamos os dados no formato necessário, garantindo que as imagens e as máscaras de segmentação correspondentes estivessem prontamente disponíveis para processamento posterior.

2. Extração e limpeza de dados

Depois de fazer o download dos conjuntos de dados, executamos as seguintes etapas:

  • Descompactação e limpeza: Extraia os dados dos arquivos zip baixados e exclua os arquivos desnecessários para economizar espaço em disco.
  • Extração de ID: Identificadores exclusivos (IDs) para imagens e máscaras são extraídos para garantir o mapeamento correto entre eles durante o treinamento.
  • Remoção de arquivos desnecessários: Remova todos os arquivos ruidosos ou irrelevantes, como determinadas imagens com problemas conhecidos, para manter a integridade do conjunto de dados.

3. Conversão para formatos utilizáveis

Como o modelo SAM2 requer entradas em formatos específicos, convertemos os dados da seguinte forma:

  • DICOM para NumPy: As imagens DICOM foram lidas e armazenadas como matrizes NumPy, que foram redimensionadas para uma dimensão padrão de 512x512 pixels.
  • NRRD para NumPy para máscaras: Da mesma forma, os arquivos NRRD contendo máscaras para pulmões, coração e traqueia foram processados e salvos como matrizes NumPy. Essas máscaras foram então remodeladas para corresponder às imagens correspondentes.
  • Conversão para JPG/PNG: Para melhor visualização e compatibilidade, as matrizes NumPy foram convertidas para os formatos JPG/PNG. Essa etapa incluiu a normalização dos valores de intensidade da imagem e a garantia de que as máscaras estavam orientadas corretamente.

4. Salvar e organizar dados

As imagens e máscaras processadas são então organizadas nas respectivas pastas para facilitar o acesso durante o processo de ajuste fino. Além disso, os caminhos para essas imagens e máscaras são gravados em um arquivo CSV (train.csv ) para facilitar o carregamento de dados durante o treinamento.

5. Visualização e validação

A etapa final envolveu a validação do conjunto de dados para garantir sua precisão:

  • Visualização: Visualizamos os pares imagem-máscara sobrepondo as máscaras às imagens. Isso nos ajudou a verificar o alinhamento e a precisão das máscaras.
  • Inspeção: Ao inspecionar algumas amostras, pudemos confirmar que o conjunto de dados estava corretamente preparado e pronto para ser usado no ajuste fino.

Aqui você encontra está um bloco de notas rápido para mostrar a você o código para a criação do conjunto de dados. Você pode seguir esse caminho de criação de dados ou usar diretamente qualquer conjunto de dados disponível on-line no mesmo formato que o mencionado nos pré-requisitos.

Ajuste fino do SAM2

O Segment Anything Model 2 contém vários componentes, mas o truque aqui para um ajuste fino mais rápido é treinar apenas componentes leves, como o decodificador de máscara e o codificador de prompt, em vez de todo o modelo. As etapas para o ajuste fino desse modelo são as seguintes:

Etapa 1: Instalar o SAM-2

Para iniciar o processo de ajuste fino, precisamos instalar a biblioteca SAM-2, que é essencial para o Segment Anything Model (SAM2). Esse modelo foi projetado para lidar com várias tarefas de segmentação de forma eficaz. A instalação envolve a clonagem do repositório SAM-2 do GitHub e a instalação das dependências necessárias.

!git clone https://github.com/facebookresearch/segment-anything-2
%cd /content/segment-anything-2
!pip install -q -e .

Esse trecho de código garante que a biblioteca SAM2 esteja corretamente instalada e pronta para ser usada em nosso fluxo de trabalho de ajuste fino.

Etapa 2: Faça o download do conjunto de dados

Depois que a biblioteca SAM-2 estiver instalada, a próxima etapa é adquirir o conjunto de dados que usaremos para o ajuste fino. Usamos um conjunto de dados disponível no Kaggle, especificamente um conjunto de dados de segmentação de TC de tórax que contém imagens e máscaras de pulmões, coração e traqueia.

O conjunto de dados contém:

  • images.zip: Imagens no formato RGB
  • masks.zip: Máscaras de segmentação no formato RGB
  • train.csv: Arquivo CSV com nomes de imagens

Imagem do conjunto de dados de tomografia computadorizada

Imagem do conjunto de dados de tomografia computadorizada

Neste blog, usaremos apenas imagens e máscaras de pulmões para segmentação. A API do Kaggle permite que você baixe conjuntos de dados diretamente para o nosso ambiente. Começamos fazendo o upload do arquivo kaggle.json do Kaggle para acessar qualquer conjunto de dados com facilidade.

Para obter o kaggle.json, acesse a guia Settings (Configurações) no seu perfil de usuário e selecione Create New Token (Criar novo token). Isso acionará o download do Kaggle. json, um arquivo que contém suas credenciais de API.

# get dataset from Kaggle
from google.colab import files
files.upload()  # This will prompt you to upload the kaggle.json file

!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d polomarco/chest-ct-segmentation

Descompacte os dados:

!unzip chest-ct-segmentation.zip -d chest-ct-segmentation

Com o conjunto de dados pronto, vamos iniciar o processo de ajuste fino. Como mencionei anteriormente, o segredo aqui é fazer o ajuste fino apenas dos componentes leves do SAM2, como o decodificador de máscara e o codificador de prompt, em vez de todo o modelo. Essa abordagem é mais eficiente e requer menos recursos.

Etapa 3: Faça o download dos pontos de controle do SAM-2

Para o processo de ajuste fino, precisamos começar com pesos pré-treinados do modelo SAM2. Esses pesos, chamados de "pontos de controle", são o ponto de partida para treinamento adicional. Os pontos de verificação foram treinados em uma ampla variedade de imagens e, ao ajustá-los em nosso conjunto de dados específico, podemos obter melhor desempenho em nossas tarefas de destino.

!wget -O sam2_hiera_tiny.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"
!wget -O sam2_hiera_small.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"
!wget -O sam2_hiera_base_plus.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"
!wget -O sam2_hiera_large.pt "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"

Nessa etapa, baixamos vários pontos de verificação do SAM-2 que correspondem a diferentes tamanhos de modelo (por exemplo, minúsculo, pequeno, base_plus, grande). A escolha do ponto de verificação pode ser ajustada com base nos recursos computacionais disponíveis e na tarefa específica em questão.

Etapa 4: Preparação de dados

Com o download do conjunto de dados, a próxima etapa é prepará-lo para o treinamento. Isso envolve a divisão do conjunto de dados em conjuntos de treinamento e teste e a criação de estruturas de dados que podem ser inseridas no modelo SAM 2 durante o ajuste fino.

%cd /content/segment-anything-2

import os
import pandas as pd
import cv2
import torch
import torch.nn.utils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.model_selection import train_test_split
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Path to the chest-ct-segmentation dataset folder
data_dir = "/content/segment-anything-2/chest-ct-segmentation"
images_dir = os.path.join(data_dir, "images/images")
masks_dir = os.path.join(data_dir, "masks/masks")

# Load the train.csv file
train_df = pd.read_csv(os.path.join(data_dir, "train.csv"))

# Split the data into two halves: one for training and one for testing
train_df, test_df = train_test_split(train_df, test_size=0.2, random_state=42)

# Prepare the training data list
train_data = []
for index, row in train_df.iterrows():
   image_name = row['ImageId']
   mask_name = row['MaskId']

   # Append image and corresponding mask paths
   train_data.append({
       "image": os.path.join(images_dir, image_name),
       "annotation": os.path.join(masks_dir, mask_name)
   })

# Prepare the testing data list (if needed for inference or evaluation later)
test_data = []
for index, row in test_df.iterrows():
   image_name = row['ImageId']
   mask_name = row['MaskId']

   # Append image and corresponding mask paths
   test_data.append({
       "image": os.path.join(images_dir, image_name),
       "annotation": os.path.join(masks_dir, mask_name)
   })

Dividimos o conjunto de dados em um conjunto de treinamento (80%) e um conjunto de teste (20%) para garantir que possamos avaliar o desempenho do modelo após o treinamento. Os dados de treinamento serão usados para ajustar o modelo SAM 2, enquanto os dados de teste serão usados para inferência e avaliação.

Depois de dividir o conjunto de dados em conjuntos de treinamento e teste, a próxima etapa envolve a criação de máscaras binárias, a seleção de pontos-chave dentro dessas máscaras e a visualização desses elementos para garantir que os dados sejam processados corretamente. 

1. Leitura e redimensionamento de imagens: O processo começa com a seleção aleatória de uma imagem e sua máscara correspondente do conjunto de dados. A imagem é convertida do formato BGR para o formato RGB, que é o formato de cor padrão para a maioria dos modelos de aprendizagem profunda. A anotação correspondente (máscara) é lida no modo de escala de cinza. Em seguida, tanto a imagem quanto a máscara de anotação são redimensionadas para uma dimensão máxima de 1024 pixels, mantendo a proporção para garantir que os dados se ajustem aos requisitos de entrada do modelo e reduzam a carga computacional.

def read_batch(data, visualize_data=False):
   # Select a random entry
   ent = data[np.random.randint(len(data))]

   # Get full paths
   Img = cv2.imread(ent["image"])[..., ::-1]  # Convert BGR to RGB
   ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_GRAYSCALE)  # Read annotation as grayscale

   if Img is None or ann_map is None:
       print(f"Error: Could not read image or mask from path {ent['image']} or {ent['annotation']}")
       return None, None, None, 0

   # Resize image and mask
   r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])  # Scaling factor
   Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
   ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

2. Binarização de máscaras de segmentação: A máscara de anotação multiclasse (que pode ter várias classes de objetos rotuladas com valores de pixel diferentes) é convertida em uma máscara binária. Essa máscara destaca todas as regiões de interesse na imagem, simplificando a tarefa de segmentação para um problema de classificação binária: objeto vs. plano de fundo. A máscara binária é então erodida usando um kernel 5x5.

A erosão reduz ligeiramente o tamanho da máscara, o que ajuda a evitar efeitos de limite ao selecionar pontos. Isso garante que os pontos selecionados estejam bem no interior do objeto, e não perto das bordas, que podem ser ruidosas ou ambíguas.

Os pontos-chave são selecionados dentro da máscara erodida. Esses pontos funcionam como avisos durante o processo de ajuste fino, orientando o modelo sobre onde concentrar sua atenção. Os pontos são selecionados aleatoriamente no interior dos objetos para garantir que sejam representativos e não sejam influenciados por limites ruidosos.

   ### Continuation of read_batch() ###

   # Initialize a single binary mask
   binary_mask = np.zeros_like(ann_map, dtype=np.uint8)
   points = []

   # Get binary masks and combine them into a single mask
   inds = np.unique(ann_map)[1:]  # Skip the background (index 0)
   for ind in inds:
       mask = (ann_map == ind).astype(np.uint8)  # Create binary mask for each unique index
       binary_mask = np.maximum(binary_mask, mask)  # Combine with the existing binary mask

   # Erode the combined binary mask to avoid boundary points
   eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)

   # Get all coordinates inside the eroded mask and choose a random point
   coords = np.argwhere(eroded_mask > 0)
   if len(coords) > 0:
       for _ in inds:  # Select as many points as there are unique labels
           yx = np.array(coords[np.random.randint(len(coords))])
           points.append([yx[1], yx[0]])

   points = np.array(points)

3. Visualização: Essa etapa é fundamental para verificar se as etapas de pré-processamento de dados foram executadas corretamente. Ao inspecionar visualmente os pontos na máscara binarizada, você pode garantir que o modelo receberá a entrada apropriada durante o treinamento. Por fim, a máscara binária é remodelada e formatada corretamente (com dimensões adequadas para a entrada do modelo), e os pontos também são remodelados para uso posterior no processo de treinamento. A função retorna a imagem processada, a máscara binária, os pontos selecionados e o número de máscaras encontradas.

    ### Continuation of read_batch() ###

    if visualize_data:
        # Plotting the images and points
        plt.figure(figsize=(15, 5))

        # Original Image
        plt.subplot(1, 3, 1)
        plt.title('Original Image')
        plt.imshow(img)
        plt.axis('off')

        # Segmentation Mask (binary_mask)
        plt.subplot(1, 3, 2)
        plt.title('Binarized Mask')
        plt.imshow(binary_mask, cmap='gray')
        plt.axis('off')

        # Mask with Points in Different Colors
        plt.subplot(1, 3, 3)
        plt.title('Binarized Mask with Points')
        plt.imshow(binary_mask, cmap='gray')

        # Plot points in different colors
        colors = list(mcolors.TABLEAU_COLORS.values())
        for i, point in enumerate(points):
            plt.scatter(point[0], point[1], c=colors[i % len(colors)], s=100, label=f'Point {i+1}')  # Corrected to plot y, x order

        # plt.legend()
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    binary_mask = np.expand_dims(binary_mask, axis=-1)  # Now shape is (1024, 1024, 1)
    binary_mask = binary_mask.transpose((2, 0, 1))
    points = np.expand_dims(points, axis=1)

    # Return the image, binarized mask, points, and number of masks
    return img, binary_mask, points, len(inds)

# Visualize the data
Img1, masks1, points1, num_masks = read_batch(train_data, visualize_data=True)

O código acima retorna a figura a seguir, que contém a imagem original do conjunto de dados junto com a máscara binarizada e a máscara binarizada com pontos. 

Máscara e pontos binarizados para o conjunto de dados

Imagem original, máscara binarizada e máscara binarizada com pontos para o conjunto de dados.

Etapa 5: Ajuste fino do modelo SAM2

O ajuste fino do modelo SAM2 envolve várias etapas, inclusive o carregamento do modelo, a configuração do otimizador e do agendador e a atualização iterativa dos pesos do modelo com base nos dados de treinamento.

Carregue os pontos de controle do modelo:

sam2_checkpoint = "sam2_hiera_small.pt"  # @param ["sam2_hiera_tiny.pt", "sam2_hiera_small.pt", "sam2_hiera_base_plus.pt", "sam2_hiera_large.pt"]
model_cfg = "sam2_hiera_s.yaml" # @param ["sam2_hiera_t.yaml", "sam2_hiera_s.yaml", "sam2_hiera_b+.yaml", "sam2_hiera_l.yaml"]

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

Começamos criando o modelo SAM2 usando os pontos de verificação pré-treinados. Em seguida, o modelo é agrupado em uma classe de previsão, o que simplifica o processo de definição de imagens, prompts de codificação e máscaras de decodificação.

Configurar hiperparâmetros

Configuramos vários hiperparâmetros para garantir que o modelo aprenda de forma eficaz, como a taxa de aprendizado, o decaimento do peso e as etapas de acumulação do gradiente. Esses hiperparâmetros controlam o processo de aprendizagem, inclusive a velocidade com que o modelo atualiza seus pesos e como ele evita o ajuste excessivo. Você pode brincar com eles.

# Train mask decoder.
predictor.model.sam_mask_decoder.train(True)

# Train prompt encoder.
predictor.model.sam_prompt_encoder.train(True)

# Configure optimizer.
optimizer=torch.optim.AdamW(params=predictor.model.parameters(),lr=0.0001,weight_decay=1e-4) #1e-5, weight_decay = 4e-5

# Mix precision.
scaler = torch.cuda.amp.GradScaler()

# No. of steps to train the model.
NO_OF_STEPS = 3000 # @param 

# Fine-tuned model name.
FINE_TUNED_MODEL_NAME = "fine_tuned_sam2"

O otimizador é responsável por atualizar os pesos do modelo, enquanto o programador ajusta a taxa de aprendizagem durante o treinamento para melhorar a convergência. Ao fazer o ajuste fino desses parâmetros, podemos obter melhor precisão de segmentação.

Iniciar o treinamento

O processo de ajuste fino real é iterativo, no qual, em cada etapa, um lote de imagens e máscaras apenas para pulmões é passado pelo modelo, e a perda é calculada e usada para atualizar os pesos do modelo.

# Initialize scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.2) # 500 , 250, gamma = 0.1
accumulation_steps = 4  # Number of steps to accumulate gradients before updating

for step in range(1, NO_OF_STEPS + 1):
   with torch.cuda.amp.autocast():
       image, mask, input_point, num_masks = read_batch(train_data, visualize_data=False)
       if image is None or mask is None or num_masks == 0:
           continue

       input_label = np.ones((num_masks, 1))
       if not isinstance(input_point, np.ndarray) or not isinstance(input_label, np.ndarray):
           continue

       if input_point.size == 0 or input_label.size == 0:
           continue

       predictor.set_image(image)
       mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
       if unnorm_coords is None or labels is None or unnorm_coords.shape[0] == 0 or labels.shape[0] == 0:
           continue

       sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
           points=(unnorm_coords, labels), boxes=None, masks=None,
       )

       batched_mode = unnorm_coords.shape[0] > 1
       high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
       low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
           image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
           image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
           sparse_prompt_embeddings=sparse_embeddings,
           dense_prompt_embeddings=dense_embeddings,
           multimask_output=True,
           repeat_image=batched_mode,
           high_res_features=high_res_features,
       )
       prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

       gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
       prd_mask = torch.sigmoid(prd_masks[:, 0])
       seg_loss = (-gt_mask * torch.log(prd_mask + 0.000001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()

       inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
       iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
       score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
       loss = seg_loss + score_loss * 0.05

       # Apply gradient accumulation
       loss = loss / accumulation_steps
       scaler.scale(loss).backward()

       # Clip gradients
       torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)

       if step % accumulation_steps == 0:
           scaler.step(optimizer)
           scaler.update()
           predictor.model.zero_grad()

       # Update scheduler
       scheduler.step()

       if step % 500 == 0:
           FINE_TUNED_MODEL = FINE_TUNED_MODEL_NAME + "_" + str(step) + ".torch"
           torch.save(predictor.model.state_dict(), FINE_TUNED_MODEL)

       if step == 1:
           mean_iou = 0

       mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())

       if step % 100 == 0:
           print("Step " + str(step) + ":\t", "Accuracy (IoU) = ", mean_iou)

Durante cada iteração, o modelo processa um lote de imagens, calcula as máscaras de segmentação e as compara com a verdade terrestre para calcular a perda. Essa perda é então usada para ajustar os pesos do modelo, melhorando gradualmente o desempenho do modelo. Após o treinamento de cerca de 3.000 épocas, obtivemos uma precisão (IoU - Intersection over Union) de cerca de 72%.

Etapa 6: Inferência com o modelo de ajuste fino

O modelo pode então ser usado para inferência, onde ele prevê máscaras de segmentação em imagens novas e não vistas. Comece com asfunções auxiliares read_images e get_points para obter a imagem de inferência e sua máscara junto com os pontos-chave.

def read_image(image_path, mask_path):  # read and resize image and mask
   img = cv2.imread(image_path)[..., ::-1]  # Convert BGR to RGB
   mask = cv2.imread(mask_path, 0)
   r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
   img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
   mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
   return img, mask

def get_points(mask, num_points):  # Sample points inside the input mask
   points = []
   coords = np.argwhere(mask > 0)
   for i in range(num_points):
       yx = np.array(coords[np.random.randint(len(coords))])
       points.append([[yx[1], yx[0]]])
   return np.array(points)

Em seguida, carregue as imagens de amostra que você deseja para inferência, juntamente com os pesos recém-ajustados, e execute a inferência definindo torch.no_grad().

# Randomly select a test image from the test_data
selected_entry = random.choice(test_data)
image_path = selected_entry['image']
mask_path = selected_entry['annotation']

# Load the selected image and mask
image, mask = read_image(image_path, mask_path)

# Generate random points for the input
num_samples = 30  # Number of points per segment to sample
input_points = get_points(mask, num_samples)

# Load the fine-tuned model
FINE_TUNED_MODEL_WEIGHTS = "fine_tuned_sam2_1000.torch"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Build net and load weights
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(FINE_TUNED_MODEL_WEIGHTS))

# Perform inference and predict masks
with torch.no_grad():
   predictor.set_image(image)
   masks, scores, logits = predictor.predict(
       point_coords=input_points,
       point_labels=np.ones([input_points.shape[0], 1])
   )

# Process the predicted masks and sort by scores
np_masks = np.array(masks[:, 0])
np_scores = scores[:, 0]
sorted_masks = np_masks[np.argsort(np_scores)][::-1]

# Initialize segmentation map and occupancy mask
seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

# Combine masks to create the final segmentation map
for i in range(sorted_masks.shape[0]):
   mask = sorted_masks[i]
   if (mask * occupancy_mask).sum() / mask.sum() > 0.15:
       continue

   mask_bool = mask.astype(bool)
   mask_bool[occupancy_mask] = False  # Set overlapping areas to False in the mask
   seg_map[mask_bool] = i + 1  # Use boolean mask to index seg_map
   occupancy_mask[mask_bool] = True  # Update occupancy_mask

# Visualization: Show the original image, mask, and final segmentation side by side
plt.figure(figsize=(18, 6))

plt.subplot(1, 3, 1)
plt.title('Test Image')
plt.imshow(image)
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title('Original Mask')
plt.imshow(mask, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title('Final Segmentation')
plt.imshow(seg_map, cmap='jet')
plt.axis('off')

plt.tight_layout()
plt.show()

Nessa etapa, usamos o modelo ajustado para gerar máscaras de segmentação para imagens de teste. Em seguida, as máscaras previstas são visualizadas juntamente com as imagens originais e as máscaras da verdade básica para avaliar o desempenho do modelo.

Segmentação final na inferência

Imagem de segmentação final nos dados de teste 

Conclusão

O ajuste fino do SAM2 oferece uma maneira prática de aprimorar seus recursos para tarefas específicas. Independentemente de você estar trabalhando com imagens médicas, veículos autônomos ou edição de vídeo, o ajuste fino permite que você use o SAM2 para atender às suas necessidades específicas. Seguindo este guia, você pode adaptar o SAM2 aos seus projetos e obter resultados de segmentação de última geração.

Para casos de uso mais avançados, considere o ajuste fino de componentes adicionais do SAM2, como o codificador de imagem. Embora isso exija mais recursos, oferece maior flexibilidade e melhorias de desempenho.

Temas

Aprenda IA com estes cursos!

curso

Image Processing in Python

4 hr
44.8K
Learn to process, transform, and manipulate images at your will.
Ver DetalhesRight Arrow
Iniciar Curso
Ver maisRight Arrow
Relacionado
AI shaking hands with a human

blog

As 5 melhores ferramentas de IA para ciência de dados em 2024: Aumente seu fluxo de trabalho hoje mesmo

Os recentes avanços em IA têm o potencial de mudar drasticamente a ciência de dados. Leia este artigo para descobrir as cinco melhores ferramentas de IA que todo cientista de dados deve conhecer

blog

5 maneiras exclusivas de usar a IA na análise de dados

A análise de dados com IA está em alta entre os profissionais de dados. Aprenda cinco maneiras exclusivas de aproveitar o poder da IA para a análise de dados neste guia.
Austin Chia's photo

Austin Chia

tutorial

Guia para iniciantes do LlaMA-Factory WebUI: Ajuste fino dos LLMs

Saiba como fazer o ajuste fino dos LLMs em conjuntos de dados personalizados, avaliar o desempenho e exportar e servir modelos com facilidade usando a estrutura com pouco ou nenhum código do LLaMA-Factory.
Abid Ali Awan's photo

Abid Ali Awan

12 min

tutorial

Como fazer o ajuste fino do GPT 3.5: Liberando todo o potencial da IA

Explore o GPT-3.5 Turbo e descubra o potencial transformador do ajuste fino. Saiba como personalizar esse modelo de linguagem avançado para aplicativos de nicho, aprimorar seu desempenho e entender os custos associados, a segurança e as considerações de privacidade.
Moez Ali's photo

Moez Ali

11 min

tutorial

Tutorial do DeepChecks: Automatizando os testes de machine learning

Saiba como realizar a validação de dados e modelos para garantir um desempenho robusto de machine learning usando nosso guia passo a passo para automatizar testes com o DeepChecks.
Abid Ali Awan's photo

Abid Ali Awan

12 min

tutorial

Guia de Introdução ao Ajuste Fino de LLMs

O ajuste fino dos grandes modelos de linguagem (LLMs, Large Language Models) revolucionou o processamento de linguagem natural (PLN), oferecendo recursos sem precedentes em tarefas como tradução de idiomas, análise de sentimentos e geração de textos. Essa abordagem transformadora aproveita modelos pré-treinados como o GPT-2, aprimorando seu desempenho em domínios específicos pelo processo de ajuste fino.
Josep Ferrer's photo

Josep Ferrer

12 min

See MoreSee More