Upload 15 files
- .gitattributes +2 -0
- README.md +48 -14
- app.py +81 -0
- examples/ISIC_0012880.jpg +3 -0
- examples/ISIC_0015972.jpg +3 -0
- inception.pt +3 -0
- requirements.txt +9 -0
- results/class_distribution.png +0 -0
- results/gmm_distribution.png +0 -0
- results/kmeans_distribution.png +0 -0
- results/segformer_distribution.png +0 -0
- results/unet_distribution.png +0 -0
- segformer.pt +3 -0
- supervised.py +141 -0
- unet.pt +3 -0
- utils.py +205 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/ISIC_0012880.jpg filter=lfs diff=lfs merge=lfs -text
+examples/ISIC_0015972.jpg filter=lfs diff=lfs merge=lfs -text

README.md
CHANGED
@@ -1,14 +1,48 @@
# ISIC 2018 Skin Lesion Segmentation
This project explores unsupervised and supervised image segmentation methods applied to the **ISIC 2018 skin lesion dataset**. It compares simple segmentation techniques like **KMeans** and **Gaussian Mixture Models (GMM)** against deep learning models (UNet, an Inception-inspired CNN, and SegFormer). The deep models are trained on ISIC data, evaluated on the test set, and their performance is compared with the baseline models.

## Goals
- Segment skin lesions from dermoscopic images.
- Compare baseline unsupervised methods (KMeans, GMM) with deep learning models (UNet, Inception-inspired CNN, SegFormer).
- Evaluate masks using standard metrics: **IoU**, **F1-score**, **Accuracy**.
- Visualize results with overlays (predictions vs. ground truth).
- Explore which morphological operations (erosion, dilation, opening, closing) can improve segmentation quality, as sketched below.
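
The morphological operations listed above are applied with OpenCV; the project's implementation is `postprocess` in `utils.py`. The following is only an illustrative sketch — the toy mask and the 5x5 kernel are placeholders chosen to make the snippet runnable:

```python
# Toy example of the four morphological operations explored in this project.
# The actual implementation is utils.postprocess; mask and kernel size here are placeholders.
import cv2
import numpy as np

mask = np.zeros((128, 128), np.uint8)
mask[40:90, 40:90] = 1                      # toy binary "lesion" mask
kernel = np.ones((5, 5), np.uint8)

opened  = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)    # erosion then dilation: removes small specks
closed  = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)   # dilation then erosion: fills small holes
eroded  = cv2.erode(mask, kernel, iterations=1)             # shrinks the mask
dilated = cv2.dilate(mask, kernel, iterations=1)            # grows the mask
```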

## Dataset
- **ISIC 2018 Challenge - Task 1**
- ~2,600 dermoscopic images with ground truth lesion masks.
- Downloaded from the [ISIC Archive](https://challenge.isic-archive.com/data/#2018); images and masks are stored in separate folders.
- The dataset is split into training, validation, and test sets.
- Due to the nature of the task, there is a notable class imbalance, which has to be taken into account when training and evaluating the models.

![Class Distribution](results/class_distribution.png)

## Unsupervised Methods
- **KMeans**: Clustering algorithm that partitions the pixels into K clusters based on the distances between their values (see the sketch after this section).
- Results:
![Kmeans Distribution](results/kmeans_distribution.png)

- **Gaussian Mixture Models (GMM)**: Probabilistic model that treats the pixel values as a mixture of Gaussian distributions and assigns each pixel to the most likely component.
- Results:
![GMM Distribution](results/gmm_distribution.png)
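
Both baselines simply cluster pixel colours. A minimal sketch of the KMeans variant is shown below (a simplified version of what `predict_and_visualize_single` in `utils.py` does; the image path is one of the files shipped in `examples/`, and the GMM baseline is analogous with `GaussianMixture(n_components=2)`):

```python
# Minimal pixel-clustering baseline: cluster RGB values into 2 groups and treat
# one cluster as the lesion. The cluster id assigned to the lesion is arbitrary,
# so it may need flipping against the ground truth (see utils.fix_labels).
import numpy as np
from PIL import Image
from sklearn.cluster import KMeans

image = np.array(Image.open("examples/ISIC_0012880.jpg").convert("RGB").resize((128, 128)))
pixels = image.reshape(-1, 3)                              # (H*W, 3) colour features
labels = KMeans(n_clusters=2, n_init=10).fit_predict(pixels)
mask = labels.reshape(image.shape[:2]).astype(np.uint8)    # 0/1 cluster map
```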

## Supervised Methods
- **UNet**: A convolutional neural network architecture designed for biomedical image segmentation.
- Results:
![Unet Distribution](results/unet_distribution.png)
- **Inception CNN**: A custom architecture inspired by the Inception model, designed for image segmentation tasks.
- Results:
<!--  -->
- **SegFormer**: A transformer-based model that captures long-range dependencies in images, achieving state-of-the-art results in various vision tasks.
- Results:
![Segformer Distribution](results/segformer_distribution.png)

## Evaluation results
From the evaluation of the models on the test set, we can see that the deep learning models outperform the unsupervised methods in terms of IoU, F1-score, and accuracy. The SegFormer model achieves the best results, followed by UNet and the Inception CNN. Overall, the deep learning models segment skin lesions considerably better than the unsupervised methods. It is also worth noting that all models significantly outperform the majority baseline (which would achieve an accuracy of 76.4%).
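
For reference, these metrics are computed per image with scikit-learn, essentially as in the simplified sketch below (the full version, including the aggregated confusion matrix and histograms, is `evaluate_masks` in `utils.py`; the masks here are random placeholders only so the snippet runs):

```python
# Per-image metric computation, simplified from utils.evaluate_masks.
import numpy as np
from sklearn.metrics import accuracy_score, jaccard_score, f1_score

pred = np.random.randint(0, 2, (128, 128))   # placeholder predicted mask
gt   = np.random.randint(0, 2, (128, 128))   # placeholder ground-truth mask

pred_flat, gt_flat = pred.flatten(), gt.flatten()
print("Accuracy :", accuracy_score(gt_flat, pred_flat))
print("IoU      :", jaccard_score(gt_flat, pred_flat))
print("F1 (Dice):", f1_score(gt_flat, pred_flat))
```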

## How to Run
1. **Setup**: Install dependencies with `pip install -r requirements.txt`.
2. **Training**: Run `python main.py --model segformer --epochs 20 --visualize True` to train the SegFormer (or any of the other models).
3. **Testing**: Run `python main.py --model segformer --visualize True` (with a `segformer.pt` checkpoint in the same directory) to evaluate and visualize model predictions on the test set.

app.py
ADDED
@@ -0,0 +1,81 @@
import gradio as gr
import torch
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from utils import *
from supervised import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load models
models = {
    "unet": UNet(num_classes=2).to(device),
    "segformer": Segformer(num_classes=2).to(device),
    "inception": Inception(num_classes=2).to(device),
    "kmeans": KMeans(n_clusters=2),
    "gmm": GaussianMixture(n_components=2),
}

models["unet"].load_state_dict(torch.load("unet.pt", map_location=device))
models["segformer"].load_state_dict(torch.load("segformer.pt", map_location=device))
models["inception"].load_state_dict(torch.load("inception.pt", map_location=device))

for model in models.values():
    if isinstance(model, (UNet, Segformer, Inception)):
        model.eval()

# Inference function
def inference(image, model_name, postprocess_mode):
    model = models[model_name]
    status_text = f"✅ Inference with {model_name.upper()} and postprocessing mode: {postprocess_mode}"
    # Pass the app's device so the input tensor ends up on the same device as the model
    bw_mask, overlay = predict_and_visualize_single(model, image, postprocess_mode=postprocess_mode, device=device)
    return overlay, bw_mask, status_text

# Gradio interface
with gr.Blocks(theme=gr.themes.Base(primary_hue="rose", secondary_hue="slate")) as demo:
    gr.Markdown("## 🩺 Skin Lesion Segmentation")
    gr.Markdown("Upload a skin image, choose a model, and view segmentation results.")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type='numpy', label="📷 Upload Image")
            model_choice = gr.Radio(
                choices=["unet", "segformer", "inception", "kmeans", "gmm"],
                label="Model",
                value="unet"
            )
            post_choice = gr.Radio(
                choices=["none", "open", "close", "erosion", "dilation"],
                label="Postprocessing",
                value="none"
            )
            run_btn = gr.Button("▶ Run Segmentation")

        with gr.Column(scale=2):
            with gr.Row():
                overlay_output = gr.Image(type='numpy', label="🎯 Overlay")
                mask_output = gr.Image(type='numpy', label="🖤 Predicted Mask")
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        gr.Examples(
            examples=["./examples/ISIC_0012880.jpg", "./examples/ISIC_0015972.jpg"],
            inputs=[image_input],
            label="Use Example Images"
        )

    with gr.Accordion("ℹ️ Legend", open=False):
        gr.Markdown("""
        - **🔴 Red**: Predicted lesion overlay
        - **⚪ White**: Binary mask
        - **Postprocessing**: Cleans up noisy segmentation
        """)

    run_btn.click(
        fn=inference,
        inputs=[image_input, model_choice, post_choice],
        outputs=[overlay_output, mask_output, status]
    )

demo.launch(share=True)

examples/ISIC_0012880.jpg
ADDED
Git LFS Details
examples/ISIC_0015972.jpg
ADDED
Git LFS Details
inception.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:18c17a1d87be5906b31a76558132e3c3fc16e643747b8e0859c25cb914eadce9
size 14355657

requirements.txt
ADDED
@@ -0,0 +1,9 @@
torch
torchvision
numpy
pillow
scikit-learn
matplotlib
opencv-python
gradio
transformers

results/class_distribution.png
ADDED
results/gmm_distribution.png
ADDED
results/kmeans_distribution.png
ADDED
results/segformer_distribution.png
ADDED
results/unet_distribution.png
ADDED
segformer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0dcb4e5c9d19ab4caffaa324e4cf26090bcc9d37fb47840e38d81268691d8341
size 14946661

supervised.py
ADDED
@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

#=======================================
#========= UNet Architecture ===========
#=======================================
class UNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=2):
        super(UNet, self).__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True)
            )

        self.encoder1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)

        self.encoder2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.encoder3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = conv_block(256, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = conv_block(128, 64)

        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))

        bottleneck = self.bottleneck(self.pool3(enc3))

        dec3 = self.upconv3(bottleneck)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return self.final(dec1)

#=======================================
#======= Inception Architecture ========
#=======================================
class InceptionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(InceptionBlock, self).__init__()
        self.b1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1),
                                nn.ReLU(inplace=True))
        self.b2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.b3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.Conv2d(out_channels, out_channels, kernel_size=5, padding=2),
            nn.ReLU(inplace=True)
        )
        self.b4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        b1 = self.b1(x)
        b2 = self.b2(x)
        b3 = self.b3(x)
        b4 = self.b4(x)
        return torch.cat([b1, b2, b3, b4], dim=1)

class Inception(nn.Module):
    def __init__(self, in_channels=3, num_classes=2):
        super(Inception, self).__init__()
        self.inception1 = InceptionBlock(in_channels, 64)
        self.inception2 = InceptionBlock(256, 128)
        self.inception3 = InceptionBlock(512, 256)

        self.conv1x1 = nn.Conv2d(1024, num_classes, kernel_size=1)
        self.upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)

        # Initialize weights only after all layers have been created,
        # otherwise self.modules() contains nothing to initialize.
        self.weights_init()

    def weights_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")

    def forward(self, x):
        height, width = x.shape[2], x.shape[3]
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.inception3(x)
        x = self.conv1x1(x)
        x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=True)
        return x

#=======================================
#======= SegFormer Architecture ========
#=======================================
class Segformer(nn.Module):
    def __init__(self, model_name='nvidia/segformer-b0-finetuned-ade-512-512', num_classes=2):
        super(Segformer, self).__init__()
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        )
        self.processor = SegformerImageProcessor.from_pretrained(model_name)
        self.normalizer = T.Normalize(mean=self.processor.image_mean, std=self.processor.image_std)

    def forward(self, x):
        x = self.normalizer(x)
        logits = self.model(pixel_values=x).logits  # Shape: [B, C, H', W']
        logits = F.interpolate(logits, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=True)
        return logits

unet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d58601e6433b7ebeab5ec4249ca02e413b62b28cd9e69b1719a368cb6deae5bb
size 30861979

utils.py
ADDED
@@ -0,0 +1,205 @@
import cv2
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from supervised import UNet, Segformer, Inception
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from torchvision import transforms
from sklearn.metrics import accuracy_score, jaccard_score, f1_score, confusion_matrix, ConfusionMatrixDisplay

def postprocess(masks, mode="open", kernel_size=5, iters=1):
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    if mode == "open":
        new_masks = [cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel, iterations=iters) for mask in masks]
    elif mode == "close":
        new_masks = [cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel, iterations=iters) for mask in masks]
    elif mode == "erosion":
        new_masks = [cv2.erode(mask.astype(np.uint8), kernel, iterations=iters) for mask in masks]
    elif mode == "dilation":
        new_masks = [cv2.dilate(mask.astype(np.uint8), kernel, iterations=iters) for mask in masks]
    else:
        new_masks = masks
    return new_masks

def fix_labels(pred_masks, gt_masks, lesion_positive=True):
    """
    Flip predicted masks if needed based on the GT, and ensure the lesion is labelled 1.
    If lesion_positive=True, the final output has the lesion as class 1.
    """
    fixed_preds = []

    for pred, gt in zip(pred_masks, gt_masks):
        pred = pred.astype(np.uint8)
        gt = (gt > 0).astype(np.uint8)

        # Flatten for metric comparison
        pred_flat = pred.flatten()
        gt_flat = gt.flatten()

        # Try both label assignments
        iou_0 = jaccard_score(gt_flat, (pred_flat == 0))
        iou_1 = jaccard_score(gt_flat, (pred_flat == 1))

        # Flip if label 0 gives better IoU
        if iou_0 > iou_1:
            pred = 1 - pred

        # Optional: ensure lesion is positive (class 1)
        if lesion_positive:
            # If GT has more lesion pixels than background, make sure pred does too
            gt_lesion_ratio = np.sum(gt) / gt.size
            pred_lesion_ratio = np.sum(pred) / pred.size

            if pred_lesion_ratio < 0.5 and gt_lesion_ratio > 0.5:
                pred = 1 - pred

        fixed_preds.append(pred)

    return fixed_preds

def evaluate_masks(pred_masks, gt_masks):
    """
    Evaluate predicted masks.
    Prints and plots mean metrics (accuracy, IoU, F1) over all mask pairs.
    """
    acc_list = []
    iou_list = []
    f1_list = []
    cm = np.zeros((2, 2), dtype=int)
    for pred, gt in zip(pred_masks, gt_masks):
        pred_flat = pred.flatten()
        gt_flat = (gt.flatten() > 0).astype(np.uint8)

        acc = accuracy_score(gt_flat, pred_flat)
        iou = jaccard_score(gt_flat, pred_flat)
        f1 = f1_score(gt_flat, pred_flat)

        acc_list.append(acc)
        iou_list.append(iou)
        f1_list.append(f1)
        cm += confusion_matrix(gt_flat, pred_flat, labels=[0, 1])

    mean_acc = np.mean(acc_list)
    mean_iou = np.mean(iou_list)
    mean_f1 = np.mean(f1_list)

    print(f"Mean Accuracy: {mean_acc:.4f}")
    print(f"Mean IoU (Jaccard): {mean_iou:.4f}")
    print(f"Mean F1 Score (Dice): {mean_f1:.4f}")

    disp = ConfusionMatrixDisplay(cm, display_labels=["Background", "Lesion"])
    disp.plot(cmap="Blues", values_format="d")
    plt.title("Confusion Matrix (Aggregated)")
    plt.show()

    # Plot histograms
    plt.figure(figsize=(15, 4))
    plt.subplot(1, 3, 1)
    plt.hist(acc_list, bins=10, color='r', alpha=0.6, edgecolor='black')
    plt.title("Accuracy Distribution")

    plt.subplot(1, 3, 2)
    plt.hist(iou_list, bins=10, color='g', alpha=0.6, edgecolor='black')
    plt.title("IoU Distribution")

    plt.subplot(1, 3, 3)
    plt.hist(f1_list, bins=10, color='skyblue', alpha=0.6, edgecolor='black')
    plt.title("F1 Score Distribution")

    plt.tight_layout()
    plt.show()

def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """
    Overlay a binary mask on top of an image.
    - image: (H, W, 3) numpy array, RGB
    - mask: (H, W) numpy array, 0/1 values or 0/255
    - color: RGB tuple for mask color
    - alpha: transparency factor (0=transparent, 1=opaque)
    """
    image = image.copy()

    # Make sure mask is binary 0 or 1
    if mask.max() > 1:
        mask = (mask > 127).astype(np.uint8)

    # Create colored mask
    colored_mask = np.zeros_like(image)
    colored_mask[:, :, 0] = color[0]
    colored_mask[:, :, 1] = color[1]
    colored_mask[:, :, 2] = color[2]

    # Apply mask
    mask_3d = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    overlay = np.where(mask_3d, (1 - alpha) * image + alpha * colored_mask, image)

    return overlay.astype(np.uint8)


def visualize_overlay(image, gt_mask, pred_mask, post_mask=None, alpha=0.5):
    """
    Plot the original image plus ground-truth and prediction overlays.
    """
    plt.figure(figsize=(18, 6))

    # Original
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.title("Original Image")
    plt.axis("off")

    # Ground Truth Overlay
    overlay_gt = overlay_mask(image, gt_mask, color=(0, 255, 0), alpha=alpha)
    plt.subplot(1, 3, 2)
    plt.imshow(overlay_gt)
    plt.title("Ground Truth Overlay (Green)")
    plt.axis("off")

    # Predicted Overlay
    overlay_pred = overlay_mask(image, pred_mask, color=(255, 0, 0), alpha=alpha)
    plt.subplot(1, 3, 3)
    plt.imshow(overlay_pred)
    plt.title("Prediction Overlay (Red)")
    plt.axis("off")

    plt.tight_layout()
    plt.show()

def predict_and_visualize_single(model, image_array, postprocess_mode='none', alpha=0.5, device='cpu'):
    # image_array is an (H, W, 3) RGB numpy array (e.g. from the Gradio image input)
    image = Image.fromarray(image_array).convert('RGB')
    original_np = np.array(image.resize((128, 128)))

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor()
    ])
    input_tensor = transform(image).unsqueeze(0).to(device)

    if isinstance(model, (UNet, Segformer, Inception)):
        with torch.no_grad():
            output = model(input_tensor)
            if isinstance(output, dict):
                output = output.get("logits") or output.get("out")
            pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()
    elif isinstance(model, (KMeans, GaussianMixture)):
        model.fit(original_np.reshape(-1, 3))
        pred_mask = model.predict(original_np.reshape(-1, 3)).reshape(128, 128)

    if postprocess_mode != 'none':
        pred_mask = postprocess([pred_mask], mode=postprocess_mode)[0]

    # Resize outputs to 256x256 for display
    bw_mask = cv2.resize(pred_mask.astype(np.uint8) * 255, (256, 256), interpolation=cv2.INTER_NEAREST)
    overlay = cv2.resize(overlay_mask(original_np, pred_mask, color=(255, 0, 0), alpha=alpha),
                         (256, 256),
                         interpolation=cv2.INTER_LINEAR
                        )

    return bw_mask, overlay