Spaces:
Runtime error
Runtime error
import datetime
import math
import os
import torch
import time
import skimage.io
import skimage.transform
import matplotlib.pyplot as plt
import glob
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from skimage import exposure
# Module-level tensor <-> PIL converters; toPIL is used by EvaluateModel to turn the
# network output tensor into an image.
toTensor = transforms.ToTensor()
toPIL = transforms.ToPILImage()
import numpy as np
from PIL import Image
# NOTE(review): wildcard import; presumably this is where GetModel (called in
# LoadModel) comes from — confirm.
from models import *
# Expose only GPU 0 to this process.
# NOTE(review): torch is imported above; this only takes effect if CUDA has not been
# initialised before this line — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def remove_dataparallel_wrapper(state_dict):
    r"""Convert a DataParallel state dict to a plain one by removing the "module." prefix.

    Only keys that actually start with ``module.`` are rewritten; all other keys are
    kept unchanged, so calling this on an unwrapped state dict is harmless.

    Args:
        state_dict: a (possibly) ``torch.nn.DataParallel`` state dictionary.

    Returns:
        An ``OrderedDict`` with the same values in the same order, keyed by the
        unprefixed names.
    """
    from collections import OrderedDict

    prefix = 'module.'
    return OrderedDict(
        # Blindly slicing k[7:] would mangle keys that were never wrapped.
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
| from argparse import Namespace | |
def GetOptions():
    """Return the inference options for the RCAN reconstruction model.

    The namespace holds the network hyper-parameters (``model``, ``n_resgroups``,
    ``n_resblocks``, ``n_feats``, ``reduction``, ``narch``), the input normalisation
    mode (``norm``), device selection (``cpu``, ``multigpu``, ``undomulti``, ``device``),
    the checkpoint and I/O paths (``weights``, ``root``, ``out``) and the task /
    channel configuration (``task``, ``scale``, ``nch_in``, ``nch_out``).
    """
    force_cpu = False
    return Namespace(
        # network architecture
        model='rcan',
        n_resgroups=3,
        n_resblocks=10,
        n_feats=96,
        reduction=16,
        narch=0,
        # preprocessing
        norm='minmax',
        # device selection: CUDA whenever it is available and CPU is not forced
        cpu=force_cpu,
        multigpu=False,
        undomulti=False,
        device=torch.device('cuda' if torch.cuda.is_available() and not force_cpu else 'cpu'),
        imageSize=512,
        # files
        weights="model/simrec_simin_gtout_rcan_512_2_ntrain790-final.pth",
        root="model/0080.jpg",
        out="model",
        # task / channel layout
        task='simin_gtout',
        scale=1,
        nch_in=9,
        nch_out=1,
    )
def GetOptions_Swin_2702():
    """Return the inference options for the ``swinir_rcab`` reconstruction model.

    Besides the shared fields (normalisation, device selection, checkpoint and I/O
    paths, task and channel layout), ``model_opts`` carries the architecture
    hyper-parameters handed to the model: per-stage ``depths`` and ``num_heads``,
    ``embed_dim``, ``window_size``, ``use_rcab`` and ``mlp_ratio``.
    """
    architecture = {
        "depths": [6] * 5,
        "embed_dim": 64,
        "window_size": 8,
        "num_heads": [8] * 5,
        "use_rcab": True,
        "mlp_ratio": 2,
    }
    force_cpu = False
    return Namespace(
        model='swinir_rcab',
        norm='minmax',
        # device selection: CUDA whenever it is available and CPU is not forced
        cpu=force_cpu,
        multigpu=False,
        undomulti=False,
        device=torch.device('cuda' if torch.cuda.is_available() and not force_cpu else 'cpu'),
        imageSize=128,
        # files
        weights="model/swinir_rcab_nov2021.pth",
        root="model/0080.jpg",
        out="model",
        # task / channel layout
        task='simin_gtout',
        scale=2,
        nch_in=9,
        nch_out=1,
        model_opts=architecture,
    )
def LoadModel(opt):
    """Build the network described by *opt* and load its weights from ``opt.weights``.

    The checkpoint is either used directly as a state dict or, when it is exactly a
    plain ``dict``, unwrapped from its ``'params_ema'`` or ``'state_dict'`` entry
    (checked in that order), falling back to ``'params'``. When ``opt.undomulti`` is
    set, the DataParallel ``module.`` prefix is removed from the keys first.

    Returns:
        The network with the checkpoint weights loaded.
    """
    print('Loading model')
    print(opt)
    net = GetModel(opt)

    print('loading checkpoint', opt.weights)
    checkpoint = torch.load(opt.weights, map_location=opt.device)

    # Only an exact ``dict`` is treated as a wrapper; dict subclasses (e.g. an
    # OrderedDict state dict) are used as the state dict unchanged.
    if type(checkpoint) is not dict:
        state_dict = checkpoint
    else:
        wrapper_key = next(
            (key for key in ('params_ema', 'state_dict') if key in checkpoint),
            'params',
        )
        state_dict = checkpoint[wrapper_key]

    if opt.undomulti:
        state_dict = remove_dataparallel_wrapper(state_dict)
    net.load_state_dict(state_dict)
    return net
def _select_frames(stack, nch_in):
    """Take the first 9 frames of *stack* and keep the subset matching *nch_in*.

    nch_in == 6 keeps frames [0, 1, 3, 4, 6, 7]; nch_in == 3 keeps [0, 4, 8];
    any other value keeps all of the first 9 frames.
    """
    frames = stack[:9]
    if nch_in == 6:
        frames = frames[[0, 1, 3, 4, 6, 7]]
    elif nch_in == 3:
        frames = frames[[0, 4, 8]]
    return frames


def _convert_raw(frames, peak):
    """Rotate every frame by 90 degrees, reorder the frames and rescale each to max *peak*.

    The frame order becomes [6, 7, 8, 3, 4, 5, 0, 1, 2]. Frames whose maximum is 0
    are left as-is instead of being divided by zero.
    """
    frames = np.rot90(frames, axes=(1, 2))
    frames = frames[[6, 7, 8, 3, 4, 5, 0, 1, 2]]  # could also do [8,7,6,5,4,3,2,1,0]
    for i in range(len(frames)):
        frame_max = np.max(frames[i])
        if frame_max != 0:
            frames[i] = peak / frame_max * frames[i]
    return frames


def _minmax(t, fac=1.0):
    """Return ``fac * (t - min) / (max - min)``; a constant tensor maps to zeros, not NaN."""
    lo = torch.min(t)
    span = torch.max(t) - lo
    if span == 0:
        return torch.zeros_like(t)
    return fac * (t - lo) / span


def prepimg(stack, self):
    """Turn a stack of frames into the network input and a widefield image.

    Args:
        stack: frames indexed as ``stack[frame, row, col]``; only the first 9 frames
            are used. NOTE(review): assumed to be a numpy array (fancy indexing and
            ``.astype`` are used) — confirm against callers.
        self: the options namespace (not a class instance despite the name); reads
            ``nch_in`` and ``norm``.

    Returns:
        ``(inputimg, widefield)`` as float32 tensors: the selected, normalised frames
        ``(C, H, W)`` and their mean image ``(H, W)``. Frames are cropped to at most
        512x512.
    """
    inputimg = _select_frames(stack, self.nch_in)

    if inputimg.shape[1] > 512 or inputimg.shape[2] > 512:
        print('Over 512x512! Cropping')
        inputimg = inputimg[:, :512, :512]

    if self.norm == 'convert':  # raw img from microscope, needs normalisation and correct frame ordering
        print('Raw input assumed - converting')
        inputimg = _convert_raw(inputimg, 100)
    elif 'convert' in self.norm:
        # 'convert<fac>': same conversion, frames rescaled to fac * 255
        inputimg = _convert_raw(inputimg, float(self.norm[7:]) * 255)

    # Global scaling to [0, 1] (used to be /255); an all-zero stack stays zero
    # instead of becoming NaN.
    peak = np.max(inputimg)
    inputimg = inputimg.astype('float')
    if peak != 0:
        inputimg = inputimg / peak
    widefield = np.mean(inputimg, 0)

    if self.norm == 'adapthist':
        for i in range(len(inputimg)):
            inputimg[i] = exposure.equalize_adapthist(inputimg[i], clip_limit=0.001)
        widefield = exposure.equalize_adapthist(widefield, clip_limit=0.001)
    else:
        inputimg = torch.as_tensor(inputimg).float()
        widefield = _minmax(torch.as_tensor(widefield).float())
        if self.norm == 'minmax':
            # per-frame min-max to [0, 1]
            inputimg = torch.stack([_minmax(frame) for frame in inputimg])
        elif 'minmax' in self.norm:
            # 'minmax<fac>': per-frame min-max to [0, fac]
            fac = float(self.norm[6:])
            inputimg = torch.stack([_minmax(frame, fac) for frame in inputimg])

    # as_tensor does not re-copy (or warn about) values that are already tensors
    return torch.as_tensor(inputimg).float(), torch.as_tensor(widefield).float()
def save_image(data, filename, cmap):
    """Save a 2-D array as a frameless image with roughly one pixel per array element.

    The figure gets the array's aspect ratio, a single axes fills the whole canvas
    with no frame or ticks, and the dpi equals the number of rows. The saved image is
    therefore about ``data.shape[1]`` pixels wide and ``data.shape[0]`` pixels tall.

    Args:
        data: array drawn with ``imshow``; its first two dimensions are (rows, cols).
        filename: output path passed to ``savefig``.
        cmap: matplotlib colormap name.
    """
    height, width = np.shape(data)[:2]
    fig = plt.figure()
    # set_size_inches takes (width, height): width/height inches by 1 inch, at
    # dpi=height, gives width x height pixels. The previous height/width ratio
    # swapped the aspect for non-square arrays.
    fig.set_size_inches(1. * width / height, 1, forward=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(data, cmap=cmap)
    fig.savefig(filename, dpi=height)
    plt.close(fig)
def EvaluateModel(net, opt, stack):
    """Run *net* on *stack* and write preview and download images.

    Three PNG files, all prefixed ``VSR-SIM_<HH-MM-SS>`` (UTC time), are written
    relative to the current working directory:
      * ``<prefix>_wf.png``: widefield preview, resized to 768x768 and drawn with a colormap;
      * ``<prefix>_sr.png``: network-output preview, resized to 768x768 and drawn with a colormap;
      * ``<prefix>.png``: the network output without a colormap.

    Args:
        net: the network (as returned by ``LoadModel``).
        opt: options namespace; reads ``out``, ``norm`` and ``device``, plus the
            fields ``prepimg`` reads.
        stack: frame stack passed to ``prepimg``.

    Returns:
        The file names ``(sr preview, widefield preview, raw sr)``.
    """
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now" formats identically.
    timestamp = datetime.datetime.now(datetime.timezone.utc).strftime('%H-%M-%S')
    outfile = 'VSR-SIM_%s' % timestamp
    # NOTE(review): opt.out is created, but every file below is written relative to
    # the working directory rather than inside opt.out. Confirm this is intended.
    os.makedirs(opt.out, exist_ok=True)
    print(stack.shape)
    inputimg, widefield = prepimg(stack, opt)

    if opt.norm == 'convert' or 'minmax' in opt.norm or 'adapthist' in opt.norm:
        cmap = 'viridis'
    else:
        cmap = 'gray'

    # Widefield preview: 8-bit, cubic (order=3) resize to 768x768.
    # NOTE(review): for norm == 'convert', prepimg rotated the frames. The SR output is
    # rotated back below but this preview is not. Confirm this is intended.
    wf = (255 * widefield.numpy()).astype('uint8')
    wf_upscaled = skimage.transform.resize(wf, (768, 768), order=3)  # should ideally be done by drawing on client side, in javascript
    save_image(wf_upscaled, '%s_wf.png' % outfile, cmap)

    inputimg = inputimg.unsqueeze(0)  # add a leading (batch) dimension
    with torch.no_grad():
        sr = net(inputimg.to(opt.device))
    sr = sr.cpu()
    sr = torch.clamp(sr, min=0, max=1)
    print('min max', inputimg.min(), inputimg.max())  # reports the *input* range

    pil_sr_img = toPIL(sr[0])
    if opt.norm == 'convert':
        pil_sr_img = transforms.functional.rotate(pil_sr_img, -90)
    sr_img = np.array(pil_sr_img)
    skimage.io.imsave('%s.png' % outfile, sr_img)  # true out for downloading, no LUT
    sr_img = skimage.transform.resize(sr_img, (768, 768), order=3)  # should ideally be done by drawing on client side, in javascript
    save_image(sr_img, '%s_sr.png' % outfile, cmap)
    return outfile + '_sr.png', outfile + '_wf.png', outfile + '.png'