import os
import shutil
from torchvision.datasets.utils import download_and_extract_archive
from .base import BSDS300_URL, BaseDataset
class BSDS300(BaseDataset):
    """BSDS300 image dataset.

    On construction the archive at ``BSDS300_URL`` is downloaded and
    extracted (unless already present) under ``<data_dir>/BSDS300``, and the
    split folders are flattened to ``<data_dir>/BSDS300/train`` and
    ``<data_dir>/BSDS300/test``.

    Parameters
    ----------
    scale_factor : int
        Stored as ``self.scale_factor``; presumably the HR/LR downscale factor
        used by ``get_lr_transforms`` (defined on ``BaseDataset`` -- confirm).
    image_size : int
        Stored as ``self.image_size``; presumably used by the default
        transforms (confirm against ``BaseDataset``).
    color_space : str
        Stored as ``self.color_space``.
    train : bool
        Use the ``train`` split when True, otherwise the ``test`` split.
    data_dir : str
        Base dataset directory. Empty (or None) means
        ``os.path.join(os.getcwd(), self.base_dir)``.
    lr_transforms, hr_transforms : callable or None
        Optional user-supplied transforms. When None, the defaults returned by
        ``self.get_lr_transforms()`` / ``self.get_hr_transforms()`` are used.
    """

    url = BSDS300_URL
    extensions = ['.jpg']

    def __init__(
        self,
        scale_factor: int = 2,
        image_size: int = 256,
        color_space: str = 'RGB',
        train: bool = True,
        data_dir: str = '',
        lr_transforms=None,
        hr_transforms=None,
    ):
        super().__init__()
        self.scale_factor = scale_factor
        self.image_size = image_size
        self.color_space = color_space
        self.lr_transforms = lr_transforms
        self.hr_transforms = hr_transforms

        if not data_dir:
            # `base_dir` is not defined in this class; presumably it comes
            # from BaseDataset.
            data_dir = os.path.join(os.getcwd(), self.base_dir)
        self.root_dir = os.path.join(data_dir, 'BSDS300')
        self.download(data_dir)

        self.set_dir = os.path.join(self.root_dir, 'train' if train else 'test')
        self.file_names = self.get_files(self.set_dir)

        # Resolve the effective transforms once and expose them under both the
        # plural (constructor-argument) and singular attribute names, so code
        # reading either name sees the same transform whether or not the user
        # supplied one.
        if lr_transforms is None:
            lr_transforms = self.get_lr_transforms()
        if hr_transforms is None:
            hr_transforms = self.get_hr_transforms()
        self.lr_transforms = self.lr_transform = lr_transforms
        self.hr_transforms = self.hr_transform = hr_transforms

    def download(self, data_dir: str) -> None:
        """Download and extract the dataset if it is not already present.

        The dataset counts as present only when both ``<root_dir>/train`` and
        ``<root_dir>/test`` are directories, so an interrupted earlier
        download/tidy-up is retried instead of being silently accepted.

        Parameters
        ----------
        data_dir : str
            Path to base dataset directory; the archive is extracted here.

        Returns
        -------
        None
        """
        splits = ('train', 'test')
        if all(os.path.isdir(os.path.join(self.root_dir, s)) for s in splits):
            return

        # makedirs also creates missing parents; exist_ok makes reruns safe.
        os.makedirs(data_dir, exist_ok=True)
        download_and_extract_archive(self.url, data_dir, remove_finished=True)

        # The archive is expected to extract to <root_dir>/images/{train,test};
        # flatten so the splits sit directly under root_dir. A split left over
        # from a previous partial run is kept rather than nested inside itself.
        images_dir = os.path.join(self.root_dir, 'images')
        for split in splits:
            dst = os.path.join(self.root_dir, split)
            if not os.path.isdir(dst):
                shutil.move(os.path.join(images_dir, split), dst)

        # Remove everything else the archive contained.
        for name in os.listdir(self.root_dir):
            if name in splits:
                continue
            path = os.path.join(self.root_dir, name)
            if os.path.isdir(path):
                shutil.rmtree(path)
            else:
                os.remove(path)