Source code for torch_enhance.models.vdsr

import torch
import torch.nn as nn

from .base import BaseModel
from .baseline import Bicubic


class VDSR(BaseModel):
    """Very Deep Super Resolution

    https://arxiv.org/pdf/1511.04587.pdf

    The low-resolution input is first upsampled with bicubic interpolation.
    A stack of 3x3 convolutions then predicts a residual that is added back
    onto the upsampled image.

    Parameters
    ----------
    scale_factor : int
        Super-Resolution scale factor. Determines Low-Resolution downsampling.
    n_layers : int, optional
        Total number of convolutional layers: the initial layer, the hidden
        layers and the output layer. Must be at least 2. Defaults to 20.
    channels : int, optional
        Number of channels of the input image and of the output image.
        Defaults to 3. For example, use 1 for a single-channel model.
    n_filters : int, optional
        Number of feature maps produced by every non-output layer.
        Defaults to 64.

    Raises
    ------
    ValueError
        If ``n_layers`` is less than 2.
    """

    def __init__(self, scale_factor, n_layers=20, channels=3, n_filters=64):
        super(VDSR, self).__init__()
        # The initial and output layers are always present, so fewer than two
        # layers cannot be built. Without this check, range(n_layers - 2)
        # would silently yield zero hidden layers and self.n_layers would
        # disagree with the actual depth.
        if n_layers < 2:
            raise ValueError(f"n_layers must be at least 2, got {n_layers}")
        self.n_layers = n_layers
        self.upsample = Bicubic(scale_factor)

        # Initial layer: image channels -> feature maps.
        # kernel_size=3 with padding=1 and stride=1 keeps the spatial size
        # unchanged. The residual addition in forward() relies on this.
        layers = [
            nn.Conv2d(
                in_channels=channels,
                out_channels=n_filters,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
        ]

        # Hidden layers: the initial and output layers account for the other 2.
        for _ in range(n_layers - 2):
            layers.append(
                nn.Conv2d(
                    in_channels=n_filters,
                    out_channels=n_filters,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            layers.append(nn.ReLU())

        # Output reconstruction layer: feature maps -> image channels.
        # There is no activation here, so the predicted residual can be
        # negative.
        layers.append(
            nn.Conv2d(
                in_channels=n_filters,
                out_channels=channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Super-resolve Low-Resolution input tensor

        Parameters
        ----------
        x : torch.Tensor
            Input Low-Resolution image as tensor. It is assumed to be
            (N, C, H, W) with C equal to ``channels``; confirm against
            ``Bicubic``.

        Returns
        -------
        torch.Tensor
            Super-Resolved image as tensor
        """
        x = self.upsample(x)
        # Residual learning: the conv stack predicts a correction that is
        # added to the bicubic-upsampled image.
        return self.model(x) + x