from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image

from busam import Busam

resize_to = 512
checkpoint = "weights.pth"
device = "cpu"

print("Loading model...")
busam = Busam(checkpoint=checkpoint, device=device, side=resize_to)

# scale a tensor/array linearly to the [0, 1] range
minmaxnorm = lambda x: (x - x.min()) / (x.max() - x.min())
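# e.g. minmaxnorm(np.array([2.0, 4.0, 6.0])) -> array([0. , 0.5, 1. ])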
def edge_inference(img, algorithm, th_low=None, th_high=None):
    algorithm = algorithm.lower()
    print("Loading image...")
    img = np.array(img[:, :, :3])  # drop the alpha channel if present
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Computing edges...")
    if algorithm == "sobel":
        edge = busam.sobel_from_pred(pred, size)
    elif algorithm == "canny":
        th_low, th_high = th_low or 5000, th_high or 10000
        edge = busam.canny_from_pred(pred, size, th_low=th_low, th_high=th_high)
    else:
        raise ValueError("algorithm should be sobel or canny")
    edge = edge.cpu().numpy() if isinstance(edge, torch.Tensor) else edge
    print("Done")
    return Image.fromarray(
        (minmaxnorm(edge) * 255).astype(np.uint8)
    ).resize(size[::-1])  # PIL expects (width, height)
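# A minimal sketch of calling edge_inference outside Gradio (the image path is
# hypothetical; point it at any real RGB file):
#
#   test_img = np.array(Image.open("demoimgs/example.jpg").convert("RGB"))
#   edge_inference(test_img, "Canny", 5000, 10000).save("edges.png")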
def dimred_inference(img, algorithm, resample_pct):
    algorithm = algorithm.lower()
    img = np.array(img[:, :, :3])
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    # pred is (1, F, S, S)
    assert pred.shape[1] >= 3, "should have at least 3 channels"
    if algorithm == "pca":
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=3)
    elif algorithm == "tsne":
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=3)
    elif algorithm == "umap":
        from umap import UMAP
        reducer = UMAP(n_components=3)
    else:
        raise ValueError("algorithm should be pca, tsne or umap")
    np_y_hat = pred.detach().cpu().permute(1, 0, 2, 3).numpy()  # F, B, H, W
    np_y_hat = np_y_hat.reshape(np_y_hat.shape[0], -1)  # F, BHW
    np_y_hat = np_y_hat.T  # BHW, F
    # the slider value is an exponent: fit on (10**resample_pct * 100)% of the pixels
    resample_pct = 10**resample_pct
    resample_size = max(1, int(resample_pct * np_y_hat.shape[0]))
    sampled_pixels = np_y_hat[:: np_y_hat.shape[0] // resample_size]
    print("dim reduction fit..." + " " * 30, end="\r")
    if hasattr(reducer, "transform"):
        reducer = reducer.fit(sampled_pixels)
        print("dim reduction transform..." + " " * 30, end="\r")
        reducer.transform(np_y_hat[:10])  # warm-up call (triggers UMAP's numba compilation)
        np_y_hat = reducer.transform(np_y_hat)  # BHW, 3
    else:
        # sklearn's TSNE has no transform(); it must embed everything in one fit_transform (very slow)
        np_y_hat = reducer.fit_transform(np_y_hat)
    print()
    print("Done. Saving...")
    # revert back to the spatial shape
    colors = np_y_hat.reshape(pred.shape[2], pred.shape[3], 3)
    return Image.fromarray((minmaxnorm(colors) * 255).astype(np.uint8)).resize(
        size[::-1]
    )
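# Sketch of the subsampling arithmetic above (hypothetical numbers): a 512x512
# feature map yields 262144 pixel vectors; a slider value of -3 keeps
# 10**-3 * 100% = 0.1% of them, so the reducer is fit on roughly 262 evenly
# strided pixels before transforming the full set.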
def segmentation_inference(img, algorithm, scale):
    algorithm = algorithm.lower()
    img = np.array(img[:, :, :3])
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Computing segmentation...")
    if algorithm == "kmeans":
        from sklearn.cluster import KMeans
        # scale 0.1 -> 63 clusters, 1.0 -> 1 cluster
        n_clusters = int(100 / 100**scale)
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(
            pred.view(pred.shape[1], -1).T
        )
        labels = kmeans.labels_
        labels = labels.reshape(pred.shape[2], pred.shape[3])
    elif algorithm == "felzenszwalb":
        from skimage.segmentation import felzenszwalb
        labels = felzenszwalb(
            (minmaxnorm(pred[0].cpu().numpy()) * 255)
            .astype(np.uint8)
            .transpose(1, 2, 0),
            scale=10 ** (8 * scale - 3),
            sigma=0,
            min_size=50,
        )
    elif algorithm == "slic":
        from skimage.segmentation import slic
        labels = slic(
            (minmaxnorm(pred[0].cpu().numpy()) * 255)
            .astype(np.uint8)
            .transpose(1, 2, 0),
            n_segments=int(100 / 100**scale),
            compactness=0.00001,
            sigma=1,
        )
    elif algorithm == "watershed":
        from scipy import ndimage as ndi
        from skimage.feature import peak_local_max
        from skimage.segmentation import watershed
        sobel = busam.sobel_from_pred(pred, size)
        sobel = sobel.cpu().numpy() if isinstance(sobel, torch.Tensor) else sobel
        # contrast stretch: clip at the 95th percentile so the strongest edges saturate to 1
        sobel = np.clip(sobel / np.percentile(sobel, 95), 0, 1)
        distance = ndi.distance_transform_edt(sobel < 1)  # distance to the borders
        coords = peak_local_max(
            distance, min_distance=int(1 + 100 * scale), labels=sobel < 1
        )
        mask = np.zeros(sobel.shape, dtype=bool)
        mask[tuple(coords.T)] = True
        markers, _ = ndi.label(mask)
        labels = watershed(sobel, markers)
    else:
        raise ValueError("algorithm should be kmeans, felzenszwalb, slic or watershed")
    print("Done")
    # Neighboring segments usually get consecutive label ids, which map to nearly
    # identical gray levels; shuffle the labels so adjacent segments look distinct.
    out = labels.copy()
    out[labels % 4 == 0] = labels[labels % 4 == 0] * 1 // 4
    out[labels % 4 == 1] = labels[labels % 4 == 1] * 4 // 4 + 1
    out[labels % 4 == 2] = labels[labels % 4 == 2] * 2 // 4 + 2
    out[labels % 4 == 3] = labels[labels % 4 == 3] * 3 // 4 + 3
    return Image.fromarray(
        (minmaxnorm(out) * 255).astype(np.uint8)
    ).resize(size[::-1])
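# Worked example of the shuffle (computed by hand): consecutive labels 0..7 map
# to [0, 2, 3, 5, 1, 6, 5, 8], so a smooth ramp of neighboring ids becomes
# visually scattered. The mapping is not injective (3 and 6 both land on 5),
# but it is only used for display.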
def one_click_segmentation(img, row, col, threshold):
    row, col = int(row), int(col)
    img = np.array(img[:, :, :3])
    # draw a small cross centered on the click, for visualization only
    click_map = np.zeros(img.shape[:2], dtype=bool)
    side = min(img.shape[:2]) // 100
    click_map[
        max(0, row - side) : min(img.shape[0], row + side),
        max(0, col - side // 5) : min(img.shape[1], col + side // 5),
    ] = True
    click_map[
        max(0, row - side // 5) : min(img.shape[0], row + side // 5),
        max(0, col - side) : min(img.shape[1], col + side),
    ] = True
    print("Getting features...")
    pred, size = busam.process_image(img, do_activate=True)
    print("Getting mask...")
    # NOTE: `threshold` is accepted from the UI but not used here
    mask = busam.get_mask((pred, size), (row, col))
    print("Done")
    print("shapes=", img.shape, mask.shape, click_map.shape)
    return (img, [(mask, "Prediction"), (click_map, "Click")])
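# gr.AnnotatedImage renders the tuple returned above: a base image plus a list
# of (mask, label) pairs, where boolean masks the size of the image are drawn
# as colored overlays.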
with gr.Blocks() as demo:
    with gr.Tab("Edge detection"):
        def enable_sliders(algorithm):
            # the threshold sliders only make sense for Canny
            visible = algorithm.lower() == "canny"
            return gr.Slider(visible=visible), gr.Slider(visible=visible)

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                run_button = gr.Button("Run")
                algorithm = gr.Radio(["Sobel", "Canny"], label="Algorithm", value="Sobel")
                # sliders for th_low / th_high, shown only when Canny is selected
                th_low_slider = gr.Slider(0, 32768, 10000, label="Canny's low threshold", visible=False)
                th_high_slider = gr.Slider(0, 32768, 20000, label="Canny's high threshold", visible=False)
                algorithm.change(enable_sliders, inputs=[algorithm], outputs=[th_low_slider, th_high_slider])
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(edge_inference, inputs=[image_input, algorithm, th_low_slider, th_high_slider], outputs=output_image)
        gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)

    with gr.Tab("Reduction to 3D"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                algorithm = gr.Radio(["PCA", "TSNE", "UMAP"], label="Algorithm", value="PCA")
                run_button = gr.Button("Run")
                gr.Markdown("⚠️ UMAP is slow and TSNE is far slower; both may time out. ⚠️")
                resample_pct = gr.Slider(-5, 0, -3, label="Resample (10^x)*100%")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(dimred_inference, inputs=[image_input, algorithm, resample_pct], outputs=output_image)
        gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)

    with gr.Tab("Classical Segmentation"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                algorithm = gr.Radio(["KMeans", "Felzenszwalb", "SLIC", "Watershed"], label="Algorithm", value="SLIC")
                scale = gr.Slider(0.1, 1.0, 0.5, label="Scale")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.Image(label="Output Image")
        run_button.click(segmentation_inference, inputs=[image_input, algorithm, scale], outputs=output_image)
        gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)

    with gr.Tab("One-click segmentation"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image")
                threshold = gr.Slider(0, 1, 0.5, label="Threshold")
                with gr.Row():
                    row = gr.Textbox(10, label="Click's row")
                    col = gr.Textbox(10, label="Click's column")
                run_button = gr.Button("Run")
            with gr.Column():
                output_image = gr.AnnotatedImage(label="Output")
        run_button.click(one_click_segmentation, inputs=[image_input, row, col, threshold], outputs=output_image)
        gr.Examples([str(p) for p in Path("demoimgs").glob("*")], inputs=image_input)

demo.launch(share=False)