Spaces:
Runtime error
Runtime error
| import importlib | |
| import json | |
| import os | |
| import subprocess | |
| import sys | |
| from collections import OrderedDict | |
| from pathlib import Path | |
| parent_path = Path(__file__).absolute().parent.parent | |
| sys.path.append(os.path.abspath(parent_path)) | |
| os.chdir(parent_path) | |
| print(f'>-------------> parent path {parent_path}') | |
| print(f'>-------------> current work dir {os.getcwd()}') | |
| cache_path = os.path.join(parent_path, 'cache') | |
| os.environ["HF_DATASETS_CACHE"] = cache_path | |
| os.environ["TRANSFORMERS_CACHE"] = cache_path | |
| os.environ["torch_HOME"] = cache_path | |
| import torch | |
| from PIL import Image | |
| from tqdm import tqdm | |
| import numpy as np | |
| from torch.utils.tensorboard import SummaryWriter | |
| from sam_diffsr.utils_sr.hparams import hparams, set_hparams | |
| from sam_diffsr.utils_sr.utils import plot_img, move_to_cuda, load_checkpoint, save_checkpoint, tensors_to_scalars, Measure, \ | |
| get_all_ckpts | |
| class Trainer: | |
| def __init__(self): | |
| self.logger = self.build_tensorboard(save_dir=hparams['work_dir'], name='tb_logs') | |
| self.measure = Measure() | |
| self.dataset_cls = None | |
| self.metric_keys = ['psnr', 'ssim', 'lpips', 'lr_psnr'] | |
| self.metric_2_keys = ['psnr-Y', 'ssim', 'fid'] | |
| self.work_dir = hparams['work_dir'] | |
| self.first_val = True | |
| self.val_steps = hparams['val_steps'] | |
| def build_tensorboard(self, save_dir, name, **kwargs): | |
| log_dir = os.path.join(save_dir, name) | |
| os.makedirs(log_dir, exist_ok=True) | |
| return SummaryWriter(log_dir=log_dir, **kwargs) | |
| def build_train_dataloader(self): | |
| dataset = self.dataset_cls('train') | |
| return torch.utils.data.DataLoader( | |
| dataset, batch_size=hparams['batch_size'], shuffle=True, | |
| pin_memory=False, num_workers=hparams['num_workers']) | |
| def build_val_dataloader(self): | |
| return torch.utils.data.DataLoader( | |
| self.dataset_cls('valid'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False) | |
| def build_test_dataloader(self): | |
| return torch.utils.data.DataLoader( | |
| self.dataset_cls('test'), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False) | |
| def build_model(self): | |
| raise NotImplementedError | |
| def sample_and_test(self, sample): | |
| raise NotImplementedError | |
| def build_optimizer(self, model): | |
| raise NotImplementedError | |
| def build_scheduler(self, optimizer): | |
| raise NotImplementedError | |
| def training_step(self, batch): | |
| raise NotImplementedError | |
| def train(self): | |
| model = self.build_model() | |
| optimizer = self.build_optimizer(model) | |
| self.global_step = training_step = load_checkpoint(model, optimizer, hparams['work_dir'], steps=self.val_steps) | |
| self.scheduler = scheduler = self.build_scheduler(optimizer) | |
| scheduler.step(training_step) | |
| dataloader = self.build_train_dataloader() | |
| train_pbar = tqdm(dataloader, initial=training_step, total=float('inf'), | |
| dynamic_ncols=True, unit='step') | |
| while self.global_step < hparams['max_updates']: | |
| for batch in train_pbar: | |
| if training_step % hparams['val_check_interval'] == 0: | |
| with torch.no_grad(): | |
| model.eval() | |
| self.validate(training_step) | |
| save_checkpoint(model, optimizer, self.work_dir, training_step, hparams['num_ckpt_keep']) | |
| model.train() | |
| batch = move_to_cuda(batch) | |
| losses, total_loss = self.training_step(batch) | |
| optimizer.zero_grad() | |
| total_loss.backward() | |
| optimizer.step() | |
| training_step += 1 | |
| scheduler.step(training_step) | |
| self.global_step = training_step | |
| if training_step % 100 == 0: | |
| self.log_metrics({f'tr/{k}': v for k, v in losses.items()}, training_step) | |
| train_pbar.set_postfix(**tensors_to_scalars(losses)) | |
| def validate(self, training_step): | |
| val_dataloader = self.build_val_dataloader() | |
| pbar = tqdm(enumerate(val_dataloader), total=len(val_dataloader)) | |
| metrics = {} | |
| for batch_idx, batch in pbar: | |
| # 每次运行的第一次validation只跑一小部分数据,来验证代码能否跑通 | |
| if self.first_val and batch_idx > hparams['num_sanity_val_steps'] - 1: | |
| break | |
| batch = move_to_cuda(batch) | |
| img, rrdb_out, ret = self.sample_and_test(batch) | |
| img_hr = batch['img_hr'] | |
| img_lr = batch['img_lr'] | |
| img_lr_up = batch['img_lr_up'] | |
| if img is not None: | |
| self.logger.add_image(f'Pred_{batch_idx}', plot_img(img[0]), self.global_step) | |
| if hparams.get('aux_l1_loss'): | |
| self.logger.add_image(f'rrdb_out_{batch_idx}', plot_img(rrdb_out[0]), self.global_step) | |
| if self.global_step <= hparams['val_check_interval']: | |
| self.logger.add_image(f'HR_{batch_idx}', plot_img(img_hr[0]), self.global_step) | |
| self.logger.add_image(f'LR_{batch_idx}', plot_img(img_lr[0]), self.global_step) | |
| self.logger.add_image(f'BL_{batch_idx}', plot_img(img_lr_up[0]), self.global_step) | |
| metrics = {} | |
| metrics.update({k: np.mean(ret[k]) for k in self.metric_keys}) | |
| pbar.set_postfix(**tensors_to_scalars(metrics)) | |
| if hparams['infer']: | |
| print('Val results:', metrics) | |
| else: | |
| if not self.first_val: | |
| self.log_metrics({f'val/{k}': v for k, v in metrics.items()}, training_step) | |
| print('Val results:', metrics) | |
| else: | |
| print('Sanity val results:', metrics) | |
| self.first_val = False | |
| def build_test_my_dataloader(self, data_name): | |
| return torch.utils.data.DataLoader( | |
| self.dataset_cls(data_name), batch_size=hparams['eval_batch_size'], shuffle=False, pin_memory=False) | |
| def benchmark(self, benchmark_name_list, metric_list): | |
| from sam_diffsr.tools.caculate_iqa import eval_img_IQA | |
| model = self.build_model() | |
| optimizer = self.build_optimizer(model) | |
| training_step = load_checkpoint(model, optimizer, hparams['work_dir'], hparams['val_steps']) | |
| self.global_step = training_step | |
| optimizer = None | |
| for data_name in benchmark_name_list: | |
| test_dataloader = self.build_test_my_dataloader(data_name) | |
| self.results = {k: 0 for k in self.metric_keys} | |
| self.n_samples = 0 | |
| self.gen_dir = f"{hparams['work_dir']}/results_{self.global_step}_{hparams['gen_dir_name']}/benchmark/{data_name}" | |
| if hparams['test_save_png']: | |
| subprocess.check_call(f'rm -rf {self.gen_dir}', shell=True) | |
| os.makedirs(f'{self.gen_dir}/outputs', exist_ok=True) | |
| os.makedirs(f'{self.gen_dir}/SR', exist_ok=True) | |
| self.model.sample_tqdm = False | |
| torch.backends.cudnn.benchmark = False | |
| if hparams['test_save_png']: | |
| if hasattr(self.model.denoise_fn, 'make_generation_fast_'): | |
| self.model.denoise_fn.make_generation_fast_() | |
| os.makedirs(f'{self.gen_dir}/HR', exist_ok=True) | |
| result_dict = {} | |
| with torch.no_grad(): | |
| model.eval() | |
| pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader)) | |
| for batch_idx, batch in pbar: | |
| move_to_cuda(batch) | |
| gen_dir = self.gen_dir | |
| item_names = batch['item_name'] | |
| img_hr = batch['img_hr'] | |
| img_lr = batch['img_lr'] | |
| img_lr_up = batch['img_lr_up'] | |
| res = self.sample_and_test(batch) | |
| if len(res) == 3: | |
| img_sr, rrdb_out, ret = res | |
| else: | |
| img_sr, ret = res | |
| rrdb_out = img_sr | |
| img_lr_up = batch.get('img_lr_up', img_lr_up) | |
| if img_sr is not None: | |
| metrics = list(self.metric_keys) | |
| result_dict[batch['item_name'][0]] = {} | |
| for k in metrics: | |
| self.results[k] += ret[k] | |
| result_dict[batch['item_name'][0]][k] = ret[k] | |
| self.n_samples += ret['n_samples'] | |
| print({k: round(self.results[k] / self.n_samples, 3) for k in self.results}, 'total:', | |
| self.n_samples) | |
| if hparams['test_save_png'] and img_sr is not None: | |
| img_sr = self.tensor2img(img_sr) | |
| img_hr = self.tensor2img(img_hr) | |
| img_lr = self.tensor2img(img_lr) | |
| img_lr_up = self.tensor2img(img_lr_up) | |
| rrdb_out = self.tensor2img(rrdb_out) | |
| for item_name, hr_p, hr_g, lr, lr_up, rrdb_o in zip( | |
| item_names, img_sr, img_hr, img_lr, img_lr_up, rrdb_out): | |
| item_name = os.path.splitext(item_name)[0] | |
| hr_p = Image.fromarray(hr_p) | |
| hr_g = Image.fromarray(hr_g) | |
| hr_p.save(f"{gen_dir}/SR/{item_name}.png") | |
| hr_g.save(f"{gen_dir}/HR/{item_name}.png") | |
| exp_name = hparams['work_dir'].split('/')[-1] | |
| sr_img_dir = f"{gen_dir}/SR/" | |
| gt_img_dir = f"{gen_dir}/HR/" | |
| excel_path = f"{hparams['work_dir']}/IQA-val-benchmark-{exp_name}.xlsx" | |
| epoch = training_step | |
| eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name) | |
| os.makedirs(f'{self.gen_dir}', exist_ok=True) | |
| eval_json_path = os.path.join(self.gen_dir, 'eval.json') | |
| avg_result = {k: round(self.results[k] / self.n_samples, 4) for k in self.results} | |
| with open(eval_json_path, 'w+') as file: | |
| json.dump(avg_result, file, sort_keys=True, indent=4, separators=(',', ': '), ensure_ascii=False) | |
| json.dump(result_dict, file, sort_keys=True, indent=4, separators=(',', ': '), ensure_ascii=False) | |
| def benchmark_loop(self, benchmark_name_list, metric_list, gt_path): | |
| # infer and evaluation all save checkpoint | |
| from sam_diffsr.tools.caculate_iqa import eval_img_IQA | |
| model = self.build_model() | |
| def get_checkpoint(model, checkpoint): | |
| stat_dict = checkpoint['state_dict']['model'] | |
| new_state_dict = OrderedDict() | |
| for k, v in stat_dict.items(): | |
| if k[:7] == 'module.': | |
| k = k[7:] # 去掉 `module.` | |
| new_state_dict[k] = v | |
| model.load_state_dict(new_state_dict) | |
| model.cuda() | |
| training_step = checkpoint['global_step'] | |
| del checkpoint | |
| torch.cuda.empty_cache() | |
| return training_step | |
| ckpt_paths = get_all_ckpts(hparams['work_dir']) | |
| for ckpt_path in ckpt_paths: | |
| checkpoint = torch.load(ckpt_path, map_location='cpu') | |
| training_step = get_checkpoint(model, checkpoint) | |
| self.global_step = training_step | |
| for data_name in benchmark_name_list: | |
| test_dataloader = self.build_test_my_dataloader(data_name) | |
| self.results = {k: 0 for k in self.metric_keys + self.metric_2_keys} | |
| self.n_samples = 0 | |
| self.gen_dir = f"{hparams['work_dir']}/results_{training_step}_{hparams['gen_dir_name']}/benchmark/{data_name}" | |
| os.makedirs(f'{self.gen_dir}/outputs', exist_ok=True) | |
| os.makedirs(f'{self.gen_dir}/SR', exist_ok=True) | |
| self.model.sample_tqdm = False | |
| torch.backends.cudnn.benchmark = False | |
| with torch.no_grad(): | |
| model.eval() | |
| pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader)) | |
| for batch_idx, batch in pbar: | |
| move_to_cuda(batch) | |
| gen_dir = self.gen_dir | |
| item_names = batch['item_name'] | |
| res = self.sample_and_test(batch) | |
| if len(res) == 3: | |
| img_sr, rrdb_out, ret = res | |
| else: | |
| img_sr, ret = res | |
| rrdb_out = img_sr | |
| img_sr = self.tensor2img(img_sr) | |
| for item_name, hr_p in zip(item_names, img_sr): | |
| item_name = os.path.splitext(item_name)[0] | |
| hr_p = Image.fromarray(hr_p) | |
| hr_p.save(f"{gen_dir}/SR/{item_name}.png") | |
| exp_name = hparams['work_dir'].split('/')[-1] | |
| sr_img_dir = f"{gen_dir}/SR/" | |
| gt_img_dir = f"{gt_path}/{data_name}/HR" | |
| excel_path = f"{hparams['work_dir']}/IQA-val-benchmark_loop-{exp_name}.xlsx" | |
| epoch = training_step | |
| eval_img_IQA(gt_img_dir, sr_img_dir, excel_path, metric_list, epoch, data_name) | |
| # utils_sr | |
| def log_metrics(self, metrics, step): | |
| metrics = self.metrics_to_scalars(metrics) | |
| logger = self.logger | |
| for k, v in metrics.items(): | |
| if isinstance(v, torch.Tensor): | |
| v = v.item() | |
| logger.add_scalar(k, v, step) | |
| def metrics_to_scalars(self, metrics): | |
| new_metrics = {} | |
| for k, v in metrics.items(): | |
| if isinstance(v, torch.Tensor): | |
| v = v.item() | |
| if type(v) is dict: | |
| v = self.metrics_to_scalars(v) | |
| new_metrics[k] = v | |
| return new_metrics | |
| def tensor2img(img): | |
| img = np.round((img.permute(0, 2, 3, 1).cpu().numpy() + 1) * 127.5) | |
| img = img.clip(min=0, max=255).astype(np.uint8) | |
| return img | |
| if __name__ == '__main__': | |
| set_hparams() | |
| pkg = ".".join(hparams["trainer_cls"].split(".")[:-1]) | |
| cls_name = hparams["trainer_cls"].split(".")[-1] | |
| trainer = getattr(importlib.import_module(pkg), cls_name)() | |
| if hparams['benchmark_loop']: | |
| trainer.benchmark_loop(hparams['benchmark_name_list'], hparams['metric_list'], hparams['gt_img_path']) | |
| elif hparams['benchmark']: | |
| trainer.benchmark(hparams['benchmark_name_list'], hparams['metric_list']) | |
| else: | |
| trainer.train() | |