Остаточная нейросеть (ResNet)

Автор:

Архитектура нейросетей, которая позволяет обучать очень глубокие модели за счёт специальных пропускных связей

ResNet (Residual Network) — это архитектура нейросетей, которая позволяет обучать очень глубокие модели за счёт специальных пропускных связей (skip connections), добавляющих вход слоя напрямую к его выходу. ResNet учит сеть не ломаться, когда она становится глубокой.

До появления ResNet в deep learning была странная и неприятная проблема.
— добавляем больше слоёв;
— получаем более мощную модель;
— качество должно расти.

Но на практике происходило обратное:
— глубокие сети хуже обучались;
— ошибка могла расти, даже на обучающей выборке;
— обучение становилось нестабильным.

Это была не проблема переобучения, а именно проблема обучаемости (optimization). Сеть из 30 слоёв могла работать хуже, чем сеть из 10 слоёв — просто потому, что оптимизатор «терялся» в глубине.

Создатели ResNet предложили простой трюк: вместо того чтобы каждый слой учил новое представление «с нуля», пусть он добавляет небольшое улучшение к уже существующему.

Для этого:
— входные данные слоя напрямую передаётся дальше,
— результат слоя складывается с этим входом.

Если слой не нужен — он может почти ничего не делать. Если нужен — аккуратно скорректировать результат. Это резко упрощает обучение глубоких сетей.

Что это дало на практике

После появления ResNet стало возможным:
— обучать сети глубиной 50, 100 и более слоёв;
— стабильно улучшать качество при увеличении глубины;
— использовать глубину как реальный инструмент, а не риск.

ResNet стал стандартом де-факто в Сomputer Vision и повлиял на дизайн почти всех современных архитектур.

Код ниже — минимальный пример residual-блока:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # если размерности не совпадают — приводим их
        self.shortcut = nn.Identity()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


Проверка формы выхода:

x = torch.randn(1, 64, 56, 56)
block = ResidualBlock(64, 128, stride=2)
y = block(x)

print(y.shape) # torch.Size([1, 128, 28, 28])