SAM-Audio Local Inference (CUDA)
Text & Span Prompting on AudioCaps
This notebook demonstrates how to run Meta's SAM-Audio locally on a CUDA-enabled GPU (RTX 3090 / RunPod) for prompt-based audio separation.
It provides a stable, end-to-end workflow for loading the SAM-Audio model, preparing audio inputs, and performing text and span prompting to isolate target sounds from real-world audio examples.
What this notebook covers
- ✅ Stable environment setup for SAM-Audio on CUDA (including the TorchCodec 0.7.0 compatibility fix to avoid kernel crashes)
- ✅ Loading the SAM-Audio base model with FP16 inference
- ✅ Fetching and previewing audio samples from AudioCaps
- ✅ Preparing audio tensors (resampling, mono conversion, dtype alignment)
- ✅ Span prompting (temporal anchors) combined with text descriptions
- ✅ Running audio separation on GPU
- ✅ Visualizing waveforms directly in the notebook
- ✅ Saving and listening to target and residual audio outputs
Models & Data
- Model: facebook/sam-audio-base
- Dataset: AudioCaps (OpenSound)
- Prompting modes used:
  - Text prompting
  - Span (temporal) prompting
Hardware & Runtime Assumptions
- NVIDIA GPU with CUDA support (tested on RTX 3090)
- PyTorch 2.8 + CUDA 12.8
- FP16 inference on GPU
- Python 3.12 (RunPod base image)
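
To confirm your runtime actually matches these assumptions before running anything heavier, a quick check like the one below (plain PyTorch, nothing SAM-Audio-specific) prints the versions and the detected GPU:

```python
import torch

# Quick sanity check of the runtime before loading any models.
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("CUDA build:", torch.version.cuda)
    print("GPU:", torch.cuda.get_device_name(0))
```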
Output
At the end of the notebook, you will have:
- 🎯 An isolated target sound (target.wav)
- 🧩 A residual track containing all other audio
- Inline playback and waveform visualization for both
1. Environment Setup & Compatibility Fixes
```python
# One-time setup: install SAM-Audio and pin TorchCodec 0.7.0
# (the compatibility fix that avoids kernel crashes).
# !pip install -q git+https://github.com/facebookresearch/sam-audio.git
# !pip uninstall -y torchcodec
# !pip install --no-cache-dir "torchcodec==0.7.0" -f https://download.pytorch.org/whl/torchcodec/
# !pip install hf_transfer

import os
from huggingface_hub import login

# Authenticate with the Hugging Face Hub; HF_TOKEN must be set in the environment.
login(token=os.environ["HF_TOKEN"])
```
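
It is also worth confirming that the TorchCodec pin actually took effect in this kernel, since a mismatched TorchCodec build is what the kernel-crash fix above guards against. This is a minimal check using the standard `importlib.metadata` API:

```python
from importlib.metadata import version

# Confirm the TorchCodec pin; 0.7.0 is the version this notebook was tested with.
print("torchcodec:", version("torchcodec"))
```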
Load the SAM-Audio Model (FP16 on CUDA)

```python
import torch
import torchaudio
from pathlib import Path

from sam_audio import SAMAudio, SAMAudioProcessor

MODEL_ID = "facebook/sam-audio-base"  # base model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use fp16 on GPU to cut VRAM roughly in half (the 3090 supports fp16 well);
# fall back to fp32 if no CUDA device is available.
dtype = torch.float16 if device.type == "cuda" else torch.float32

model = SAMAudio.from_pretrained(MODEL_ID).to(device=device, dtype=dtype).eval()
processor = SAMAudioProcessor.from_pretrained(MODEL_ID)

print("Device:", device)
print("Sampling rate:", processor.audio_sampling_rate)
```
```python
!nvidia-smi
```

Load an Example Audio from AudioCaps
```python
from datasets import load_dataset
from IPython.display import Audio

# Get an example audio from AudioCaps
dset = load_dataset(
    "parquet",
    data_files="hf://datasets/OpenSound/AudioCaps/data/test-00000-of-00041.parquet",
)

samples = dset["train"][8]["audio"].get_all_samples()
Audio(samples.data, rate=samples.sample_rate)
```
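
The exact columns in this parquet shard depend on the OpenSound export, so before committing to a particular row index it can help to inspect what metadata is available alongside the audio:

```python
# Columns and row count of this shard; the audio column itself is decoded lazily.
print(dset["train"].column_names)
print(dset["train"].num_rows, "rows in this shard")
```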
Visualizing the Audio Track
```python
import matplotlib.pyplot as plt
import torch

def plot_waveform(wav, sr, title="Waveform", max_seconds=15):
    """
    wav: Tensor [C, T] or [T]
    sr: sample rate (int)
    """
    if isinstance(wav, torch.Tensor):
        w = wav.detach().cpu()
    else:
        w = torch.tensor(wav)
    if w.ndim == 1:
        w = w.unsqueeze(0)

    # limit duration for readability
    max_samples = int(sr * max_seconds)
    w = w[:, :max_samples]

    t = torch.arange(w.shape[1]) / sr

    plt.figure(figsize=(14, 3))
    for c in range(w.shape[0]):
        plt.plot(t, w[c].numpy(), label=f"ch{c}")
    plt.title(title)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    if w.shape[0] > 1:
        plt.legend()
    plt.tight_layout()
    plt.show()

plot_waveform(samples.data, samples.sample_rate, title="Traffic noise + Horn")
```
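
If a frequency-domain view is also useful (for example, to see where the horn sits against the broadband traffic noise), a log-mel spectrogram is a small addition. The sketch below uses only `torchaudio.transforms` and keeps the same `[C, T]`/`[T]` input convention as `plot_waveform`; the helper name is our own:

```python
import matplotlib.pyplot as plt
import torch
import torchaudio

def plot_melspec(wav, sr, title="Log-mel spectrogram", n_mels=80):
    """wav: Tensor [C, T] or [T]; sr: sample rate (int)."""
    w = torch.as_tensor(wav).detach().cpu().float()
    if w.ndim == 2:
        w = w.mean(0)  # collapse to mono for display
    mel = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_mels=n_mels)(w)
    mel_db = torchaudio.transforms.AmplitudeToDB()(mel)

    plt.figure(figsize=(14, 3))
    plt.imshow(mel_db.numpy(), origin="lower", aspect="auto")
    plt.title(title)
    plt.xlabel("Frames")
    plt.ylabel("Mel bins")
    plt.tight_layout()
    plt.show()

plot_melspec(samples.data, samples.sample_rate, title="Traffic noise + Horn (mel)")
```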
Prepare Audio for Inference
```python
wav = torch.tensor(samples.data)

# Ensure [C, T]
if wav.ndim == 1:
    wav = wav.unsqueeze(0)        # [1, T]
elif wav.ndim == 2 and wav.shape[0] > wav.shape[1]:
    wav = wav.transpose(0, 1)     # [T, C] -> [C, T]

wav = wav.float()

orig_sr = int(samples.sample_rate)
target_sr = int(processor.audio_sampling_rate)
if orig_sr != target_sr:
    wav = torchaudio.functional.resample(wav, orig_sr, target_sr)

# Mono
wav = wav.mean(0, keepdim=True)  # [1, T]

# Move to device + match model dtype (FP16 on CUDA)
wav = wav.to(device=device, dtype=next(model.parameters()).dtype)

duration_s = wav.shape[-1] / target_sr
print(f"Resampled duration: {duration_s:.2f}s @ {target_sr} Hz")
```