# Сегментация изображений

<div align="center">
    <div style="display: inline-block; width: 30%;">
        <img src="https://courses.cv-gml.ru/storage/seminars/simple-segmentation/seg-image.png"/>
        Изображение
    </div>
    <div style="display: inline-block; width: 10%;"></div>
    <div style="display: inline-block; width: 30%;">
        <img src="https://courses.cv-gml.ru/storage/seminars/simple-segmentation/seg-semantic.png"/>
        Семантическая сегментация
    </div>
</div>

<br/>
<br/>

<div align="center">
    <div style="display: inline-block; width: 30%;">
        <img src="https://courses.cv-gml.ru/storage/seminars/simple-segmentation/seg-instance.png"/>
        Сегментация экземпляров
    </div>
    <div style="display: inline-block; width: 10%;"></div>
    <div style="display: inline-block; width: 30%;">
        <img src="https://courses.cv-gml.ru/storage/seminars/simple-segmentation/seg-panoptic.png"/>
        Паноптическая сегментация
    </div>
</div>

На этом семинаре будет рассмотрена **семантическая** сегментация на простом игрушечном датасете из синтетических картинок.

## Содержание:
* Вспомогательные функции
* Примеры сгенерированных изображений для демо-тренировки
* Создание датасетов и даталоудеров
* Построение UNet модели
* Обучение
* Визуализация предсказаний
* Идеи для улучшения

---

## Вспомогательные функции

### Функции визуализации

In [None]:
import itertools
import math
import random
from functools import reduce

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.models
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torchvision.transforms import v2 as transforms

# Not using `bfloat16` matrix multiplication for consistency:
# "high" is requested here instead of "medium" (see the PyTorch docs for
# `torch.set_float32_matmul_precision` for what each level permits).
# You might get better performance without much precision loss
# by setting this to "medium" on some devices
torch.set_float32_matmul_precision("high")

In [None]:
def plot_img_array(img_array, ncol=3):
    """Plot a sequence of images on a grid with ``ncol`` columns.

    Images are placed row by row (row-major order). The number of rows is
    rounded up, so a trailing partial row is still drawn.

    Args:
        img_array: sequence of images accepted by ``Axes.imshow``.
        ncol: number of grid columns.
    """
    # Ceil instead of floor: a partial last row used to be dropped, which
    # made the loop below index past the grid.
    nrow = max(1, math.ceil(len(img_array) / ncol))
    _, plots = plt.subplots(
        nrow,
        ncol,
        sharex="all",
        sharey="all",
        figsize=(ncol * 4, nrow * 4),
        # Always return a 2-D axes array; with the default squeeze=True a
        # single-row grid comes back 1-D and `plots[r, c]` fails.
        squeeze=False,
    )
    for i, img in enumerate(img_array):
        plots[i // ncol, i % ncol].imshow(img)


def plot_side_by_side(img_arrays):
    """Plot several parallel image sequences as the columns of one grid.

    ``img_arrays[k][i]`` is drawn in row ``i``, column ``k``. Sequences are
    paired with ``zip``, so the shortest one determines the number of rows.

    Args:
        img_arrays: non-empty list of equally long image sequences whose
            images all share one shape (they are stacked with ``np.array``).

    Raises:
        ValueError: if there is nothing to plot.
    """
    if len(img_arrays) == 0:
        raise ValueError("plot_side_by_side needs at least one image sequence")
    # Interleave row by row: a0, b0, c0, a1, b1, c1, ...
    # A flat comprehension avoids reduce(+) over tuples, which is quadratic.
    flatten_list = [img for row in zip(*img_arrays) for img in row]
    if not flatten_list:
        raise ValueError("plot_side_by_side got only empty image sequences")
    plot_img_array(np.array(flatten_list), ncol=len(img_arrays))


# RGB palette indexed by class id. Row 0 (white) is used for label 0
# (background) by `labels_to_colorimg`; rows 1-6 are the six shape classes
# (order as in `generate_image_and_label`). `masks_to_colorimg` uses rows 1..
# for its mask channels.
colors = np.asarray(
    [
        (255, 255, 255),  # 0: background
        (31, 119, 180),  # 1: filled square
        (255, 127, 14),  # 2: filled circle
        (44, 160, 44),  # 3: triangle
        (214, 39, 40),  # 4: plus
        (148, 103, 189),  # 5: wheel (ring)
        (140, 86, 75),  # 6: mesh
    ]
)

def masks_to_colorimg(masks):
    """Render a stack of per-class masks as an RGB image.

    Channel ``c`` of ``masks`` is painted with ``colors[c + 1]`` wherever its
    value is greater than 0.5. Pixels selected by several channels get the
    mean of their colors; pixels selected by no channel stay white.

    Args:
        masks: array of shape (channels, height, width) with at most
            ``len(colors) - 1`` channels.

    Returns:
        uint8 array of shape (height, width, 3).

    Raises:
        ValueError: if there are more channels than class colors.
    """
    masks = np.asarray(masks)
    channels, height, width = masks.shape
    palette = colors[1 : 1 + channels]
    if len(palette) != channels:
        raise ValueError(
            f"at most {len(colors) - 1} mask channels are supported, got {channels}"
        )

    # Vectorized replacement for a per-pixel Python loop.
    selected = masks > 0.5  # (C, H, W)
    counts = selected.sum(axis=0)  # (H, W): number of active channels per pixel
    # Sum of the active channels' colors per pixel -> (H, W, 3).
    color_sums = np.tensordot(selected.astype(np.float64), palette, axes=(0, 0))

    colorimg = np.full((height, width, 3), 255, dtype=np.float32)
    covered = counts > 0
    colorimg[covered] = color_sums[covered] / counts[covered][:, None]
    return colorimg.astype(np.uint8)


def labels_to_colorimg(labels):
    """Map a 2-D array of class ids to an RGB image using the ``colors`` palette.

    Args:
        labels: integer array of shape (height, width) with values that are
            valid row indices into ``colors``.

    Returns:
        uint8 array of shape (height, width, 3).
    """
    # Fancy indexing picks one palette row per pixel: (H, W) -> (H, W, 3).
    rgb = colors[labels]
    return rgb.astype(np.uint8)


def reverse_transform(inp):
    """Undo ImageNet normalization of a CHW tensor and return an HWC uint8 image.

    Args:
        inp: CPU tensor of shape (3, H, W) normalized with the ImageNet
            mean/std used by ``prepr``.

    Returns:
        uint8 numpy array of shape (H, W, 3) with values clipped to [0, 255].
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    hwc = np.transpose(inp.numpy(), (1, 2, 0))
    restored = np.clip(hwc * imagenet_std + imagenet_mean, 0, 1)
    return (restored * 255).astype(np.uint8)

### Функции генерации синтетических изображений и их масок

In [None]:
def generate_random_data(height, width, count):
    """Generate ``count`` synthetic (image, label) pairs.

    Returns:
        X: uint8 array (count, height, width, 3); the single-channel shape
           image is scaled to {0, 255} and replicated into 3 channels.
        Y: integer array (count, height, width) of class ids.
    """
    samples = [generate_image_and_label(height, width) for _ in range(count)]
    images, labels = zip(*samples)

    # (N, 1, H, W) in {0, 1} -> (N, 3, H, W) in {0, 255} -> (N, H, W, 3) uint8
    scaled = np.asarray(images) * 255
    X = scaled.repeat(3, axis=1).transpose(0, 2, 3, 1).astype(np.uint8)
    Y = np.asarray(labels)
    return X, Y


def generate_image_and_label(height, width):
    """Draw one random instance of each of the six shape classes.

    Returns:
        image: float32 array (1, height, width), 1.0 on the visible pixels
            of any shape and 0.0 elsewhere.
        label: integer array (height, width) of class ids: 0 background,
            1 filled square, 2 filled circle, 3 triangle, 4 plus, 5 wheel,
            6 mesh (the mesh is labelled with its filled bounding square).
            Where shapes overlap, the smallest class index wins.
    """
    shape = (height, width)

    # One entry per class, in class-index order 1..6:
    # (zoom passed to get_random_location,
    #  drawer of the visible shape,
    #  drawer of the target mask, or None to reuse the visible shape).
    shape_specs = (
        (0.8, lambda loc: add_filled_square(shape, *loc), None),
        (0.5, lambda loc: add_circle(shape, *loc, fill=True), None),
        (1.0, lambda loc: add_triangle(shape, *loc), None),
        (1.2, lambda loc: add_plus(shape, *loc), None),
        (0.7, lambda loc: add_circle(shape, *loc), None),
        (
            1.0,
            lambda loc: add_mesh_square(shape, *loc),
            lambda loc: add_filled_square(shape, *loc),
        ),
    )

    canvas = np.zeros(shape, dtype=bool)
    class_masks = [np.zeros(shape, dtype=bool)]  # index 0: background
    for zoom, draw_visible, draw_target in shape_specs:
        location = get_random_location(*shape, zoom=zoom)
        visible = draw_visible(location)
        canvas |= visible
        class_masks.append(visible if draw_target is None else draw_target(location))

    # argmax returns the first True along the class axis, so overlapping
    # pixels get the smallest class index; all-False pixels get 0.
    label = np.argmax(class_masks, axis=0)

    return canvas[np.newaxis].astype(np.float32), label


def add_filled_square(shape, x, y, size):
    """Return a boolean mask of a filled axis-aligned square.

    Args:
        shape: (rows, cols) of the mask.
        x: row of the square's centre.
        y: column of the square's centre.
        size: nominal side length. Pixels strictly within ``int(size / 2)`` of
            the centre along both axes are set, so the drawn side is
            ``2 * int(size / 2) - 1`` pixels (empty for ``size < 2``).

    Returns:
        Boolean array of shape ``shape``.
    """
    half = int(size / 2)
    # Open grids broadcast to (rows, cols) without materializing two full
    # index arrays; there is also no need to OR with an all-False array.
    rows, cols = np.ogrid[: shape[0], : shape[1]]
    # |r - x| < half  <=>  x - half < r < x + half  (same for columns)
    return (np.abs(rows - x) < half) & (np.abs(cols - y) < half)


def logical_and(arrays):
    """Element-wise AND of every array in ``arrays``.

    The fold starts from an all-True array shaped like ``arrays[0]``, so the
    result is always boolean (and broadcasts like ``np.logical_and``).
    """
    all_true = np.ones(arrays[0].shape, dtype=bool)
    return reduce(np.logical_and, arrays, all_true)


def add_mesh_square(shape, x, y, size):
    """Return a boolean mask of a dotted square: odd rows AND odd columns only.

    The square covers the same area as ``add_filled_square(shape, x, y, size)``
    (pixels strictly within ``int(size / 2)`` of ``(x, y)`` along both axes).
    Inside it, only pixels whose row and column indices are both odd are set.

    Returns:
        Boolean array of shape ``shape``.
    """
    half = int(size / 2)
    # Open grids broadcast to (rows, cols); no full index arrays and no
    # OR with an all-False array are needed.
    rows, cols = np.ogrid[: shape[0], : shape[1]]
    inside = (np.abs(rows - x) < half) & (np.abs(cols - y) < half)
    return inside & (rows % 2 == 1) & (cols % 2 == 1)


def add_triangle(shape, x, y, size):
    """Return a boolean mask containing a filled right triangle.

    The triangle is the lower-triangular part (diagonal included) of a
    ``size`` x ``size`` box whose top-left corner is at row
    ``x - int(size / 2)``, column ``y - int(size / 2)``. The box must lie
    inside ``shape``: otherwise the slice assignment raises a shape-mismatch
    error or, for negative starts, may wrap around.
    """
    half = int(size / 2)
    top, left = x - half, y - half
    mask = np.zeros(shape, dtype=bool)
    mask[top : top + size, left : left + size] = np.tril(
        np.ones((size, size), dtype=bool)
    )
    return mask


def add_circle(shape, x, y, size, fill=False):
    """Return a boolean mask of a disk (``fill=True``) or a ring (``fill=False``).

    Pixels whose Euclidean distance to ``(x, y)`` (row, column) is below
    ``size`` are set. For a ring, pixels closer than ``0.7 * size`` are
    removed again.
    """
    rows, cols = np.mgrid[: shape[0], : shape[1]]
    distance = np.sqrt((rows - x) ** 2 + (cols - y) ** 2)

    mask = distance < size
    if not fill:
        mask &= distance >= size * 0.7
    return mask


def add_plus(shape, x, y, size):
    """Return a boolean mask of a plus sign centred near ``(x, y)``.

    Both bars are 2 pixels thick (indices ``x - 1, x`` / ``y - 1, y``) and
    span ``[c - int(size / 2), c + int(size / 2))`` along their long axis.
    """
    half = int(size / 2)
    thin_rows, thin_cols = slice(x - 1, x + 1), slice(y - 1, y + 1)
    long_rows, long_cols = slice(x - half, x + half), slice(y - half, y + half)

    canvas = np.zeros(shape, dtype=bool)
    canvas[thin_rows, long_cols] = True
    canvas[long_rows, thin_cols] = True
    return canvas


def get_random_location(width, height, zoom=1.0):
    """Draw a random (row, column, size) triple for placing a shape.

    The first coordinate is scaled by ``width`` and the second by
    ``height``, each at a uniform fraction in [0.1, 0.9]. Note that callers
    pass ``*shape`` == ``(height, width)``, so the first coordinate is
    effectively a row index. The size is a uniform fraction in
    [0.06, 0.12] of the shorter side, multiplied by ``zoom``.
    """
    shortest_side = min(width, height)
    first_frac, second_frac, size_frac = (
        random.uniform(0.1, 0.9),
        random.uniform(0.1, 0.9),
        random.uniform(0.06, 0.12),
    )
    return (
        int(width * first_frac),
        int(height * second_frac),
        int(shortest_side * size_frac * zoom),
    )

---

## Примеры сгенерированных изображений для демо-тренировки

При обучении будет использован синтетически сгенерированный датасет. Давайте посмотрим на несколько примеров из обучающей выборки.

In [None]:
# Generate a couple of random (image, label) pairs for a quick look
input_images, target_labels = generate_random_data(192, 192, count=2)

# input_images: (N, H, W, 3) uint8; target_labels: (N, H, W) class ids
print("input_images shape and range", input_images.shape, input_images.min(), input_images.max())
print("target_labels shape and range", target_labels.shape, target_labels.min(), target_labels.max())

# The images are already (H, W, 3) uint8 and can be shown as-is;
# `astype` just makes an independent copy of each one
input_images_rgb = [image.astype(np.uint8) for image in input_images]

# Paint every class id with its palette color
target_labels_rgb = [labels_to_colorimg(label) for label in target_labels]

* <strong>Слева: Input image (черно-белое)</strong>

Чёрный фон, на котором в случайных местах расположены 6 различных фигур (фигуры могут перекрываться).

* <strong>Справа:  Target labels (одноканальное изображение с индексами от 0 до 6)</strong>

0 - фон, 1-6 - квадрат, круг, треугольник, плюс, колесо, решетка.

In [None]:
# Left column: input images; right column: target labels painted with `colors`
plot_side_by_side([input_images_rgb, target_labels_rgb])

---

## Создание датасетов и даталоудеров

In [None]:
class SimDataset(data.Dataset):
    """In-memory dataset of synthetic shape images for semantic segmentation.

    All samples are generated once in ``__init__`` with
    ``generate_random_data``: images are (H, W, 3) uint8 arrays and labels
    are (H, W) integer class-id arrays.

    Args:
        count: number of samples to generate.
        preprocess: optional callable applied to each image in
            ``__getitem__`` (e.g. a torchvision transform). If ``None``, the raw
            uint8 array is returned.
        height: image height in pixels.
        width: image width in pixels.
    """

    def __init__(self, count, preprocess=None, height=192, width=192):
        # Named `images`/`labels` so the `data` module is not shadowed here.
        images, labels = generate_random_data(height, width, count=count)
        self._images = images
        self._labels = labels
        self._preprocess = preprocess

    def __len__(self):
        return len(self._images)

    def __getitem__(self, idx):
        """Return ``[image, label]`` for sample ``idx``."""
        image = self._images[idx]
        label = self._labels[idx]

        # `preprocess` defaults to None; calling it unconditionally crashed.
        if self._preprocess is not None:
            image = self._preprocess(image)

        return [image, label]

In [None]:
# Train and validation share one preprocessing pipeline in this example:
# numpy HWC uint8 image -> float32 tensor image scaled to [0, 1]
# -> normalization with the ImageNet mean/std expected by the ResNet encoder
prepr = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# 2000 training and 200 validation samples, generated once up front
train_set = SimDataset(2000, preprocess=prepr)
valid_set = SimDataset(200, preprocess=prepr)

# Only the training loader shuffles
dl_train, dl_valid = (
    data.DataLoader(train_set, batch_size=25, shuffle=True, num_workers=2),
    data.DataLoader(valid_set, batch_size=25, shuffle=False, num_workers=2),
)

---

## Построение UNet модели

![](https://courses.cv-gml.ru/storage/seminars/simple-segmentation/unet.png)

In [None]:
def convrelu(in_channels, out_channels, kernel, padding):
    """Build a ``Conv2d -> ReLU(inplace)`` block.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel: convolution kernel size.
        padding: zero padding passed to ``nn.Conv2d``.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel, padding=padding)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)


class ResNetUNet(nn.Module):
    """U-Net style segmentation network with an ImageNet-pretrained ResNet-18 encoder.

    The encoder reuses the first eight children of torchvision's ResNet-18
    (conv1 ... layer4). Every encoder stage goes through a 1x1 conv adapter
    and is concatenated into the decoder after 2x bilinear upsampling. A
    separate two-conv branch on the raw input provides a full-resolution skip.

    Args:
        n_class: number of foreground classes; the output has ``1 + n_class``
            channels (channel 0 is the background).

    Input height and width must be divisible by 32 so that every upsampled
    decoder map matches the spatial size of its skip connection.
    """

    def __init__(self, n_class):
        super().__init__()

        weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        base_model = torchvision.models.resnet18(weights=weights)
        # Named `layers` rather than `L` so it does not shadow the notebook's
        # `lightning as L` alias inside this method.
        self.base_layers = layers = list(base_model.children())[:8]

        self.layer0 = nn.Sequential(*layers[:3])  # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)

        self.layer1 = nn.Sequential(*layers[3:5])  # size=(N, 64, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)

        self.layer2 = layers[5]  # size=(N, 128, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)

        self.layer3 = layers[6]  # size=(N, 256, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)

        self.layer4 = layers[7]  # size=(N, 512, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(512, 512, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        self.conv_up3 = convrelu(256 + 512, 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)

        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
        self.conv_last = nn.Conv2d(64, 1 + n_class, 1)

    def _up_merge(self, x, skip, conv):
        """Upsample ``x`` by 2, concatenate ``skip`` on channels, apply ``conv``."""
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)
        return conv(x)

    def forward(self, x_original):
        """Return per-pixel class logits of shape (N, 1 + n_class, H, W)."""
        # Full-resolution branch on the raw input.
        x_full = self.conv_original_size0(x_original)
        x_full = self.conv_original_size1(x_full)  # (N, 64, H, W)

        # Encoder.
        layer0 = self.layer0(x_original)  # (N, 64, H/2, W/2)
        layer1 = self.layer1(layer0)  # (N, 64, H/4, W/4)
        layer2 = self.layer2(layer1)  # (N, 128, H/8, W/8)
        layer3 = self.layer3(layer2)  # (N, 256, H/16, W/16)
        layer4 = self.layer4(layer3)  # (N, 512, H/32, W/32)

        # Decoder: upsample, merge the matching skip, convolve.
        # Channel counts match the conv_up* definitions in __init__.
        x = self.layer4_1x1(layer4)
        x = self._up_merge(x, self.layer3_1x1(layer3), self.conv_up3)  # (N, 512, H/16, W/16)
        x = self._up_merge(x, self.layer2_1x1(layer2), self.conv_up2)  # (N, 256, H/8, W/8)
        x = self._up_merge(x, self.layer1_1x1(layer1), self.conv_up1)  # (N, 256, H/4, W/4)
        x = self._up_merge(x, self.layer0_1x1(layer0), self.conv_up0)  # (N, 128, H/2, W/2)
        x = self._up_merge(x, x_full, self.conv_original_size2)  # (N, 64, H, W)

        x = self.conv_last(x)
        return x

---

## Обучение

Близость дискретных областей или множеств можно измерять с помощью различных индексов, например Intersection over Union или коэффициента Dice. Однако их формулы подразумевают использование уже дискретизированных предсказаний и, соответственно, не являются дифференцируемыми.

Несмотря на это, мы можем придумать обобщённую версию этих дискретных операторов над множествами. Есть несколько способов подобного «продолжения» этих функций на непрерывные входы, но чаще всего используются следующие:

$$
\left\lvert{A \cap B}\right\rvert \quad\Longrightarrow\quad \sum p_B \cdot p_A \quad\text{или}\quad \sum \mathrm{min}\left(p_B, p_A\right)
$$

<br/>

$$
\left\lvert{A}\right\rvert \quad\Longrightarrow\quad \sum p^2_A
$$

<br/>

$$
\left\lvert{A \cup B}\right\rvert \quad\Longrightarrow\quad \sum \mathrm{max}\left(p_B, p_A\right)
$$

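Обратите внимание, что все эти обобщения в точности совпадают со своими дискретными аналогами, когда $p \in \left\{0, 1\right\}$, и непрерывны при $p \in \left[0, 1\right]$. Варианты с произведением и квадратом гладкие везде, а $\mathrm{min}$ и $\mathrm{max}$ дифференцируемы почти всюду (кроме точек, где $p_A = p_B$), чего на практике достаточно для обучения градиентными методами.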

### Dice Loss

[Dice коэффициент](https://en.wikipedia.org/wiki/Dice-S%C3%B8rensen_coefficient):
$$
\mathrm{DICE} = \frac{
        2 \cdot \left\lvert{gt \cap pr}\right\rvert
    }{
        \left\lvert{gt}\right\rvert + \left\lvert{pr}\right\rvert
    }
$$

Реализуем функцию потерь Dice Loss – гладкое обобщение $\mathrm{DICE}$:

$$
    \mathrm{L}_{dice} = 1 - \frac{
        2 \cdot \sum p_{gt} \cdot p_{pr}
    }{
        \sum p^2_{gt} + \sum p^2_{pr}
    }
$$

In [None]:
def dice_loss(pr, gt, eps=1e-6):
    """Soft Dice loss averaged over the non-background classes and the batch.

    For every class ``c >= 1`` and every image:
    ``1 - 2 * sum(p_gt * p_pr) / (sum(p_gt ** 2) + sum(p_pr ** 2))``,
    where ``p_gt`` is the one-hot target mask of ``c`` and ``p_pr`` its
    predicted probability map.

    Args:
        pr: class probabilities of shape (N, C, *spatial), e.g. a softmax over
            dim 1 of the logits. Channel 0 is the background and is skipped.
        gt: integer class ids of shape (N, *spatial) with values in [0, C).
        eps: smoothing added to numerator and denominator, so a class absent
            from both ``gt`` and ``pr`` gives zero loss instead of NaN.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: if ``pr`` has fewer than 2 channels (no foreground class).
    """
    num_classes = pr.shape[1]
    if num_classes < 2:
        raise ValueError(
            f"expected at least 2 channels (background + classes), got {num_classes}"
        )
    # Reduce over every spatial axis, keep the batch axis -> one value per image.
    spatial_dims = tuple(range(1, gt.ndim))

    # Compute Dice loss only for non-background classes
    per_class_loss = []
    for cls in range(1, num_classes):
        # For each class, for each image
        p_pr = pr[:, cls]
        p_gt = (gt == cls).to(pr.dtype)
        intersection = (p_gt * p_pr).sum(dim=spatial_dims)
        denominator = (p_gt**2).sum(dim=spatial_dims) + (p_pr**2).sum(dim=spatial_dims)
        loss = 1 - (2 * intersection + eps) / (denominator + eps)

        per_class_loss.append(loss)

    loss = sum(per_class_loss) / len(per_class_loss)
    loss = loss.mean()
    return loss

### Jaccard Loss

[Jaccard index](https://en.wikipedia.org/wiki/Jaccard_index) (или IoU):
$$
\mathrm{IoU} = \frac{
        \left\lvert{gt \cap pr}\right\rvert
    }{
        \left\lvert{gt \cup pr}\right\rvert
    }
$$

Реализуем функцию потерь Jaccard Loss – гладкое обобщение $\mathrm{IoU}$:

$$
    \mathrm{L}_{jacc} = 1 - \frac{
        \sum \mathrm{min}\left(p_{gt}, p_{pr}\right)
    }{
        \sum \mathrm{max}\left(p_{gt}, p_{pr}\right)
    }
$$

In [None]:
def jacc_loss(pr, gt, eps=1e-6):
    """Soft Jaccard (IoU) loss averaged over the non-background classes and the batch.

    For every class ``c >= 1`` and every image:
    ``1 - sum(min(p_gt, p_pr)) / sum(max(p_gt, p_pr))``,
    where ``p_gt`` is the one-hot target mask of ``c`` and ``p_pr`` its
    predicted probability map.

    Args:
        pr: class probabilities of shape (N, C, *spatial), e.g. a softmax over
            dim 1 of the logits. Channel 0 is the background and is skipped.
        gt: integer class ids of shape (N, *spatial) with values in [0, C).
        eps: smoothing added to numerator and denominator, so a class absent
            from both ``gt`` and ``pr`` gives zero loss instead of NaN.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: if ``pr`` has fewer than 2 channels (no foreground class).
    """
    num_classes = pr.shape[1]
    if num_classes < 2:
        raise ValueError(
            f"expected at least 2 channels (background + classes), got {num_classes}"
        )
    # Reduce over every spatial axis, keep the batch axis -> one value per image.
    spatial_dims = tuple(range(1, gt.ndim))

    # Compute Jaccard loss only for non-background classes
    per_class_loss = []
    for cls in range(1, num_classes):
        # For each class, for each image
        p_pr = pr[:, cls]
        p_gt = (gt == cls).to(pr.dtype)
        intersection = torch.minimum(p_gt, p_pr).sum(dim=spatial_dims)
        union = torch.maximum(p_gt, p_pr).sum(dim=spatial_dims)
        loss = 1 - (intersection + eps) / (union + eps)

        per_class_loss.append(loss)

    loss = sum(per_class_loss) / len(per_class_loss)
    loss = loss.mean()
    return loss

### LightningModule

In [None]:
class MyModel(L.LightningModule):
    """Lightning module that trains ``ResNetUNet`` for semantic segmentation.

    The total loss is a weighted sum of pixel-wise cross-entropy on the
    logits and the soft Dice and Jaccard losses on the softmax probabilities.

    Args:
        num_classes: number of foreground classes (the network adds a
            background channel on top).
    """

    def __init__(self, num_classes):
        super().__init__()

        self.model = ResNetUNet(num_classes)

        # Freeze the pretrained ResNet-18 encoder layers; the 1x1 adapters,
        # the decoder and the full-resolution branch stay trainable.
        for layer in self.model.base_layers:
            for param in layer.parameters():
                param.requires_grad = False

        # Weights of the three loss terms.
        self.xent_weight = 0.8
        self.dice_weight = 0.1
        self.jacc_weight = 0.1

    def training_step(self, batch):
        return self._step(batch, "train")

    def validation_step(self, batch):
        return self._step(batch, "valid")

    def _step(self, batch, kind):
        """Compute, log and return the combined loss for one batch.

        ``batch`` is ``(x, y)``: ``x`` are images (N, 3, H, W) and ``y`` integer
        class ids (N, H, W), as produced by ``SimDataset`` + ``DataLoader``.
        """
        x, y = batch
        y = y.long()  # cross_entropy expects int64 class indices

        p_logit = self.model(x)  # (N, 1 + num_classes, H, W)
        xent = F.cross_entropy(p_logit, y)

        p_proba = p_logit.softmax(dim=1)
        dice = dice_loss(p_proba, y)
        jacc = jacc_loss(p_proba, y)

        loss = (
            self.xent_weight * xent
            + self.dice_weight * dice
            + self.jacc_weight * jacc
        )

        return self._log_metrics(
            kind, loss,
            xent=xent,
            dice=dice,
            jacc=jacc,
        )

    def _log_metrics(self, kind, loss, **metrics):
        """Log ``{kind}_{name}`` for every metric plus ``{kind}_loss``; return loss."""
        metrics = {f"{kind}_{name}": value for name, value in metrics.items()}
        metrics[f"{kind}_loss"] = loss

        self.log_dict(
            metrics,
            prog_bar=True,
            logger=True,
            on_step=kind == "train",
            on_epoch=True,
        )
        return loss

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau driven by the epoch-level ``valid_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=5e-4)

        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=0.1,
            patience=1,
        )
        lr_dict = {
            "scheduler": lr_scheduler,
            "interval": "epoch",
            "frequency": 1,
            "monitor": "valid_loss",
        }

        return [optimizer], [lr_dict]

### Callbacks

In [None]:
## Checkpointing: keep only the checkpoint with the lowest `valid_loss`.
MyModelCheckpoint = L.pytorch.callbacks.ModelCheckpoint(
    monitor="valid_loss",
    mode="min",
    save_top_k=1,
    filename="{epoch}-{valid_loss:.3f}",
)

## Early stopping: end training after `patience` validation checks
## in a row without an improvement of `valid_loss`.
MyEarlyStopping = L.pytorch.callbacks.EarlyStopping(
    monitor="valid_loss",
    mode="min",
    verbose=True,
    patience=2,
)

### Обучение

In [None]:
# At most 7 epochs; logs and checkpoints go under runs/segmentation
trainer = L.Trainer(
    default_root_dir="runs/segmentation",
    callbacks=[MyEarlyStopping, MyModelCheckpoint],
    max_epochs=7,
)

# 6 shape classes; the network adds a background channel on top of them
model = MyModel(num_classes=6)

In [None]:
# Train on `dl_train` and validate on `dl_valid`; the callbacks passed to the
# trainer handle checkpointing and early stopping on `valid_loss`.
trainer.fit(model, dl_train, dl_valid)

---

## Визуализация предсказаний обученной сети

In [None]:
# Inference mode for BatchNorm and similar layers
model.eval()

# A fresh batch of 3 synthetic samples that the model has never seen
test_dataset = SimDataset(3, preprocess=prepr)
test_loader = data.DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)
inputs, labels = next(iter(test_loader))

# Per-pixel predicted class id = argmax over the class channel
with torch.no_grad():
    pred = model.model(inputs).argmax(dim=1)

# Undo normalization for display; labels and predictions share one palette
input_images_rgb = list(map(reverse_transform, inputs.cpu()))
target_labels_rgb = [labels_to_colorimg(lbl) for lbl in labels.cpu().numpy()]
pred_rgb = [labels_to_colorimg(p) for p in pred.cpu().numpy()]

# Columns: input, ground truth, prediction
plot_side_by_side([input_images_rgb, target_labels_rgb, pred_rgb])

---

## Идеи для улучшения

* другие накопленные знания с прошлых семинаров
* функция потерь (как минимум покрутить `*_weight`)
* разморозка части слоев resnet18
* увеличение batch_size
* скорость обучения (learning rate) и её расписание (schedule)