theodore-ioann committed on
Commit 5b303e8 · verified · 1 Parent(s): fe47d51

Upload 15 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/ISIC_0012880.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/ISIC_0015972.jpg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,48 @@
- ---
- title: Skin Lesion Segmentation
- emoji: 🔥
- colorFrom: indigo
- colorTo: pink
- sdk: gradio
- sdk_version: 5.31.0
- app_file: app.py
- pinned: false
- license: mit
- short_description: Different models trained on skin lesion segmentation
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # ISIC 2018 Skin Lesion Segmentation
+ This project explores unsupervised and supervised image segmentation methods on the **ISIC 2018 skin lesion dataset**. It compares simple segmentation techniques such as **KMeans** and **Gaussian Mixture Models (GMM)** against deep learning models (Unet, an Inception-inspired CNN, and SegFormer). The deep models are trained on ISIC data, evaluated on the test set, and their performance is compared with the baseline models.
+
+ ## Goals
+ - Segment skin lesions from dermoscopic images.
+ - Compare baseline unsupervised methods (KMeans, GMM) with deep learning models (Unet, Inception-inspired CNN, SegFormer).
+ - Evaluate masks using standard metrics: **IoU**, **F1-score**, **Accuracy** (a minimal metric sketch follows this list).
+ - Visualize results with overlays (predictions vs. ground truth).
+ - Explore which morphological operations (erosion, dilation, opening, closing) can improve the quality of the segmentations.
+
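A minimal sketch of how these metrics can be computed for a single predicted/ground-truth mask pair, mirroring the flattened-pixel approach used in `utils.py` (the mask arrays below are placeholders, not project data):

```python
import numpy as np
from sklearn.metrics import accuracy_score, jaccard_score, f1_score

# Placeholder binary masks (1 = lesion, 0 = background); real masks come from the ISIC data.
gt_mask = np.zeros((128, 128), dtype=np.uint8)
gt_mask[40:90, 40:90] = 1
pred_mask = np.zeros((128, 128), dtype=np.uint8)
pred_mask[45:95, 45:95] = 1

# Metrics are computed over the flattened per-pixel labels.
gt_flat, pred_flat = gt_mask.flatten(), pred_mask.flatten()
print("Accuracy:", accuracy_score(gt_flat, pred_flat))
print("IoU (Jaccard):", jaccard_score(gt_flat, pred_flat))
print("F1 (Dice):", f1_score(gt_flat, pred_flat))
```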
+ ## Dataset
+ - **ISIC 2018 Challenge - Task 1**
+ - ~2,600 dermoscopic images and ground truth lesion masks.
+ - Downloaded from [ISIC Archive](https://challenge.isic-archive.com/data/#2018). The images and masks are stored in different folders.
+ - The dataset is split into training, validation, and test sets.
+ - Due to the nature of the task, there is a notable class imbalance, which we have to take into account when training and evaluating the models.
+
+ ![Class Distribution](results/class_distribution.png)
+
+ ## Unsupervised Methods
+ - **KMeans**: Clustering algorithm that partitions the image into K clusters based on distances between pixel values (see the sketch after this section).
+ - Results:
+ ![KMeans Results](results/kmeans_distribution.png)
+
+ - **Gaussian Mixture Models (GMM)**: Probabilistic model that assumes the pixel values are generated by a mixture of Gaussian distributions.
+ - Results:
+ ![GMM Results](results/gmm_distribution.png)
+
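A minimal sketch of the pixel-clustering idea behind both baselines, assuming a 128×128 RGB image array (the random image below is only a stand-in for a resized dermoscopic image):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

# Stand-in for a resized dermoscopic image; the project uses 128x128 RGB crops.
image = np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)
pixels = image.reshape(-1, 3)  # one row per pixel, RGB values as features

# KMeans: hard assignment of every pixel to one of two clusters.
kmeans_mask = KMeans(n_clusters=2, n_init=10).fit_predict(pixels).reshape(128, 128)

# GMM: fits two Gaussians to the pixel values, then takes the most likely component per pixel.
gmm_mask = GaussianMixture(n_components=2).fit(pixels).predict(pixels).reshape(128, 128)

# Which cluster is "lesion" is arbitrary; utils.fix_labels flips the prediction
# when the opposite assignment agrees better with the ground truth.
```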
+ ## Supervised Methods
+ - **Unet**: A convolutional neural network architecture designed for biomedical image segmentation.
+ - Results:
+ ![Unet Results](results/unet_distribution.png)
+ - **Inception CNN**: A custom architecture inspired by the Inception model, designed for image segmentation tasks.
+ - Results:
+ <!-- ![Inception Results](results/inception_distribution.png) -->
+ - **SegFormer**: A transformer-based model that captures long-range dependencies in images, achieving state-of-the-art results in various vision tasks.
+ - Results:
+ ![SegFormer Results](results/segformer_distribution.png)
+
+ ## Evaluation results
+ From the evaluation of the models on the test set, we can see that the deep learning models outperform the unsupervised methods in terms of IoU, F1-score, and accuracy. The SegFormer model achieves the best results, followed by Unet and the Inception CNN. Overall, it is clear that the deep learning models segment skin lesions better than the unsupervised methods. It is also worth noting that all models significantly outperform the majority baseline (which would achieve an accuracy of 76.4%).
+
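As a sanity check on that 76.4% figure: the majority baseline labels every pixel as background, so its accuracy is simply the background fraction of the test pixels. A toy illustration (the proportions below are made up to match the reported figure):

```python
import numpy as np

# Toy ground truth with 23.6% lesion pixels and 76.4% background pixels.
gt = np.zeros(1000, dtype=np.uint8)
gt[:236] = 1

baseline_pred = np.zeros_like(gt)    # predict "background" everywhere
print((baseline_pred == gt).mean())  # 0.764 = background fraction
```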
+ ## How to Run
+ 1. **Setup**: Install dependencies with `pip install -r requirements.txt`.
+ 2. **Training**: Run `python main.py --model segformer --epochs 20 --visualize True` to train the SegFormer (or any of the other models).
+ 3. **Testing**: Run `python main.py --model segformer --visualize True` to evaluate and visualize model predictions on the test set, with a `segformer.pt` file in the same directory.
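`main.py` itself is not included in this commit; purely as an illustration, a single training step for the two-class segmentation models could look roughly like the sketch below (the optimizer, learning rate, and batch shapes are assumptions, not the project's actual training code):

```python
import torch
import torch.nn as nn
from supervised import UNet  # any of the models defined in supervised.py

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumed hyperparameters
criterion = nn.CrossEntropyLoss()  # per-pixel 2-class loss; class weights could counter the imbalance

# Placeholder batch: images (B, 3, 128, 128) and integer masks (B, 128, 128) with values {0, 1}.
images = torch.rand(4, 3, 128, 128, device=device)
masks = torch.randint(0, 2, (4, 128, 128), device=device)

logits = model(images)           # (B, 2, 128, 128)
loss = criterion(logits, masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```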
app.py ADDED
@@ -0,0 +1,81 @@
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from sklearn.cluster import KMeans
+ from sklearn.mixture import GaussianMixture
+ from utils import *
+ from supervised import *
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Load Models
+ models = {
+     "unet": UNet(num_classes=2).to(device),
+     "segformer": Segformer(num_classes=2).to(device),
+     "inception": Inception(num_classes=2).to(device),
+     "kmeans": KMeans(n_clusters=2),
+     "gmm": GaussianMixture(n_components=2),
+ }
+
+ models["unet"].load_state_dict(torch.load("unet.pt", map_location=device))
+ models["segformer"].load_state_dict(torch.load("segformer.pt", map_location=device))
+ models["inception"].load_state_dict(torch.load("inception.pt", map_location=device))
+
+ for model in models.values():
+     if isinstance(model, (UNet, Segformer, Inception)):
+         model.eval()
+
+ # Inference function
+ def inference(image, model_name, postprocess_mode):
+     model = models[model_name]
+     status_text = f"✅ Inference with {model_name.upper()} and postprocessing mode: {postprocess_mode}"
+     # Pass the device so the input tensor ends up on the same device as the loaded models.
+     bw_mask, overlay = predict_and_visualize_single(model, image, postprocess_mode=postprocess_mode, device=device)
+     return overlay, bw_mask, status_text
+
+ # Gradio Interface
+ with gr.Blocks(theme=gr.themes.Base(primary_hue="rose", secondary_hue="slate")) as demo:
+     gr.Markdown("## 🩺 Skin Lesion Segmentation")
+     gr.Markdown("Upload a skin image, choose a model, and view segmentation results.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             image_input = gr.Image(type='numpy', label="📷 Upload Image")
+             model_choice = gr.Radio(
+                 choices=["unet", "segformer", "inception", "kmeans", "gmm"],
+                 label="Model",
+                 value="unet"
+             )
+             post_choice = gr.Radio(
+                 choices=["none", "open", "close", "erosion", "dilation"],
+                 label="Postprocessing",
+                 value="none"
+             )
+             run_btn = gr.Button("▶ Run Segmentation")
+
+         with gr.Column(scale=2):
+             with gr.Row():
+                 overlay_output = gr.Image(type='numpy', label="🎯 Overlay")
+                 mask_output = gr.Image(type='numpy', label="🖤 Predicted Mask")
+             status = gr.Textbox(label="Status", interactive=False)
+
+     with gr.Row():
+         gr.Examples(
+             examples=["./examples/ISIC_0012880.jpg", "./examples/ISIC_0015972.jpg"],
+             inputs=[image_input],
+             label="Use Example Images"
+         )
+
+     with gr.Accordion("ℹ️ Legend", open=False):
+         gr.Markdown("""
+         - **🔴 Red**: Predicted lesion overlay
+         - **⚪ White**: Binary mask
+         - **Postprocessing**: Cleans up noisy segmentation
+         """)
+
+     run_btn.click(
+         fn=inference,
+         inputs=[image_input, model_choice, post_choice],
+         outputs=[overlay_output, mask_output, status]
+     )
+
+ demo.launch(share=True)
examples/ISIC_0012880.jpg ADDED

Git LFS Details

  • SHA256: b41abada1baaeb678002cea86ac171b6b802ab24440c6ec0d0ae78076d29eb23
  • Pointer size: 132 Bytes
  • Size of remote file: 2.51 MB
examples/ISIC_0015972.jpg ADDED

Git LFS Details

  • SHA256: fd21eec581f03f74b3abb98e0251ebf124cb73202a459e9bf605ae360287c768
  • Pointer size: 132 Bytes
  • Size of remote file: 3.83 MB
inception.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:18c17a1d87be5906b31a76558132e3c3fc16e643747b8e0859c25cb914eadce9
+ size 14355657
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ torchvision
+ numpy
+ pillow
+ scikit-learn
+ matplotlib
+ opencv-python
+ gradio
+ transformers
results/class_distribution.png ADDED
results/gmm_distribution.png ADDED
results/kmeans_distribution.png ADDED
results/segformer_distribution.png ADDED
results/unet_distribution.png ADDED
segformer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0dcb4e5c9d19ab4caffaa324e4cf26090bcc9d37fb47840e38d81268691d8341
+ size 14946661
supervised.py ADDED
@@ -0,0 +1,141 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
+
+ #=======================================
+ #========= UNet Architecture ===========
+ #=======================================
+ class UNet(nn.Module):
+     def __init__(self, in_channels=3, num_classes=2):
+         super(UNet, self).__init__()
+
+         def conv_block(in_c, out_c):
+             return nn.Sequential(
+                 nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+                 nn.BatchNorm2d(out_c),
+                 nn.ReLU(inplace=True),
+                 nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
+                 nn.BatchNorm2d(out_c),
+                 nn.ReLU(inplace=True)
+             )
+
+         self.encoder1 = conv_block(in_channels, 64)
+         self.pool1 = nn.MaxPool2d(2)
+
+         self.encoder2 = conv_block(64, 128)
+         self.pool2 = nn.MaxPool2d(2)
+
+         self.encoder3 = conv_block(128, 256)
+         self.pool3 = nn.MaxPool2d(2)
+
+         self.bottleneck = conv_block(256, 512)
+
+         self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
+         self.decoder3 = conv_block(512, 256)
+
+         self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
+         self.decoder2 = conv_block(256, 128)
+
+         self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
+         self.decoder1 = conv_block(128, 64)
+
+         self.final = nn.Conv2d(64, num_classes, kernel_size=1)
+
+     def forward(self, x):
+         enc1 = self.encoder1(x)
+         enc2 = self.encoder2(self.pool1(enc1))
+         enc3 = self.encoder3(self.pool2(enc2))
+
+         bottleneck = self.bottleneck(self.pool3(enc3))
+
+         dec3 = self.upconv3(bottleneck)
+         dec3 = torch.cat((dec3, enc3), dim=1)
+         dec3 = self.decoder3(dec3)
+
+         dec2 = self.upconv2(dec3)
+         dec2 = torch.cat((dec2, enc2), dim=1)
+         dec2 = self.decoder2(dec2)
+
+         dec1 = self.upconv1(dec2)
+         dec1 = torch.cat((dec1, enc1), dim=1)
+         dec1 = self.decoder1(dec1)
+
+         return self.final(dec1)
+
+ #=======================================
+ #======= Inception Architecture ========
+ #=======================================
+ class InceptionBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(InceptionBlock, self).__init__()
+         self.b1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1),
+                                 nn.ReLU(inplace=True))
+         self.b2 = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=1),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+         self.b3 = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=1),
+             nn.Conv2d(out_channels, out_channels, kernel_size=5, padding=2),
+             nn.ReLU(inplace=True)
+         )
+         self.b4 = nn.Sequential(
+             nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+             nn.Conv2d(in_channels, out_channels, kernel_size=1),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, x):
+         b1 = self.b1(x)
+         b2 = self.b2(x)
+         b3 = self.b3(x)
+         b4 = self.b4(x)
+         return torch.cat([b1, b2, b3, b4], dim=1)
+
+ class Inception(nn.Module):
+     def __init__(self, in_channels=3, num_classes=2):
+         super(Inception, self).__init__()
+         self.inception1 = InceptionBlock(in_channels, 64)
+         self.inception2 = InceptionBlock(256, 128)
+         self.inception3 = InceptionBlock(512, 256)
+
+         self.conv1x1 = nn.Conv2d(1024, num_classes, kernel_size=1)
+         self.upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
+
+         # Initialize weights after the layers exist; otherwise the loop below has nothing to initialize.
+         self.weights_init()
+
+     def weights_init(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+
+     def forward(self, x):
+         height, width = x.shape[2], x.shape[3]
+         x = self.inception1(x)
+         x = self.inception2(x)
+         x = self.inception3(x)
+         x = self.conv1x1(x)
+         x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=True)
+         return x
+
+ #=======================================
+ #======= SegFormer Architecture ========
+ #=======================================
+ class Segformer(nn.Module):
+     def __init__(self, model_name='nvidia/segformer-b0-finetuned-ade-512-512', num_classes=2):
+         super(Segformer, self).__init__()
+         self.model = SegformerForSemanticSegmentation.from_pretrained(
+             model_name,
+             num_labels=num_classes,
+             ignore_mismatched_sizes=True
+         )
+         self.processor = SegformerImageProcessor.from_pretrained(model_name)
+         self.normalizer = T.Normalize(mean=self.processor.image_mean, std=self.processor.image_std)
+
+     def forward(self, x):
+         x = self.normalizer(x)
+         logits = self.model(pixel_values=x).logits  # Shape: [B, C, H', W']
+         logits = F.interpolate(logits, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=True)
+         return logits
unet.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d58601e6433b7ebeab5ec4249ca02e413b62b28cd9e69b1719a368cb6deae5bb
+ size 30861979
utils.py ADDED
@@ -0,0 +1,205 @@
+ import cv2
+ import torch
+ import numpy as np
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ from supervised import UNet, Segformer, Inception
+ from sklearn.cluster import KMeans
+ from sklearn.mixture import GaussianMixture
+ from torchvision import transforms
+ from sklearn.metrics import accuracy_score, jaccard_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
+
+ def postprocess(masks, mode="open", kernel_size=5, iters=1):
+     kernel = np.ones((kernel_size, kernel_size), np.uint8)
+     if mode == "open":
+         new_masks = [cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel, iterations=iters) for mask in masks]
+     elif mode == "close":
+         new_masks = [cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel, iterations=iters) for mask in masks]
+     elif mode == "erosion":
+         new_masks = [cv2.erode(mask.astype(np.uint8), kernel, iterations=iters) for mask in masks]
+     elif mode == "dilation":
+         new_masks = [cv2.dilate(mask.astype(np.uint8), kernel, iterations=iters) for mask in masks]
+     else:
+         new_masks = masks
+     return new_masks
+
+ def fix_labels(pred_masks, gt_masks, lesion_positive=True):
+     """
+     Flip predicted masks if needed based on GT, and ensure lesion is 1.
+     If lesion_positive=True, final output has lesion as 1.
+     """
+     fixed_preds = []
+
+     for pred, gt in zip(pred_masks, gt_masks):
+         pred = pred.astype(np.uint8)
+         gt = (gt > 0).astype(np.uint8)
+
+         # Flatten for metric comparison
+         pred_flat = pred.flatten()
+         gt_flat = gt.flatten()
+
+         # Try both label assignments
+         iou_0 = jaccard_score(gt_flat, (pred_flat == 0))
+         iou_1 = jaccard_score(gt_flat, (pred_flat == 1))
+
+         # Flip if label 0 gives better IoU
+         if iou_0 > iou_1:
+             pred = 1 - pred
+
+         # Optional: ensure lesion is positive (class 1)
+         if lesion_positive:
+             # If GT has more lesion pixels than background, make sure pred does too
+             gt_lesion_ratio = np.sum(gt) / gt.size
+             pred_lesion_ratio = np.sum(pred) / pred.size
+
+             if pred_lesion_ratio < 0.5 and gt_lesion_ratio > 0.5:
+                 pred = 1 - pred
+
+         fixed_preds.append(pred)
+
+     return fixed_preds
+
+ def evaluate_masks(pred_masks, gt_masks):
+     """
+     Evaluate predicted masks.
+     Returns mean metrics (accuracy, iou, f1).
+     """
+     acc_list = []
+     iou_list = []
+     f1_list = []
+     cm = np.zeros((2, 2), dtype=int)
+     for pred, gt in zip(pred_masks, gt_masks):
+         pred_flat = pred.flatten()
+         gt_flat = (gt.flatten() > 0).astype(np.uint8)
+
+         acc0 = accuracy_score(gt_flat, (pred_flat == 0))
+         acc1 = accuracy_score(gt_flat, (pred_flat == 1))
+
+         acc = accuracy_score(gt_flat, pred_flat)
+         iou = jaccard_score(gt_flat, pred_flat)
+         f1 = f1_score(gt_flat, pred_flat)
+
+         acc_list.append(acc)
+         iou_list.append(iou)
+         f1_list.append(f1)
+         cm += confusion_matrix(gt_flat, pred_flat, labels=[0, 1])
+
+     mean_acc = np.mean(acc_list)
+     mean_iou = np.mean(iou_list)
+     mean_f1 = np.mean(f1_list)
+
+     print(f"Mean Accuracy: {mean_acc:.4f}")
+     print(f"Mean IoU (Jaccard): {mean_iou:.4f}")
+     print(f"Mean F1 Score (Dice): {mean_f1:.4f}")
+
+     disp = ConfusionMatrixDisplay(cm, display_labels=["Background", "Lesion"])
+     disp.plot(cmap="Blues", values_format="d")
+     plt.title("Confusion Matrix (Aggregated)")
+     plt.show()
+
+     # Plot histograms
+     plt.figure(figsize=(15, 4))
+     plt.subplot(1, 3, 1)
+     plt.hist(acc_list, bins=10, color='r', alpha=0.6, edgecolor='black')
+     plt.title("Accuracy Distribution")
+
+     plt.subplot(1, 3, 2)
+     plt.hist(iou_list, bins=10, color='g', alpha=0.6, edgecolor='black')
+     plt.title("IoU Distribution")
+
+     plt.subplot(1, 3, 3)
+     plt.hist(f1_list, bins=10, color='skyblue', alpha=0.6, edgecolor='black')
+     plt.title("F1 Score Distribution")
+
+     plt.tight_layout()
+     plt.show()
+
+ def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
+     """
+     Overlay a binary mask on top of an image.
+     - image: (H, W, 3) numpy array, RGB
+     - mask: (H, W) numpy array, 0/1 values or 0/255
+     - color: RGB tuple for mask color
+     - alpha: transparency factor (0=transparent, 1=opaque)
+     """
+     image = image.copy()
+
+     # Make sure mask is binary 0 or 1
+     if mask.max() > 1:
+         mask = (mask > 127).astype(np.uint8)
+
+     # Create colored mask
+     colored_mask = np.zeros_like(image)
+     colored_mask[:, :, 0] = color[0]
+     colored_mask[:, :, 1] = color[1]
+     colored_mask[:, :, 2] = color[2]
+
+     # Apply mask
+     mask_3d = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
+     overlay = np.where(mask_3d, (1 - alpha) * image + alpha * colored_mask, image)
+
+     return overlay.astype(np.uint8)
+
+
+ def visualize_overlay(image, gt_mask, pred_mask, post_mask=None, alpha=0.5):
+     """
+     Plot original image + overlay GT mask and Predicted mask.
+     """
+     plt.figure(figsize=(18, 6))
+
+     # Original
+     plt.subplot(1, 3, 1)
+     plt.imshow(image)
+     plt.title("Original Image")
+     plt.axis("off")
+
+     # Ground Truth Overlay
+     overlay_gt = overlay_mask(image, gt_mask, color=(0, 255, 0), alpha=alpha)
+     plt.subplot(1, 3, 2)
+     plt.imshow(overlay_gt)
+     plt.title("Ground Truth Overlay (Green)")
+     plt.axis("off")
+
+     # Predicted Overlay
+     overlay_pred = overlay_mask(image, pred_mask, color=(255, 0, 0), alpha=alpha)
+     plt.subplot(1, 3, 3)
+     plt.imshow(overlay_pred)
+     plt.title("Prediction Overlay (Red)")
+     plt.axis("off")
+
+     plt.tight_layout()
+     plt.show()
+
+ def predict_and_visualize_single(model, image_array, postprocess_mode='none', alpha=0.5, device='cpu'):
+     # image_array is an RGB numpy array (as delivered by the Gradio image component).
+     image = Image.fromarray(image_array).convert('RGB')
+     original_np = np.array(image.resize((128, 128)))
+
+     transform = transforms.Compose([
+         transforms.Resize((128, 128)),
+         transforms.ToTensor()
+     ])
+     input_tensor = transform(image).unsqueeze(0).to(device)
+
+     if isinstance(model, (UNet, Segformer, Inception)):
+         with torch.no_grad():
+             output = model(input_tensor)
+             if isinstance(output, dict):
+                 output = output.get("logits", output.get("out"))
+             pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()
+     elif isinstance(model, (KMeans, GaussianMixture)):
+         # The clustering baselines are fitted on the pixels of this single image.
+         model.fit(original_np.reshape(-1, 3))
+         pred_mask = model.predict(original_np.reshape(-1, 3)).reshape(128, 128)
+
+     if postprocess_mode != 'none':
+         pred_mask = postprocess([pred_mask], mode=postprocess_mode)[0]
+
+     # Resize outputs to 256x256 for display
+     bw_mask = cv2.resize(pred_mask.astype(np.uint8) * 255, (256, 256), interpolation=cv2.INTER_NEAREST)
+     overlay = cv2.resize(overlay_mask(original_np, pred_mask, color=(255, 0, 0), alpha=alpha),
+                          (256, 256),
+                          interpolation=cv2.INTER_LINEAR
+                          )
+
+     return bw_mask, overlay