Spaces:
Runtime error
Runtime error
| import importlib | |
| import os | |
| import sys | |
| from collections import OrderedDict | |
| from pathlib import Path | |
| from tasks.srdiff_df2k import InferDataSet | |
| parent_path = Path(__file__).absolute().parent.parent | |
| sys.path.append(os.path.abspath(parent_path)) | |
| os.chdir(parent_path) | |
| print(f'>-------------> parent path {parent_path}') | |
| print(f'>-------------> current work dir {os.getcwd()}') | |
| cache_path = os.path.join(parent_path, 'cache') | |
| os.environ["HF_DATASETS_CACHE"] = cache_path | |
| os.environ["TRANSFORMERS_CACHE"] = cache_path | |
| os.environ["torch_HOME"] = cache_path | |
| import torch | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from torch.utils.tensorboard import SummaryWriter | |
| from utils_sr.hparams import hparams, set_hparams | |
| def load_ckpt(ckpt_path, model): | |
| checkpoint = torch.load(ckpt_path, map_location='cpu') | |
| stat_dict = checkpoint['state_dict']['model'] | |
| new_state_dict = OrderedDict() | |
| for k, v in stat_dict.items(): | |
| if k[:7] == 'module.': | |
| k = k[7:] # ε»ζ `module.` | |
| new_state_dict[k] = v | |
| model.load_state_dict(new_state_dict) | |
| model.cuda() | |
| def infer(trainer, ckpt_path, img_dir, save_dir): | |
| trainer.build_model() | |
| load_ckpt(ckpt_path, trainer.model) | |
| dataset = InferDataSet(img_dir) | |
| test_dataloader = torch.utils.data.DataLoader( | |
| dataset, batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False) | |
| torch.backends.cudnn.benchmark = False | |
| with torch.no_grad(): | |
| trainer.model.eval() | |
| pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader)) | |
| for batch_idx, batch in pbar: | |
| img_lr, img_lr_up, img_name = batch | |
| img_lr = img_lr.to('cuda') | |
| img_lr_up = img_lr_up.to('cuda') | |
| img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape) | |
| img_sr = img_sr.clamp(-1, 1) | |
| img_sr = trainer.tensor2img(img_sr)[0] | |
| img_sr = Image.fromarray(img_sr) | |
| img_sr.save(os.path.join(save_dir, img_name[0])) | |
| if __name__ == '__main__': | |
| set_hparams() | |
| img_dir = hparams['img_dir'] | |
| save_dir = hparams['save_dir'] | |
| ckpt_path = hparams['ckpt_path'] | |
| pkg = ".".join(hparams["trainer_cls"].split(".")[:-1]) | |
| cls_name = hparams["trainer_cls"].split(".")[-1] | |
| trainer = getattr(importlib.import_module(pkg), cls_name)() | |
| os.makedirs(save_dir, exist_ok=True) | |
| infer(trainer, ckpt_path, img_dir, save_dir) | |