import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
[docs]class VGG(nn.Module):
"""VGG/Perceptual Loss
Parameters
----------
conv_index : str
Convolutional layer in VGG model to use as perceptual output
"""
def __init__(self, conv_index: str = '22'):
super(VGG, self).__init__()
vgg_features = torchvision.models.vgg19(pretrained=True).features
modules = [m for m in vgg_features]
if conv_index == '22':
self.vgg = nn.Sequential(*modules[:8])
elif conv_index == '54':
self.vgg = nn.Sequential(*modules[:35])
vgg_mean = (0.485, 0.456, 0.406)
vgg_std = (0.229, 0.224, 0.225)
#self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
self.vgg.requires_grad = False
[docs] def forward(self, sr: torch.Tensor, hr: torch.Tensor) -> torch.Tensor:
"""Compute VGG/Perceptual loss between Super-Resolved and High-Resolution
Parameters
----------
sr : torch.Tensor
Super-Resolved model output tensor
hr : torch.Tensor
High-Resolution image tensor
Returns
-------
loss : torch.Tensor
Perceptual VGG loss between sr and hr
"""
def _forward(x):
#x = self.sub_mean(x)
x = self.vgg(x)
return x
vgg_sr = _forward(sr)
with torch.no_grad():
vgg_hr = _forward(hr.detach())
loss = F.mse_loss(vgg_sr, vgg_hr)
return loss