Spaces:
Runtime error
Runtime error
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT License. | |
| import os | |
| import ntpath | |
| import time | |
| from . import util | |
| import scipy.misc | |
| try: | |
| from StringIO import StringIO # Python 2.7 | |
| except ImportError: | |
| from io import BytesIO # Python 3.x | |
| import torchvision.utils as vutils | |
| from tensorboardX import SummaryWriter | |
| import torch | |
| import numpy as np | |
class Visualizer:
    """Log training losses / sample images and save result images.

    Behaviour is driven by fields of the option object ``opt``:

    * ``opt.tensorboard_log`` -- while training, image grids and grouped losses
      go to a tensorboardX ``SummaryWriter`` in ``<checkpoints_dir>/<name>/logs``;
      while testing, image grids are saved as PNG files in
      ``<checkpoints_dir>/<name>/<results_dir>``.
    * ``opt.tf_log`` (only honoured when ``opt.isTrain``) -- every entry of the
      errors dict is logged as its own scalar to the same ``SummaryWriter``.
    * while training, a text loss log is appended to
      ``<checkpoints_dir>/<name>/loss_log.txt``.
    """

    def __init__(self, opt):
        """Read logging options from ``opt`` and create log dirs / writer.

        Args:
            opt: option namespace; reads ``isTrain``, ``tf_log``,
                ``tensorboard_log``, ``display_winsize``, ``name``,
                ``checkpoints_dir`` and, in test mode with
                ``tensorboard_log`` set, ``results_dir``.
        """
        self.opt = opt
        self.tf_log = opt.isTrain and opt.tf_log
        self.tensorboard_log = opt.tensorboard_log
        self.win_size = opt.display_winsize
        self.name = opt.name

        run_dir = os.path.join(opt.checkpoints_dir, opt.name)

        # Both training-time logging modes write through one SummaryWriter.
        # Previously the tf_log path referenced a writer/``self.tf`` that was
        # never created and always raised AttributeError.
        if opt.isTrain and (self.tensorboard_log or self.tf_log):
            self.log_dir = os.path.join(run_dir, "logs")
            os.makedirs(self.log_dir, exist_ok=True)
            self.writer = SummaryWriter(log_dir=self.log_dir)
        elif self.tensorboard_log:
            # Test mode: image grids are written to disk instead of TensorBoard.
            self.log_dir = os.path.join(run_dir, opt.results_dir)
            os.makedirs(self.log_dir, exist_ok=True)

        if opt.isTrain:
            self.log_name = os.path.join(run_dir, "loss_log.txt")
            # The run directory may not exist yet when tensorboard logging is off.
            os.makedirs(run_dir, exist_ok=True)
            with open(self.log_name, "a") as log_file:
                now = time.strftime("%c")
                log_file.write("================ Training Loss (%s) ================\n" % now)

    def display_current_results(self, visuals, epoch, step):
        """Show or save a grid built from the visual tensors in ``visuals``.

        Does nothing unless ``tensorboard_log`` is set or when ``visuals`` is
        empty. Tensors are concatenated along dim 0, then laid out with
        ``opt.batchSize`` images per row.

        Args:
            visuals: dict of label -> image tensor (values are mapped with
                ``(x + 1) / 2``; presumably inputs lie in [-1, 1] -- confirm).
            epoch: unused; kept for interface compatibility.
            step: global step for TensorBoard, or file stem of the saved PNG.
        """
        if not self.tensorboard_log or not visuals:
            return
        output = torch.cat([(tensor.data.cpu() + 1) / 2 for tensor in visuals.values()], 0)
        if self.opt.isTrain:
            img_grid = vutils.make_grid(output, nrow=self.opt.batchSize, padding=0, normalize=False)
            self.writer.add_image("Face_SPADE/training_samples", img_grid, step)
        else:
            vutils.save_image(
                output,
                os.path.join(self.log_dir, str(step) + ".png"),
                nrow=self.opt.batchSize,
                padding=0,
                normalize=False,
            )

    def plot_current_errors(self, errors, step):
        """Send loss values to the SummaryWriter.

        Args:
            errors: dict of loss name -> tensor; each value is reduced with
                ``.mean().float()`` before logging.
            step: global step used as the x-axis value.
        """
        if self.tf_log:
            for tag, value in errors.items():
                self.writer.add_scalar(tag, value.mean().float(), step)
        if self.tensorboard_log:

            def scalar(key):
                return errors[key].mean().float()

            # Only log the loss terms that are actually present, so a run that
            # disables a loss term does not crash with KeyError.
            if "GAN_Feat" in errors:
                self.writer.add_scalar("Loss/GAN_Feat", scalar("GAN_Feat"), step)
            if "VGG" in errors:
                self.writer.add_scalar("Loss/VGG", scalar("VGG"), step)
            gan_losses = {}
            if "GAN" in errors:
                gan_losses["G"] = scalar("GAN")
            if "D_Fake" in errors and "D_real" in errors:
                # Discriminator curve: mean of the fake and real terms.
                gan_losses["D"] = (scalar("D_Fake") + scalar("D_real")) / 2
            if gan_losses:
                self.writer.add_scalars("Loss/GAN", gan_losses, step)

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t):
        """Print one loss line and append it to ``loss_log.txt``.

        Args:
            epoch: current epoch number.
            i: iteration counter within the epoch.
            errors: dict of loss name -> tensor, reduced with ``.mean()``.
            t: time value printed with three decimals.
        """
        message = "(epoch: %d, iters: %d, time: %.3f) " % (epoch, i, t)
        for k, v in errors.items():
            message += "%s: %.3f " % (k, v.mean().float())
        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write("%s\n" % message)

    def convert_visuals_to_numpy(self, visuals):
        """Convert each tensor in ``visuals`` to an image array, in place.

        The ``"input_label"`` entry is converted with ``util.tensor2label``
        using ``opt.label_nc + 2`` classes; every other entry goes through
        ``util.tensor2im``. Both helpers receive ``tile=True`` when
        ``opt.batchSize > 8``. The same (mutated) dict is returned.
        """
        tile = self.opt.batchSize > 8
        for key, t in visuals.items():
            if key == "input_label":
                # Original author's note: B*H*W*C 0-255 numpy.
                visuals[key] = util.tensor2label(t, self.opt.label_nc + 2, tile=tile)
            else:
                visuals[key] = util.tensor2im(t, tile=tile)
        return visuals

    def save_images(self, webpage, visuals, image_path):
        """Save every visual to disk and add a row of them to ``webpage``.

        The file stem is the basename (without extension) of
        ``image_path[0]``; each visual is written to
        ``<webpage image dir>/<label>/<stem>.png``.
        """
        visuals = self.convert_visuals_to_numpy(visuals)
        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]
        webpage.add_header(name)
        ims = []
        txts = []
        links = []
        for label, image_numpy in visuals.items():
            image_name = os.path.join(label, "%s.png" % (name))
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path, create_dir=True)
            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)