# Page header captured together with this file (not part of the program):
#   jvillar02's picture
#   Update app.py
#   558d7d3 verified
import gradio as gr
import torch
import numpy as np
import json
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
from huggingface_hub import hf_hub_download
import os
# --- 1. CONFIGURATION ---
# Hub repository the adapter, the tokenizer and the evaluation report are
# all pulled from.
ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"

# Checkpoint the adapter is applied on top of.
BASE_MODEL_ID = "distilbert-base-uncased"

# Class index -> human-readable label, in index order.
CLASS_NAMES = dict(enumerate(("World", "Sports", "Business", "Sci/Tech")))
# --- 2. DYNAMIC METRICS LOADING ---
def fetch_metrics():
    """Fetch the headline evaluation metrics from the Model Hub.

    Downloads ``evaluation_report.json`` from ``ADAPTER_REPO`` and reads the
    ``Accuracy`` and ``F1 Macro`` entries stored under ``overall_metrics``.

    Returns:
        A dict with two display-ready strings: ``"Accuracy"`` (percentage,
        two decimals) and ``"F1_Score"`` (four decimals). If the download,
        the parsing or the formatting fails, both values are ``"N/A"``.
    """
    try:
        file_path = hf_hub_download(repo_id=ADAPTER_REPO, filename="evaluation_report.json")
        # Decode explicitly as UTF-8 so the result does not depend on the
        # host's locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Extract numbers
        overall = data['overall_metrics']
        acc = overall['Accuracy']
        f1 = overall['F1 Macro']
        return {
            "Accuracy": f"{acc:.2%}",
            "F1_Score": f"{f1:.4f}",
        }
    except Exception as e:
        # Best-effort by design: any failure is reported on stdout and
        # replaced by placeholders instead of propagating to the caller.
        print(f"Error loading metrics: {e}")
        return {"Accuracy": "N/A", "F1_Score": "N/A"}
# Load metrics once on app startup (module import); the resulting strings are
# interpolated into the performance banner of the UI further down.
MODEL_METRICS = fetch_metrics()
# --- 3. MODEL LOADING ---
def load_model():
    """Assemble the classifier used for inference.

    Loads the base sequence-classification model identified by
    ``BASE_MODEL_ID``, the tokenizer stored in ``ADAPTER_REPO``, and wraps the
    base model with the PEFT adapter from ``ADAPTER_REPO``. The result is
    moved to the CPU and switched to evaluation mode.

    Returns:
        A ``(model, tokenizer, device)`` tuple.
    """
    print("Loading Base Model...")
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_ID,
        num_labels=len(CLASS_NAMES),
        # Derive both mappings from the same items so they can never
        # disagree, whatever keys CLASS_NAMES uses.
        id2label=dict(CLASS_NAMES),
        label2id={name: idx for idx, name in CLASS_NAMES.items()},
    )
    print("Loading Tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)
    print("Loading Adapters...")
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
    # Force CPU for Free Tier Spaces
    device = torch.device("cpu")
    model.to(device)
    model.eval()
    return model, tokenizer, device
# Build the model, tokenizer and device once at import time; predict() reads
# these module-level names on every call.
model, tokenizer, device = load_model()
# --- 4. PREDICTION LOGIC ---
def _confidence_badge(conf):
    """Render a confidence value as a small colour-coded HTML badge."""
    if conf > 0.85:
        bg_color, txt_color, icon = "#d4edda", "#155724", "↑"  # Green
    elif conf > 0.60:
        bg_color, txt_color, icon = "#fff3cd", "#856404", "~"  # Yellow
    else:
        bg_color, txt_color, icon = "#f8d7da", "#721c24", "↓"  # Red
    return (
        "\n"
        f"<div style='background-color: {bg_color}; color: {txt_color};\n"
        "padding: 8px 16px; border-radius: 5px; display: inline-block; "
        "font-weight: bold; font-size: 16px;'>\n"
        f"{icon} Confidence: {conf:.2%}\n"
        "</div>\n"
    )


def predict(text):
    """Classify a news snippet and prepare the three UI outputs.

    Args:
        text: The article text typed or pasted by the user.

    Returns:
        A 3-tuple ``(label_markdown, badge_html, class_probs)``:
        a Markdown heading with the predicted class name, an HTML badge
        showing the confidence of that prediction, and a dict mapping every
        class name to its probability. For missing or blank input the tuple
        is ``(None, None, None)``.
    """
    # Treat None the same as an empty / whitespace-only string instead of
    # failing on ``None.strip()``.
    if not text or not text.strip():
        return None, None, None

    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding="max_length", max_length=128
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single text yields a single row of logits; index that row explicitly
    # rather than relying on squeeze() to drop the batch dimension.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().numpy()

    # 1. Get Top Label
    pred_idx = int(np.argmax(probs))
    pred_label = CLASS_NAMES[pred_idx]
    conf = float(probs[pred_idx])

    # 2. Create Probability Dict for the Chart
    class_probs = {name: float(probs[idx]) for idx, name in CLASS_NAMES.items()}

    # 3. Return: Label Text, Badge HTML, Chart Data
    return f"# {pred_label}", _confidence_badge(conf), class_probs
# --- 5. UI LAYOUT (gr.Blocks) ---
# Page structure: title and description, a banner with the test-set metrics,
# then a two-column row (input on the left, prediction outputs on the right).
with gr.Blocks() as demo:
    gr.Markdown("# 📰 NLP News Classifier")
    gr.Markdown("Classify news articles into World, Sports, Business, or Sci/Tech using DistilBERT + LoRA.")
    # -- The "Green Banner" (HTML) --
    gr.HTML(f"""
    <div style="
        background-color: #d1e7dd;
        padding: 15px;
        border-radius: 5px;
        border: 1px solid #badbcc;
        margin-bottom: 20px;
        color: #0f5132;
    ">
        <span style="color: #0f5132; font-weight: bold;">✅ Model Performance (Test Set):</span>
        <span style="color: #0f5132;">Accuracy: {MODEL_METRICS['Accuracy']} | F1 Score: {MODEL_METRICS['F1_Score']}</span>
    </div>
    """)
    with gr.Row():
        # Left Column: Input
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=6,
                placeholder="Paste a news snippet here...",
                label="News Article"
            )
            btn = gr.Button("Classify Article", variant="primary")
            gr.Markdown("### Examples")
            gr.Examples(
                examples=[
                    ["The stock market rallied today as tech companies reported record profits."],
                    ["The local team won the championship after a stunning overtime goal."],
                    ["NASA announces plans to launch a new rover to Mars next July."]
                ],
                inputs=input_text
            )
        # Right Column: Results
        with gr.Column(scale=1):
            gr.Markdown("### Prediction")
            # Output 1: Big Label text
            out_label = gr.Markdown()
            # Output 2: The Colored Badge
            out_badge = gr.HTML()
            gr.Markdown("### Probability Breakdown")
            # Output 3: Bar Chart (show every class, however many are configured)
            out_chart = gr.Label(num_top_classes=len(CLASS_NAMES), label="Confidence Scores")
    # Wire up the button: the three values returned by predict() fill the
    # three output components in order.
    btn.click(
        fn=predict,
        inputs=input_text,
        outputs=[out_label, out_badge, out_chart]
    )
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    demo.launch()