Spaces:
Runtime error
Runtime error
| import math | |
| import numpy as np | |
| import torch | |
| from matplotlib import pyplot as plt | |
| from torchvision.utils import make_grid | |
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to (min, max), image values are normalized to [0, 1].
    Different tensor shapes are handled as follows:

    1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
       Use `make_grid` to stitch the images of the batch together, then
       convert the grid to a numpy array.
    2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
       Convert directly to a numpy array.

    The channels of the input tensors are expected in RGB order. This
    function converts them to the cv2 convention, i.e. (H x W x C) with BGR
    order. The input tensors themselves are never modified.

    Args:
        tensor (Tensor | list[Tensor]): Input tensors.
        out_type (numpy type): Output type. If ``np.uint8``, outputs are
            scaled and rounded to uint8 in [0, 255]; otherwise values stay in
            the float range [0, 1] and are cast to ``out_type``.
            Default: ``np.uint8``.
        min_max (tuple): (min, max) values used for clamping; max must be
            strictly greater than min. Default: (0, 1).

    Returns:
        ndarray | list[ndarray]: 3D ndarray of shape (H x W x C) or 2D ndarray
        of shape (H x W). A single array is returned when exactly one tensor
        was given (including a one-element list); otherwise a list of arrays.

    Raises:
        TypeError: If ``tensor`` is neither a Tensor nor a list of Tensors.
        ValueError: If ``min_max`` is not strictly increasing, or a tensor is
            not 2D, 3D or 4D after squeezing.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and
             all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')
    lo, hi = min_max
    # A zero or negative span would make the normalisation below divide by
    # zero (or flip the value range), silently producing nan/garbage pixels.
    if hi <= lo:
        raise ValueError(f'min_max must satisfy min < max, got {min_max}')
    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # Squeeze the leading dim twice so that:
        #   1. (1, 1, h, w) -> (h, w)
        #   2. (1, 3, h, w) -> (3, h, w)
        #   3. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w)
        _tensor = _tensor.squeeze(0).squeeze(0)
        # Out-of-place clamp: for a float CPU input, squeeze()/.float()/
        # .detach()/.cpu() all return views sharing the caller's storage, so
        # an in-place clamp_ would modify the caller's tensor.
        _tensor = _tensor.float().detach().cpu().clamp(lo, hi)
        _tensor = (_tensor - lo) / (hi - lo)
        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(
                _tensor, nrow=int(math.sqrt(_tensor.size(0))),
                normalize=False).numpy()
            # RGB -> BGR, then CHW -> HWC.
            img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
        elif n_dim == 3:
            img_np = _tensor.numpy()
            # RGB -> BGR, then CHW -> HWC.
            img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise ValueError('Only support 4D, 3D or 2D tensor. '
                             f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    return result[0] if len(result) == 1 else result
def plt_tensor_img(tensor, save_path=None):
    """Display a tensor as an image with matplotlib, optionally saving it.

    The tensor is converted with ``tensor2img`` using its defaults (uint8
    output, (0, 1) clamp range). ``tensor2img`` returns cv2-style BGR
    channels, so 3-channel results are flipped back to RGB, which
    ``plt.imshow`` expects.

    Args:
        tensor (Tensor): Image tensor accepted by ``tensor2img``. It should
            yield a single image (a Tensor or a one-element list).
        save_path (str | None): If truthy, the figure is written there with
            ``plt.savefig`` before it is shown.
    """
    img = tensor2img(tensor)
    if img.ndim == 3:
        # BGR (tensor2img output) -> RGB (matplotlib input).
        img = img[..., ::-1]
    plt.imshow(img)
    # Save *before* show(): once show() returns the figure may already be
    # closed, and savefig would then write a new, empty figure.
    if save_path:
        plt.savefig(save_path)
    plt.show()
def plt_tensor_img_one(tensor, t_dim=1):
    """Plot every slice of ``tensor`` along ``t_dim`` in a square subplot grid.

    Each slice ``tensor.select(t_dim, i)`` is converted with ``tensor2img``
    and drawn in its own tick-less subplot of a ``mash x mash`` grid, where
    ``mash = ceil(sqrt(tensor.shape[t_dim]))``.

    Args:
        tensor (Tensor | list[Tensor]): Tensor to visualise. A list is first
            concatenated along ``t_dim``.
        t_dim (int): Dimension whose entries become individual subplots
            (negative values allowed). Default: 1 — presumably the channel
            dim of an N x C x H x W batch; confirm with callers.
    """
    if isinstance(tensor, list):
        tensor = torch.cat(tensor, dim=t_dim)
    nums = tensor.shape[t_dim]
    # Smallest square grid holding all slices; mash ** 2 >= nums always.
    mash = math.ceil(math.sqrt(nums))
    plt.figure(dpi=300)
    for i in range(nums):
        plt.subplot(mash, mash, i + 1)
        # select(t_dim, i) equals tensor[:, i, ...] for t_dim=1 and
        # tensor[i, ...] for t_dim=0, and also works for any other dim.
        img = tensor2img(tensor.select(t_dim, i))
        if img.ndim == 3:
            # BGR (tensor2img output) -> RGB (matplotlib input).
            img = img[..., ::-1]
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
    # tight_layout() recomputes subplot spacing, so it must run before the
    # zero-gap adjustment or it would override it.
    plt.tight_layout()
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def plt_img(img, save_path=None):
    """Display an image array with matplotlib and optionally save it.

    Args:
        img: Anything accepted by ``plt.imshow``. It is shown as-is, so
            3-channel data is interpreted as RGB.
        save_path (str | None): If truthy, the figure is written there with
            ``plt.savefig`` before it is shown.
    """
    plt.imshow(img)
    # Save *before* show(): once show() returns the figure may already be
    # closed, and savefig would then write a new, empty figure.
    if save_path:
        plt.savefig(save_path)
    plt.show()