print("Importing standard...")
import subprocess
import shutil
from pathlib import Path

print("Importing external...")
import torch
import numpy as np
from PIL import Image

REDUCTION = "pca"
if REDUCTION == "umap":
    from umap import UMAP
elif REDUCTION == "tsne":
    from sklearn.manifold import TSNE
elif REDUCTION == "pca":
    from sklearn.decomposition import PCA

def symlog(x):
    """Symmetric log: sign(x) * log(|x| + 1). Compresses large magnitudes while keeping the sign."""
    return torch.sign(x) * torch.log(torch.abs(x) + 1)

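# Usage sketch (illustrative values, not part of the original module):
#   x = torch.tensor([-100.0, 0.0, 100.0])
#   symlog(x)  # -> tensor([-4.6151, 0.0000, 4.6151]); ln(101) ~= 4.6151
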
def preprocess_masks_features(masks, features):
    # Get shapes right
    B, M, H, W = masks.shape
    Bf, F, Hf, Wf = features.shape
    masks = masks.reshape(B, M, 1, H * W)
    mask_areas = masks.sum(dim=3)  # B, M, 1
    # The following assertions should hold; they are commented out for speed:
    # assert H == Hf and W == Wf and B == Bf
    # assert masks.dtype == torch.bool
    # assert (mask_areas > 0).all(), "you shouldn't have empty masks"
    # TODO: reduce M if there are empty masks; mask_areas is only computed
    # for that (currently disabled) check
    features = features.reshape(B, 1, F, H * W)
    # output shapes
    # features: B, 1, F, H*W
    # masks: B, M, 1, H*W
    return masks, features, M, B, H, W, F

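# Usage sketch (hypothetical shapes; the masked-average pooling line is an
# assumption about how the broadcast-ready outputs are meant to be combined):
#   masks = torch.rand(2, 5, 32, 32) > 0.5            # (B, M, H, W) boolean masks
#   feats = torch.randn(2, 64, 32, 32)                # (B, F, H, W) feature maps
#   m, f, M, B, H, W, F = preprocess_masks_features(masks, feats)
#   pooled = (m * f).sum(3) / m.sum(3).clamp_min(1)   # (B, M, F) per-mask mean feature
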
def get_row_col(H, W, device):
    # get positions of pixels in [0, 1]
    row = torch.linspace(0, 1, H, device=device)
    col = torch.linspace(0, 1, W, device=device)
    return row, col

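# Usage sketch (illustrative; building a (2, H, W) coordinate grid is an assumed use):
#   row, col = get_row_col(4, 4, torch.device("cpu"))
#   rr, cc = torch.meshgrid(row, col, indexing="ij")
#   pos = torch.stack([rr, cc], dim=0)  # (2, H, W) normalized pixel coordinates
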
def get_current_git_commit():
    try:
        # Run the git command to get the current commit hash
        commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip()
        # Decode from bytes to a string
        return commit_hash.decode("utf-8")
    except subprocess.CalledProcessError:
        # Handle the case where the command fails (e.g., not a Git repository)
        print("An error occurred while trying to retrieve the git commit hash.")
        return None

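# Usage sketch (run_dir is a hypothetical pathlib.Path for an experiment folder):
#   commit = get_current_git_commit()
#   if commit is not None:
#       (run_dir / "commit.txt").write_text(commit)  # record the code version with the run
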
def clean_dir(dirname):
    """Removes all directories in dirname that don't have a done.txt file"""
    dstdir = Path(dirname)
    dstdir.mkdir(exist_ok=True, parents=True)
    for f in dstdir.iterdir():
        # if the directory doesn't have a done.txt file, remove it
        if f.is_dir() and not (f / "done.txt").exists():
            shutil.rmtree(f)

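# Usage sketch (the "runs" layout is an assumption; pair cleanup with a done-marker
# written at the end of a successful run so finished outputs survive):
#   clean_dir("runs")                         # drops runs/<exp>/ dirs without done.txt
#   (Path("runs/exp1") / "done.txt").touch()  # marks exp1 as complete
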
def save_tensor_as_image(tensor, dstfile, global_step):
    dstfile = Path(dstfile)
    dstfile = (dstfile.parent / (dstfile.stem + "_" + str(global_step))).with_suffix(
        ".jpg"
    )
    save(tensor, str(dstfile))

def minmaxnorm(x):
    # assumes x is not constant, otherwise this divides by zero
    return (x - x.min()) / (x.max() - x.min())

def save(tensor, name, channel_offset=0):
    tensor = to_img(tensor, channel_offset=channel_offset)
    Image.fromarray(tensor).save(name)

def to_img(tensor, channel_offset=0):
    tensor = minmaxnorm(tensor)
    tensor = (tensor * 255).to(torch.uint8)
    C, H, W = tensor.shape  # the unpack doubles as a check that the input is 3D
    if tensor.shape[0] == 1:
        # single channel: drop the channel dim, save as grayscale
        tensor = tensor[0]
    elif tensor.shape[0] == 2:
        # two channels: render as red/blue with an empty green channel
        tensor = torch.stack([tensor[0], torch.zeros_like(tensor[0]), tensor[1]], dim=0)
        tensor = tensor.permute(1, 2, 0)
    elif tensor.shape[0] >= 3:
        # three or more channels: pick 3 consecutive channels as RGB
        tensor = tensor[channel_offset : channel_offset + 3]
        tensor = tensor.permute(1, 2, 0)
    tensor = tensor.cpu().numpy()
    return tensor

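# Usage sketch (illustrative; renders channels 3..5 of a feature map as RGB):
#   feat = torch.randn(8, 64, 64)  # (C, H, W)
#   save(feat, "features.jpg", channel_offset=3)
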
def log_input_output(
    name,
    x,
    y_hat,
    global_step,
    img_dstdir,
    out_dstdir,
    reduce_dim=True,
    reduction=REDUCTION,
    resample_size=20000,
):
    # collapse the singleton dim: (B, 1, F, H, W) -> (B, F, H, W)
    y_hat = y_hat.reshape(
        y_hat.shape[0], y_hat.shape[2], y_hat.shape[3], y_hat.shape[4]
    )
    if reduce_dim and y_hat.shape[1] >= 3:
        if reduction == "umap":
            reducer = UMAP(n_components=3)
        elif reduction == "tsne":
            # note: sklearn's TSNE has no transform(), so the calls below
            # only work with "umap" or "pca"
            reducer = TSNE(n_components=3)
        elif reduction == "pca":
            reducer = PCA(n_components=3)
        else:
            reducer = None
        np_y_hat = y_hat.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
        np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1)  # F, BHW
        np_y_hat = np_y_hat.T  # BHW, F
        # subsample pixels for fitting; max(1, ...) guards against BHW < resample_size
        sampled_pixels = np_y_hat[:: max(1, np_y_hat.shape[0] // resample_size)]
        print("dim reduction fit..." + " " * 30, end="\r")
        reducer = reducer.fit(sampled_pixels)
        print("dim reduction transform..." + " " * 30, end="\r")
        reducer.transform(np_y_hat[:10])  # warm-up call so numba JIT-compiles UMAP's transform
        np_y_hat = reducer.transform(np_y_hat)  # BHW, 3
        # revert back to original shape
        y_hat2 = (
            torch.from_numpy(
                np_y_hat.T.reshape(3, y_hat.shape[0], y_hat.shape[2], y_hat.shape[3])
            )
            .to(y_hat.device)
            .permute(1, 0, 2, 3)
        )
        print("done" + " " * 30, end="\r")
    else:
        y_hat2 = y_hat
    for i in range(min(len(x), 8)):
        save_tensor_as_image(
            x[i],
            img_dstdir / f"input_{name}_{str(i).zfill(2)}",
            global_step=global_step,
        )
        for c in range(y_hat.shape[1]):
            save_tensor_as_image(
                y_hat[i, c : c + 1],
                out_dstdir / f"pred_channel_{name}_{str(i).zfill(2)}_{c}",
                global_step=global_step,
            )
        # log color image
        assert len(y_hat2.shape) == 4, "should be B, F, H, W"
        if reduce_dim:
            save_tensor_as_image(
                y_hat2[i][:3],
                out_dstdir / f"pred_reduced_{name}_{str(i).zfill(2)}",
                global_step=global_step,
            )
        save_tensor_as_image(
            y_hat[i][:3],
            out_dstdir / f"pred_colorchs_{name}_{str(i).zfill(2)}",
            global_step=global_step,
        )

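# Usage sketch (hypothetical shapes and paths; y_hat is assumed to carry a
# singleton dim at index 1, e.g. (B, 1, F, H, W), as the reshape above implies):
#   x = torch.rand(4, 3, 64, 64)
#   y_hat = torch.randn(4, 1, 16, 64, 64)
#   img_dstdir = Path("out/img"); img_dstdir.mkdir(parents=True, exist_ok=True)
#   out_dstdir = Path("out/pred"); out_dstdir.mkdir(parents=True, exist_ok=True)
#   log_input_output("val", x, y_hat, global_step=100,
#                    img_dstdir=img_dstdir, out_dstdir=out_dstdir)
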
def check_for_nan(loss, model, batch):
    try:
        assert not torch.isnan(loss).any()
    except Exception as e:
        # print things useful to debug
        # does the batch contain nan?
        print("img batch contains nan?", torch.isnan(batch[0]).any())
        print("mask batch contains nan?", torch.isnan(batch[1]).any())
        # do the model weights contain nan?
        for name, param in model.named_parameters():
            if torch.isnan(param).any():
                print(name, "contains nan")
        # does the output contain nan?
        print("output contains nan?", torch.isnan(model(batch[0])).any())
        # now raise the error
        raise e

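# Usage sketch (training-step names are hypothetical):
#   loss = criterion(model(imgs), target)
#   check_for_nan(loss, model, (imgs, masks))  # prints diagnostics and re-raises if loss is NaN
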
def calculate_iou(pred, label):
    intersection = ((label == 1) & (pred == 1)).sum()
    union = ((label == 1) | (pred == 1)).sum()
    if not union:
        # empty union means both masks are empty; define IoU as 0
        return 0.0
    return intersection.item() / union.item()

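# Worked example (illustrative values):
#   pred = torch.tensor([[1, 0], [1, 1]])
#   label = torch.tensor([[1, 1], [0, 1]])
#   calculate_iou(pred, label)  # intersection 2, union 4 -> 0.5
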
def load_from_ckpt(net, ckpt_path, strict=True):
    """Load network weights from a raw state dict, or from a checkpoint that
    wraps one under "MODEL_STATE" or "state_dict". Returns net unchanged if
    ckpt_path is falsy or does not exist."""
    if ckpt_path and Path(ckpt_path).exists():
        ckpt = torch.load(ckpt_path, map_location="cpu")
        if "MODEL_STATE" in ckpt:
            ckpt = ckpt["MODEL_STATE"]
        elif "state_dict" in ckpt:
            ckpt = ckpt["state_dict"]
        net.load_state_dict(ckpt, strict=strict)
        print("Loaded checkpoint from", ckpt_path)
    return net
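
# Usage sketch (the checkpoint path is hypothetical; strict=False tolerates
# missing/unexpected keys, e.g. when loading a partial backbone):
#   net = load_from_ckpt(net, "checkpoints/last.ckpt", strict=False)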