SAM-Audio Local Inference (CUDA)

Text & Span Prompting on AudioCaps

This notebook demonstrates how to run Meta's SAM-Audio locally on a CUDA-enabled GPU (RTX 3090 / RunPod) for prompt-based audio separation.

It provides a stable, end-to-end workflow for loading the SAM-Audio model, preparing audio inputs, and performing text and span prompting to isolate target sounds from real-world audio examples.


What this notebook covers

  • ✅ Stable environment setup for SAM-Audio on CUDA
    (including the TorchCodec 0.7.0 compatibility fix to avoid kernel crashes)
  • ✅ Loading the SAM-Audio base model with FP16 inference
  • ✅ Fetching and previewing audio samples from AudioCaps
  • ✅ Preparing audio tensors (resampling, mono conversion, dtype alignment)
  • ✅ Span prompting (temporal anchors) combined with text descriptions
  • ✅ Running audio separation on GPU
  • ✅ Visualizing waveforms directly in the notebook
  • ✅ Saving and listening to target and residual audio outputs

Models & Data

  • Model: facebook/sam-audio-base
  • Dataset: AudioCaps (OpenSound)
  • Prompting modes used:
    • Text prompting
    • Span (temporal) prompting

Hardware & Runtime Assumptions

  • NVIDIA GPU with CUDA support (tested on RTX 3090)
  • PyTorch 2.8 + CUDA 12.8
  • FP16 inference on GPU
  • Python 3.12 (RunPod base image)
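
Before running the setup cells, it is worth confirming that the runtime actually matches these assumptions. A minimal check (the expected values in the comments are the versions listed above):

import torch

print("PyTorch:", torch.__version__)              # expect 2.8.x
print("CUDA build:", torch.version.cuda)          # expect 12.8
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))  # tested on an RTX 3090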

Output

At the end of the notebook, you will have:

  • 🎯 An isolated target sound (target.wav)
  • 🧩 A residual track containing all other audio
  • Inline playback and waveform visualization for both

1. Environment Setup & Compatibility Fixes

# !pip install -q git+https://github.com/facebookresearch/sam-audio.git
# !pip uninstall -y torchcodec
# !pip install --no-cache-dir "torchcodec==0.7.0" -f https://download.pytorch.org/whl/torchcodec/
# !pip install hf_transfer
import os
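# Optional: hf_transfer (installed above) only takes effect when this env var is
# set, and it must be set before huggingface_hub is imported below.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"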
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])

Load the SAM-Audio Model (FP16 on CUDA)

import torch
import torchaudio
from pathlib import Path

from sam_audio import SAMAudio, SAMAudioProcessor

MODEL_ID = "facebook/sam-audio-base"  # base model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SAMAudio.from_pretrained(MODEL_ID)

# Use fp16 on the GPU to cut VRAM roughly in half (the RTX 3090 handles fp16 well);
# keep fp32 if we had to fall back to CPU, where half precision is poorly supported.
dtype = torch.float16 if device.type == "cuda" else torch.float32
model = model.to(device=device, dtype=dtype).eval()

processor = SAMAudioProcessor.from_pretrained(MODEL_ID)

print("Device:", device)
print("Sampling rate:", processor.audio_sampling_rate)
!nvidia-smi
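
As a rough cross-check against the nvidia-smi readout, the weight footprint can be estimated directly from the parameter tensors. This is plain PyTorch rather than anything SAM-Audio-specific, and it ignores activations and the CUDA context:

n_params = sum(p.numel() for p in model.parameters())
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Parameters: {n_params / 1e6:.1f}M (~{param_bytes / 1024**3:.2f} GiB in fp16)")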

Load an Example Audio Clip from AudioCaps

from datasets import load_dataset
from IPython.display import Audio

# Get an example audio from AudioCaps
dset = load_dataset(
    "parquet",
    data_files="hf://datasets/OpenSound/AudioCaps/data/test-00000-of-00041.parquet",
)

samples = dset["train"][8]["audio"].get_all_samples()
Audio(samples.data, rate=samples.sample_rate)
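
Before any preprocessing, it helps to confirm what the decoder actually returned. The column names printed here depend on the AudioCaps parquet schema, so treat them as illustrative rather than guaranteed:

example = dset["train"][8]
print("Columns:", list(example.keys()))        # e.g. an audio column plus a caption
print("Shape:", tuple(samples.data.shape))     # (channels, num_samples)
print("Sample rate:", samples.sample_rate)
print(f"Duration: {samples.data.shape[-1] / samples.sample_rate:.2f}s")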

Visualizing the Audio Track

import matplotlib.pyplot as plt
import torch

def plot_waveform(wav, sr, title="Waveform", max_seconds=15):
    """
    wav: Tensor [C, T] or [T]
    sr: sample rate (int)
    """
    if isinstance(wav, torch.Tensor):
        w = wav.detach().cpu()
    else:
        w = torch.tensor(wav)

    if w.ndim == 1:
        w = w.unsqueeze(0)

    # limit duration for readability
    max_samples = int(sr * max_seconds)
    w = w[:, :max_samples]

    t = torch.arange(w.shape[1]) / sr

    plt.figure(figsize=(14, 3))
    for c in range(w.shape[0]):
        plt.plot(t, w[c].numpy(), label=f"ch{c}")
    plt.title(title)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    if w.shape[0] > 1:
        plt.legend()
    plt.tight_layout()
    plt.show()
plot_waveform(samples.data, samples.sample_rate, title="Traffic noise + Horn")
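
A log-mel spectrogram often makes the event you will prompt for (here, the horn over steady traffic noise) easier to spot than the raw waveform. The helper below is a torchaudio-based sketch with arbitrary STFT settings (n_fft=1024, hop_length=256), not something SAM-Audio requires:

import matplotlib.pyplot as plt
import torch
import torchaudio

def plot_mel(wav, sr, title="Mel spectrogram", n_mels=80):
    w = wav if isinstance(wav, torch.Tensor) else torch.tensor(wav)
    if w.ndim == 1:
        w = w.unsqueeze(0)
    w = w.float().mean(0, keepdim=True)  # mono [1, T]
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=1024, hop_length=256, n_mels=n_mels
    )(w)
    mel_db = torchaudio.transforms.AmplitudeToDB()(mel)  # power -> dB
    plt.figure(figsize=(14, 3))
    plt.imshow(
        mel_db[0].numpy(), origin="lower", aspect="auto",
        extent=[0, w.shape[-1] / sr, 0, n_mels],
    )
    plt.title(title)
    plt.xlabel("Time (s)")
    plt.ylabel("Mel bin")
    plt.tight_layout()
    plt.show()

plot_mel(samples.data, samples.sample_rate, title="Traffic noise + Horn (mel)")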

Prepare Audio for Inference

wav = torch.as_tensor(samples.data)  # avoids a copy (and a warning) if this is already a Tensor

# Ensure [C, T]
if wav.ndim == 1:
    wav = wav.unsqueeze(0)  # [1, T]
elif wav.ndim == 2 and wav.shape[0] > wav.shape[1]:
    wav = wav.transpose(0, 1)  # [T, C] -> [C, T]

wav = wav.float()

orig_sr = int(samples.sample_rate)
target_sr = int(processor.audio_sampling_rate)

if orig_sr != target_sr:
    wav = torchaudio.functional.resample(wav, orig_sr, target_sr)

# Mono
wav = wav.mean(0, keepdim=True)  # [1, T]

# Move to device + match model dtype (FP16 on CUDA)
wav = wav.to(device=device, dtype=next(model.parameters()).dtype)

duration_s = wav.shape[-1] / target_sr
print(f"Resampled duration: {duration_s:.2f}s @ {target_sr} Hz")