File size: 4,058 Bytes
ef3d1e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr

from baseline.baseline_convnext import predict_convnext
from baseline.baseline_infer import predict_baseline


# --- Placeholder models (for future extensions) ---
def predict_placeholder_1(image):
    """Stand-in for the "Future Model 1 (Placeholder)" dropdown entry.

    Args:
        image: The uploaded image, or ``None`` if nothing was uploaded.

    Returns:
        A status string that is shown in the output label.
    """
    if image is None:
        return "Please upload an image."
    # Use the same name as the dropdown entry that routes here, so the user
    # sees a message that matches what they selected.
    return "Future Model 1 is not available yet. Please check back later."


def predict_placeholder_2(image):
    """Stand-in for the "Future Model 2 (Placeholder)" dropdown entry.

    Args:
        image: The uploaded image, or ``None`` if nothing was uploaded.

    Returns:
        A status string that is shown in the output label.
    """
    if image is None:
        return "Please upload an image."
    # Use the same name as the dropdown entry that routes here, so the user
    # sees a message that matches what they selected.
    return "Future Model 2 is not available yet. Please check back later."


# --- Main Prediction Logic ---
def predict(model_choice, image):
    """Route an uploaded image to the classifier chosen in the dropdown.

    Args:
        model_choice: The dropdown label of the selected model.
        image: The uploaded image (the input component uses ``type="pil"``),
            or ``None`` when "Classify" is clicked without an upload.

    Returns:
        Whatever the selected model's prediction function returns (this is
        passed to the ``gr.Label`` output), or a plain status string when the
        model name is unknown or no image was uploaded.
    """
    handlers = {
        # ConvNeXtV2 mix-stream CNN baseline.
        "Herbarium Species Classifier": predict_convnext,
        # Plant-pretrained DINOv2 features + logistic-regression head.
        "Baseline (DINOv2 + LogReg)": predict_baseline,
        "Future Model 1 (Placeholder)": predict_placeholder_1,
        "Future Model 2 (Placeholder)": predict_placeholder_2,
    }
    handler = handlers.get(model_choice)
    if handler is None:
        return "Invalid model selected."

    # Guard once here so every backend, not only the placeholders, returns a
    # friendly message instead of being called with ``None``.
    if image is None:
        return "Please upload an image."

    return handler(image)


# --- Gradio Interface ---
# Two-panel layout: model picker + image upload on the left, top-5 label on
# the right. Custom styling is loaded from style.css; the elem_id values below
# are the hooks it can target.
with gr.Blocks(theme=gr.themes.Soft(), css="style.css") as demo:
    with gr.Column(elem_id="app-wrapper"):
        # Header.
        # NOTE(review): the inner <div> and the Markdown component share the
        # id "app-header" (likewise "badge-row" and "model-help" below), which
        # produces duplicate DOM ids. Confirm which element style.css targets
        # before removing either.
        gr.Markdown(
            """
            <div id="app-header">
              <h1>🌿 Plant Species Classification</h1>
              <h3>AML Group Project – PsychicFireSong</h3>
            </div>
            """,
            elem_id="app-header",
        )

        # Badges row
        gr.Markdown(
            """
            <div id="badge-row">
              <span class="badge">Herbarium + Field images</span>
              <span class="badge">ConvNeXtV2 mix-stream CNN</span>
              <span class="badge">DINOv2 + Logistic Regression</span>
            </div>
            """,
            elem_id="badge-row",
        )

        # Main card
        with gr.Row(elem_id="main-card"):
            # Left side: model + image
            with gr.Column(scale=1, elem_id="left-panel"):
                # These labels must match the keys that predict() dispatches on.
                model_selector = gr.Dropdown(
                    label="Select model",
                    choices=[
                        "Herbarium Species Classifier",
                        "Baseline (DINOv2 + LogReg)",
                        "Future Model 1 (Placeholder)",
                        "Future Model 2 (Placeholder)",
                    ],
                    value="Herbarium Species Classifier",
                )

                gr.Markdown(
                    """
                    <div id="model-help">
                      <b>Herbarium Species Classifier</b> – end-to-end ConvNeXtV2 CNN.<br>
                      <b>Baseline</b> – plant-pretrained DINOv2 features + logistic regression head.
                    </div>
                    """,
                    elem_id="model-help",
                )

                image_input = gr.Image(
                    type="pil",
                    label="Upload plant image",
                )
                submit_button = gr.Button("Classify 🌱", variant="primary")

            # Right side: predictions
            with gr.Column(scale=1, elem_id="right-panel"):
                output_label = gr.Label(
                    label="Top 5 predictions",
                    num_top_classes=5,
                )

        submit_button.click(
            fn=predict,
            inputs=[model_selector, image_input],
            outputs=output_label,
        )

        # Optional example images: add paths here (e.g. "examples/leaf.jpg").
        # The component is only built when the list is non-empty, rather than
        # relying on how gr.Examples handles an empty list.
        example_images = []
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=image_input,
                outputs=output_label,
                # Examples are always scored with the default model, not with
                # whatever is currently selected in the dropdown.
                fn=lambda img: predict("Herbarium Species Classifier", img),
                cache_examples=False,
            )

        gr.Markdown(
            "Built for the AML course – compare CNN vs. DINOv2 feature-extractor baselines.",
            elem_id="footer",
        )

if __name__ == "__main__":
    demo.launch()