import gradio as gr
import torch
import numpy as np
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
from huggingface_hub import hf_hub_download
import os
# --- 1. CONFIGURATION ---
ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"
BASE_MODEL_ID = "distilbert-base-uncased"
CLASS_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
# --- 2. DYNAMIC METRICS LOADING ---
def fetch_metrics():
    """Downloads evaluation_report.json from the Model Hub."""
    try:
        file_path = hf_hub_download(repo_id=ADAPTER_REPO, filename="evaluation_report.json")
        with open(file_path, "r") as f:
            data = json.load(f)
        # Extract numbers
        acc = data['overall_metrics']['Accuracy']
        f1 = data['overall_metrics']['F1 Macro']
        return {
            "Accuracy": f"{acc:.2%}",
            "F1_Score": f"{f1:.4f}"
        }
    except Exception as e:
        print(f"Error loading metrics: {e}")
        return {"Accuracy": "N/A", "F1_Score": "N/A"}
# Load metrics on app startup
MODEL_METRICS = fetch_metrics()
# --- 3. MODEL LOADING ---
def load_model():
    print("Loading Base Model...")
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_ID,
        num_labels=len(CLASS_NAMES),
        id2label={k: v for k, v in enumerate(CLASS_NAMES.values())},
        label2id={v: k for k, v in CLASS_NAMES.items()}
    )
    print("Loading Tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)
    print("Loading Adapters...")
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)

    # Force CPU for Free Tier Spaces
    device = torch.device("cpu")
    model.to(device)
    model.eval()
    return model, tokenizer, device

model, tokenizer, device = load_model()
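# Optional: for CPU-only inference, folding the LoRA weights into the base model,
# e.g. `model = model.merge_and_unload()`, can reduce per-request overhead; the
# app keeps the un-merged PeftModel as loaded above.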
# --- 4. PREDICTION LOGIC ---
def predict(text):
    if not text.strip():
        return None, None, None

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding="max_length", max_length=128
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=1).squeeze().cpu().numpy()

    # 1. Get Top Label
    pred_idx = np.argmax(probs)
    pred_label = CLASS_NAMES[pred_idx]
    conf = float(probs[pred_idx])

    # 2. Create Probability Dict for the Chart
    class_probs = {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}

    # 3. Create HTML for the "Confidence Badge"
    if conf > 0.85:
        bg_color, txt_color, icon = "#d4edda", "#155724", "✅"  # Green
    elif conf > 0.60:
        bg_color, txt_color, icon = "#fff3cd", "#856404", "⚠️"  # Yellow
    else:
        bg_color, txt_color, icon = "#f8d7da", "#721c24", "❌"  # Red

    badge_html = f"""
    <div style='background-color: {bg_color}; color: {txt_color};
        padding: 8px 16px; border-radius: 5px; display: inline-block; font-weight: bold; font-size: 16px;'>
        {icon} Confidence: {conf:.2%}
    </div>
    """

    # Return: Label Text, Badge HTML, Chart Data
    return f"# {pred_label}", badge_html, class_probs
# --- 5. UI LAYOUT (gr.Blocks) ---
with gr.Blocks() as demo:
    gr.Markdown("# 📰 NLP News Classifier")
    gr.Markdown("Classify news articles into World, Sports, Business, or Sci/Tech using DistilBERT + LoRA.")

    # -- The "Green Banner" (HTML) --
    gr.HTML(f"""
    <div style="
        background-color: #d1e7dd;
        padding: 15px;
        border-radius: 5px;
        border: 1px solid #badbcc;
        margin-bottom: 20px;
        color: #0f5132;
    ">
        <span style="color: #0f5132; font-weight: bold;">✅ Model Performance (Test Set):</span>
        <span style="color: #0f5132;">Accuracy: {MODEL_METRICS['Accuracy']} | F1 Score: {MODEL_METRICS['F1_Score']}</span>
    </div>
    """)

    with gr.Row():
        # Left Column: Input
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=6,
                placeholder="Paste a news snippet here...",
                label="News Article"
            )
            btn = gr.Button("Classify Article", variant="primary")
            gr.Markdown("### Examples")
            gr.Examples(
                examples=[
                    ["The stock market rallied today as tech companies reported record profits."],
                    ["The local team won the championship after a stunning overtime goal."],
                    ["NASA announces plans to launch a new rover to Mars next July."]
                ],
                inputs=input_text
            )

        # Right Column: Results
        with gr.Column(scale=1):
            gr.Markdown("### Prediction")
            # Output 1: Big Label text
            out_label = gr.Markdown()
            # Output 2: The Colored Badge
            out_badge = gr.HTML()
            gr.Markdown("### Probability Breakdown")
            # Output 3: Bar Chart
            out_chart = gr.Label(num_top_classes=4, label="Confidence Scores")

    # Wire up the button
    btn.click(
        fn=predict,
        inputs=input_text,
        outputs=[out_label, out_badge, out_chart]
    )

# Launch
if __name__ == "__main__":
    demo.launch()
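# Note: on shared CPU hardware, enabling Gradio's request queue before launching,
# e.g. `demo.queue().launch()`, helps handle concurrent users; adjust as needed.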