Spaces:
Sleeping
Sleeping
import contextlib
import os
import random
import shutil

import numpy as np
import torch
def set_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    RNG, in that order. When CUDA is available it additionally seeds the
    current CUDA device and all CUDA devices, and puts cuDNN into
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    # Seed the current device, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def save_checkpoint(state, is_best, checkpoint_path, filename="checkpoint.pt"):
    """Serialize ``state`` into ``checkpoint_path``; optionally mirror it as the best model.

    Args:
        state: Object handed to ``torch.save`` (typically a dict of state dicts).
        is_best: When true, the freshly written file is also copied to
            ``model_best.pt`` in ``checkpoint_path``.
        checkpoint_path: Directory to write into. Missing directories are
            created.
        filename: Name of the checkpoint file inside ``checkpoint_path``.
    """
    checkpoint_file = os.path.join(checkpoint_path, filename)

    # torch.save does not create parent directories, so make sure they exist.
    # dirname() is empty when writing into the current directory; skip it then.
    parent_dir = os.path.dirname(checkpoint_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(state, checkpoint_file)

    if is_best:
        # Copy the file just written instead of serializing a second time.
        shutil.copyfile(checkpoint_file, os.path.join(checkpoint_path, "model_best.pt"))
def load_checkpoint(model, path, map_location="cpu"):
    """Restore ``model``'s parameters from a checkpoint file.

    Args:
        model: Module whose ``load_state_dict`` receives the stored weights.
        path: Path to a checkpoint written by ``torch.save``. It must contain
            a mapping with a ``"state_dict"`` entry.
        map_location: Forwarded to ``torch.load``. The default is ``"cpu"``, so
            checkpoints saved from GPU tensors can still be read on CPU-only
            machines. ``load_state_dict`` then copies the values into the
            model's existing parameters.

    Returns:
        The whole loaded checkpoint, so callers can read any extra entries.
    """
    # SECURITY: torch.load is pickle-based; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    return checkpoint
def log_metrics(set_name, metrics, logger):
    """Log a one-line summary of loss and accuracies for one data split.

    Args:
        set_name: Label for the split; printed at the start of the line.
        metrics: Mapping providing ``"loss"``, ``"spec_acc"`` and ``"rgb_acc"``.
            Each value is formatted with five decimal places.
        logger: Object with an ``info(message)`` method, e.g. a ``logging.Logger``.
    """
    template = "{}: Loss: {:.5f} | spec_acc: {:.5f}, rgb_acc: {:.5f}"
    values = [metrics[key] for key in ("loss", "spec_acc", "rgb_acc")]
    message = template.format(set_name, *values)
    logger.info(message)
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
    """Context manager that seeds NumPy's global RNG and restores its prior state on exit.

    Example::

        with numpy_seed(0):
            sample = np.random.rand(3)

    Args:
        seed: Base seed. ``None`` makes the context a no-op: the RNG is
            neither reseeded nor restored.
        *addl_seeds: Optional extra values combined with ``seed`` through
            ``hash`` and reduced modulo 1e6, giving each combination its own
            seed.

    Note:
        ``hash`` of ints (and tuples of ints) is stable across processes.
        ``hash`` of ``str``/``bytes`` is salted per process (PYTHONHASHSEED),
        so string values in ``addl_seeds`` yield seeds that are not
        reproducible across runs.
    """
    if seed is None:
        yield
        return
    if addl_seeds:
        # hash() can be negative; % with a positive divisor maps it into [0, 1e6).
        seed = int(hash((seed, *addl_seeds)) % 1e6)
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Restore even if the body raised, so code outside keeps its RNG stream.
        np.random.set_state(state)