Spaces:
Runtime error
Runtime error
| import numpy as np | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| from .hparams import hparams | |
| from .indexed_datasets import IndexedDataset | |
| from .matlab_resize import imresize | |
class SRDataSet(Dataset):
    """Super-resolution dataset read from a binarized ``IndexedDataset``.

    Each item yields three images derived from one stored image:

    * ``img_hr``: the stored image after ``pre_process``;
    * ``img_lr``: ``img_hr`` downscaled by ``1 / hparams['sr_scale']`` using
      ``hparams['data_interp']``;
    * ``img_lr_up``: ``img_lr`` upscaled back by ``hparams['sr_scale']``.

    All three are converted with ``ToTensor`` + ``Normalize(0.5, 0.5)``, so they
    share the same value scale (roughly ``[-1, 1]``).
    """

    def __init__(self, prefix='train'):
        """Set up the dataset for one split.

        Args:
            prefix: Split name; items are read from
                ``f"{hparams['binary_data_dir']}/{prefix}"``. For ``'valid'``
                the length is limited to
                ``hparams['eval_batch_size'] * hparams['valid_steps']``.

        Raises:
            ValueError: If ``hparams['data_interp']`` is not ``'bilinear'`` or
                ``'bicubic'``.
        """
        self.hparams = hparams
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        # The indexed data is opened here only to read its size; the handle
        # used for reading items is opened lazily in ``_get_item``.
        num_items = len(IndexedDataset(f'{self.data_dir}/{self.prefix}'))
        self.len = num_items
        self.to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # Explicit check instead of ``assert`` so it is not stripped under -O.
        if hparams['data_interp'] not in ('bilinear', 'bicubic'):
            raise ValueError(
                "hparams['data_interp'] must be 'bilinear' or 'bicubic', "
                f"got {hparams['data_interp']!r}"
            )
        # Stored for use elsewhere; not read by the code in this class.
        self.data_augmentation = hparams['data_augmentation']
        self.indexed_ds = None
        if self.prefix == 'valid':
            # Validation uses a fixed number of batches. Cap that at the real
            # number of stored items so indices never run past the data.
            self.len = min(
                num_items, hparams['eval_batch_size'] * hparams['valid_steps']
            )

    def _get_item(self, index):
        """Return the raw stored item at ``index``, opening the data on first use.

        The ``IndexedDataset`` is opened lazily rather than in ``__init__``;
        presumably so that each DataLoader worker process opens its own handle
        instead of inheriting one (NOTE(review): confirm intent).
        """
        if self.indexed_ds is None:
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Build the HR / LR / upsampled-LR tensors for item ``index``.

        Returns:
            dict with keys ``'img_hr'``, ``'img_lr'``, ``'img_lr_up'`` (float
            tensors produced by ``self.to_tensor_norm``) and ``'item_name'``
            (copied from the stored item).
        """
        item = self._get_item(index)
        hparams = self.hparams
        img_hr = Image.fromarray(np.uint8(item['img']))
        img_hr = self.pre_process(img_hr)  # PIL
        # ``np.array`` (not ``np.asarray``) gives a writable copy, which avoids
        # torch's warning about non-writable arrays inside ``ToTensor``.
        img_hr = np.array(img_hr)  # np.uint8 [H, W, C]
        img_lr = imresize(
            img_hr, 1 / hparams['sr_scale'], method=hparams['data_interp']
        )  # np.uint8 [H, W, C] per the original annotation -- values in [0, 255]
        # ToTensor divides uint8 arrays by 255 but leaves float arrays
        # unscaled. The float upsampled image is therefore brought into
        # [0, 1] by the same 255 factor, so that after Normalize it matches
        # img_hr / img_lr. The previous /256 left it slightly off-scale.
        img_lr_up = imresize(img_lr / 255.0, hparams['sr_scale'])  # float [H, W, C]
        img_hr, img_lr, img_lr_up = [
            self.to_tensor_norm(x).float() for x in (img_hr, img_lr, img_lr_up)
        ]
        return {
            'img_hr': img_hr,
            'img_lr': img_lr,
            'img_lr_up': img_lr_up,
            'item_name': item['item_name'],
        }

    def pre_process(self, img_hr):
        """Transform the HR PIL image before LR images are derived from it.

        Identity here; presumably an override point for subclasses (e.g.
        cropping or augmentation). NOTE(review): confirm against subclasses.
        """
        return img_hr

    def __len__(self):
        """Return the number of items exposed by this split."""
        return self.len