# NOTE(review): stray paste artifacts ("Spaces:", "Runtime error" x2) were at the
# top of this file; converted to a comment so the module parses.
| import os | |
| import random | |
| import numpy as np | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
| from sam_diffsr.tasks.srdiff import SRDiffTrainer | |
| from sam_diffsr.utils_sr.dataset import SRDataSet | |
| from sam_diffsr.utils_sr.hparams import hparams | |
| from sam_diffsr.utils_sr.matlab_resize import imresize | |
class InferDataSet(Dataset):
    """Inference dataset: yields (img_lr, img_lr_up, img_name) for each file in a directory.

    Every file in ``img_dir`` is treated as a low-resolution input; a
    MATLAB-compatible bicubic upsampled copy is produced alongside it so the
    SR model receives both views.
    """

    def __init__(self, img_dir):
        super().__init__()
        # Every entry in the directory is used; no extension filtering is done.
        self.img_path_list = [os.path.join(img_dir, name) for name in os.listdir(img_dir)]
        # uint8 [0, 255] -> float tensor normalized to [-1, 1].
        self.to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __getitem__(self, index):
        scale = hparams['sr_scale']
        path = self.img_path_list[index]
        name = os.path.basename(path)

        lr = np.uint8(np.asarray(Image.open(path).convert('RGB')))

        # Trim the LR image so the upscaled result is divisible by (scale * 2),
        # mirroring the cropping applied to training data.
        h, w, _ = lr.shape
        h_up = h * scale - (h * scale) % (scale * 2)
        w_up = w * scale - (w * scale) % (scale * 2)
        lr = lr[:h_up // scale, :w_up // scale]

        # NOTE(review): division by 256 (not 255) reproduces the normalization
        # convention used elsewhere in this project.
        lr_up = imresize(lr / 256, scale)  # np.float [H, W, C]

        lr = self.to_tensor_norm(lr).float()
        lr_up = self.to_tensor_norm(lr_up).float()
        return lr, lr_up, name

    def __len__(self):
        return len(self.img_path_list)
class Df2kDataSet(SRDataSet):
    """DF2K dataset producing aligned HR / LR / LR-upsampled samples.

    The 'valid' prefix is mapped onto the underlying 'test' split, but its
    length is capped at ``eval_batch_size * valid_steps`` so validation
    consumes a fixed number of items.
    """

    def __init__(self, prefix='train'):
        # Validation reads from the on-disk 'test' split.
        _prefix = 'test' if prefix == 'valid' else prefix
        super().__init__(_prefix)
        self.patch_size = hparams['patch_size']
        self.patch_size_lr = hparams['patch_size'] // hparams['sr_scale']
        if prefix == 'valid':
            # Exactly the number of samples used across valid_steps eval steps.
            self.len = hparams['eval_batch_size'] * hparams['valid_steps']
        # FIX: RandomRotation's `resample` kwarg was renamed `interpolation`
        # (and later removed) in torchvision; `resample=Image.BICUBIC` raises a
        # TypeError on current torchvision releases.
        self.data_aug_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(20, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        ])

    def __getitem__(self, index):
        item = self._get_item(index)
        hparams = self.hparams  # shadows the module-level hparams on purpose
        sr_scale = hparams['sr_scale']
        img_hr = np.uint8(item['img'])
        img_lr = np.uint8(item['img_lr'])
        # TODO: clip for SRFlow
        # Crop both images so HR dims are divisible by (sr_scale * 2) and the
        # LR crop stays exactly aligned with the HR crop.
        h, w, c = img_hr.shape
        h = h - h % (sr_scale * 2)
        w = w - w % (sr_scale * 2)
        h_l = h // sr_scale
        w_l = w // sr_scale
        img_hr = img_hr[:h, :w]
        img_lr = img_lr[:h_l, :w_l]
        # random crop (train only); crop origin is snapped to a multiple of
        # sr_scale so the LR patch maps onto whole LR pixels.
        if self.prefix == 'train':
            if self.data_augmentation and random.random() < 0.5:
                img_hr, img_lr = self.data_augment(img_hr, img_lr)
            i = random.randint(0, h - self.patch_size) // sr_scale * sr_scale
            i_lr = i // sr_scale
            j = random.randint(0, w - self.patch_size) // sr_scale * sr_scale
            j_lr = j // sr_scale
            img_hr = img_hr[i:i + self.patch_size, j:j + self.patch_size]
            img_lr = img_lr[i_lr:i_lr + self.patch_size_lr, j_lr:j_lr + self.patch_size_lr]
        # NOTE(review): division by 256 (not 255) is the project-wide convention.
        img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
        img_hr, img_lr, img_lr_up = [self.to_tensor_norm(x).float() for x in [img_hr, img_lr, img_lr_up]]
        return {
            'img_hr': img_hr, 'img_lr': img_lr,
            'img_lr_up': img_lr_up, 'item_name': item['item_name'],
            'loc': np.array(item['loc']), 'loc_bdr': np.array(item['loc_bdr'])
        }

    def __len__(self):
        return self.len

    def data_augment(self, img_hr, img_lr):
        """Randomly flip/rotate/jitter the HR image, then rebuild the LR image
        by downscaling so the pair stays consistent."""
        sr_scale = self.hparams['sr_scale']
        img_hr = Image.fromarray(img_hr)
        img_hr = self.data_aug_transforms(img_hr)
        img_hr = np.asarray(img_hr)  # np.uint8 [H, W, C]
        img_lr = imresize(img_hr, 1 / sr_scale)
        return img_hr, img_lr
class SRDiffDf2k(SRDiffTrainer):
    """SRDiff trainer specialized to the DF2K dataset."""

    def __init__(self):
        super().__init__()
        # Wire the DF2K dataset into the generic SRDiff training loop.
        self.dataset_cls = Df2kDataSet