
Setup

!pip install torchmultimodal-nightly
import torch
import torchvision
import torchvision.transforms.functional as F

from torch import nn
from tqdm import tqdm
from torchmultimodal.diffusion_labs.modules.adapters.cfguidance import CFGuidance
from torchmultimodal.diffusion_labs.modules.losses.diffusion_hybrid_loss import DiffusionHybridLoss
from torchmultimodal.diffusion_labs.samplers.ddpm import DDPModule
from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import linear_beta_schedule, DiscreteGaussianSchedule
from torchmultimodal.diffusion_labs.transforms.diffusion_transform import RandomDiffusionSteps
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput

Schedule

# Define the diffusion schedule: 1000 steps with linearly increasing betas
schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000))
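
For intuition, the schedule is just a tensor of per-step noise variances (betas); everything else is derived from them. A minimal sketch of the forward-process arithmetic in plain torch (illustrative only; DiscreteGaussianSchedule computes these quantities internally):

betas = linear_beta_schedule(1000)                  # 1-D tensor of per-step variances
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # fraction of signal kept after t steps
# Forward process: x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise
print(alphas_cumprod[0].item(), alphas_cumprod[-1].item())  # near 1 at t=0, near 0 at t=999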

Predictor

# Define the prediction target: the network predicts the noise added at each
# step, and recovered images are clamped to the data range [-1, 1]
predictor = NoisePredictor(schedule, lambda x: torch.clamp(x, -1, 1))
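
With a NoisePredictor the network output is read as the added noise. A clean-image estimate can be recovered by inverting the forward process, and the lambda above clamps that estimate to the data range [-1, 1]. A hedged sketch of the same arithmetic in plain torch (NoisePredictor performs the equivalent step internally):

def recover_x0(xt, eps_hat, t, alphas_cumprod):
    # x0_hat = (x_t - sqrt(1 - ac_t) * eps_hat) / sqrt(ac_t)
    ac = alphas_cumprod[t].view(-1, 1, 1, 1)
    x0_hat = (xt - torch.sqrt(1 - ac) * eps_hat) / torch.sqrt(ac)
    return torch.clamp(x0_hat, -1, 1)  # same role as the clamp passed above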

U-Net

# Diffusion Labs ships a full ADM UNet that could be used directly;
# this tutorial builds a small UNet by hand below instead
from torchmultimodal.diffusion_labs.models.adm_unet.adm import adm_unet

unet = adm_unet(
    time_embed_dim=32,
    embed_dim=32,
    embed_name="context",
    predict_variance_value=True,
    image_channels=1,
)
# Down scaling input blocks for unet
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, cond_channels):
        super().__init__()
        # Define convolutions
        self.block = nn.Sequential(
            nn.Conv2d(in_channels + cond_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.pooling = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x, c):
        _, _, h, w = x.size()
        c = c.expand(-1, -1, h, w)              # Broadcast conditional input to match image
        x = self.block(torch.cat([x, c], 1))    # Convolutions over image + condition
        x_small = self.pooling(x)               # Downsample output for next block
        return x, x_small
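
# Quick shape check of one DownBlock (illustrative numbers only, not part of
# the model defined below): the skip output keeps the input resolution while
# the pooled output is halved for the next block.
_blk = DownBlock(in_channels=128, out_channels=256, cond_channels=64)
_skip, _small = _blk(torch.randn(2, 128, 28, 28), torch.randn(2, 64, 1, 1))
print(_skip.shape, _small.shape)  # (2, 256, 28, 28) and (2, 256, 14, 14)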

# Upscaling blocks on unet
class UpBlock(nn.Module):
    def __init__(self, inp, out):
        super().__init__()
        # Define convolutions; 2 * inp because we concatenate input from two sources
        self.block = nn.Sequential(
            nn.Conv2d(2 * inp, out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out, out, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, x, x_small):
        x_big = self.upsample(x_small)          # Upscale input back toward original size
        x = torch.cat((x_big, x), dim=1)        # Join skip connection with upsampled features
        x = self.block(x)                       # Convolutions over image
        return x
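
# Matching shape check for UpBlock (illustrative): the deeper features are
# upsampled back to the skip connection's resolution before concatenation.
_up = UpBlock(256, 128)
_out = _up(torch.randn(2, 256, 28, 28), torch.randn(2, 256, 14, 14))
print(_out.shape)  # (2, 128, 28, 28)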

class UNet(nn.Module):
    def __init__(self, time_size=32, digit_size=32, steps=1000):
        super().__init__()
        cond_size = time_size + digit_size
        # Define UNet
        self.conv = nn.Conv2d(1, 128, kernel_size=3, padding=1)
        # Channel widths here are arbitrary
        self.down = nn.ModuleList([DownBlock(128, 256, cond_size), DownBlock(256, 512, cond_size)])
        # A different architecture could be used for the bottleneck, but a DownBlock works fine
        self.bottleneck = DownBlock(512, 512, cond_size)
        self.up = nn.ModuleList([UpBlock(512, 256), UpBlock(256, 128)])

        self.time_projection = nn.Embedding(steps, time_size)
        # Predict a single noise channel to match the single-channel input images
        self.prediction = nn.Conv2d(128, 1, kernel_size=3, padding=1)
        self.variance = nn.Conv2d(128, 1, kernel_size=3, padding=1)
        
    def forward(self, x, t, conditional_inputs):
        b, c, h, w = x.shape
        # Project the timestep and digit embeddings and broadcast them over the image
        timestep = self.time_projection(t).view(b, -1, 1, 1)
        condition = conditional_inputs["context"].view(b, -1, 1, 1)
        condition = torch.cat([timestep, condition], dim=1)

        # Define forward
        x = self.conv(x)
        outs = []  # skip connections saved for the up path
        for block in self.down:
            out, x = block(x, condition)
            outs.append(out)
        x, _ = self.bottleneck(x, condition)
        for block in self.up:
            x = block(outs.pop(), x)

        # Predicting the variance is optional, but it helps training
        v = self.variance(x)
        p = self.prediction(x)
        # DiffusionOutput expects a prediction and a variance value
        return DiffusionOutput(p, v)
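
# Smoke test (illustrative): one forward pass with MNIST-shaped inputs. The
# random "context" stands in for a digit embedding; prediction and
# variance_value are assumed to be the DiffusionOutput field names.
_net = UNet(time_size=32, digit_size=32)
_out = _net(
    torch.randn(4, 1, 28, 28),        # images
    torch.randint(0, 1000, (4,)),     # diffusion timesteps
    {"context": torch.randn(4, 32)},  # conditional embedding
)
print(_out.prediction.shape, _out.variance_value.shape)  # both (4, 1, 28, 28)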
            
            
            

Diffusion Model

unet = UNet(time_size=32, digit_size=32)
# Add support for classifier-free guidance
unet = CFGuidance(unet, {"context": 32}, guidance=2.0)
# Define evaluation steps
# Sampling can skip steps and still generate well
eval_steps = torch.linspace(0, 999, 250, dtype=torch.long)  # only every 4th step
model = DDPModule(unet, schedule, predictor, eval_steps)
# Define conditional embeddings
# Larger models use LLM text encoders; since this dataset has 10 classes,
# a simple learned embedding per digit is enough
encoder = nn.Embedding(10, 32)
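
As a quick sanity check, the conditional input is a dict whose "context" entry carries the digit embedding, matching the key the UNet reads. A sketch under the assumption that DDPModule mirrors the UNet's (x, t, conditional_inputs) signature in training mode (check the library for the exact API):

labels = torch.tensor([3, 7])          # two digit classes
cond = {"context": encoder(labels)}    # one (2, 32) embedding per label
model.train()
out = model(torch.randn(2, 1, 28, 28),      # noised images x_t
            torch.randint(0, 1000, (2,)),   # their timesteps
            cond)                           # assumed train-mode signature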

Data