Spaces:
Runtime error
Runtime error
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT License. | |
| import os | |
| from collections import OrderedDict | |
| import data | |
| from options.test_options import TestOptions | |
| from models.pix2pix_model import Pix2PixModel | |
| from util.visualizer import Visualizer | |
| import torchvision.utils as vutils | |
| import warnings | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
| opt = TestOptions().parse() | |
| dataloader = data.create_dataloader(opt) | |
| model = Pix2PixModel(opt) | |
| model.eval() | |
| visualizer = Visualizer(opt) | |
| single_save_url = os.path.join(opt.checkpoints_dir, opt.name, opt.results_dir, "each_img") | |
| if not os.path.exists(single_save_url): | |
| os.makedirs(single_save_url) | |
| for i, data_i in enumerate(dataloader): | |
| if i * opt.batchSize >= opt.how_many: | |
| break | |
| generated = model(data_i, mode="inference") | |
| img_path = data_i["path"] | |
| for b in range(generated.shape[0]): | |
| img_name = os.path.split(img_path[b])[-1] | |
| save_img_url = os.path.join(single_save_url, img_name) | |
| vutils.save_image((generated[b] + 1) / 2, save_img_url) | |