Abdulahad79 commited on
Commit
39aaa38
Β·
verified Β·
1 Parent(s): 0a592d0

Upload 4 files

Browse files
Files changed (4) hide show
  1. feature_columns.pkl +3 -0
  2. label_encoder.pkl +3 -0
  3. main.py +717 -0
  4. scaler.pkl +3 -0
feature_columns.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9edecbbd51d8e519880bd49f32000dc1ca4c66b4081dda095be6c2ad8d5f4a0
3
+ size 274
label_encoder.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5dba0f890adde9d70d9ea3596a990a3eb3b14c692cdcf5a2c00761be1a3a500
3
+ size 448
main.py ADDED
@@ -0,0 +1,717 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ import torch.nn as nn
6
+ import joblib
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib.patches import Patch
9
+ import matplotlib
10
+ from shapely.geometry import shape, Point
11
+ import folium
12
+ from folium.plugins import Draw
13
+ from io import BytesIO
14
+ import base64
15
+ import json
16
+ import os
17
+ from PIL import Image
18
+ import ee
19
+ from datetime import datetime, timedelta
20
+ import rasterio
21
+ from rasterio.transform import xy
22
+
23
+ # Initialize Earth Engine
24
+ try:
25
+ ee.Initialize(project='artful-striker-466710-b3')
26
+ except Exception as e:
27
+ print(f"Error initializing GEE: {str(e)}")
28
+ ee.Authenticate()
29
+ ee.Initialize(project='artful-striker-466710-b3')
30
+
31
+ # Define crop season dictionary
32
+ crop_season_dict = {
33
+ "Punjab": {
34
+ "Rabi": [
35
+ "wheat", "barley", "gram (chickpea)", "lentil", "mustard", "rapeseed mustard",
36
+ "linseed", "peas", "garlic", "onion", "coriander", "fennel", "potato",
37
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
38
+ ],
39
+ "Kharif": [
40
+ "cotton", "rice", "sugarcane", "maize", "sesame", "millet", "sorghum", "sunflower",
41
+ "groundnuts", "okra", "tomato", "chillies", "banana", "mango",
42
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
43
+ ]
44
+ },
45
+ "Sindh": {
46
+ "Rabi": [
47
+ "wheat", "barley", "peas", "gram (chickpea)", "mustard", "onion", "garlic", "spinach",
48
+ "coriander", "potato", "fennel", "turnip",
49
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
50
+ ],
51
+ "Kharif": [
52
+ "cotton", "rice", "sugarcane", "maize", "sesame", "millet", "okra", "tomato",
53
+ "chillies", "banana", "mango", "sunflower", "guava",
54
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
55
+ ]
56
+ },
57
+ "Balochistan": {
58
+ "Rabi": [
59
+ "wheat", "barley", "gram (chickpea)", "lentil", "peas", "mustard", "potato",
60
+ "onion", "coriander", "fallow (agriculture)", "water", "barren", "shrubs", "forest"
61
+ ],
62
+ "Kharif": [
63
+ "maize", "rice", "millet", "sorghum", "peach", "apple", "grapes", "tomato",
64
+ "chillies", "pomegranate", "groundnuts", "sunflower",
65
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
66
+ ]
67
+ },
68
+ "Khyber Pakhtunkhwa": {
69
+ "Rabi": [
70
+ "wheat", "barley", "gram (chickpea)", "lentil", "peas", "mustard", "onion",
71
+ "garlic", "turnip", "potato", "coriander",
72
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
73
+ ],
74
+ "Kharif": [
75
+ "maize", "rice", "sugarcane", "tomato", "chillies", "peach", "plum", "apricot",
76
+ "apple", "mango", "sunflower", "okra", "sesame",
77
+ "fallow (agriculture)", "water", "barren", "shrubs", "forest"
78
+ ]
79
+ }
80
+ }
81
+
82
+ # Define model
83
+ class CropClassifier(nn.Module):
84
+ def __init__(self, input_size, num_classes):
85
+ super(CropClassifier, self).__init__()
86
+ self.network = nn.Sequential(
87
+ nn.Linear(input_size, 512),
88
+ nn.BatchNorm1d(512),
89
+ nn.LeakyReLU(),
90
+ nn.Dropout(0.4),
91
+ nn.Linear(512, 256),
92
+ nn.BatchNorm1d(256),
93
+ nn.LeakyReLU(),
94
+ nn.Dropout(0.3),
95
+ nn.Linear(256, 128),
96
+ nn.BatchNorm1d(128),
97
+ nn.LeakyReLU(),
98
+ nn.Dropout(0.2),
99
+ nn.Linear(128, 64),
100
+ nn.BatchNorm1d(64),
101
+ nn.LeakyReLU(),
102
+ nn.Dropout(0.1),
103
+ nn.Linear(64, num_classes)
104
+ )
105
+ def forward(self, x):
106
+ return self.network(x)
107
+
108
+ # Load saved objects
109
+ scaler = joblib.load("scaler.pkl")
110
+ label_to_idx = joblib.load("label_encoder.pkl")
111
+ feature_columns = joblib.load("feature_columns.pkl")
112
+ idx_to_label = {v: k for k, v in label_to_idx.items()}
113
+
114
+ # Load model
115
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
+ model = CropClassifier(len(feature_columns), len(label_to_idx)).to(device)
117
+ model.load_state_dict(torch.load("final_crop_model.pth", map_location=device))
118
+ model.eval()
119
+
120
+ # Uncertainty threshold
121
+ uncertainty_threshold = 0.2
122
+ uncertain_class_idx = len(label_to_idx)
123
+ idx_to_label[uncertain_class_idx] = "Uncertain"
124
+
125
+ # Global variable to store current polygon
126
+ current_polygon_data = None
127
+
128
+ def get_color_palette(n):
129
+ if n <= 20:
130
+ palette = list(matplotlib.colors.TABLEAU_COLORS.values()) + list(matplotlib.colors.CSS4_COLORS.values())
131
+ return palette[:n]
132
+ else:
133
+ return [matplotlib.colors.rgb2hex(matplotlib.cm.hsv(i/n)) for i in range(n)]
134
+
135
+ def assign_crop_colors(unique_crops):
136
+ palette = get_color_palette(len(unique_crops))
137
+ return {crop: palette[i] for i, crop in enumerate(unique_crops)}
138
+
139
+ def get_valid_user_classes(province, season):
140
+ """Fetch valid classes based on province and season from crop_season_dict."""
141
+ try:
142
+ user_classes = crop_season_dict.get(province, {}).get(season, [])
143
+ return [cls for cls in user_classes if cls in label_to_idx]
144
+ except:
145
+ return []
146
+
147
+ # --- Upload Processing Function ---
148
+ def process_upload(file, province, season, date):
149
+ if file is None:
150
+ return "No file uploaded. Please upload a .tiff or .tif file.", None
151
+
152
+ if not file.name.endswith(('.tiff', '.tif')):
153
+ return "Unsupported file format. Please upload a .tiff or .tif file.", None
154
+
155
+ # Load GeoTIFF file
156
+ try:
157
+ with rasterio.open(file) as src:
158
+ patch = src.read() # Shape: (bands, height, width)
159
+ transform = src.transform
160
+ rows, cols = patch.shape[1], patch.shape[2]
161
+ row_indices, col_indices = np.meshgrid(np.arange(rows), np.arange(cols), indexing='ij')
162
+ lon, lat = xy(transform, row_indices, col_indices)
163
+ # Convert lon, lat to 2D arrays (shape: [rows, cols])
164
+ lon_mask = np.array(lon).reshape(rows, cols)
165
+ lat_mask = np.array(lat).reshape(rows, cols)
166
+ except Exception as e:
167
+ return f"Error reading GeoTIFF file: {str(e)}", None
168
+
169
+ # Validate the number of bands
170
+ if len(patch.shape) != 3 or patch.shape[0] < 7:
171
+ return "Invalid GeoTIFF file format. Expected at least 7 bands [r, g, b, rededge, nir, swr1, swr2].", None
172
+
173
+ # # Resize patch to 500x500 if necessary
174
+ # patch = patch[:, :500, :500]
175
+ patch = np.transpose(patch, (1, 2, 0)) # Shape: (H, W, 7)
176
+ H, W, _ = patch.shape
177
+
178
+ # Extract RGB for visualization
179
+ r, g, b = patch[..., 0], patch[..., 1], patch[..., 2]
180
+ rgb = np.stack([r, g, b], axis=-1).astype(np.float32)
181
+ rgb_norm = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)
182
+
183
+
184
+ # Process pixels for prediction
185
+ pixels = []
186
+ for i in range(H):
187
+ for j in range(W):
188
+ pix = patch[i, j].astype(np.float32)
189
+ red, green, blue, nir, swr1 = pix[0], pix[1], pix[2], pix[4], pix[5]
190
+ pixels.append({
191
+ "Province": province,
192
+ "Season": season,
193
+ "Latitude": lat_mask[i, j],
194
+ "Longitude": lon_mask[i, j],
195
+ "NDVI": (nir - red) / (nir + red + 1e-6),
196
+ "NDWI": (green - nir) / (green + nir + 1e-6),
197
+ "NDBI": (swr1 - nir) / (swr1 + nir + 1e-6),
198
+ "Red": red,
199
+ "Green": green,
200
+ "Blue": blue,
201
+ "NIR": nir,
202
+ "SWIR": swr1,
203
+ "Date": date
204
+ })
205
+
206
+ # Create DataFrame and preprocess
207
+ df = pd.DataFrame(pixels)
208
+ try:
209
+ df["Date"] = pd.to_datetime(df["Date"], dayfirst=True)
210
+ except:
211
+ return "Invalid date format. Please use DD/MM/YYYY.", None
212
+ df["HalfMonth"] = df["Date"].dt.day.apply(lambda x: 0 if x <= 15 else 1)
213
+ df["Month"] = df["Date"].dt.month
214
+ df.drop(columns=["Date"], inplace=True)
215
+
216
+ # Dummy encoding and feature alignment
217
+ df = pd.get_dummies(df, columns=['Province', 'Season'], dummy_na=True)
218
+ missing_cols = set(feature_columns) - set(df.columns)
219
+ for col in missing_cols:
220
+ df[col] = 0
221
+ df = df[feature_columns]
222
+ df = df.replace([np.inf, -np.inf], np.finfo(np.float32).eps)
223
+
224
+ # Model prediction
225
+ try:
226
+ X_scaled = scaler.transform(df)
227
+ except Exception as e:
228
+ return f"Error scaling features: {str(e)}", None
229
+ X_tensor = torch.tensor(X_scaled, dtype=torch.float32).to(device)
230
+ with torch.no_grad():
231
+ outputs = model(X_tensor)
232
+ valid_user_classes = get_valid_user_classes(province, season)
233
+ user_class_indices = [label_to_idx[cls] for cls in valid_user_classes if cls in label_to_idx]
234
+ if user_class_indices:
235
+ mask = torch.ones_like(outputs) * -1e10
236
+ for idx in user_class_indices:
237
+ mask[:, idx] = 0
238
+ outputs = outputs + mask
239
+ probs = torch.softmax(outputs, dim=1)
240
+ max_probs, preds = torch.max(probs, dim=1)
241
+ uncertain_mask = max_probs < uncertainty_threshold
242
+ preds[uncertain_mask] = uncertain_class_idx
243
+ preds = preds.cpu().numpy().reshape(H, W)
244
+
245
+ # Create visualization
246
+ unique_classes = np.unique(preds)
247
+ color_map = assign_crop_colors([idx_to_label[cls] for cls in unique_classes])
248
+ mask_img = np.zeros((H, W, 3))
249
+ for cls, color in color_map.items():
250
+ mask_img[preds == label_to_idx.get(cls, uncertain_class_idx)] = matplotlib.colors.to_rgb(color)
251
+
252
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
253
+ ax1.imshow(rgb_norm)
254
+ ax1.set_title("Original RGB Patch")
255
+ ax1.axis("off")
256
+ ax2.imshow(mask_img)
257
+ ax2.set_title("Predicted Crop Classification")
258
+ ax2.axis("off")
259
+ legend_elements = [Patch(facecolor=color_map[idx_to_label[cls]], edgecolor='black', label=idx_to_label[cls]) for cls in unique_classes]
260
+ fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(1.15, 0.5), title="Predicted Crops")
261
+ plt.tight_layout()
262
+
263
+ buf = BytesIO()
264
+ plt.savefig(buf, format="png", bbox_inches="tight")
265
+ plt.close()
266
+ buf.seek(0)
267
+ image = Image.open(buf)
268
+
269
+ # Generate prediction statistics
270
+ stats = "Prediction Statistics:\n"
271
+ for cls in unique_classes:
272
+ class_name = idx_to_label[cls]
273
+ pixel_count = np.sum(preds == cls)
274
+ percentage = (pixel_count / (H * W)) * 100
275
+ stats += f"{class_name}: {pixel_count} pixels ({percentage:.2f}%)\n"
276
+
277
+ return stats, image
278
+
279
+ # --- Map Interface ---
280
+ def generate_grid_points(polygon, spacing_deg):
281
+ min_lon, min_lat, max_lon, max_lat = polygon.bounds
282
+ grid_points = []
283
+ point_id = 1
284
+ lat_step = spacing_deg / 2
285
+ lon_step = spacing_deg / 2
286
+ lat = min_lat
287
+ while lat <= max_lat:
288
+ lon = min_lon
289
+ while lon <= max_lon:
290
+ pt = Point(lon, lat)
291
+ if polygon.contains(pt):
292
+ is_spaced = True
293
+ for existing_pt in grid_points:
294
+ dist = ((existing_pt["latitude"] - lat) ** 2 + (existing_pt["longitude"] - lon) ** 2) ** 0.5
295
+ if dist < spacing_deg:
296
+ is_spaced = False
297
+ break
298
+ if is_spaced:
299
+ grid_points.append({
300
+ "point_id": point_id,
301
+ "latitude": round(lat, 6),
302
+ "longitude": round(lon, 6)
303
+ })
304
+ point_id += 1
305
+ lon += lon_step
306
+ lat += lat_step
307
+ return grid_points
308
+
309
+ def get_indices(lat, lon, date_str):
310
+ try:
311
+ point = ee.Geometry.Point([lon, lat])
312
+ date = datetime.strptime(date_str, "%d/%m/%Y")
313
+ start = ee.Date(date.strftime('%Y-%m-%d'))
314
+ end = ee.Date((date + timedelta(days=30)).strftime('%Y-%m-%d'))
315
+
316
+ collection = (ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
317
+ .filterBounds(point)
318
+ .filterDate(start, end)
319
+ .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 10)))
320
+
321
+ image = collection.median().clip(point)
322
+
323
+ band_names = image.bandNames().getInfo()
324
+ if not band_names:
325
+ return None
326
+
327
+ B2 = image.select('B2') # Blue
328
+ B3 = image.select('B3') # Green
329
+ B4 = image.select('B4') # Red
330
+ B8 = image.select('B8') # NIR
331
+ B11 = image.select('B11') # SWIR
332
+
333
+ ndvi = image.normalizedDifference(['B8', 'B4']).rename('NDVI')
334
+ ndwi = image.normalizedDifference(['B3', 'B8']).rename('NDWI')
335
+ evi = image.expression(
336
+ '2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))',
337
+ {'NIR': B8, 'RED': B4, 'BLUE': B2}).rename('EVI')
338
+ gndvi = image.normalizedDifference(['B8', 'B3']).rename('GNDVI')
339
+ savi = image.expression(
340
+ '((NIR - RED) / (NIR + RED + 0.5)) * 1.5',
341
+ {'NIR': B8, 'RED': B4}).rename('SAVI')
342
+
343
+ all_bands = image.addBands([ndvi, ndwi, evi, gndvi, savi])
344
+
345
+ values = all_bands.reduceRegion(
346
+ reducer=ee.Reducer.first(),
347
+ geometry=point,
348
+ scale=10,
349
+ maxPixels=1e8
350
+ ).getInfo()
351
+
352
+ return {
353
+ 'NDVI': values.get('NDVI', 0.0),
354
+ 'NDWI': values.get('NDWI', 0.0),
355
+ 'EVI': values.get('EVI', 0.0),
356
+ 'GNDVI': values.get('GNDVI', 0.0),
357
+ 'SAVI': values.get('SAVI', 0.0),
358
+ 'Red': values.get('B4', 0.0),
359
+ 'Green': values.get('B3', 0.0),
360
+ 'Blue': values.get('B2', 0.0),
361
+ 'NIR': values.get('B8', 0.0),
362
+ 'SWIR': values.get('B11', 0.0)
363
+ }
364
+ except Exception as e:
365
+ print(f"Error fetching indices for lat={lat}, lon={lon}: {str(e)}")
366
+ return None
367
+
368
+ def predict_crop_description(point, static_features, scaler, feature_columns, province, season):
369
+ df = pd.DataFrame([{
370
+ **static_features,
371
+ "Latitude": point["latitude"],
372
+ "Longitude": point["longitude"],
373
+ "Date": static_features["Date"]
374
+ }])
375
+ df["Date"] = pd.to_datetime(df["Date"], dayfirst=True)
376
+ df["HalfMonth"] = df["Date"].dt.day.apply(lambda x: 0 if x <= 15 else 1)
377
+ df["Month"] = df["Date"].dt.month
378
+ df.drop(columns=["Date"], inplace=True)
379
+ df = pd.get_dummies(df)
380
+ for col in feature_columns:
381
+ if col not in df.columns:
382
+ df[col] = 0
383
+ df = df[feature_columns]
384
+ df = df.replace([np.inf, -np.inf], np.finfo(np.float32).eps)
385
+ scaled = scaler.transform(df)
386
+ X_tensor = torch.tensor(scaled, dtype=torch.float32).to(device)
387
+ with torch.no_grad():
388
+ outputs = model(X_tensor)
389
+ valid_user_classes = get_valid_user_classes(province, season)
390
+ user_class_indices = [label_to_idx[cls] for cls in valid_user_classes if cls in label_to_idx]
391
+ if user_class_indices:
392
+ mask = torch.ones_like(outputs) * -1e10
393
+ for idx in user_class_indices:
394
+ mask[:, idx] = 0
395
+ outputs = outputs + mask
396
+ probs = torch.softmax(outputs, dim=1)
397
+ max_probs, preds = torch.max(probs, dim=1)
398
+ uncertain_mask = max_probs < uncertainty_threshold
399
+ preds[uncertain_mask] = uncertain_class_idx
400
+ return idx_to_label[preds.cpu().numpy()[0]]
401
+
402
+ def create_interactive_map():
403
+ m = folium.Map(location=[30.809, 73.45], zoom_start=12)
404
+ Draw(
405
+ export=True,
406
+ filename='polygon.geojson',
407
+ draw_options={
408
+ "polyline": False,
409
+ "rectangle": True,
410
+ "circle": True,
411
+ "circlemarker": False,
412
+ "marker": False,
413
+ "polygon": True
414
+ }
415
+ ).add_to(m)
416
+ return m._repr_html_()
417
+
418
+ def select_polygon(geojson_file):
419
+ global current_polygon_data
420
+ if not geojson_file:
421
+ return "❌ No GeoJSON file uploaded. Please draw a polygon, export it, and upload the file."
422
+
423
+ try:
424
+ with open(geojson_file.name, 'r') as f:
425
+ geojson_data = json.load(f)
426
+
427
+ if geojson_data.get('type') == 'FeatureCollection':
428
+ features = geojson_data.get('features', [])
429
+ for feature in features:
430
+ if feature.get('geometry', {}).get('type') == 'Polygon':
431
+ current_polygon_data = feature
432
+ return "βœ… Polygon selected successfully!"
433
+ return "❌ No valid polygon found in the GeoJSON file."
434
+ except Exception as e:
435
+ return f"Error reading GeoJSON file: {str(e)}"
436
+
437
+ def process_polygon_prediction(spacing_m, province, season, date, geojson_file):
438
+ global current_polygon_data
439
+
440
+ try:
441
+ datetime.strptime(date, "%d/%m/%Y")
442
+ except ValueError:
443
+ return "Invalid date format. Please use DD/MM/YYYY.", None, None
444
+
445
+ if not current_polygon_data:
446
+ return "❌ No polygon selected. Please draw a polygon, export it as GeoJSON, and upload it.", None, None
447
+
448
+ try:
449
+ polygon = shape(current_polygon_data['geometry'])
450
+ except Exception as e:
451
+ return f"Error parsing polygon: {str(e)}", None, None
452
+
453
+ spacing_deg = spacing_m / 111320.0
454
+ points = generate_grid_points(polygon, spacing_deg)
455
+ print(f"Number of points selected: {len(points)}")
456
+
457
+ if not points:
458
+ return "No points generated within the polygon. Try increasing the spacing.", None, None
459
+
460
+ predicted_points = []
461
+ static_features = {
462
+ "Province": province,
463
+ "Season": season,
464
+ "Date": date
465
+ }
466
+
467
+ for i, point in enumerate(points, 1):
468
+ indices = get_indices(point["latitude"], point["longitude"], date)
469
+ print(f"GEE started for point {i} at lat={point['latitude']}, lon={point['longitude']}")
470
+ if indices:
471
+ print(f"GEE values fetched for point {i}")
472
+ static_features.update({
473
+ "NDVI": indices["NDVI"],
474
+ "NDWI": indices["NDWI"],
475
+ "EVI": indices["EVI"],
476
+ "GNDVI": indices["GNDVI"],
477
+ "SAVI": indices["SAVI"],
478
+ "Red": indices["Red"],
479
+ "Green": indices["Green"],
480
+ "Blue": indices["Blue"],
481
+ "NIR": indices["NIR"],
482
+ "SWIR": indices["SWIR"]
483
+ })
484
+ crop = predict_crop_description(point, static_features, scaler, feature_columns, province, season)
485
+ point.update({
486
+ "crop": crop,
487
+ "NDVI": indices["NDVI"],
488
+ "NDWI": indices["NDWI"],
489
+ "EVI": indices["EVI"],
490
+ "GNDVI": indices["GNDVI"],
491
+ "SAVI": indices["SAVI"]
492
+ })
493
+ predicted_points.append(point)
494
+
495
+ if not predicted_points:
496
+ return "No valid data found for any grid points.", None, None
497
+
498
+ pred_df = pd.DataFrame(predicted_points)
499
+ unique_crops = pred_df['crop'].unique()
500
+ crop_colors = assign_crop_colors(unique_crops)
501
+
502
+ center_lat = sum(pt["latitude"] for pt in predicted_points) / len(predicted_points)
503
+ center_lon = sum(pt["longitude"] for pt in predicted_points) / len(predicted_points)
504
+ pred_map = folium.Map(location=[center_lat, center_lon], zoom_start=12)
505
+
506
+ folium.GeoJson(
507
+ current_polygon_data,
508
+ style_function=lambda x: {'color': 'red', 'weight': 3, 'fill': False}
509
+ ).add_to(pred_map)
510
+
511
+ for pt in predicted_points:
512
+ crop_type = pt.get("crop", "Other")
513
+ color = crop_colors.get(crop_type, "#808080")
514
+ folium.Circle(
515
+ location=[pt["latitude"], pt["longitude"]],
516
+ radius=spacing_m/2,
517
+ color='black',
518
+ weight=1,
519
+ fill=True,
520
+ fillColor=color,
521
+ fillOpacity=0.7,
522
+ popup=f"Crop: {crop_type}<br>Lat: {pt['latitude']:.4f}<br>Lon: {pt['longitude']:.4f}<br>NDVI: {pt['NDVI']:.3f}<br>NDWI: {pt['NDWI']:.3f}<br>EVI: {pt['EVI']:.3f}<br>GNDVI: {pt['GNDVI']:.3f}<br>SAVI: {pt['SAVI']:.3f}",
523
+ tooltip=crop_type
524
+ ).add_to(pred_map)
525
+
526
+ legend_html = '''
527
+ <div style="position: fixed; bottom: 50px; left: 50px; width: 180px;
528
+ background-color: white; border:2px solid grey; z-index:9999;
529
+ font-size:14px; padding: 10px; border-radius: 5px;">
530
+ <p style="margin: 0 0 10px 0; font-weight:bold;">🌾 Crop Types</p>
531
+ '''
532
+ for crop in unique_crops:
533
+ color = crop_colors[crop]
534
+ legend_html += f'<p style="margin: 5px 0;"><span style="color:{color}; font-size:16px;">●</span> {crop}</p>'
535
+ legend_html += '</div>'
536
+ pred_map.get_root().html.add_child(folium.Element(legend_html))
537
+
538
+ crop_stats = pred_df['crop'].value_counts()
539
+ stats = f"βœ… Polygon processed successfully!\n\nCrop Distribution (Province: {province}, Season: {season}):\n"
540
+ for crop, count in crop_stats.items():
541
+ percentage = (count / len(predicted_points)) * 100
542
+ stats += f"{crop}: {count} points ({percentage:.1f}%)\n"
543
+ for index in ['NDVI', 'NDWI', 'EVI', 'GNDVI', 'SAVI']:
544
+ avg = pred_df[index].mean()
545
+ stats += f"Average {index}: {avg:.3f}\n"
546
+
547
+ csv_file_path = f"crop_predictions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
548
+ try:
549
+ pred_df.to_csv(csv_file_path, index=False)
550
+ except Exception as e:
551
+ print(f"Error creating CSV file: {str(e)}")
552
+ csv_file_path = None
553
+
554
+ return stats, pred_map._repr_html_(), csv_file_path
555
+
556
+ # --- Instance Interface ---
557
+ def predict_instance(province, season, latitude, longitude, date, ndvi, ndwi, ndbi, red, green, blue, nir, swir):
558
+ static_features = {
559
+ "Province": province,
560
+ "Season": season,
561
+ "NDVI": ndvi,
562
+ "NDWI": ndwi,
563
+ "NDBI": ndbi,
564
+ "Red": red,
565
+ "Green": green,
566
+ "Blue": blue,
567
+ "NIR": nir,
568
+ "SWIR": swir,
569
+ "Date": date
570
+ }
571
+ crop = predict_crop_description({"latitude": latitude, "longitude": longitude}, static_features, scaler, feature_columns, province, season)
572
+ return f"{crop}"
573
+
574
+ from pathlib import Path
575
+ import gradio as gr
576
+
577
+ # Sample file paths
578
+ sample_dir = Path("samples") # Ensure this directory exists with .tif files
579
+ sample_files = {
580
+ "Sample 1": sample_dir / "sample1.tif",
581
+ "Sample 2": sample_dir / "sample2.tif"
582
+ }
583
+
584
+ # Function to simulate upload when sample is clicked
585
+ def load_sample_and_predict(sample_name, province, season, date):
586
+ file_path = sample_files[sample_name]
587
+ return process_upload(file_path, province, season, date)
588
+
589
+ # --- Gradio Interface ---
590
+ with gr.Blocks(title="Crop Predictor", theme=gr.themes.Soft()) as demo:
591
+ gr.Markdown("# 🌾 Crop Predictor")
592
+
593
+ with gr.Tabs():
594
+ with gr.TabItem("πŸ“€ Upload"):
595
+ gr.Markdown("Upload a .tiff or .tif file with bands [r, g, b, rededge, nir, swr1, swr2]")
596
+
597
+ file_input = gr.File(label="Upload .tiff/.tif file", file_types=[".tiff", ".tif"])
598
+
599
+ with gr.Row():
600
+ province = gr.Textbox(label="Province", value="Punjab")
601
+ season = gr.Textbox(label="Season", value="Rabi")
602
+
603
+ with gr.Row():
604
+ date = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")
605
+
606
+ upload_btn = gr.Button("πŸ” Predict", variant="primary")
607
+ output_stats = gr.Textbox(label="Prediction Statistics", lines=10)
608
+ output_image = gr.Image(label="Prediction Result")
609
+
610
+ upload_btn.click(
611
+ fn=process_upload,
612
+ inputs=[file_input, province, season, date],
613
+ outputs=[output_stats, output_image]
614
+ )
615
+
616
+ # -- Add Sample File Buttons Here --
617
+ gr.Markdown("### Or try with a sample file:")
618
+ with gr.Row():
619
+ for name in sample_files:
620
+ gr.Button(name).click(
621
+ fn=load_sample_and_predict,
622
+ inputs=[gr.State(name), province, season, date],
623
+ outputs=[output_stats, output_image]
624
+ )
625
+
626
+ with gr.TabItem("πŸ—ΊοΈ Map"):
627
+ gr.Markdown("""
628
+ ## Interactive Polygon Crop Prediction
629
+
630
+ **Instructions:**
631
+ 1. Draw a polygon on the map below using the polygon tool.
632
+ 2. Click the "Export" button on the map to save the polygon as a GeoJSON file (polygon.geojson).
633
+ 3. Upload the exported GeoJSON file using the file input below.
634
+ 4. Adjust settings and click "πŸ” Predict" to process.
635
+ """)
636
+
637
+ map_html = gr.HTML(create_interactive_map, label="Draw Your Polygon Here")
638
+
639
+ with gr.Row():
640
+ geojson_input = gr.File(label="Upload Exported GeoJSON File")
641
+ select_btn = gr.Button("🎯 Select My Polygon", variant="secondary")
642
+ spacing = gr.Slider(
643
+ label="Grid Spacing (meters)",
644
+ minimum=10, maximum=1000, value=30, step=100
645
+ )
646
+
647
+ with gr.Row():
648
+ province_map = gr.Textbox(label="Province", value="Punjab")
649
+ season_map = gr.Textbox(label="Season", value="Multan")
650
+ date_map = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")
651
+
652
+ polygon_status = gr.Textbox(
653
+ label="Selection Status",
654
+ value="⏳ Please draw a polygon, export it, and upload the GeoJSON file.",
655
+ interactive=False
656
+ )
657
+
658
+ predict_btn = gr.Button("πŸ” Predict Crops", variant="primary", size="lg")
659
+
660
+ output_map_stats = gr.Textbox(label="Prediction Results", lines=10)
661
+ output_map = gr.HTML(label="Crop Prediction Map")
662
+ output_csv = gr.File(label="πŸ“₯ Download Results CSV")
663
+
664
+ select_btn.click(
665
+ fn=select_polygon,
666
+ inputs=[geojson_input],
667
+ outputs=polygon_status
668
+ )
669
+
670
+ predict_btn.click(
671
+ fn=process_polygon_prediction,
672
+ inputs=[spacing, province_map, season_map, date_map, geojson_input],
673
+ outputs=[output_map_stats, output_map, output_csv]
674
+ )
675
+
676
+ with gr.TabItem("πŸ“Š Instance"):
677
+ gr.Markdown("## Single Point Prediction")
678
+ gr.Markdown("Enter features manually for a single point prediction")
679
+
680
+ with gr.Row():
681
+ province_inst = gr.Textbox(label="Province", value="Punjab")
682
+ season_inst = gr.Textbox(label="Season", value="Rabi")
683
+
684
+ with gr.Row():
685
+ latitude_inst = gr.Number(label="Latitude", value=30.809)
686
+ longitude_inst = gr.Number(label="Longitude", value=73.450)
687
+ date_inst = gr.Textbox(label="Date (DD/MM/YYYY)", value="10/01/2023")
688
+
689
+ gr.Markdown("### Spectral Indices")
690
+ with gr.Row():
691
+ ndvi_inst = gr.Number(label="NDVI", value=0.65)
692
+ ndwi_inst = gr.Number(label="NDWI", value=-2.0)
693
+ ndbi_inst = gr.Number(label="NDBI", value=0.10)
694
+
695
+ gr.Markdown("### Band Values")
696
+ with gr.Row():
697
+ red_inst = gr.Number(label="Red", value=678)
698
+ green_inst = gr.Number(label="Green", value=732)
699
+ blue_inst = gr.Number(label="Blue", value=620)
700
+
701
+ with gr.Row():
702
+ nir_inst = gr.Number(label="NIR", value=3000)
703
+ swir_inst = gr.Number(label="SWIR", value=1800)
704
+
705
+ instance_btn = gr.Button("πŸ” Predict", variant="primary")
706
+ output_instance = gr.Textbox(label="Prediction Result", lines=3)
707
+
708
+ instance_btn.click(
709
+ fn=predict_instance,
710
+ inputs=[province_inst, season_inst, latitude_inst, longitude_inst,
711
+ date_inst, ndvi_inst, ndwi_inst, ndbi_inst, red_inst,
712
+ green_inst, blue_inst, nir_inst, swir_inst],
713
+ outputs=output_instance
714
+ )
715
+
716
+ if __name__ == "__main__":
717
+ demo.launch(share=True)
scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da797b275739c1567c2311b4d475c8f2d09c2a69d02e91f29b51bb3ad4366840
3
+ size 2183