Diffusion_FashionMNIST
Setup
!pip install torchmultimodal-nightly
import torch
import torchvision
import torchvision.transforms.functional as F
from torch import nn
from tqdm import tqdm
from torchmultimodal.diffusion_labs.modules.adapters.cfguidance import CFGuidance
from torchmultimodal.diffusion_labs.modules.losses.diffusion_hybrid_loss import DiffusionHybridLoss
from torchmultimodal.diffusion_labs.samplers.ddpm import DDPModule
from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import linear_beta_schedule, DiscreteGaussianSchedule
from torchmultimodal.diffusion_labs.transforms.diffusion_transform import RandomDiffusionSteps
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput
Schedule
# Define Diffusion Schedule
schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000))
Predictor
# Define Prediction Target
predictor = NoisePredictor(schedule, lambda x: torch.clamp(x, -1, 1))
U-Net
from torchmultimodal.diffusion_labs.models.adm_unet.adm import adm_unet

unet = adm_unet(
    time_embed_dim=32,
    embed_dim=32,
    embed_name="context",
    predict_variance_value=True,
    image_channels=1,
)
# Downscaling input blocks for the UNet
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, cond_channels):
        super().__init__()
        # Define convolutions
        self.block = nn.Sequential(
            nn.Conv2d(in_channels + cond_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.pooling = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x, c):
        _, _, h, w = x.size()
        c = c.expand(-1, -1, h, w)  # Broadcast the conditional input to the image size
        x = self.block(torch.cat([x, c], 1))  # Convolutions over image + condition
        x_small = self.pooling(x)  # Downsample output for the next block
        return x, x_small
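As a quick shape check (an illustrative sketch; the batch size and channel counts are arbitrary but match the UNet defined below), a DownBlock returns both the full-resolution feature map kept for the skip connection and the pooled map passed to the next block:
demo_down = DownBlock(in_channels=128, out_channels=256, cond_channels=64)
x_demo = torch.randn(4, 128, 28, 28)  # feature maps for a batch of 4 FashionMNIST-sized images
c_demo = torch.randn(4, 64, 1, 1)     # per-image conditioning vector (time + class embedding)
skip, x_small = demo_down(x_demo, c_demo)
# skip:    (4, 256, 28, 28) -> saved for the matching UpBlock
# x_small: (4, 256, 14, 14) -> passed to the next DownBlock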
# Upscaling blocks of the UNet
class UpBlock(nn.Module):
    def __init__(self, inp, out):
        super().__init__()
        # Define convolutions
        self.block = nn.Sequential(
            # 2 * inp because the input is concatenated from two sources
            nn.Conv2d(2 * inp, out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out, out, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, x, x_small):
        x_big = self.upsample(x_small)  # Upscale input back towards the original size
        x = torch.cat((x_big, x), dim=1)  # Join the skip connection with the upsampled features
        x = self.block(x)  # Convolutions over the combined features
        return x
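A matching check for UpBlock (same illustrative shapes, continuing the sketch above): the skip connection and the upsampled deeper features are concatenated before the convolutions:
demo_up = UpBlock(inp=256, out=128)
skip_demo = torch.randn(4, 256, 28, 28)   # full-resolution features from the matching DownBlock
small_demo = torch.randn(4, 256, 14, 14)  # features from the previous, deeper block
y_demo = demo_up(skip_demo, small_demo)   # -> (4, 128, 28, 28)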
class UNet(nn.Module):
    def __init__(self, time_size=32, digit_size=32, steps=1000):
        super().__init__()
        cond_size = time_size + digit_size
        # Define UNet
        self.conv = nn.Conv2d(1, 128, kernel_size=3, padding=1)
        # Channel sizes are arbitrary
        self.down = nn.ModuleList([DownBlock(128, 256, cond_size), DownBlock(256, 512, cond_size)])
        # A different architecture could be used for the bottleneck, but a DownBlock works
        self.bottleneck = DownBlock(512, 512, cond_size)
        self.up = nn.ModuleList([UpBlock(512, 256), UpBlock(256, 128)])
        self.time_projection = nn.Embedding(steps, time_size)
        # Predict noise with the same channel count as the input image
        self.prediction = nn.Conv2d(128, 1, kernel_size=3, padding=1)
        self.variance = nn.Conv2d(128, 1, kernel_size=3, padding=1)

    def forward(self, x, t, conditional_inputs):
        b, c, h, w = x.shape
        # Project the timestep to an embedding and reshape so it broadcasts over spatial dims
        timestep = self.time_projection(t).view(b, -1, 1, 1)
        condition = conditional_inputs["context"].view(b, -1, 1, 1)
        condition = torch.cat([timestep, condition], dim=1)
        # Define forward
        x = self.conv(x)
        outs = []
        for block in self.down:
            out, x = block(x, condition)
            outs.append(out)
        x, _ = self.bottleneck(x, condition)
        for block in self.up:
            x = block(outs.pop(), x)
        # Predicting the variance is optional, but helps with training
        v = self.variance(x)
        p = self.prediction(x)
        # DiffusionOutput expects a prediction and a variance value
        return DiffusionOutput(p, v)
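A quick forward pass with dummy tensors (an illustrative sketch; 28x28 matches FashionMNIST) confirms the output shapes:
_unet = UNet(time_size=32, digit_size=32)
x_demo = torch.randn(4, 1, 28, 28)     # batch of noised grayscale images
t_demo = torch.randint(0, 1000, (4,))  # random diffusion timesteps
out_demo = _unet(x_demo, t_demo, {"context": torch.randn(4, 32)})  # dummy class embeddings
# out_demo is a DiffusionOutput whose prediction and variance tensors are each (4, 1, 28, 28)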
Diffusion Model
unet = UNet(time_size=32, digit_size=32)
# Add support for classifier free guidance
unet = CFGuidance(unet, {"context": 32}, guidance=2.0)
# Define evaluation steps
# can skip some steps, and the model will still learn well
eval_steps = torch.linspace(0, 999, 250, dtype=torch.long)  # use only every 4th step
model = DDPModule(unet, schedule, predictor, eval_steps)
# Define conditional embeddings
# Larger models use text encoders (LLMs) for conditioning
# Here a simple embedding suffices, since the dataset has only 10 classes
encoder = nn.Embedding(10, 32)
Data
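Below is a minimal sketch of the data loading this section presumably builds on. The normalization to [-1, 1] matches the torch.clamp(x, -1, 1) used by the predictor; the batch size and the per-sample use of RandomDiffusionSteps (batched=False) are assumptions to check against diffusion_labs:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),                 # pixel values in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # rescale to [-1, 1]
])

train_set = datasets.FashionMNIST("./data", train=True, download=True, transform=preprocess)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

# RandomDiffusionSteps (imported above) draws a random timestep and the corresponding
# noised sample from the schedule for each training image
noising = RandomDiffusionSteps(schedule, batched=False)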