# Введение в диффузионные модели

В данном ноутбуке будет рассмотрено несколько базовых подходов к генерации изображений с помощью диффузионных моделей. Мы будем генерировать фотографии лиц. В качестве эталона будем использовать выровненные и центрированные фотографии знаменитостей из датасета CelebFaces Attributes (CelebA).

Импортируем нужные пакеты, скачаем датасет и чекпоинты моделей, которые понадобятся нам в этом задании.

In [None]:
import glob
import math
import os
from functools import partial

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms.v2 as T
import tqdm.auto as tqdm
from einops import einsum, rearrange, repeat
from einops.layers.torch import Rearrange
from IPython.display import display
from PIL import Image

# Default size of every figure in this notebook; raise these if plots look small
fx, fy = 18.00, 18.00
plt.rcParams["figure.figsize"] = (fx, fy)

# Keep float32 matmuls at "high" precision (no bfloat16) so results stay
# consistent across devices; "medium" may run faster on some hardware
# with little loss of precision.
torch.set_float32_matmul_precision("high")

In [None]:
# True: load the pretrained checkpoint instead of training from scratch
SKIP_TRAIN: bool = True

In [None]:
try:
    # Download, verify and extract the CelebA dataset
    tv.datasets.CelebA(".", download=True)
except Exception:
    print("Official download method failed. Trying direct download.")
    !mkdir -p celeba
    !curl 'https://courses.cv-gml.ru/storage/seminars/ae-vae-gan/img_align_celeba.zip' -o celeba/img_align_celeba.zip
    !curl 'https://courses.cv-gml.ru/storage/seminars/ae-vae-gan/celeba_txt.zip' -O
    !unzip -o celeba_txt.zip

    print()
    for _, md5sum, file_name in tv.datasets.CelebA.file_list:
        print("Expected:  ", md5sum, " celeba/" + file_name)
        print("Downloaded: ", end="", flush=True)
        !md5sum "celeba/{file_name}"
        print()

    # Verify and extract the CelebA dataset
    tv.datasets.CelebA(".", download=True)

# Download pretrained model checkpoints
if not os.path.exists("DenoisingDiffusion.ckpt"):
    !curl -O 'https://courses.cv-gml.ru/storage/seminars/image-diffusion/DenoisingDiffusion.ckpt'

In [None]:
class CelebABoilerplate(L.LightningModule):
    """Shared CelebA data loading and optimisation plumbing.

    Subclasses must define ``step(batch, batch_idx)`` returning the loss; it
    is used for both the training and the validation step.

    Args:
        image_size: side length of the square crops fed to the model.
        batch_size: data-loader batch size.
        lr: Adam learning rate.
        **kwargs: forwarded to ``L.LightningModule``.
    """

    def __init__(self, image_size, batch_size=16, lr=0.0001, **kwargs):
        super().__init__(**kwargs)
        # Makes image_size / batch_size / lr available as self.hparams.*
        self.save_hyperparameters()

    def get_dataset(self, kind):
        """Return the CelebA split ``kind`` as square float tensors in [-1, 1]."""
        return tv.datasets.CelebA(
            ".",
            split=kind,
            transform=T.Compose(
                [
                    # Resize the shorter side, then crop the central square
                    T.Resize(self.hparams.image_size),
                    T.CenterCrop(self.hparams.image_size),
                    # [0; 255] -> [0; 1]
                    T.ToImage(),
                    T.ToDtype(torch.float32, scale=True),
                    # [0; 1] -> [-1; 1]
                    T.Normalize(3 * [0.5], 3 * [0.5]),
                ]
            ),
        )

    def get_dataloader(self, kind):
        """Return a DataLoader over split ``kind``; only "train" is shuffled."""
        return torch.utils.data.DataLoader(
            self.get_dataset(kind),
            # os.cpu_count() returns None when the count cannot be determined,
            # which DataLoader rejects; load in the main process in that case.
            num_workers=os.cpu_count() or 0,
            shuffle=kind == "train",
            batch_size=self.hparams.batch_size,
        )

    def train_dataloader(self):
        return self.get_dataloader("train")

    def val_dataloader(self):
        return self.get_dataloader("valid")

    def configure_optimizers(self):
        """Adam plus StepLR (default gamma 0.1) stepping every 2 scheduler steps."""
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=2)
        return [opt], [sch]

    def training_step(self, *args, **kwargs):
        return self.step(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self.step(*args, **kwargs)


def get_latest_version(root_dir):
    """Return the highest-numbered ``lightning_logs/version_N`` path under ``root_dir``.

    Returns ``None`` when there are no version directories.
    """

    def version_number(path):
        # ".../version_12" -> 12
        return int(path.split("_")[-1])

    candidates = glob.glob(f"{root_dir}/lightning_logs/version_*")
    if not candidates:
        return None
    return sorted(candidates, key=version_number)[-1]


def get_latest_checkpoint(root_dir):
    """Return the most recent checkpoint under ``root_dir/lightning_logs``.

    Versions are scanned from the highest number down. The first version that
    contains any checkpoint wins, so an interrupted run that never saved one
    does not hide earlier progress. Within a version, the most recently
    modified file is returned.

    Returns ``None`` when no checkpoint exists at all.
    """
    versions = glob.glob(f"{root_dir}/lightning_logs/version_*")
    # Newest version first (numeric, so version_10 > version_9)
    versions.sort(key=lambda path: int(path.split("_")[-1]), reverse=True)
    for version in versions:
        checkpoints = glob.glob(f"{version}/checkpoints/*")
        if checkpoints:
            return max(checkpoints, key=os.path.getmtime)
    return None


def img_to_u8(img):
    """Convert a channels-first float image in [-1, 1] to a channels-last uint8 array.

    Values are mapped linearly to [0, 255], rounded, and clipped to that range.
    """
    channels_last = img.transpose(1, 2, 0)
    scaled = (channels_last + 1) * 127.5
    return np.clip(np.round(scaled), 0, 255).astype(np.uint8)


def make_grid(imgs):
    """Tile a list of rows of images into one uint8 array via ``img_to_u8``.

    Images in a row are placed left to right (axis 1); rows are stacked top
    to bottom (axis 0).
    """
    strips = []
    for row in imgs:
        converted = [img_to_u8(img) for img in row]
        strips.append(np.concatenate(converted, axis=1))
    return np.concatenate(strips, axis=0)


def show_images(imgs, scale=2):
    """Display a grid of images inline, optionally enlarged.

    ``imgs`` is a list of rows accepted by ``make_grid``. A truthy ``scale``
    multiplies both sides with nearest-neighbour resampling; a falsy one
    (0 / None) shows the grid at native size.
    """
    grid = Image.fromarray(make_grid(imgs))
    if not scale:
        display(grid)
        return
    width, height = grid.size
    enlarged = grid.resize(
        [scale * width, scale * height],
        resample=Image.Resampling.NEAREST,
    )
    display(enlarged)

In [None]:
@torch.no_grad()
def show_generation_process(ddpm, method, cuda=True, num_images=16, **kwargs):
    """Show intermediate states of a sampling ``method``.

    Calls ``method(ddpm, num_images=..., keep_noisy=True, **kwargs)`` and shows
    one row per generated image, with ``num_images`` columns taken at evenly
    spaced indices of the returned list of intermediate states.

    Side effect: switches ``ddpm`` to eval mode (and moves it to CUDA when
    ``cuda`` is set).
    """
    model = ddpm.eval()
    if cuda:
        model = model.cuda()

    snapshots = method(model, num_images=num_images, keep_noisy=True, **kwargs)
    picked = np.linspace(0, len(snapshots) - 1, num_images, dtype=int)

    rows = [[] for _ in range(num_images)]
    for index in picked:
        batch = snapshots[index].cpu().numpy()
        # Distribute the snapshot column across the per-image rows
        for row, image in zip(rows, batch):
            row.append(image)

    show_images(rows)

In [None]:
@torch.no_grad()
def show_generation_results(ddpm, method, cuda=True, num_images=4, **kwargs):
    """Generate ``num_images ** 2`` samples with ``method`` and show a square grid.

    Side effect: switches ``ddpm`` to eval mode (and moves it to CUDA when
    ``cuda`` is set).
    """
    model = ddpm.eval()
    if cuda:
        model = model.cuda()

    batch = method(model, num_images=num_images * num_images, **kwargs)
    samples = list(batch.cpu().numpy())
    # Grid is filled from the end of the batch backwards (list.pop order)
    grid = []
    for _ in range(num_images):
        grid.append([samples.pop() for _ in range(num_images)])

    show_images(grid, scale=4)

In [None]:
def exists(x):
    """Return True if ``x`` is anything other than ``None``.

    Unlike a truthiness test, falsy values such as ``0``, ``""`` or ``[]``
    count as existing.
    """
    is_missing = x is None
    return not is_missing


def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise return the fallback ``d``.

    A callable ``d`` acts as a lazy factory: it is called with no arguments,
    and only when the fallback is actually needed.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val


def cast_tuple(t, length=1):
    """Return ``t`` unchanged if it is a tuple, else a tuple of ``length`` copies of ``t``."""
    return t if isinstance(t, tuple) else tuple([t] * length)

## DDPM

Рассмотрим статью "[`Denoising Diffusion Probabilistic Models`](https://arxiv.org/pdf/2006.11239.pdf)", опубликованную в июне 2020 года.

### Модифицированная архитектура UNet

В данной статье предлагается использовать UNet-подобную архитектуру с рядом популярных модификаций, часто используемых в индустрии. Рассмотрим слегка упрощенный код, адаптированный из [репозитория `lucidrains/denoising-diffusion-pytorch`](https://github.com/lucidrains/denoising-diffusion-pytorch).

#### ResNet-подобные базовые блоки

In [None]:
class Block(nn.Module):
    """Conv3x3 -> GroupNorm -> optional scale/shift modulation -> SiLU.

    Args:
        dim: input channels.
        dim_out: output channels (GroupNorm needs it divisible by ``groups``).
        groups: number of GroupNorm groups.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        # Attribute names form the state_dict keys stored in checkpoints
        self.proj = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Apply the block.

        ``scale_shift`` is an optional ``(scale, shift)`` pair, broadcastable
        to the normalised activations, applied as ``h * (scale + 1) + shift``.
        """
        h = self.norm(self.proj(x))
        if exists(scale_shift):
            scale, shift = scale_shift
            h = h * (scale + 1) + shift
        return self.act(h)


class ResnetBlock(nn.Module):
    """Two ``Block``s plus a residual path, optionally conditioned on a time embedding.

    When ``time_dim`` is given, the time embedding is projected to a
    ``(scale, shift)`` pair that modulates the first block.
    The residual uses a 1x1 conv only when ``dim != dim_out``.
    """

    def __init__(self, dim, dim_out, *, time_dim=None):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_dim, dim_out * 2))
            if exists(time_dim)
            else None
        )

        self.block1 = Block(dim, dim_out)
        self.block2 = Block(dim_out, dim_out)

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            # (b, 2*dim_out) -> (b, 2*dim_out, 1, 1) -> two (b, dim_out, 1, 1) halves
            cond = rearrange(self.mlp(time_emb), "b c -> b c 1 1")
            scale_shift = cond.chunk(2, dim=1)

        h = self.block2(self.block1(x, scale_shift=scale_shift))
        return h + self.res_conv(x)

#### Attention блоки с дополнительной памятью

In [None]:
class RMSNorm(nn.Module):
    """RMS normalisation over dim 1 with a learned per-channel gain.

    Each channel vector is scaled to unit L2 norm, then multiplied by the
    gain ``g`` and by sqrt(num_channels), so its RMS becomes ``g``.
    """

    def __init__(self, dim):
        super().__init__()
        # Gain shaped to broadcast over (N, C, H, W); name kept for checkpoints
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        channels = x.shape[1]
        unit = F.normalize(x, dim=1)
        return unit * self.g * (channels ** 0.5)


class MemoryKV(nn.Module):
    """Trainable key/value memory tokens (aka "registers") for attention.

    Stores ``num_tokens`` learned keys and values per head. ``forward``
    appends them to the incoming keys/values along the sequence axis, so
    every query can also attend to them.
    """

    def __init__(self, heads, num_tokens, dim_head):
        super().__init__()
        self.heads = heads
        self.num_tokens = num_tokens
        self.dim_head = dim_head

        # memory[0] holds keys, memory[1] values: (2, heads, num_tokens, dim_head)
        self.memory = nn.Parameter(torch.randn(2, heads, num_tokens, dim_head))

    def forward(self, k, v):
        """Append memory tokens to ``k`` and ``v`` laid out as (b, h, seq, c).

        Returns the extended ``(k, v)`` with ``seq + num_tokens`` positions.
        """
        batch, heads, _, channels = k.shape

        # Broadcast the memory over the batch; einops also checks h/n/c sizes
        mem_k, mem_v = repeat(
            self.memory,
            "t h n c -> t b h n c",
            t=2,
            b=batch,
            h=heads,
            n=self.num_tokens,
            c=channels,
        )

        return torch.cat((k, mem_k), dim=-2), torch.cat((v, mem_v), dim=-2)

In [None]:
class MemAttention(nn.Module):
    """Pre-norm multi-head self-attention over image pixels with memory tokens.

    This base class normalises the input, projects it to queries, keys and
    values, splits heads, appends memory tokens, projects the result back and
    adds the residual. Subclasses implement ``attend(q, k, v)``.
    """

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        num_mem_kv=4,
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head
        self.num_mem_kv = num_mem_kv
        hidden_dim = dim_head * heads
        self.hidden_dim = hidden_dim

        # Submodule creation order is kept (it determines random initialisation)
        self.norm = RMSNorm(dim)
        self.mem_kv = MemoryKV(heads, num_mem_kv, dim_head)
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, img):
        """Attend over the ``x * y`` pixels of ``img`` and add the result to ``img``."""
        b, _, x, y = img.shape
        axes = dict(b=b, h=self.heads, c=self.dim_head, x=x, y=y)

        qkv = self.to_qkv(self.norm(img))
        # One token per pixel with heads split out: (3, b, h, x*y, c)
        q, k, v = rearrange(qkv, "b (t h c) x y -> t b h (x y) c", t=3, **axes)

        k, v = self.mem_kv(k, v)
        # Soft dictionary lookup: a differentiable ``{k: v}[q]``
        attended = self.attend(q, k, v)

        merged = rearrange(attended, "b h (x y) c -> b (h c) x y", **axes)
        return self.to_out(merged) + img

    def attend(self, q, k, v):
        """Combine queries with keys/values; provided by subclasses."""
        raise NotImplementedError()

In [None]:
class FullMemAttention(MemAttention):
    """MemAttention using exact softmax attention (quadratic in sequence length)."""

    def attend(self, q, k, v):
        # q may be a non-contiguous view after rearrange; make it contiguous for SDPA
        return F.scaled_dot_product_attention(q.contiguous(), k, v)

In [None]:
class LinearMemAttention(MemAttention):
    """MemAttention whose cost grows linearly with the number of pixels.

    Keys are softmax-normalised over the sequence and summarised with the
    values into a (dim_head x dim_head) context per head. Queries, softmaxed
    over channels, then read from that context. The output projection is
    followed by an extra RMSNorm.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.scale = self.dim_head**-0.5
        # Wrap the base projection; checkpoint keys become to_out.0.* / to_out.1.g
        projection = self.to_out
        self.to_out = nn.Sequential(projection, RMSNorm(self.dim))

    def attend(self, q, k, v):
        k_weights = k.softmax(dim=-2)  # normalise over sequence positions
        context = einsum(k_weights, v, "... n d, ... n e -> ... e d")

        q_weights = q.softmax(dim=-1) * self.scale  # normalise over channels
        return einsum(q_weights, context, "... n d, ... e d -> ... n e")

#### Позиционные эмбеддинги

In [None]:
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of a 1-D tensor of scalars.

    Each input ``x_i`` is multiplied by ``scale`` and by ``dim // 2``
    geometrically spaced frequencies from 1 down to ``1 / theta``. The output
    is ``[sin(...), cos(...)]`` with shape ``(len(x), 2 * (dim // 2))``.
    """

    def __init__(self, dim, theta=10000, scale=1000):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.scale = scale

    def forward(self, x):
        half = self.dim // 2
        log_step = math.log(self.theta) / (half - 1)
        freqs = torch.arange(half, device=x.device).mul(-log_step).exp()
        angles = (self.scale * x).unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [None]:
# Heatmap of embeddings for 100 times in [0, 1]: rows are features, columns are times
plt.imshow(SinusoidalPosEmb(dim=64)(torch.linspace(0, 1, steps=100)).t());

#### UNet блоки

In [None]:
class Upsample(nn.Sequential):
    """2x nearest-neighbour upsampling, then a 3x3 conv from ``dim`` to ``dim_out``.

    As an ``nn.Sequential``, submodules run in assignment order (``scale``,
    then ``conv``). ``dim_out`` defaults to ``dim``.
    """

    def __init__(self, dim, dim_out=None):
        super().__init__()
        out_channels = default(dim_out, dim)
        self.scale = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, out_channels, kernel_size=3, padding=1)


class Downsample(nn.Sequential):
    """Halve the resolution by folding 2x2 patches into channels, then a 1x1 conv.

    The rearrange turns ``dim`` channels into ``4 * dim`` at half the height
    and width (which must both be even). The conv then maps to ``dim_out``,
    which defaults to ``dim``.
    """

    def __init__(self, dim, dim_out=None):
        super().__init__()
        out_channels = default(dim_out, dim)
        self.scale = Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2)
        self.conv = nn.Conv2d(4 * dim, out_channels, kernel_size=1)

In [None]:
class UNetBlock(nn.Module):
    """One UNet level wrapped around a deeper ``inner_block``.

    Down path: two ResnetBlocks and attention at ``dim`` channels, then
    ``downsample`` to ``dim_inner`` channels (a strided space-to-depth when
    ``downscale``, otherwise a 3x3 conv at the same resolution).
    Up path: two ResnetBlocks fed with skip concatenations, attention at
    ``dim_inner`` channels, then ``upsample`` back to ``dim`` channels
    (2x upsampling when ``upscale``, otherwise a 3x3 conv).

    Args:
        inner_block: module called as ``inner_block(x, t)``; its output must
            have ``dim_inner`` channels.
        dim: channels at this level.
        dim_inner: channels passed to ``inner_block``.
        time_dim: size of the time embedding ``t``.
        downscale / upscale: whether to change resolution on the way down / up.
        full_attn: use FullMemAttention instead of LinearMemAttention.
    """

    def __init__(
        self,
        inner_block,
        dim,
        dim_inner,
        time_dim,
        *,
        downscale=True,
        upscale=True,
        full_attn=False,
    ):
        super().__init__()
        attention = FullMemAttention if full_attn else LinearMemAttention

        # Down path (creation order kept for identical initialisation)
        self.down_block1 = ResnetBlock(dim, dim, time_dim=time_dim)
        self.down_block2 = ResnetBlock(dim, dim, time_dim=time_dim)
        self.down_attn = attention(dim)
        self.downsample = (
            Downsample(dim, dim_inner)
            if downscale
            else nn.Conv2d(dim, dim_inner, 3, padding=1)
        )

        self.inner_block = inner_block

        # Up-path inputs are [features (dim_inner) | skip (dim)] concatenations
        dim_skip = dim_inner + dim
        self.up_block1 = ResnetBlock(dim_skip, dim_inner, time_dim=time_dim)
        self.up_block2 = ResnetBlock(dim_skip, dim_inner, time_dim=time_dim)
        self.up_attn = attention(dim_inner)
        self.upsample = (
            Upsample(dim_inner, dim)
            if upscale
            else nn.Conv2d(dim_inner, dim, 3, padding=1)
        )

    def forward(self, x, t):
        skip1 = self.down_block1(x, t)
        skip2 = self.down_attn(self.down_block2(skip1, t))

        h = self.inner_block(self.downsample(skip2), t)

        h = self.up_block1(torch.cat((h, skip2), dim=1), t)
        h = self.up_block2(torch.cat((h, skip1), dim=1), t)
        return self.upsample(self.up_attn(h))

In [None]:
class UNetInnerBlock(nn.Module):
    """UNet bottleneck: ResnetBlock -> full attention -> ResnetBlock at ``mid_dim`` channels."""

    def __init__(self, mid_dim, time_dim):
        super().__init__()
        self.block1 = ResnetBlock(mid_dim, mid_dim, time_dim=time_dim)
        self.attn = FullMemAttention(mid_dim)
        self.block2 = ResnetBlock(mid_dim, mid_dim, time_dim=time_dim)

    def forward(self, x, t):
        return self.block2(self.attn(self.block1(x, t)), t)

#### Собираем все вместе

In [None]:
class UNet(nn.Module):
    """Time-conditioned UNet built from nested ``UNetBlock`` levels.

    Args:
        dim: base channel count; level widths are ``dim * m`` for ``m`` in ``dim_mults``.
        dim_mults: per-level channel multipliers, outermost first.
        channels: number of input and output image channels.

    ``forward(x, time)`` takes ``x`` whose last two dims are divisible by
    ``downsample_factor`` and a 1-D ``time`` tensor. It returns ``channels``
    channels at the spatial size of ``x``.
    """

    def __init__(
        self,
        dim,
        dim_mults=(1, 2, 4, 8),
        channels=3,
    ):
        super().__init__()

        widths = [dim, *(dim * m for m in dim_mults)]
        time_dim = 4 * dim

        in_out = list(zip(widths[:-1], widths[1:]))
        last = len(in_out) - 1
        # Every level except the innermost halves the resolution
        self.downsample_factor = 2**last

        # NOTE(review): scale=1 (not the class default 1000) embeds ``time`` as
        # given; presumably what the pretrained weights expect — confirm first.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim, scale=1),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )
        self.init_conv = nn.Conv2d(channels, dim, 7, padding=3)

        # Build from the bottleneck outwards; each level wraps the previous one
        block = UNetInnerBlock(widths[-1], time_dim)
        for level, (dim_outer, dim_inner) in enumerate(reversed(in_out)):
            is_innermost = level == 0
            block = UNetBlock(
                block,
                dim=dim_outer,
                dim_inner=dim_inner,
                time_dim=time_dim,
                full_attn=is_innermost,
                downscale=not is_innermost,
                upscale=level != last,
            )
        self.inner_block = block

        # Merge the UNet output with the initial features, then map to channels
        self.final_res_block = ResnetBlock(2 * dim, dim, time_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, channels, 1)

    def check_downsample_factor(self, shape):
        """Raise RuntimeError unless the last two dims of ``shape`` divide by downsample_factor."""
        misfits = [d for d in shape[-2:] if d % self.downsample_factor]
        if misfits:
            raise RuntimeError(
                f"The input dimensions {shape[-2:]} needs "
                f"to be divisible by {self.downsample_factor}"
            )

    def forward(self, x, time):
        self.check_downsample_factor(x.shape)

        t = self.time_mlp(time)
        h = self.init_conv(x)
        h = torch.cat((self.inner_block(h, t), h), dim=1)

        return self.final_conv(self.final_res_block(h, t))

### Обучение DDPM

In [None]:
class CelebADenoisingDiffusion(CelebABoilerplate):
    """DDPM on CelebA with a noise-predicting UNet.

    Uses a linear beta schedule from 1e-4 to 0.02 over ``num_steps`` steps,
    with ``alpha_hats[t] = prod_{s<=t} (1 - betas[s])`` precomputed as a buffer.
    """

    def __init__(self, num_steps=1000, **kwargs):
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.model = UNet(dim=64)

        betas = torch.linspace(0.0001, 0.02, num_steps)
        self.register_buffer("betas", betas)

        alpha_hats = (1 - betas).cumprod(axis=0)
        self.register_buffer("alpha_hats", alpha_hats)

    def add_noise(self, images, steps, get_noise=False):
        """Sample x_t ~ q(x_t | x_0) for clean ``images`` at integer timesteps ``steps``.

        x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps,  eps ~ N(0, I)

        Args:
            images: clean images; a batch (B, C, H, W) with ``steps`` of shape
                (B,), or a single (C, H, W) image with ``steps`` of shape (1,).
            steps: integer timesteps indexing ``alpha_hats``.
            get_noise: also return the sampled ``eps``.

        Returns:
            ``noisy_images`` of shape (B, C, H, W) (a single image gets a
            leading batch dim of 1), or ``(noisy_images, noise)`` when
            ``get_noise`` is set.
        """
        # \/ Your code here \/
        noise = torch.randn_like(images)
        # One coefficient per sample, broadcast over channels and pixels
        alpha_hat = self.alpha_hats[steps].view(-1, 1, 1, 1)
        noisy_images = alpha_hat.sqrt() * images + (1 - alpha_hat).sqrt() * noise
        # /\ Your code here /\

        if get_noise:
            return noisy_images, noise
        else:
            return noisy_images

    def forward(self, images, steps):
        """Predict the noise in ``images``; integer ``steps`` are rescaled to [0, 1]."""
        steps = steps / (self.hparams.num_steps - 1)
        return self.model(images, steps)

    def step(self, batch, batch_idx):
        """DDPM "simple" objective: MSE between the true and predicted noise."""
        kind = "train" if self.training else "valid"
        images, _ = batch

        # \/ Your code here \/

        # Add random noise to images
        steps = torch.randint(
            0, self.hparams.num_steps, (images.shape[0],), device=images.device
        )
        noisy_images, noise = self.add_noise(images, steps, get_noise=True)

        # Predict the added noise
        noise_pr = self(noisy_images, steps)

        # Minimize the difference between predicted and actual noise
        loss = F.mse_loss(noise_pr, noise)

        # /\ Your code here /\

        self._log_loss("loss", loss, kind)
        return loss

    def _log_loss(self, name, loss, kind):
        """Log ``loss`` as ``{kind}/{name}``; per-step logging only for training."""
        self.log(
            f"{kind}/{name}",
            loss,
            prog_bar=True,
            logger=True,
            on_epoch=True,
            on_step="train" in kind,
        )

#### Прямой диффузионный процесс

In [None]:
@torch.no_grad()
def show_noise_results(ddpm, cuda=True, num_images=16):
    """Visualise the forward (noising) process on random validation images.

    Draws ``num_images`` random validation images (with replacement). For
    each one, shows ``ddpm.add_noise`` at ``num_images`` evenly spaced
    timesteps between 0 and ``num_steps - 1``.

    Side effect: switches ``ddpm`` to eval mode (and moves it to CUDA when
    ``cuda`` is set).
    """
    ddpm = ddpm.eval()
    valid_ds = ddpm.get_dataset("valid")

    picks = np.random.choice(len(valid_ds), size=num_images)
    imgs = [img for img, _ in (valid_ds[i] for i in picks)]

    if cuda:
        ddpm = ddpm.cuda()
        imgs = [img.cuda() for img in imgs]

    max_steps = ddpm.hparams.num_steps
    timesteps = np.linspace(0, max_steps - 1, num_images, dtype=int)

    rows = []
    for img in imgs:
        row = []
        for t in timesteps:
            t_tensor = torch.tensor([t], device=img.device)
            noisy = ddpm.add_noise(img, t_tensor)
            row.append(noisy.squeeze(0).cpu().numpy())
        rows.append(row)

    show_images(rows)

In [None]:
# The forward (noising) process does not depend on the network weights, so a freshly constructed model is enough
show_noise_results(CelebADenoisingDiffusion(image_size=64), cuda=True, num_images=16)

#### Обучаем модель

Для экономии времени на семинаре вам предлагается загрузить заранее обученную версию модели. <br/>
Также для справки приведен код, который был использован для обучения этой модели.

In [None]:
# Create a Denoising Diffusion Probabilistic Model
if not SKIP_TRAIN:
    ckpt = get_latest_checkpoint("DenoisingDiffusion")
    if ckpt:
        print(f"Continue training from {ckpt}")
        # Rebuild the module with the hyper-parameters stored in the checkpoint
        ddpm = CelebADenoisingDiffusion.load_from_checkpoint(ckpt)
    else:
        ddpm = CelebADenoisingDiffusion(image_size=64)

    trainer = L.Trainer(
        callbacks=[
            L.pytorch.callbacks.DeviceStatsMonitor(),
            L.pytorch.callbacks.EarlyStopping(
                "valid/loss",
                patience=4,
                verbose=True,
            ),
        ],
        default_root_dir="DenoisingDiffusion",
        max_epochs=-1,
    )

    # Train the model. ``ckpt_path`` also restores the optimizer, the StepLR
    # schedule, callback state and the epoch counter. Without it, a resumed
    # run restarts at the initial learning rate. ``None`` trains from scratch.
    trainer.fit(ddpm, ckpt_path=ckpt)
else:
    ddpm = CelebADenoisingDiffusion.load_from_checkpoint(
        "DenoisingDiffusion.ckpt",
    )

### Генерация изображений

#### Обратный диффузионный процесс

In [None]:
@torch.no_grad()
def ddpm_remove_noise_step(ddpm, images, step):
    """One ancestral DDPM sampling step x_t -> x_{t-1} (Ho et al., 2020, Alg. 2).

    x_{t-1} = (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps(x_t, t)) / sqrt(alpha_t)
              + sqrt(beta_t) * z,   z ~ N(0, I)

    No noise is added at the final step (``step == 0``).

    Args:
        ddpm: model with ``betas`` / ``alpha_hats`` tensors that predicts the
            noise as ``ddpm(images, steps)``.
        images: current noisy batch x_t.
        step: integer timestep t shared by the whole batch.

    Returns:
        The batch x_{t-1}.
    """
    # Predict the noise
    step_ch = torch.tensor([step], device=images.device)
    noise_pr = ddpm(images, step_ch)

    # Remove some of the noise (single step)
    # \/ Your code here \/
    beta = ddpm.betas[step]
    alpha = 1 - beta
    alpha_hat = ddpm.alpha_hats[step]
    images = (images - beta / (1 - alpha_hat).sqrt() * noise_pr) / alpha.sqrt()
    # /\ Your code here /\

    # Add back some noise (except at the final step)
    if step:
        # \/ Your code here \/
        # sigma_t^2 = beta_t (one of the two variance choices in the DDPM paper)
        images = images + beta.sqrt() * torch.randn_like(images)
        # /\ Your code here /\

    return images

In [None]:
@torch.no_grad()
def ddpm_generate_batch(ddpm, num_images=None, keep_noisy=False):
    """Sample images by running the full DDPM reverse chain (one model call per step).

    Args:
        ddpm: diffusion model; its hparams provide batch size, image size and step count.
        num_images: number of images; falls back to ``ddpm.hparams.batch_size``.
        keep_noisy: return every intermediate state instead of only the result.

    Returns:
        The final images, or with ``keep_noisy`` a list of ``num_steps + 1``
        tensors starting with the initial pure noise.
    """
    batch = num_images or ddpm.hparams.batch_size
    side = ddpm.hparams.image_size
    num_steps = ddpm.hparams.num_steps

    # Start from pure Gaussian noise
    images = torch.randn(batch, 3, side, side, device=ddpm.device)
    history = [images] if keep_noisy else None

    # Iterate t = num_steps - 1, ..., 0 (a range keeps tqdm's total)
    for step in tqdm.tqdm(range(num_steps - 1, -1, -1)):
        images = ddpm_remove_noise_step(ddpm, images, step)
        if history is not None:
            history.append(images)

    return history if keep_noisy else images

#### Визуализация результатов

In [None]:
# Evolution of 16 samples from pure noise to images under full DDPM sampling
show_generation_process(ddpm, method=ddpm_generate_batch)

In [None]:
# 4x4 grid of final DDPM samples
show_generation_results(ddpm, method=ddpm_generate_batch)

## DDIM

Рассмотрим статью "[`Denoising Diffusion Implicit Models`](https://arxiv.org/pdf/2010.02502.pdf)", опубликованную в октябре 2020 года.

In [None]:
@torch.no_grad()
def ddim_remove_noise_step(ddpm, images, cur_step, prev_step, eta=0.0):
    """One DDIM update x_t -> x_{t_prev} (Song et al., 2020, eq. 12), with a = alpha_hat.

    x0_pred = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
    x_prev  = sqrt(a_prev) * x0_pred + sqrt(1 - a_prev - sigma^2) * eps + sigma * z
    sigma   = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)

    Args:
        ddpm: model with an ``alpha_hats`` tensor that predicts the noise as
            ``ddpm(images, steps)``.
        images: current batch x_t (not modified in place).
        cur_step: integer timestep t of ``images``.
        prev_step: target timestep; ``-1`` denotes the clean image (a_prev = 1).
        eta: stochasticity; 0 (default) gives deterministic DDIM.

    Returns:
        The batch x_{t_prev}.
    """
    device = images.device

    # Predict the noise
    step_ch = torch.tensor([cur_step], device=device)
    noise_pr = ddpm(images, step_ch)

    # \/ Your code here \/

    # Remove some of the noise (single step)
    alpha_hat = ddpm.alpha_hats[int(cur_step)]
    if prev_step >= 0:
        alpha_hat_prev = ddpm.alpha_hats[int(prev_step)]
    else:
        # -1 must not be used as a Python index (it would pick the LAST step)
        alpha_hat_prev = torch.ones_like(alpha_hat)
    # clamp guards against tiny negative values from rounding
    sigma = (
        eta
        * ((1 - alpha_hat_prev) / (1 - alpha_hat)).sqrt()
        * (1 - alpha_hat / alpha_hat_prev).clamp(min=0).sqrt()
    )

    # Predicted x_0
    pred_x0 = (images - (1 - alpha_hat).sqrt() * noise_pr) / alpha_hat.sqrt()

    # Step in the direction pointing to x_t
    dir_xt = (1 - alpha_hat_prev - sigma**2).clamp(min=0).sqrt() * noise_pr
    images = alpha_hat_prev.sqrt() * pred_x0 + dir_xt

    # Add back some random noise (only when eta > 0)
    if eta > 0:
        images = images + sigma * torch.randn_like(images)

    # /\ Your code here /\

    return images

In [None]:
@torch.no_grad()
def ddim_generate_batch(ddpm, num_steps, num_images=None, keep_noisy=False):
    """Sample images with DDIM using ``num_steps`` of the model's training timesteps.

    Args:
        ddpm: diffusion model; its hparams provide batch size, image size and
            the number of training timesteps.
        num_steps: number of denoising steps (evenly spaced training timesteps).
        num_images: number of images; falls back to ``ddpm.hparams.batch_size``.
        keep_noisy: return every intermediate state instead of only the result.

    Returns:
        The final images, or with ``keep_noisy`` a list of ``num_steps + 1``
        tensors starting with the initial pure noise.
    """
    shape = (
        num_images or ddpm.hparams.batch_size,
        3,
        ddpm.hparams.image_size,
        ddpm.hparams.image_size,
    )
    steps = ddpm.hparams.num_steps
    device = ddpm.device

    # Generate random noise
    images = torch.randn(*shape, device=device)
    if keep_noisy:
        noisy = [images]

    # Select a subset from the DDPM steps
    ddim_steps = np.linspace(steps - 1, 0, num_steps, dtype=int)
    # -1 is a special value for the last step (alpha = 1)
    ddim_steps = np.pad(ddim_steps, (0, 1), constant_values=-1)

    # Iteratively remove noise. A list (unlike a bare zip) has a length,
    # so tqdm can show the total number of steps.
    step_pairs = list(zip(ddim_steps[:-1], ddim_steps[1:]))
    for cur_step, prev_step in tqdm.tqdm(step_pairs):
        images = ddim_remove_noise_step(ddpm, images, cur_step, prev_step)

        if keep_noisy:
            noisy.append(images)

    if keep_noisy:
        return noisy
    else:
        return images

In [None]:
# Evolution of samples under DDIM with 100 of the 1000 training timesteps
show_generation_process(ddpm, method=ddim_generate_batch, num_steps=100)

In [None]:
# Final DDIM samples with 100 denoising steps
show_generation_results(ddpm, method=ddim_generate_batch, num_steps=100)

---

In [None]:
# Final DDIM samples with only 10 denoising steps
show_generation_results(ddpm, method=ddim_generate_batch, num_steps=10)

---

In [None]:
# Final DDIM samples with just 4 denoising steps
show_generation_results(ddpm, method=ddim_generate_batch, num_steps=4)

&nbsp;