Spaces:
Runtime error
Runtime error
| import os | |
| from collections import OrderedDict | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torchvision.transforms import transforms | |
| from sam_diffsr.utils_sr.hparams import set_hparams, hparams | |
| from sam_diffsr.utils_sr.matlab_resize import imresize | |
| from sam_diffsr.tasks.srdiff_df2k_sam import SRDiffDf2k_sam as trainer_ori | |
# Directory containing this module; the bundled checkpoint path is resolved relative to it.
ROOT_PATH = os.path.split(__file__)[0]
class sam_diffsr_demo:
    """Minimal inference wrapper around the SAM-DiffSR model for a demo app.

    On construction it calls ``set_hparams()``, builds the network through
    ``SRDiffDf2k_sam`` and loads ``weight/model_ckpt_steps_400000.ckpt``
    (resolved relative to ``ROOT_PATH``). The model is placed on CUDA when
    available and on CPU otherwise.
    """

    def __init__(self):
        set_hparams()
        ckpt_path = os.path.join(ROOT_PATH, 'weight/model_ckpt_steps_400000.ckpt')
        self.model_init(ckpt_path)

    @staticmethod
    def _default_device():
        """Return the CUDA device when available, otherwise the CPU."""
        # Hard-coding ``.cuda()`` makes the demo crash on CPU-only hosts;
        # falling back to CPU keeps GPU behavior identical while letting the
        # wrapper at least load elsewhere.
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _model_device(self):
        """Return the device that currently holds the model's parameters."""
        try:
            return next(self.model.model.parameters()).device
        except StopIteration:
            # Parameter-less model: nothing pins a device, use the default.
            return self._default_device()

    def get_img_data(self, img_PIL, hparams, sr_scale=4):
        """Turn a PIL image into the low-resolution inputs the model consumes.

        Args:
            img_PIL: input image; converted to RGB first.
            hparams: hyper-parameter mapping; only ``hparams['sr_scale']`` is
                read, as the upsampling factor passed to ``imresize``.
            sr_scale: factor used only for the crop below, which trims the
                image so its height and width are even.
                NOTE(review): the upsampling uses ``hparams['sr_scale']``
                rather than this argument; presumably both are 4 here — confirm.

        Returns:
            ``(img_lr, img_lr_up)``: float32 tensors, each with a leading batch
            dimension of 1 and normalized with mean/std 0.5 per channel.

        Raises:
            ValueError: if the image is too small to survive the crop.
        """
        img_lr = np.uint8(np.asarray(img_PIL.convert('RGB')))
        h_in, w_in, _ = img_lr.shape

        # Round the scaled size down to a multiple of 2 * sr_scale, so the
        # cropped low-resolution height/width are even.
        step = sr_scale * 2
        h = h_in * sr_scale
        w = w_in * sr_scale
        h -= h % step
        w -= w % step
        h_l = h // sr_scale
        w_l = w // sr_scale
        if h_l == 0 or w_l == 0:
            raise ValueError(
                f'Input image of size {w_in}x{h_in} is too small: cropping to an '
                f'even low-resolution size leaves nothing to process.'
            )
        img_lr = img_lr[:h_l, :w_l]

        to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # NOTE(review): ToTensor scales the uint8 image by 1/255, while the
        # upsampled copy is pre-scaled by 1/256 (ToTensor does not rescale
        # float arrays). Kept unchanged, presumably to match the preprocessing
        # the checkpoint was trained with — confirm before altering.
        img_lr_up = imresize(img_lr / 256, hparams['sr_scale'])  # np.float [H, W, C]
        img_lr, img_lr_up = [to_tensor_norm(x).float() for x in (img_lr, img_lr_up)]
        return img_lr.unsqueeze(0), img_lr_up.unsqueeze(0)

    def load_checkpoint(self, ckpt_path):
        """Load weights from ``ckpt_path`` into ``self.model.model``.

        The file is read onto the CPU. Weights come from
        ``checkpoint['state_dict']['model']``, with any leading ``module.``
        stripped from the keys. The prefix is presumably left by training
        under a (Distributed)DataParallel wrapper. The model is then moved to
        CUDA if available, otherwise it stays on CPU.
        """
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True;
        # confirm this checkpoint loads under that setting.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        print(f'loding check from: {ckpt_path}')
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)  # strip `module.`
            for k, v in checkpoint['state_dict']['model'].items()
        )
        self.model.model.load_state_dict(new_state_dict)
        self.model.model.to(self._default_device())
        del checkpoint
        torch.cuda.empty_cache()

    def model_init(self, ckpt_path):
        """Build the network via the trainer and load weights from ``ckpt_path``."""
        self.model = trainer_ori()
        self.model.build_model()
        self.load_checkpoint(ckpt_path)
        torch.backends.cudnn.benchmark = False

    def infer(self, img_PIL):
        """Super-resolve ``img_PIL`` and return the result as a PIL image."""
        # Send inputs to wherever the model actually lives (CUDA or CPU fallback).
        device = self._model_device()
        with torch.no_grad():
            self.model.model.eval()
            img_lr, img_lr_up = self.get_img_data(img_PIL, hparams, sr_scale=4)
            img_lr = img_lr.to(device)
            img_lr_up = img_lr_up.to(device)
            img_sr, _ = self.model.model.sample(img_lr, img_lr_up, img_lr_up.shape)
            img_sr = img_sr.clamp(-1, 1)
            img_sr = self.model.tensor2img(img_sr)[0]
        return Image.fromarray(img_sr)