Source code for torch_enhance.models.edsr

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base import BaseModel


class UpsampleBlock(nn.Module):
    """Base PixelShuffle Upsample Block

    """
    def __init__(
        self,
        n_upsamples: int,
        channels: int,
        kernel_size: int
    ):

        super(UpsampleBlock, self).__init__()

        layers = []
        for _ in range(n_upsamples):
            layers.extend([
            nn.Conv2d(in_channels=channels, out_channels=channels * 2**2, kernel_size=kernel_size, stride=1, padding=kernel_size//2),
            nn.PixelShuffle(2)
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x


class ResidualBlock(nn.Module):
    """Base Residual Block

    """
    def __init__(
        self,
        channels: int,
        kernel_size: int,
        res_scale: int,
        activation
    ):

        super(ResidualBlock, self).__init__()

        self.res_scale = res_scale
        
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2),
            activation(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=kernel_size//2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.model(x) * self.res_scale
        x = x + shortcut
        return x


[docs]class EDSR(BaseModel): """Enhanced Deep Residual Networks for Single Image Super-Resolution https://arxiv.org/pdf/1707.02921v1.pdf Parameters ---------- scale_factor : int Super-Resolution scale factor. Determines Low-Resolution downsampling. """ def __init__(self, scale_factor: int): super(EDSR, self).__init__() self.n_res_blocks = 32 # Pre Residual Blocks self.head = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=256, kernel_size=3, stride=1, padding=1), ) # Residual Blocks self.res_blocks = [ ResidualBlock(channels=256, kernel_size=3, res_scale=0.1, activation=nn.ReLU) \ for _ in range(self.n_res_blocks) ] self.res_blocks.append(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)) self.res_blocks = nn.Sequential(*self.res_blocks) # Upsamples n_upsamples = 1 if scale_factor == 2 else 2 self.upsample = UpsampleBlock( n_upsamples=n_upsamples, channels=256, kernel_size=3 ) # Output layer self.tail = nn.Sequential( nn.Conv2d(in_channels=256, out_channels=3, kernel_size=3, stride=1, padding=1), )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Super-resolve Low-Resolution input tensor Parameters ---------- x : torch.Tensor Input Low-Resolution image as tensor Returns ------- torch.Tensor Super-Resolved image as tensor """ x = self.head(x) shortcut = x x = self.res_blocks(x) + shortcut x = self.upsample(x) x = self.tail(x) return x