import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import joblib
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from shapely.geometry import shape, Point
import folium
from folium.plugins import Draw
from io import BytesIO
import base64
import json
import os
from pathlib import Path
from PIL import Image
import ee
from datetime import datetime, timedelta
import rasterio
from rasterio.transform import xy

# Authenticate Earth Engine with a service account.
KEY_PATH = "gee-service-key.json"
SERVICE_ACCOUNT = "[email protected]"
credentials = ee.ServiceAccountCredentials(SERVICE_ACCOUNT, KEY_PATH)

try:
    ee.Initialize(credentials)
    print("✅ Earth Engine initialized with service account.")
except Exception as e:
    print(f"❌ Initialization failed: {e}")

# Candidate crop / land-cover classes per province and growing season.
crop_season_dict = {
    "Punjab": {
        "Rabi": [
            "wheat", "barley", "gram (chickpea)", "lentil", "mustard", "rapeseed mustard",
            "linseed", "peas", "garlic", "onion", "coriander", "fennel", "potato",
            "fallow (agriculture)", "water", "barren", "shrubs", "forest"
        ],
        "Kharif": [
            "cotton", "rice", "sugarcane", "maize", "sesame", "millet", "sorghum", "sunflower",
            "groundnuts", "okra", "tomato", "chillies", "banana", "mango",
            "fallow (agriculture)", "water", "barren", "shrubs", "forest"
        ]
    },
    "Sindh": {
        "Rabi": [
            "wheat", "barley", "peas", "gram (chickpea)", "mustard", "onion", "garlic", "spinach",
            "coriander", "potato", "fennel", "turnip",
            "fallow (agriculture)", "water", "barren", "shrubs", "forest"
        ],
        "Kharif": [
            "cotton", "rice", "sugarcane", "maize", "sesame", "millet", "okra", "tomato",
            "chillies", "banana", "mango", "sunflower", "guava",
            "fallow (agriculture)", "water", "barren", "shrubs", "forest"
        ]
    },
    "Balochistan": {
        "Rabi": [
            "wheat", "barley", "gram (chickpea)", "lentil", "peas", "mustard", "potato",
            "onion", "coriander", "fallow (agriculture)", "water", "barren", "shrubs", "forest"
        ],
        "Kharif": [
            "maize", "rice", "millet", "sorghum", "peach", "apple", "grapes", "tomato",
            "chillies", "pomegranate", "groundnuts", "sunflower",
            "fallow (agriculture)", "water", "barren", "shrubs", "forest"
        ]
    },
    "Khyber Pakhtunkhwa": {
        "Rabi": [
            "wheat", "barley", "gram (chickpea)", "lentil", "peas", "mustard", "onion",
            "garlic", "turnip", "potato", "coriander",
            "fallow (agriculture)", "water", "barren", "shrubs", "forest"
        ],
        "Kharif": [
            "maize", "rice", "sugarcane", "tomato", "chillies", "peach", "plum", "apricot",
            "apple", "mango", "sunflower", "okra", "sesame",
            "fallow (agriculture)", "water", "barren", "shrubs", "forest"
        ]
    }
}


class CropClassifier(nn.Module):
    """Fully connected classifier over per-pixel spectral and contextual features."""

    def __init__(self, input_size, num_classes):
        super(CropClassifier, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        return self.network(x)
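
# Quick shape check (illustrative sizes, not the real feature count):
#
#   m = CropClassifier(input_size=18, num_classes=25).eval()
#   m(torch.randn(4, 18)).shape   # -> torch.Size([4, 25])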

# Preprocessing artifacts saved at training time.
scaler = joblib.load("scaler.pkl")
label_to_idx = joblib.load("label_encoder.pkl")
feature_columns = joblib.load("feature_columns.pkl")
idx_to_label = {v: k for k, v in label_to_idx.items()}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CropClassifier(len(feature_columns), len(label_to_idx)).to(device)
model.load_state_dict(torch.load("final_crop_model.pth", map_location=device))
model.eval()
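
# Note: on PyTorch >= 1.13, torch.load(..., weights_only=True) restricts
# unpickling to tensors and is the safer option when loading a state dict.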

# Predictions whose top softmax probability falls below this threshold are
# mapped to a synthetic "Uncertain" class appended after the real classes.
uncertainty_threshold = 0.2
uncertain_class_idx = len(label_to_idx)
idx_to_label[uncertain_class_idx] = "Uncertain"

# Most recently selected polygon (GeoJSON feature), shared across callbacks.
current_polygon_data = None


def get_color_palette(n):
    if n <= 20:
        palette = list(matplotlib.colors.TABLEAU_COLORS.values()) + list(matplotlib.colors.CSS4_COLORS.values())
        return palette[:n]
    else:
        return [matplotlib.colors.rgb2hex(matplotlib.cm.hsv(i / n)) for i in range(n)]


def assign_crop_colors(unique_crops):
    palette = get_color_palette(len(unique_crops))
    return {crop: palette[i] for i, crop in enumerate(unique_crops)}


def get_valid_user_classes(province, season):
    """Fetch valid classes based on province and season from crop_season_dict."""
    try:
        user_classes = crop_season_dict.get(province, {}).get(season, [])
        return [cls for cls in user_classes if cls in label_to_idx]
    except Exception:
        return []
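
# Example: get_valid_user_classes("Punjab", "Rabi") returns the Punjab Rabi
# list above, filtered to labels the model was actually trained on.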


def process_upload(file, province, season, date):
    if file is None:
        return "No file uploaded. Please upload a .tiff or .tif file.", None

    # gr.File may hand us a temp-file wrapper, a str path, or a Path (samples).
    file_path = str(file) if isinstance(file, (str, Path)) else file.name
    if not file_path.endswith(('.tiff', '.tif')):
        return "Unsupported file format. Please upload a .tiff or .tif file.", None

    try:
        with rasterio.open(file_path) as src:
            patch = src.read()
            transform = src.transform
            rows, cols = patch.shape[1], patch.shape[2]
            row_indices, col_indices = np.meshgrid(np.arange(rows), np.arange(cols), indexing='ij')
            lon, lat = xy(transform, row_indices, col_indices)

            lon_mask = np.array(lon).reshape(rows, cols)
            lat_mask = np.array(lat).reshape(rows, cols)
    except Exception as e:
        return f"Error reading GeoTIFF file: {str(e)}", None

    if len(patch.shape) != 3 or patch.shape[0] < 7:
        return "Invalid GeoTIFF file format. Expected at least 7 bands [r, g, b, rededge, nir, swr1, swr2].", None

    # Reorder to (H, W, bands) for per-pixel iteration.
    patch = np.transpose(patch, (1, 2, 0))
    H, W, _ = patch.shape

    # Min-max normalize the RGB composite for display.
    r, g, b = patch[..., 0], patch[..., 1], patch[..., 2]
    rgb = np.stack([r, g, b], axis=-1).astype(np.float32)
    rgb_norm = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)

    # Build one feature row per pixel.
    pixels = []
    for i in range(H):
        for j in range(W):
            pix = patch[i, j].astype(np.float32)
            red, green, blue, nir, swr1 = pix[0], pix[1], pix[2], pix[4], pix[5]
            pixels.append({
                "Province": province,
                "Season": season,
                "Latitude": lat_mask[i, j],
                "Longitude": lon_mask[i, j],
                "NDVI": (nir - red) / (nir + red + 1e-6),
                "NDWI": (green - nir) / (green + nir + 1e-6),
                "NDBI": (swr1 - nir) / (swr1 + nir + 1e-6),
                "Red": red,
                "Green": green,
                "Blue": blue,
                "NIR": nir,
                "SWIR": swr1,
                "Date": date
            })

    df = pd.DataFrame(pixels)
    try:
        df["Date"] = pd.to_datetime(df["Date"], dayfirst=True)
    except Exception:
        return "Invalid date format. Please use DD/MM/YYYY.", None
    df["HalfMonth"] = df["Date"].dt.day.apply(lambda x: 0 if x <= 15 else 1)
    df["Month"] = df["Date"].dt.month
    df.drop(columns=["Date"], inplace=True)

    # One-hot encode and align with the training feature layout.
    df = pd.get_dummies(df, columns=['Province', 'Season'], dummy_na=True)
    missing_cols = set(feature_columns) - set(df.columns)
    for col in missing_cols:
        df[col] = 0
    df = df[feature_columns]
    df = df.replace([np.inf, -np.inf], np.finfo(np.float32).eps)

    try:
        X_scaled = scaler.transform(df)
    except Exception as e:
        return f"Error scaling features: {str(e)}", None
    X_tensor = torch.tensor(X_scaled, dtype=torch.float32).to(device)
    with torch.no_grad():
        outputs = model(X_tensor)
        # Restrict logits to classes plausible for this province/season.
        valid_user_classes = get_valid_user_classes(province, season)
        user_class_indices = [label_to_idx[cls] for cls in valid_user_classes if cls in label_to_idx]
        if user_class_indices:
            mask = torch.ones_like(outputs) * -1e10
            for idx in user_class_indices:
                mask[:, idx] = 0
            outputs = outputs + mask
        probs = torch.softmax(outputs, dim=1)
        max_probs, preds = torch.max(probs, dim=1)
        uncertain_mask = max_probs < uncertainty_threshold
        preds[uncertain_mask] = uncertain_class_idx
        preds = preds.cpu().numpy().reshape(H, W)

    # Color each predicted class and render beside the RGB patch.
    unique_classes = np.unique(preds)
    color_map = assign_crop_colors([idx_to_label[cls] for cls in unique_classes])
    mask_img = np.zeros((H, W, 3))
    for cls, color in color_map.items():
        mask_img[preds == label_to_idx.get(cls, uncertain_class_idx)] = matplotlib.colors.to_rgb(color)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    ax1.imshow(rgb_norm)
    ax1.set_title("Original RGB Patch")
    ax1.axis("off")
    ax2.imshow(mask_img)
    ax2.set_title("Predicted Crop Classification")
    ax2.axis("off")
    legend_elements = [Patch(facecolor=color_map[idx_to_label[cls]], edgecolor='black', label=idx_to_label[cls]) for cls in unique_classes]
    fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(1.15, 0.5), title="Predicted Crops")
    plt.tight_layout()

    buf = BytesIO()
    plt.savefig(buf, format="png", bbox_inches="tight")
    plt.close()
    buf.seek(0)
    image = Image.open(buf)

    stats = "Prediction Statistics:\n"
    for cls in unique_classes:
        class_name = idx_to_label[cls]
        pixel_count = np.sum(preds == cls)
        percentage = (pixel_count / (H * W)) * 100
        stats += f"{class_name}: {pixel_count} pixels ({percentage:.2f}%)\n"

    return stats, image


def generate_grid_points(polygon, spacing_deg):
    """Generate points inside `polygon` spaced at least `spacing_deg` apart."""
    min_lon, min_lat, max_lon, max_lat = polygon.bounds
    grid_points = []
    point_id = 1
    lat_step = spacing_deg / 2
    lon_step = spacing_deg / 2
    lat = min_lat
    while lat <= max_lat:
        lon = min_lon
        while lon <= max_lon:
            pt = Point(lon, lat)
            if polygon.contains(pt):
                # Enforce the minimum spacing against all accepted points.
                is_spaced = True
                for existing_pt in grid_points:
                    dist = ((existing_pt["latitude"] - lat) ** 2 + (existing_pt["longitude"] - lon) ** 2) ** 0.5
                    if dist < spacing_deg:
                        is_spaced = False
                        break
                if is_spaced:
                    grid_points.append({
                        "point_id": point_id,
                        "latitude": round(lat, 6),
                        "longitude": round(lon, 6)
                    })
                    point_id += 1
            lon += lon_step
        lat += lat_step
    return grid_points
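
# Example (assuming coordinates in EPSG:4326, where one degree of latitude
# spans roughly 111,320 m): a 30 m grid uses spacing_deg = 30 / 111320.0,
# which is how process_polygon_prediction() below calls this function.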


def get_indices(lat, lon, date_str):
    """Fetch Sentinel-2 band values and vegetation indices for a point from GEE."""
    try:
        point = ee.Geometry.Point([lon, lat])
        date = datetime.strptime(date_str, "%d/%m/%Y")
        start = ee.Date(date.strftime('%Y-%m-%d'))
        end = ee.Date((date + timedelta(days=30)).strftime('%Y-%m-%d'))

        collection = (ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
                      .filterBounds(point)
                      .filterDate(start, end)
                      .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 10)))

        image = collection.median().clip(point)

        band_names = image.bandNames().getInfo()
        if not band_names:
            return None

        B2 = image.select('B2')
        B3 = image.select('B3')
        B4 = image.select('B4')
        B8 = image.select('B8')
        B11 = image.select('B11')

        ndvi = image.normalizedDifference(['B8', 'B4']).rename('NDVI')
        ndwi = image.normalizedDifference(['B3', 'B8']).rename('NDWI')
        evi = image.expression(
            '2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))',
            {'NIR': B8, 'RED': B4, 'BLUE': B2}).rename('EVI')
        gndvi = image.normalizedDifference(['B8', 'B3']).rename('GNDVI')
        savi = image.expression(
            '((NIR - RED) / (NIR + RED + 0.5)) * 1.5',
            {'NIR': B8, 'RED': B4}).rename('SAVI')

        all_bands = image.addBands([ndvi, ndwi, evi, gndvi, savi])

        values = all_bands.reduceRegion(
            reducer=ee.Reducer.first(),
            geometry=point,
            scale=10,
            maxPixels=1e8
        ).getInfo()

        return {
            'NDVI': values.get('NDVI', 0.0),
            'NDWI': values.get('NDWI', 0.0),
            'EVI': values.get('EVI', 0.0),
            'GNDVI': values.get('GNDVI', 0.0),
            'SAVI': values.get('SAVI', 0.0),
            'Red': values.get('B4', 0.0),
            'Green': values.get('B3', 0.0),
            'Blue': values.get('B2', 0.0),
            'NIR': values.get('B8', 0.0),
            'SWIR': values.get('B11', 0.0)
        }
    except Exception as e:
        print(f"Error fetching indices for lat={lat}, lon={lon}: {str(e)}")
        return None
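
# Example: get_indices(30.809, 73.450, "10/01/2023") returns a dict of index
# and band values, or None when no Sentinel-2 scene under 10% cloud cover
# falls in the 30-day window.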


def predict_crop_description(point, static_features, scaler, feature_columns, province, season):
    """Predict the crop class for a single point from its static features."""
    df = pd.DataFrame([{
        **static_features,
        "Latitude": point["latitude"],
        "Longitude": point["longitude"],
        "Date": static_features["Date"]
    }])
    df["Date"] = pd.to_datetime(df["Date"], dayfirst=True)
    df["HalfMonth"] = df["Date"].dt.day.apply(lambda x: 0 if x <= 15 else 1)
    df["Month"] = df["Date"].dt.month
    df.drop(columns=["Date"], inplace=True)
    df = pd.get_dummies(df)
    for col in feature_columns:
        if col not in df.columns:
            df[col] = 0
    df = df[feature_columns]
    df = df.replace([np.inf, -np.inf], np.finfo(np.float32).eps)
    scaled = scaler.transform(df)
    X_tensor = torch.tensor(scaled, dtype=torch.float32).to(device)
    with torch.no_grad():
        outputs = model(X_tensor)
        # Restrict logits to classes plausible for this province/season.
        valid_user_classes = get_valid_user_classes(province, season)
        user_class_indices = [label_to_idx[cls] for cls in valid_user_classes if cls in label_to_idx]
        if user_class_indices:
            mask = torch.ones_like(outputs) * -1e10
            for idx in user_class_indices:
                mask[:, idx] = 0
            outputs = outputs + mask
        probs = torch.softmax(outputs, dim=1)
        max_probs, preds = torch.max(probs, dim=1)
        uncertain_mask = max_probs < uncertainty_threshold
        preds[uncertain_mask] = uncertain_class_idx
    return idx_to_label[preds.cpu().numpy()[0]]


def create_interactive_map():
    m = folium.Map(location=[30.809, 73.45], zoom_start=12)
    Draw(
        export=True,
        filename='polygon.geojson',
        draw_options={
            "polyline": False,
            "rectangle": True,
            "circle": True,
            "circlemarker": False,
            "marker": False,
            "polygon": True
        }
    ).add_to(m)
    return m._repr_html_()


def select_polygon(geojson_file):
    global current_polygon_data
    if not geojson_file:
        return "❌ No GeoJSON file uploaded. Please draw a polygon, export it, and upload the file."

    try:
        # gr.File may hand us a temp-file wrapper or a plain str path.
        path = geojson_file if isinstance(geojson_file, str) else geojson_file.name
        with open(path, 'r') as f:
            geojson_data = json.load(f)

        if geojson_data.get('type') == 'FeatureCollection':
            features = geojson_data.get('features', [])
            for feature in features:
                if feature.get('geometry', {}).get('type') == 'Polygon':
                    current_polygon_data = feature
                    return "✅ Polygon selected successfully!"
        return "❌ No valid polygon found in the GeoJSON file."
    except Exception as e:
        return f"Error reading GeoJSON file: {str(e)}"


def process_polygon_prediction(spacing_m, province, season, date, geojson_file):
    global current_polygon_data

    try:
        datetime.strptime(date, "%d/%m/%Y")
    except ValueError:
        return "Invalid date format. Please use DD/MM/YYYY.", None, None

    if not current_polygon_data:
        return "❌ No polygon selected. Please draw a polygon, export it as GeoJSON, and upload it.", None, None

    try:
        polygon = shape(current_polygon_data['geometry'])
    except Exception as e:
        return f"Error parsing polygon: {str(e)}", None, None

    # Convert the spacing from meters to degrees (~111,320 m per degree).
    spacing_deg = spacing_m / 111320.0
    points = generate_grid_points(polygon, spacing_deg)
    print(f"Number of points selected: {len(points)}")

    if not points:
        return "No points generated within the polygon. Try increasing the spacing.", None, None

    predicted_points = []
    static_features = {
        "Province": province,
        "Season": season,
        "Date": date
    }

    for i, point in enumerate(points, 1):
        print(f"GEE started for point {i} at lat={point['latitude']}, lon={point['longitude']}")
        indices = get_indices(point["latitude"], point["longitude"], date)
        if indices:
            print(f"GEE values fetched for point {i}")
            static_features.update({
                "NDVI": indices["NDVI"],
                "NDWI": indices["NDWI"],
                "EVI": indices["EVI"],
                "GNDVI": indices["GNDVI"],
                "SAVI": indices["SAVI"],
                "Red": indices["Red"],
                "Green": indices["Green"],
                "Blue": indices["Blue"],
                "NIR": indices["NIR"],
                "SWIR": indices["SWIR"]
            })
            crop = predict_crop_description(point, static_features, scaler, feature_columns, province, season)
            point.update({
                "crop": crop,
                "NDVI": indices["NDVI"],
                "NDWI": indices["NDWI"],
                "EVI": indices["EVI"],
                "GNDVI": indices["GNDVI"],
                "SAVI": indices["SAVI"]
            })
            predicted_points.append(point)

    if not predicted_points:
        return "No valid data found for any grid points.", None, None

    pred_df = pd.DataFrame(predicted_points)
    unique_crops = pred_df['crop'].unique()
    crop_colors = assign_crop_colors(unique_crops)

    # Build a folium map centered on the predicted points.
    center_lat = sum(pt["latitude"] for pt in predicted_points) / len(predicted_points)
    center_lon = sum(pt["longitude"] for pt in predicted_points) / len(predicted_points)
    pred_map = folium.Map(location=[center_lat, center_lon], zoom_start=12)

    folium.GeoJson(
        current_polygon_data,
        style_function=lambda x: {'color': 'red', 'weight': 3, 'fill': False}
    ).add_to(pred_map)

    for pt in predicted_points:
        crop_type = pt.get("crop", "Other")
        color = crop_colors.get(crop_type, "#808080")
        folium.Circle(
            location=[pt["latitude"], pt["longitude"]],
            radius=spacing_m / 2,
            color='black',
            weight=1,
            fill=True,
            fillColor=color,
            fillOpacity=0.7,
            popup=f"Crop: {crop_type}<br>Lat: {pt['latitude']:.4f}<br>Lon: {pt['longitude']:.4f}<br>NDVI: {pt['NDVI']:.3f}<br>NDWI: {pt['NDWI']:.3f}<br>EVI: {pt['EVI']:.3f}<br>GNDVI: {pt['GNDVI']:.3f}<br>SAVI: {pt['SAVI']:.3f}",
            tooltip=crop_type
        ).add_to(pred_map)

    legend_html = '''
    <div style="position: fixed; bottom: 50px; left: 50px; width: 180px;
                background-color: white; border:2px solid grey; z-index:9999;
                font-size:14px; padding: 10px; border-radius: 5px;">
    <p style="margin: 0 0 10px 0; font-weight:bold;">🌾 Crop Types</p>
    '''
    for crop in unique_crops:
        color = crop_colors[crop]
        legend_html += f'<p style="margin: 5px 0;"><span style="color:{color}; font-size:16px;">●</span> {crop}</p>'
    legend_html += '</div>'
    pred_map.get_root().html.add_child(folium.Element(legend_html))

    crop_stats = pred_df['crop'].value_counts()
    stats = f"✅ Polygon processed successfully!\n\nCrop Distribution (Province: {province}, Season: {season}):\n"
    for crop, count in crop_stats.items():
        percentage = (count / len(predicted_points)) * 100
        stats += f"{crop}: {count} points ({percentage:.1f}%)\n"
    for index in ['NDVI', 'NDWI', 'EVI', 'GNDVI', 'SAVI']:
        avg = pred_df[index].mean()
        stats += f"Average {index}: {avg:.3f}\n"

    csv_file_path = f"crop_predictions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    try:
        pred_df.to_csv(csv_file_path, index=False)
    except Exception as e:
        print(f"Error creating CSV file: {str(e)}")
        csv_file_path = None

    return stats, pred_map._repr_html_(), csv_file_path


def predict_instance(province, season, latitude, longitude, date, ndvi, ndwi, ndbi, red, green, blue, nir, swir):
    static_features = {
        "Province": province,
        "Season": season,
        "NDVI": ndvi,
        "NDWI": ndwi,
        "NDBI": ndbi,
        "Red": red,
        "Green": green,
        "Blue": blue,
        "NIR": nir,
        "SWIR": swir,
        "Date": date
    }
    crop = predict_crop_description({"latitude": latitude, "longitude": longitude}, static_features, scaler, feature_columns, province, season)
    return f"{crop}"


# Bundled sample patches for quick testing.
sample_dir = Path("samples")
sample_files = {
    "Sample 1": sample_dir / "sample1.tif",
    "Sample 2": sample_dir / "sample2.tif"
}


def load_sample_and_predict(sample_name, province, season, date):
    file_path = sample_files[sample_name]
    return process_upload(file_path, province, season, date)


with gr.Blocks(title="Crop Predictor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌾 Crop Predictor")

    with gr.Tabs():
        with gr.TabItem("📤 Upload"):
            gr.Markdown("Upload a .tiff or .tif file with bands [r, g, b, rededge, nir, swr1, swr2]")

            file_input = gr.File(label="Upload .tiff/.tif file", file_types=[".tiff", ".tif"])

            with gr.Row():
                province = gr.Textbox(label="Province", value="Punjab")
                season = gr.Textbox(label="Season", value="Rabi")

            with gr.Row():
                date = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")

            upload_btn = gr.Button("🔍 Predict", variant="primary")
            output_stats = gr.Textbox(label="Prediction Statistics", lines=10)
            output_image = gr.Image(label="Prediction Result")

            upload_btn.click(
                fn=process_upload,
                inputs=[file_input, province, season, date],
                outputs=[output_stats, output_image]
            )

            gr.Markdown("### Or try with a sample file:")
            with gr.Row():
                for name in sample_files:
                    gr.Button(name).click(
                        fn=load_sample_and_predict,
                        inputs=[gr.State(name), province, season, date],
                        outputs=[output_stats, output_image]
                    )

        with gr.TabItem("🗺️ Map"):
            gr.Markdown("""
            ## Interactive Polygon Crop Prediction

            **Instructions:**
            1. Draw a polygon on the map below using the polygon tool.
            2. Click the "Export" button on the map to save the polygon as a GeoJSON file (polygon.geojson).
            3. Upload the exported GeoJSON file using the file input below.
            4. Adjust settings and click "🚀 Predict Crops" to process.
            """)

            map_html = gr.HTML(create_interactive_map, label="Draw Your Polygon Here")

            with gr.Row():
                geojson_input = gr.File(label="Upload Exported GeoJSON File")
                select_btn = gr.Button("🎯 Select My Polygon", variant="secondary")
                spacing = gr.Slider(
                    label="Grid Spacing (meters)",
                    minimum=10, maximum=1000, value=30, step=10
                )

            with gr.Row():
                province_map = gr.Textbox(label="Province", value="Punjab")
                season_map = gr.Textbox(label="Season", value="Rabi")
                date_map = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")

            polygon_status = gr.Textbox(
                label="Selection Status",
                value="⏳ Please draw a polygon, export it, and upload the GeoJSON file.",
                interactive=False
            )

            predict_btn = gr.Button("🚀 Predict Crops", variant="primary", size="lg")

            output_map_stats = gr.Textbox(label="Prediction Results", lines=10)
            output_map = gr.HTML(label="Crop Prediction Map")
            output_csv = gr.File(label="📥 Download Results CSV")

            select_btn.click(
                fn=select_polygon,
                inputs=[geojson_input],
                outputs=polygon_status
            )

            predict_btn.click(
                fn=process_polygon_prediction,
                inputs=[spacing, province_map, season_map, date_map, geojson_input],
                outputs=[output_map_stats, output_map, output_csv]
            )

        with gr.TabItem("📍 Instance"):
            gr.Markdown("## Single Point Prediction")
            gr.Markdown("Enter features manually for a single point prediction")

            with gr.Row():
                province_inst = gr.Textbox(label="Province", value="Punjab")
                season_inst = gr.Textbox(label="Season", value="Rabi")

            with gr.Row():
                latitude_inst = gr.Number(label="Latitude", value=30.809)
                longitude_inst = gr.Number(label="Longitude", value=73.450)
                date_inst = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")

            gr.Markdown("### Spectral Indices")
            with gr.Row():
                ndvi_inst = gr.Number(label="NDVI", value=0.65)
                ndwi_inst = gr.Number(label="NDWI", value=-0.20)
                ndbi_inst = gr.Number(label="NDBI", value=0.10)

            gr.Markdown("### Band Values")
            with gr.Row():
                red_inst = gr.Number(label="Red", value=678)
                green_inst = gr.Number(label="Green", value=732)
                blue_inst = gr.Number(label="Blue", value=620)

            with gr.Row():
                nir_inst = gr.Number(label="NIR", value=3000)
                swir_inst = gr.Number(label="SWIR", value=1800)

            instance_btn = gr.Button("🔍 Predict", variant="primary")
            output_instance = gr.Textbox(label="Prediction Result", lines=3)

            instance_btn.click(
                fn=predict_instance,
                inputs=[province_inst, season_inst, latitude_inst, longitude_inst,
                        date_inst, ndvi_inst, ndwi_inst, ndbi_inst, red_inst,
                        green_inst, blue_inst, nir_inst, swir_inst],
                outputs=output_instance
            )


if __name__ == "__main__":
    demo.launch(share=True)