|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision.models import resnet18 |
|
|
from datasets import load_dataset |
|
|
from huggingface_hub import hf_hub_download |
|
|
import numpy as np |
|
|
import random |
|
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
|
import io |
|
|
from torch.utils.data import DataLoader |
|
|
import base64 |
|
|
|
|
|
|
|
|
class ResNet18_Dropout(nn.Module):
    """ResNet-18 classifier with a configurable input depth and a dropout head.

    A torchvision ``resnet18`` is created without pretrained weights
    (``weights=None``). Its first convolution is replaced so that it accepts
    ``in_channels`` input planes, and its final ``fc`` layer is replaced by
    ``Dropout -> Linear`` producing ``num_classes`` outputs.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output units of the final linear layer.
        dropout_rate: Probability handed to ``nn.Dropout`` in the head.
    """

    def __init__(self, in_channels, num_classes, dropout_rate=0.3):
        super().__init__()
        backbone = resnet18(weights=None)

        # New stem: same geometry as the layer it replaces here (7x7 kernel,
        # stride 2, padding 3, no bias) but with `in_channels` input planes.
        backbone.conv1 = nn.Conv2d(
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

        # New head. The Sequential positions (0 = dropout, 1 = linear) are
        # part of the parameter names, so they are kept in this order.
        head_layers = [
            nn.Dropout(dropout_rate),
            nn.Linear(backbone.fc.in_features, num_classes),
        ]
        backbone.fc = nn.Sequential(*head_layers)

        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)
|
|
|
|
|
def transform_multispectral_map(example, *, augment=True, max_value=2750.0):
    """Normalise one 13-band sample and return it as channel-first tensors.

    The image is converted to ``float32``, divided by ``max_value`` and
    clipped to ``[0, 1]``. When ``augment`` is true, three independent
    coin flips (each with probability 0.5) decide whether to apply a
    horizontal flip, a vertical flip and a rotation by a random multiple
    of 90 degrees. Finally the array is transposed from (H, W, C) to
    (C, H, W).

    Args:
        example: Mapping with an ``"image"`` entry convertible to an
            (H, W, 13) array and a ``"label"`` entry.
        augment: Apply the random flips/rotation. Defaults to ``True``,
            which is the behaviour callers passing only ``example`` get.
            Pass ``False`` for a deterministic transform.
        max_value: Divisor used for normalisation before clipping.

    Returns:
        dict with ``"image"`` (``torch.float32`` tensor, channel-first) and
        ``"label"`` (``torch.long`` tensor).

    Raises:
        ValueError: If the image is not 3-D with 13 entries on its last axis.
    """
    image = np.array(example["image"], dtype=np.float32)

    if image.ndim != 3 or image.shape[2] != 13:
        raise ValueError(f"Expected shape (H, W, 13), got {image.shape}")

    image = np.clip(image / max_value, 0, 1)

    if augment:
        # The draws happen in a fixed order (flip, flip, rotate, k) so that a
        # seeded `random` module reproduces the same augmentations.
        # np.flip / np.rot90 return views; .copy() materialises them before
        # the array is handed to torch (which does not accept negative
        # strides when wrapping NumPy data).
        if random.random() < 0.5:
            image = np.flip(image, axis=1).copy()

        if random.random() < 0.5:
            image = np.flip(image, axis=0).copy()

        if random.random() < 0.5:
            k = random.choice([1, 2, 3])
            image = np.rot90(image, k=k, axes=(0, 1)).copy()

    # (H, W, C) -> (C, H, W)
    image = image.transpose(2, 0, 1)

    return {
        "image": torch.tensor(image, dtype=torch.float32),
        "label": torch.tensor(example["label"], dtype=torch.long),
    }
|
|
|
|
|
|
|
|
def load_rgb_from_multispectral_sample(numpy_array):
    """Build a displayable RGB image from a raw 13-band array.

    Bands 4, 3 and 2 (zero-based indices 3, 2, 1) become the red, green and
    blue channels. Each selected band is divided by 2750, multiplied by 255,
    clipped to ``[0, 255]`` and cast to ``uint8``. (The original author
    describes this as equivalent to loading bands 4-3-2 and scaling as GDAL
    would.)

    Args:
        numpy_array: ``numpy.ndarray`` of shape (H, W, 13).

    Returns:
        ``uint8`` array of shape (H, W, 3).

    Raises:
        TypeError: If ``numpy_array`` is not a ``numpy.ndarray``.
        ValueError: If the array is not 3-D or its last axis is not 13 long.
    """

    def scale_band(band):
        band = np.clip((band / 2750) * 255, 0, 255)
        return band.astype(np.uint8)

    bands = [3, 2, 1]

    if not isinstance(numpy_array, np.ndarray):
        raise TypeError("Input must be a NumPy array")

    # Reject wrong ranks up front: the [:, :, b] indexing below would raise an
    # opaque IndexError for 1-D/2-D input and silently produce a wrongly
    # shaped result for 4-D input.
    if numpy_array.ndim != 3:
        raise ValueError(f"Expected shape (H, W, 13), got {numpy_array.shape}")

    if numpy_array.shape[-1] != 13:
        raise ValueError(f"Input array must have 13 channels, but got {numpy_array.shape[-1]}")

    rgb = np.stack([scale_band(numpy_array[:, :, b]) for b in bands], axis=-1)
    return rgb
|
|
|
|
|
def load_rgb_from_transformed_tensor(tensor_image):
    """Build a displayable RGB image from a normalised 13-band tensor.

    The tensor is expected channel-first, (13, H, W). Channels 3, 2 and 1
    become red, green and blue. Each value is multiplied by 255, clipped to
    ``[0, 255]`` and cast to ``uint8``.

    Args:
        tensor_image: ``torch.Tensor`` of shape (13, H, W).

    Returns:
        ``uint8`` NumPy array of shape (H, W, 3).

    Raises:
        TypeError: If ``tensor_image`` is not a ``torch.Tensor``.
        ValueError: If the tensor is not 3-D or its first axis is not 13 long.
    """
    if not isinstance(tensor_image, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor")

    # Check the rank first so 0-D/1-D/4-D input gets a clear error instead of
    # failing later inside the transpose or the band indexing.
    if tensor_image.ndim != 3:
        raise ValueError(f"Expected shape (13, H, W), got {tuple(tensor_image.shape)}")

    if tensor_image.shape[0] != 13:
        raise ValueError(f"Expected 13 channels, got {tensor_image.shape[0]}")

    # detach() and cpu() make the conversion work for tensors that track
    # gradients or live on another device; both are no-ops for plain CPU
    # tensors.
    np_image = tensor_image.detach().cpu().numpy()
    np_image = np.transpose(np_image, (1, 2, 0))  # (C, H, W) -> (H, W, C)

    bands = [3, 2, 1]

    def scale_band(band):
        band = np.clip((band * 255), 0, 255)
        return band.astype(np.uint8)

    rgb = np.stack([scale_band(np_image[:, :, b]) for b in bands], axis=-1)
    return rgb
|
|
|
|
|
|
|
|
# Shared application state. Every slot starts out as None; they are declared
# global and assigned inside load_model_and_data(), and predict_images()
# checks `model` and `dataset` against None before using them.
model = dataset = label_names = label2id = id2label = None
|
|
|
|
|
def load_model_and_data():
    """Load the dataset and the trained model into the module-level state.

    Steps:
      1. Load ``blanchon/EuroSAT_MSI`` (cached under ``./hf_cache``), map
         ``transform_multispectral_map`` over its ``test`` split and switch
         that split to torch format for the ``image``/``label`` columns.
      2. Read the class names from the ``train`` split's ``label`` feature
         and build the name->id and id->name mappings.
      3. Download ``pytorch_model.bin`` from
         ``Rhodham96/Resnet18DropoutSentinel``, load it into a 13-channel
         ``ResNet18_Dropout`` on the CPU and put the model in eval mode.

    The globals ``model``, ``dataset``, ``label_names``, ``label2id`` and
    ``id2label`` are assigned only after every step succeeded, so a failure
    part-way through leaves them untouched.

    Returns:
        ``True`` on success, ``False`` if any step raised (the error and its
        traceback are printed).
    """
    global model, dataset, label_names, label2id, id2label

    try:
        print("Loading dataset...")
        loaded = load_dataset("blanchon/EuroSAT_MSI", cache_dir="./hf_cache", streaming=False)
        loaded["test"] = loaded["test"].map(transform_multispectral_map)
        loaded["test"].set_format(type="torch", columns=["image", "label"])

        names = loaded["train"].features['label'].names
        name_to_id = {name: i for i, name in enumerate(names)}
        id_to_name = {v: k for k, v in name_to_id.items()}
        num_classes = len(names)

        print("Loading model...")
        model_path = hf_hub_download(repo_id="Rhodham96/Resnet18DropoutSentinel", filename="pytorch_model.bin")
        net = ResNet18_Dropout(in_channels=13, num_classes=num_classes)
        # NOTE(security): torch.load may unpickle arbitrary objects from the
        # downloaded file. If the installed torch supports it, prefer
        # torch.load(..., weights_only=True) for a plain state dict.
        net.load_state_dict(torch.load(model_path, map_location='cpu'))
        net.eval()

    except Exception as e:
        # Top-level boundary: report and signal failure instead of crashing,
        # but keep the traceback so "check the logs" is actionable.
        import traceback

        print(f"Error loading model or dataset: {str(e)}")
        traceback.print_exc()
        return False

    # Publish everything at once, only now that nothing can fail any more.
    model = net
    dataset = loaded
    label_names = names
    label2id = name_to_id
    id2label = id_to_name

    print("Model and dataset loaded successfully!")
    print(f"Classes: {label_names}")
    return True
|
|
|
|
|
def _render_prediction_grid(images, labels, preds):
    """Draw the samples in a 2x5 grid titled with true/predicted labels and return it as a PIL image."""
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    try:
        axes = axes.flatten()

        for ax, image, label, pred in zip(axes, images, labels, preds):
            ax.imshow(load_rgb_from_transformed_tensor(image))
            ax.axis("off")
            true_label = id2label[label.item()]
            pred_label = id2label[pred.item()]
            color = "green" if pred_label == true_label else "red"
            ax.set_title(f"T: {true_label}\nP: {pred_label}", color=color)

        # Hide any cells left over when fewer samples than cells are shown.
        for ax in axes[len(images):]:
            ax.axis("off")

        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure even if rendering failed, so repeated
        # requests do not accumulate open figures.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


def predict_images():
    """Classify random test images and render up to 10 of them with their labels.

    Up to five shuffled batches of 32 images are drawn from
    ``dataset["test"]`` and run through ``model`` without gradients. Ten of
    those images (fewer if fewer were collected) are sampled at random and
    drawn in a 2x5 grid; each title shows the true and predicted class, in
    green when they match and red otherwise.

    Returns:
        A ``(PIL.Image, str)`` pair. On success: the rendered grid and a
        summary with the accuracy over the displayed images and the class
        list. If the model/dataset are not loaded, or if anything raises:
        a light-gray 800x600 placeholder and a message describing the problem.
    """
    if model is None or dataset is None:
        # The interface binds two outputs (image, text) to this function, so
        # the "not ready" case must return a pair as well, not a bare string.
        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, "Model or dataset not loaded. Please wait for initialization."

    try:
        test_dataloader = DataLoader(dataset["test"], batch_size=32, shuffle=True)

        num_batches = 5
        num_shown = 10
        collected_images = []
        collected_labels = []
        collected_preds = []

        model.eval()
        with torch.no_grad():
            # zip() stops as soon as range() is exhausted, so no extra batch
            # is fetched from the loader just to be thrown away.
            for _, batch in zip(range(num_batches), test_dataloader):
                images = batch['image']
                labels = batch['label']

                outputs = model(images)
                _, preds = outputs.max(1)

                collected_images.append(images.cpu())
                collected_labels.append(labels.cpu())
                collected_preds.append(preds.cpu())

        images = torch.cat(collected_images)
        labels = torch.cat(collected_labels)
        preds = torch.cat(collected_preds)

        # Never ask for more samples than were collected.
        indices = random.sample(range(len(images)), min(num_shown, len(images)))

        selected_images = images[indices]
        selected_labels = labels[indices]
        selected_preds = preds[indices]

        result_image = _render_prediction_grid(selected_images, selected_labels, selected_preds)

        correct_predictions = (selected_preds == selected_labels).sum().item()
        accuracy = correct_predictions / len(selected_labels) * 100
        summary = f"Accuracy: {correct_predictions}/{len(selected_labels)} ({accuracy:.1f}%)\n"
        summary += f"Classes: {', '.join(label_names)}"

        return result_image, summary

    except Exception as e:
        # UI boundary: surface the error in the output widgets rather than
        # letting the request fail.
        error_msg = f"Error during prediction: {str(e)}"
        print(error_msg)

        placeholder = Image.new('RGB', (800, 600), color='lightgray')
        return placeholder, error_msg
|
|
|
|
|
def create_interface():
    """Load the model and dataset, then build the Gradio interface.

    ``load_model_and_data()`` is called first. If it reports failure, the
    returned interface is a stub whose callback always yields a red
    800x600 placeholder image and an error message. Otherwise the returned
    interface calls ``predict_images`` and shows its image and summary.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """
    init_success = load_model_and_data()

    if not init_success:
        def error_function():
            placeholder = Image.new('RGB', (800, 600), color='red')
            return placeholder, "Failed to load model or dataset. Please check the logs."

        interface = gr.Interface(
            fn=error_function,
            inputs=[],
            outputs=[
                gr.Image(type="pil", label="Results"),
                gr.Textbox(label="Summary"),
            ],
            # The title prefix is the satellite emoji; it had been stored as
            # mis-decoded bytes.
            title="🛰️ Satellite Image Classification - ERROR",
            description="Failed to initialize the application.",
        )
        return interface

    # Markdown shown under the title.
    description = "\n".join([
        "This app classifies satellite images from the EuroSAT dataset using a trained ResNet18 model.",
        "",
        "**How it works:**",
        "- Loads 10 random satellite images from the test set",
        "- Each image has 13 spectral bands, converted to RGB for display",
        "- Shows true labels vs predicted labels",
        "- Green titles = correct predictions, Red titles = incorrect predictions",
        "",
        "**Dataset:** EuroSAT with 13 multispectral bands",
        "**Model:** ResNet18 with dropout, trained on 13-channel input",
        "",
        'Click "Generate" to process 10 new random images!',
    ])

    interface = gr.Interface(
        fn=predict_images,
        inputs=[],
        outputs=[
            gr.Image(type="pil", label="Classification Results (10 Random Images)"),
            gr.Textbox(label="Summary", lines=3),
        ],
        title="🛰️ Satellite Image Classification with ResNet18",
        description=description,
        examples=[],
        cache_examples=False,
        allow_flagging="never",
    )

    return interface
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI (create_interface() also loads the
    # model and dataset) and launch it with sharing enabled.
    launch_options = {"share": True}
    demo = create_interface()
    demo.launch(**launch_options)