import torch
import torch.nn as nn
from .base import BaseModel
class ESPCN(BaseModel):
    """Efficient Sub-Pixel Convolutional Neural Network

    https://arxiv.org/pdf/1609.05158v2.pdf

    Feature extraction runs entirely at the low-resolution size. A final
    convolution produces ``channels * scale_factor**2`` feature maps, and
    ``nn.PixelShuffle`` rearranges them into an image upscaled by
    ``scale_factor``.

    Parameters
    ----------
    scale_factor : int
        Super-Resolution scale factor (upscaling ratio ``r``). Must be >= 1.
    channels : int, optional
        Number of image channels in both input and output. Defaults to 3,
        which reproduces the original RGB-only architecture.

    Raises
    ------
    ValueError
        If ``scale_factor`` or ``channels`` is less than 1.
    """

    def __init__(self, scale_factor: int, channels: int = 3):
        super().__init__()
        # Fail fast: a non-positive value would otherwise only surface later
        # as an obscure error inside Conv2d / PixelShuffle.
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        if channels < 1:
            raise ValueError(f"channels must be >= 1, got {channels}")

        # Every conv is stride 1 with "same" padding (k=5/p=2, k=3/p=1), so the
        # spatial size is unchanged until PixelShuffle. The layer order (and
        # thus the Sequential indices / state_dict keys) is kept exactly as in
        # the original so previously saved weights still load.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # r**2 sub-pixel maps per output channel, consumed by PixelShuffle.
            nn.Conv2d(in_channels=32, out_channels=channels * scale_factor**2, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Super-resolve Low-Resolution input tensor

        Parameters
        ----------
        x : torch.Tensor
            Input Low-Resolution image as tensor, laid out as
            ``(N, channels, H, W)``.

        Returns
        -------
        torch.Tensor
            Super-Resolved image as tensor, with spatial size
            ``(H * scale_factor, W * scale_factor)``.
        """
        return self.model(x)