Source code for torch_enhance.metrics

import torch
import torch.nn.functional as F


@torch.no_grad()
def mse(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Compute the mean squared error (MSE) between two image tensors.

    Gradient tracking is disabled for the duration of the call, so the
    result is suitable for evaluation/logging but not for backpropagation.

    Parameters
    ----------
    y_pred : torch.Tensor
        Super-resolved image tensor.
    y_true : torch.Tensor
        High-resolution (reference) image tensor.

    Returns
    -------
    torch.Tensor
        Squared difference between ``y_pred`` and ``y_true``, averaged
        over every element.
    """
    mean_sq_err = F.mse_loss(input=y_pred, target=y_true, reduction="mean")
    return mean_sq_err
@torch.no_grad()
def mae(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Compute the mean absolute error (MAE) between two image tensors.

    Gradient tracking is disabled for the duration of the call, so the
    result is suitable for evaluation/logging but not for backpropagation.

    Parameters
    ----------
    y_pred : torch.Tensor
        Super-resolved image tensor.
    y_true : torch.Tensor
        High-resolution (reference) image tensor.

    Returns
    -------
    torch.Tensor
        Absolute difference between ``y_pred`` and ``y_true``, averaged
        over every element.
    """
    mean_abs_err = F.l1_loss(input=y_pred, target=y_true, reduction="mean")
    return mean_abs_err
@torch.no_grad()
def psnr(
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    max_val: float = 1.0,
) -> torch.Tensor:
    """Compute the peak signal-to-noise ratio (PSNR) between two image tensors.

    The value is ``10 * log10(max_val ** 2 / MSE(y_pred, y_true))`` in
    decibels. Gradient tracking is disabled for the duration of the call.

    Parameters
    ----------
    y_pred : torch.Tensor
        Super-resolved image tensor.
    y_true : torch.Tensor
        High-resolution (reference) image tensor.
    max_val : float, optional
        Largest value a pixel can take (the signal peak). The default of
        ``1.0`` suits images scaled to ``[0, 1]``; pass ``255.0`` for
        8-bit-range images.

    Returns
    -------
    torch.Tensor
        Peak signal-to-noise ratio between ``y_true`` and ``y_pred``.
        When the inputs are identical the MSE is zero and the result
        is ``inf``.
    """
    # Squaring the peak keeps numerator and denominator in the same
    # (squared-intensity) units; with the default peak of 1.0 this is 1 / MSE.
    peak_power = max_val ** 2
    return 10 * (peak_power / mse(y_pred, y_true)).log10()