FrAnKu34t23 commited on
Commit
eae49fb
·
verified ·
1 Parent(s): 58867b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -32
app.py CHANGED
@@ -1,38 +1,53 @@
1
  import gradio as gr
2
 
3
- from baseline.baseline_convnext import predict_convnext
4
- from baseline.baseline_infer import predict_baseline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  # --- Placeholder models (for future extensions) ---
8
- def predict_placeholder_1(image):
9
- if image is None:
10
- return "Please upload an image."
11
- return "Model 2 is not available yet. Please check back later."
12
-
13
-
14
  def predict_placeholder_2(image):
15
  if image is None:
16
  return "Please upload an image."
17
- return "Model 3 is not available yet. Please check back later."
18
-
19
 
20
  # --- Main Prediction Logic ---
21
  def predict(model_choice, image):
22
- if model_choice == "Herbarium Species Classifier":
23
- # Friend's ConvNeXt mix-stream CNN baseline
 
 
24
  return predict_convnext(image)
 
25
  elif model_choice == "Baseline (DINOv2 + LogReg)":
26
- # Your plant-pretrained DINOv2 + Logistic Regression baseline
27
  return predict_baseline(image)
28
- elif model_choice == "Future Model 1 (Placeholder)":
29
- return predict_placeholder_1(image)
 
 
 
30
  elif model_choice == "Future Model 2 (Placeholder)":
31
  return predict_placeholder_2(image)
 
32
  else:
33
  return "Invalid model selected."
34
 
35
-
36
  # --- Gradio Interface ---
37
  with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
38
  with gr.Column(elem_id="app-wrapper"):
@@ -41,24 +56,24 @@ with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
41
  """
42
  <div id="app-header">
43
  <h1>🌿 Plant Species Classification</h1>
44
- <h3>AML Group Project – PsychicFireSong</h3>
45
  </div>
46
  """,
47
  elem_id="app-header",
48
  )
49
-
50
  # Badges row
51
  gr.Markdown(
52
  """
53
  <div id="badge-row">
54
- <span class="badge">Herbarium + Field images</span>
55
- <span class="badge">ConvNeXtV2 mix-stream CNN</span>
56
- <span class="badge">DINOv2 + Logistic Regression</span>
57
  </div>
58
  """,
59
  elem_id="badge-row",
60
  )
61
-
62
  # Main card
63
  with gr.Row(elem_id="main-card"):
64
  # Left side: model + image
@@ -68,26 +83,28 @@ with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
68
  choices=[
69
  "Herbarium Species Classifier",
70
  "Baseline (DINOv2 + LogReg)",
71
- "Future Model 1 (Placeholder)",
72
  "Future Model 2 (Placeholder)",
73
  ],
74
- value="Herbarium Species Classifier",
75
  )
76
-
77
  gr.Markdown(
78
  """
79
  <div id="model-help">
80
- <b>Herbarium Species Classifier</b> – end-to-end ConvNeXtV2 CNN.<br>
81
- <b>Baseline</b> – plant-pretrained DINOv2 features + logistic regression head.
 
82
  </div>
83
  """,
84
  elem_id="model-help",
85
  )
86
-
87
  image_input = gr.Image(
88
  type="pil",
89
  label="Upload plant image",
90
  )
 
91
  submit_button = gr.Button("Classify ���", variant="primary")
92
 
93
  # Right side: predictions
@@ -103,19 +120,19 @@ with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
103
  outputs=output_label,
104
  )
105
 
106
- # Optional examples (keep empty if you don't have images)
107
  gr.Examples(
108
  examples=[],
109
  inputs=image_input,
110
  outputs=output_label,
111
- fn=lambda img: predict("Herbarium Species Classifier", img),
112
  cache_examples=False,
113
  )
114
 
115
  gr.Markdown(
116
- "Built for the AML course – compare CNN vs. DINOv2 feature-extractor baselines.",
117
  elem_id="footer",
118
  )
119
 
120
  if __name__ == "__main__":
121
- demo.launch()
 
1
  import gradio as gr
2
 
3
+ # --- 1. Import Existing Baselines ---
4
+ # Wrapped in try-except so the app doesn't crash if files are temporarily missing
5
+ try:
6
+ from baseline.baseline_convnext import predict_convnext
7
+ except ImportError:
8
+ def predict_convnext(image): return {"Error": "ConvNeXt module missing"}
9
+
10
+ try:
11
+ from baseline.baseline_infer import predict_baseline
12
+ except ImportError:
13
+ def predict_baseline(image): return {"Error": "Baseline module missing"}
14
+
15
+ # --- 2. Import NEW SPA Approach ---
16
+ # This imports the function from: new_approach/spa_ensemble.py
17
+ try:
18
+ from new_approach.spa_ensemble import predict_spa
19
+ except ImportError:
20
+ def predict_spa(image): return {"Error": "SPA module missing. Check 'new_approach' folder."}
21
 
22
 
23
  # --- Placeholder models (for future extensions) ---
 
 
 
 
 
 
24
  def predict_placeholder_2(image):
25
  if image is None:
26
  return "Please upload an image."
27
+ return "Model 4 is not available yet. Please check back later."
 
28
 
29
  # --- Main Prediction Logic ---
30
  def predict(model_choice, image):
31
+ if image is None: return None
32
+
33
+ if model_choice == "Herbarium Species Classifier (ConvNeXT)":
34
+ # Luna's ConvNeXt mix-stream CNN baseline
35
  return predict_convnext(image)
36
+
37
  elif model_choice == "Baseline (DINOv2 + LogReg)":
38
+ # Islam's baseline
39
  return predict_baseline(image)
40
+
41
+ elif model_choice == "SPA Ensemble (New Approach)":
42
+ # New approach: DINOv2 + BioCLIP + Handcrafted + SPA
43
+ return predict_spa(image)
44
+
45
  elif model_choice == "Future Model 2 (Placeholder)":
46
  return predict_placeholder_2(image)
47
+
48
  else:
49
  return "Invalid model selected."
50
 
 
51
  # --- Gradio Interface ---
52
  with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
53
  with gr.Column(elem_id="app-wrapper"):
 
56
  """
57
  <div id="app-header">
58
  <h1>🌿 Plant Species Classification</h1>
59
+ <h3>AML Group Project – Group 8</h3>
60
  </div>
61
  """,
62
  elem_id="app-header",
63
  )
64
+
65
  # Badges row
66
  gr.Markdown(
67
  """
68
  <div id="badge-row">
69
+ <span class="badge">Herbarium + Field images (ConvNeXT)</span>
70
+ <span class="badge">ConvNeXtV2</span>
71
+ <span class="badge">SPA Ensemble</span>
72
  </div>
73
  """,
74
  elem_id="badge-row",
75
  )
76
+
77
  # Main card
78
  with gr.Row(elem_id="main-card"):
79
  # Left side: model + image
 
83
  choices=[
84
  "Herbarium Species Classifier",
85
  "Baseline (DINOv2 + LogReg)",
86
+ "SPA Ensemble (New Approach)",
87
  "Future Model 2 (Placeholder)",
88
  ],
89
+ value="SPA Ensemble (New Approach)", # Default to your new model
90
  )
91
+
92
  gr.Markdown(
93
  """
94
  <div id="model-help">
95
+ <b>Herbarium Classifier</b> – ConvNeXtV2 CNN.<br>
96
+ <b>Baseline</b> – Simple DINOv2 + LogReg.<br>
97
+ <b>SPA Ensemble</b> – <i>(New)</i> DINOv2 + BioCLIP-2 + Handcrafted features.
98
  </div>
99
  """,
100
  elem_id="model-help",
101
  )
102
+
103
  image_input = gr.Image(
104
  type="pil",
105
  label="Upload plant image",
106
  )
107
+
108
  submit_button = gr.Button("Classify ���", variant="primary")
109
 
110
  # Right side: predictions
 
120
  outputs=output_label,
121
  )
122
 
123
+ # Optional examples
124
  gr.Examples(
125
  examples=[],
126
  inputs=image_input,
127
  outputs=output_label,
128
+ fn=lambda img: predict("SPA Ensemble (New Approach)", img),
129
  cache_examples=False,
130
  )
131
 
132
  gr.Markdown(
133
+ "Built for the AML course – compare CNN vs. DINOv2 feature-extractor baselines with the new approaches to address cross-domain plant identification.",
134
  elem_id="footer",
135
  )
136
 
137
  if __name__ == "__main__":
138
+ demo.launch()