# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
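
# Inference script for scratch detection: a pretrained UNet is run over every
# image in --test_path, and a thresholded binary scratch mask (plus the resized
# input image) is written under --output_dir.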

import argparse
import gc
import os
import warnings

import numpy as np
import torch
import torch.nn.functional as F
import torchvision as tv
from PIL import Image, ImageFile

from detection_models import networks
from detection_util.util import *

warnings.filterwarnings("ignore", category=UserWarning)

# Tolerate truncated image files instead of raising an error when loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True


def data_transforms(img, full_size, method=Image.BICUBIC):
    # Resize so both sides are multiples of 16, keeping the UNet's repeated
    # 2x downsampling exact.
    if full_size == "full_size":
        ow, oh = img.size
        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == oh) and (w == ow):
            return img
        return img.resize((w, h), method)

    elif full_size == "scale_256":
        # Scale the shorter side to 256, preserving the aspect ratio.
        ow, oh = img.size
        pw, ph = ow, oh
        if ow < oh:
            ow = 256
            oh = ph / pw * 256
        else:
            oh = 256
            ow = pw / ph * 256

        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == ph) and (w == pw):
            return img
        return img.resize((w, h), method)

    # No other modes are implemented; fail loudly instead of silently returning None.
    raise ValueError("unsupported input_size mode: %s" % full_size)
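
# Worked example (numbers are illustrative): a 1200x900 input under "scale_256"
# scales the shorter side to 256, giving about 341.3x256; rounding each side to
# the nearest multiple of 16 yields a 336x256 output.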


def scale_tensor(img_tensor, default_scale=256):
    _, _, w, h = img_tensor.shape
    if w < h:
        ow = default_scale
        oh = h / w * default_scale
    else:
        oh = default_scale
        ow = w / h * default_scale

    oh = int(round(oh / 16) * 16)
    ow = int(round(ow / 16) * 16)

    return F.interpolate(img_tensor, [ow, oh], mode="bilinear")
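
# Naming caveat: img_tensor is NCHW, so `w` above is actually the height and `h`
# the width. The swap is consistent throughout (F.interpolate takes [out_H, out_W],
# i.e. [ow, oh]), so the shorter side still ends up at default_scale, rounded to
# a multiple of 16.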


def blend_mask(img, mask):
    np_img = np.array(img).astype("float")
    return Image.fromarray((np_img * (1 - mask) + mask * 255.0).astype("uint8")).convert("RGB")
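
# blend_mask is not called in this script; it is a small visualization helper
# that whitens masked pixels. The mask must broadcast against the HxWx3 image
# array (e.g. shape HxWx1 with values in [0, 1]).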


def main(config):
    print("initializing the scratch detection model")

    # Single-channel in/out UNet: grayscale image in, per-pixel scratch logit out.
    model = networks.UNet(
        in_channels=1,
        out_channels=1,
        depth=4,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    )

    ## load model
    checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    print("model weights loaded")

    if config.GPU >= 0:
        model.to(config.GPU)
    else:
        model.cpu()
    model.eval()

    ## dataloader and transformation
    print("directory of testing images: " + config.test_path)
    imagelist = os.listdir(config.test_path)
    imagelist.sort()

    save_url = config.output_dir
    mkdir_if_not(save_url)

    input_dir = os.path.join(save_url, "input")
    output_dir = os.path.join(save_url, "mask")
    # blend_output_dir = os.path.join(save_url, "blend_output")
    mkdir_if_not(input_dir)
    mkdir_if_not(output_dir)
    # mkdir_if_not(blend_output_dir)

    for image_name in imagelist:
        print("processing", image_name)

        scratch_file = os.path.join(config.test_path, image_name)
        if not os.path.isfile(scratch_file):
            print("Skipping non-file %s" % image_name)
            continue
        scratch_image = Image.open(scratch_file).convert("RGB")

        # Resize, convert to single-channel grayscale, and normalize to [-1, 1].
        transformed_image_PIL = data_transforms(scratch_image, config.input_size)
        scratch_image = transformed_image_PIL.convert("L")
        scratch_image = tv.transforms.ToTensor()(scratch_image)
        scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)
        scratch_image = torch.unsqueeze(scratch_image, 0)
        _, _, ow, oh = scratch_image.shape  # NCHW, so ow is the height and oh the width
        scratch_image_scale = scale_tensor(scratch_image)

        if config.GPU >= 0:
            scratch_image_scale = scratch_image_scale.to(config.GPU)
        else:
            scratch_image_scale = scratch_image_scale.cpu()
        with torch.no_grad():
            # Per-pixel scratch probability in [0, 1].
            P = torch.sigmoid(model(scratch_image_scale))

        P = P.data.cpu()
        # Upsample the probability map back to the transformed image's size.
        P = F.interpolate(P, [ow, oh], mode="nearest")

        # Threshold at 0.4 and save the binary mask alongside the resized input.
        base_name = os.path.splitext(image_name)[0]  # robust to extensions of any length
        tv.utils.save_image(
            (P >= 0.4).float(),
            os.path.join(output_dir, base_name + ".png"),
            nrow=1,
            padding=0,
            normalize=True,
        )
        transformed_image_PIL.save(os.path.join(input_dir, base_name + ".png"))
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument('--checkpoint_name', type=str, default="FT_Epoch_latest.pt", help='Checkpoint Name')

    parser.add_argument("--GPU", type=int, default=0, help="GPU index; pass -1 to run on the CPU")
    parser.add_argument("--test_path", type=str, default=".")
    parser.add_argument("--output_dir", type=str, default=".")
    parser.add_argument("--input_size", type=str, default="scale_256", help="full_size|scale_256")
    config = parser.parse_args()

    main(config)
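
# Usage sketch (paths and the script filename are illustrative):
#   python detection.py --test_path ./test_images --output_dir ./output --GPU -1
# This writes binary masks to ./output/mask and the resized inputs to ./output/input.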