Spaces:
Sleeping
Sleeping
from pathlib import Path

import gradio as gr

from baseline.baseline_convnext import predict_convnext
from baseline.baseline_infer import predict_baseline
| # --- Placeholder models (for future extensions) --- | |
def predict_placeholder_1(image):
    """Placeholder backend for the "Future Model 1 (Placeholder)" dropdown entry.

    Args:
        image: The uploaded image (presumably a PIL image from the UI), or
            ``None`` when nothing has been uploaded.

    Returns:
        A user-facing status message string; no classification is performed.
    """
    if image is None:
        return "Please upload an image."
    # Name the model the same way the dropdown does. The previous text said
    # "Model 2", which did not match the "Future Model 1" entry the user picked.
    return "Future Model 1 is not available yet. Please check back later."
def predict_placeholder_2(image):
    """Placeholder backend for the "Future Model 2 (Placeholder)" dropdown entry.

    Args:
        image: The uploaded image (presumably a PIL image from the UI), or
            ``None`` when nothing has been uploaded.

    Returns:
        A user-facing status message string; no classification is performed.
    """
    if image is None:
        return "Please upload an image."
    # Name the model the same way the dropdown does. The previous text said
    # "Model 3", which did not match the "Future Model 2" entry the user picked.
    return "Future Model 2 is not available yet. Please check back later."
| # --- Main Prediction Logic --- | |
def predict(model_choice, image):
    """Route an uploaded image to the classifier selected in the dropdown.

    Args:
        model_choice: Label of the selected dropdown entry, e.g.
            ``"Herbarium Species Classifier"`` or ``"Baseline (DINOv2 + LogReg)"``.
        image: The uploaded image, or ``None`` if nothing has been uploaded.

    Returns:
        Whatever the selected backend returns (displayed in the output
        ``gr.Label``), or a user-facing message string when no image was
        given or the model label is not recognised.
    """
    # Check for a missing upload once, up front. The placeholder backends
    # already do this, but the real backends are called directly and nothing
    # visible here guarantees they tolerate ``None``.
    if image is None:
        return "Please upload an image."
    if model_choice == "Herbarium Species Classifier":
        # End-to-end ConvNeXtV2 mix-stream CNN.
        return predict_convnext(image)
    if model_choice == "Baseline (DINOv2 + LogReg)":
        # Plant-pretrained DINOv2 features + logistic-regression head.
        return predict_baseline(image)
    if model_choice == "Future Model 1 (Placeholder)":
        return predict_placeholder_1(image)
    if model_choice == "Future Model 2 (Placeholder)":
        return predict_placeholder_2(image)
    return "Invalid model selected."
| # --- Gradio Interface --- | |
# Build the Gradio UI. It must stay bound to the module-level name ``demo``;
# the ``__main__`` guard at the end launches it.
# NOTE(review): each gr.Markdown below sets elem_id="X" AND wraps its HTML in
# <div id="X">, so the page likely ends up with two elements sharing one id
# (invalid HTML; a rule such as ``#X { padding: ... }`` would apply twice).
# Left as-is because style.css is not visible here — drop one of the two
# and adjust the stylesheet in the same change.
with gr.Blocks(
    theme=gr.themes.Soft(),
    # Resolve style.css next to this script rather than the current working
    # directory. Gradio 4.x reads ``css`` from disk only when that path
    # exists; otherwise the string itself is used as CSS. A bare relative
    # path therefore silently drops all custom styling when the app is
    # started from another directory.
    css=str(Path(__file__).with_name("style.css")),
) as demo:
    with gr.Column(elem_id="app-wrapper"):
        # Header
        gr.Markdown(
            """
            <div id="app-header">
            <h1>🌿 Plant Species Classification</h1>
            <h3>AML Group Project — PsychicFireSong</h3>
            </div>
            """,
            elem_id="app-header",
        )

        # Badges row
        gr.Markdown(
            """
            <div id="badge-row">
            <span class="badge">Herbarium + Field images</span>
            <span class="badge">ConvNeXtV2 mix-stream CNN</span>
            <span class="badge">DINOv2 + Logistic Regression</span>
            </div>
            """,
            elem_id="badge-row",
        )

        # Main card
        with gr.Row(elem_id="main-card"):
            # Left side: model + image
            with gr.Column(scale=1, elem_id="left-panel"):
                # These labels must match the strings predict() dispatches on.
                model_selector = gr.Dropdown(
                    label="Select model",
                    choices=[
                        "Herbarium Species Classifier",
                        "Baseline (DINOv2 + LogReg)",
                        "Future Model 1 (Placeholder)",
                        "Future Model 2 (Placeholder)",
                    ],
                    value="Herbarium Species Classifier",
                )
                gr.Markdown(
                    """
                    <div id="model-help">
                    <b>Herbarium Species Classifier</b> — end-to-end ConvNeXtV2 CNN.<br>
                    <b>Baseline</b> — plant-pretrained DINOv2 features + logistic regression head.
                    </div>
                    """,
                    elem_id="model-help",
                )
                image_input = gr.Image(
                    type="pil",
                    label="Upload plant image",
                )
                submit_button = gr.Button("Classify 🌱", variant="primary")

            # Right side: predictions
            with gr.Column(scale=1, elem_id="right-panel"):
                output_label = gr.Label(
                    label="Top 5 predictions",
                    num_top_classes=5,
                )

        submit_button.click(
            fn=predict,
            inputs=[model_selector, image_input],
            outputs=output_label,
        )

        # Optional example images (paths) for the input panel. gr.Examples is
        # only created when this list is non-empty, so no empty "Examples"
        # panel is rendered when there are none.
        example_images = []
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=image_input,
                outputs=output_label,
                # Always classifies with the default model, regardless of the
                # dropdown; only invoked if examples are cached or run on click.
                fn=lambda img: predict("Herbarium Species Classifier", img),
                cache_examples=False,
            )

        gr.Markdown(
            "Built for the AML course — compare CNN vs. DINOv2 feature-extractor baselines.",
            elem_id="footer",
        )

if __name__ == "__main__":
    demo.launch()