Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| import numpy as np | |
| from cv2 import resize | |
| import cv2 | |
| from pathlib import Path | |
| from network import EfficientViT_l1_r224 | |
| from losses import IISLoss, activate | |
| from utils import minmaxnorm, load_from_ckpt | |
class Busam:
    """Interactive segmentation helper built on an ``EfficientViT_l1_r224`` network.

    The network output for an image is computed once by ``process_image``;
    ``get_mask`` then turns that output plus a click position into a binary
    mask, and ``sobel_from_pred`` / ``canny_from_pred`` derive edge maps
    from it.
    """

    def __init__(self, checkpoint, device, side=224):
        """Build the network, load ``checkpoint`` into it and put it in eval mode.

        Args:
            checkpoint: Passed unchanged to ``load_from_ckpt``.
            device: Torch device the network and its inputs are moved to.
            side: Side length in pixels of the square every input image is
                resized to before being fed to the network.
        """
        out_channels = 16
        use_norm_params = False
        net = EfficientViT_l1_r224(
            out_channels=out_channels,
            use_norm_params=use_norm_params,
            pretrained=False,
        )
        net = load_from_ckpt(net, checkpoint)
        net = net.to(device)
        net.eval()
        self.net = net
        self.device = device
        self.side = side

    def prepare_img(self, img):
        """Convert an ``(H, W, 3)`` uint8 image into a network input tensor.

        The image is resized to ``side x side``, scaled to ``[-0.5, 0.5]``,
        reordered to channel-first and given a leading batch dimension of 1.

        Args:
            img: NumPy array of shape ``(H, W, 3)`` and dtype ``np.uint8``.

        Returns:
            Float32 tensor of shape ``(1, 3, side, side)`` on ``self.device``.

        Raises:
            AssertionError: If the shape, value range or dtype check fails.
        """
        assert len(img.shape) == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.shape[2] == 3, "should be H, W, 3 but is " + str(img.shape)
        assert img.min() >= 0, "min should be more than 0 but is " + str(img.min())
        assert img.max() <= 255, "max should be less than 255 but is " + str(img.max())
        assert img.dtype == np.uint8, "dtype should be np.uint8 but is " + str(
            img.dtype
        )
        nimg = resize(img, (self.side, self.side))
        tensorimg = (
            (torch.from_numpy(nimg / 255).permute(2, 0, 1) - 0.5)
            .float()[None]
            .to(self.device)
        )
        return tensorimg

    def process_image(self, img, do_activate=False):
        """Run the network on ``img`` without tracking gradients.

        Args:
            img: ``(H, W, 3)`` uint8 image, see ``prepare_img``.
            do_activate: When true, the raw output is additionally passed
                through ``activate`` in ``"symlog"`` mode.

        Returns:
            Tuple ``(pred, (H, W))``: the network output with its batch
            dimension and the original image height and width. This tuple
            is what ``get_mask`` expects as ``aux``.
        """
        with torch.no_grad():
            x = self.prepare_img(img)
            pred = self.net(x)
            H, W = img.shape[:2]
            if do_activate:
                B, F, pH, pW = pred.shape
                # Flattening to (F, pH * pW) is only valid because
                # prepare_img always produces a batch of exactly one image.
                features, _, _, _ = activate(
                    pred.view(F, pH * pW), None, "symlog", False, False, False
                )
                pred = features.view(B, F, pH, pW)
        return pred, (H, W)

    def get_mask(self, aux, click):
        """Compute the binary mask selected by a click.

        Args:
            aux: Tuple ``(pred, (H, W))`` as returned by ``process_image``.
            click: ``(row, col)`` position in original image coordinates.

        Returns:
            Boolean NumPy array with the original image height and width.
        """
        pred = aux[0][0]  # remove batch dim
        oH, oW = aux[1]
        F, H, W = pred.shape
        features = pred.view(F, H * W)
        # Rescale the click from image coordinates to feature-map coordinates
        # and flatten it to a row-major index into the H * W positions.
        rclick = click[0] * H // oH, click[1] * W // oW
        sindex = rclick[0] * W + rclick[1]
        mask = IISLoss.get_mask_from_query(features, sindex)
        mask = mask.reshape(H, W)
        # cv2.resize takes its target size as (width, height).
        mask = (
            resize((mask.cpu().numpy() * 255).astype(np.uint8), (oW, oH)) > 100
        ).astype(bool)
        return mask

    def get_gradients(self, pred, size):
        """Compute Sobel gradient magnitudes of a feature map across channels.

        Each channel is filtered independently with 3x3 Sobel kernels
        (depthwise convolution, zero padding of 1); the per-channel responses
        are then reduced with an L2 norm over the channel axis.

        Args:
            pred: Tensor of shape ``(1, F, H, W)``.
            size: Unused; kept so existing callers keep working.

        Returns:
            Tuple ``(edge_x, edge_y)`` of ``(H, W)`` tensors with the
            horizontal and vertical gradient magnitudes.
        """
        F, H, W = pred[0].shape
        # Build the kernel in pred's own floating dtype so that float64 and
        # half-precision feature maps do not hit a conv2d dtype mismatch.
        kernel_dtype = pred.dtype if pred.is_floating_point() else torch.float32
        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            dtype=kernel_dtype,
            device=pred.device,
        )
        sobel_y = sobel_x.T
        # One (1, 3, 3) filter per channel, as required by groups=F.
        sobel_x = sobel_x.repeat(F, 1, 1, 1)
        sobel_y = sobel_y.repeat(F, 1, 1, 1)
        # conv2d yields (1, F, H, W); the view drops the batch dimension.
        edge_x = torch.nn.functional.conv2d(pred, sobel_x, padding=1, groups=F).view(
            F, H, W
        )
        edge_y = torch.nn.functional.conv2d(pred, sobel_y, padding=1, groups=F).view(
            F, H, W
        )
        edge_x = torch.norm(edge_x, dim=0, p=2)  # H, W
        edge_y = torch.norm(edge_y, dim=0, p=2)  # H, W
        return edge_x, edge_y

    def sobel_from_pred(self, pred, size):
        """Return the ``(H, W)`` Sobel edge magnitude of ``pred``.

        Combines the two maps from ``get_gradients`` as
        ``sqrt(edge_x**2 + edge_y**2)``. ``size`` is only forwarded.
        """
        edge_x, edge_y = self.get_gradients(pred, size)
        edge = torch.sqrt(edge_x**2 + edge_y**2)
        return edge

    def canny_from_pred(self, pred, size, th_low=10000, th_high=20000):
        """Run ``cv2.Canny`` on the feature-map gradients of ``pred``.

        The two gradient maps are jointly min-max normalised to ``[0, 1]``,
        converted to int16 and passed to ``cv2.Canny`` as the x and y
        derivatives.

        Args:
            pred: Tensor of shape ``(1, F, H, W)``.
            size: Unused; forwarded to ``get_gradients``.
            th_low: Lower hysteresis threshold. A falsy value (``None`` or
                ``0``) falls back to ``th_high``.
            th_high: Upper hysteresis threshold. A falsy value falls back to
                ``th_low``.

        Returns:
            The edge map returned by ``cv2.Canny``.
        """
        th_low = th_low or th_high
        th_high = th_high or th_low
        edge_x, edge_y = self.get_gradients(pred, size)
        amin = min(edge_x.min(), edge_y.min())
        amax = max(edge_x.max(), edge_y.max())
        span = amax - amin
        if span == 0:
            # Flat gradient maps: avoid 0 / 0 = NaN reaching the int16 cast.
            # Dividing by one keeps both maps at zero, i.e. no edges.
            span = torch.ones_like(span)
        edge_x, edge_y = (edge_x - amin) / span, (edge_y - amin) / span
        canny = cv2.Canny(cast_to_int16(edge_x), cast_to_int16(edge_y), th_low, th_high)
        return canny
def cast_to_int16(x):
    """Scale values from the unit range to the int16 range.

    Args:
        x: ``torch.Tensor`` or NumPy array (anything ``np.asarray`` accepts)
            whose values are expected to lie in ``[-1, 1]``.

    Returns:
        ``np.int16`` array of the same shape holding ``x * 32767``. Values
        outside ``[-1, 1]`` are saturated to the nearest bound first.
    """
    if isinstance(x, torch.Tensor):
        # detach() so tensors that still track gradients can be converted;
        # Tensor.numpy() refuses those otherwise.
        x = x.detach().cpu().numpy()
    # Saturate rather than let out-of-range values overflow the int16 cast.
    x = np.clip(np.asarray(x), -1.0, 1.0)
    return (x * 32767).astype(np.int16)
| # from segment_anything import sam_model_registry, SamPredictor | |
| # class SAM: | |
| # sam_checkpoint = "sam_vit_b_01ec64.pth" | |
| # model_type = "vit_b" | |
| # def __init__(self, device): | |
| # sam = sam_model_registry[self.model_type](checkpoint=self.sam_checkpoint) | |
| # sam.to(device=device) | |
| # self.predictor = SamPredictor(sam) | |
| # def process_image(self, img): | |
| # self.predictor.set_image(img) | |
| # return None | |
| # def get_mask(self, aux, click): | |
| # input_point = np.array([[click[1], click[0]]]) | |
| # input_label = np.array([1]) | |
| # masks, scores, logits = self.predictor.predict( | |
| # point_coords=input_point, point_labels=input_label, multimask_output=False | |
| # ) | |
| # return masks[0] | |