Joaquin Villar committed on
Commit
d149d93
·
verified ·
1 Parent(s): 0f53989

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -52
app.py CHANGED
@@ -1,87 +1,137 @@
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  from peft import PeftModel
5
 
6
- # --- CONFIGURATION ---
7
- # Replace with your specific repo name
 
 
 
 
8
  ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"
9
  BASE_MODEL_ID = "distilbert-base-uncased"
10
-
11
  CLASS_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
12
 
 
13
  def load_model():
14
  print("Loading Base Model...")
15
- # 1. Load the Base Model (Generic DistilBERT)
16
  base_model = AutoModelForSequenceClassification.from_pretrained(
17
  BASE_MODEL_ID,
18
  num_labels=len(CLASS_NAMES),
19
- id2label={k: v for k, v in CLASS_NAMES.items()},
20
  label2id={v: k for k, v in CLASS_NAMES.items()}
21
  )
22
 
23
- # 2. Load the Tokenizer from YOUR repo (ensures consistency)
24
  tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)
25
 
26
- # 3. Load and Apply your LoRA Adapters
27
- print(f"Loading LoRA Adapters from {ADAPTER_REPO}...")
28
  model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
29
 
30
- # Optimize for CPU (Free Tier Spaces are CPU)
31
  device = torch.device("cpu")
32
  model.to(device)
33
  model.eval()
34
-
35
  return model, tokenizer, device
36
 
37
- # Load model once on startup
38
  model, tokenizer, device = load_model()
39
 
40
- def classify_news(text):
41
- if not text:
42
- return None
43
-
44
- # Preprocess
45
  inputs = tokenizer(
46
- text,
47
- return_tensors="pt",
48
- truncation=True,
49
- padding="max_length",
50
- max_length=128
51
  ).to(device)
52
 
53
- # Predict
54
  with torch.no_grad():
55
  outputs = model(**inputs)
56
 
57
- # Get Probabilities
58
  logits = outputs.logits
59
- probabilities = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
60
-
61
- # Format Output
62
- results = {}
63
- for i, prob in enumerate(probabilities):
64
- results[CLASS_NAMES[i]] = float(prob)
65
-
66
- return results
67
-
68
- # Create Interface
69
- iface = gr.Interface(
70
- fn=classify_news,
71
- inputs=gr.Textbox(
72
- lines=5,
73
- placeholder="Paste a news article here...",
74
- label="News Text"
75
- ),
76
- outputs=gr.Label(num_top_classes=4, label="Prediction"),
77
- title="AI News Classifier (DistilBERT + LoRA)",
78
- description="This model classifies news into World, Sports, Business, or Sci/Tech categories. Trained on AG News using Parameter-Efficient Fine-Tuning.",
79
- examples=[
80
- ["The stock market rallied today as tech companies reported record profits."],
81
- ["The team scored a goal in the final minute to win the championship."],
82
- ["New research shows that drinking coffee may increase life expectancy."],
83
- ["Diplomats gathered in Geneva to discuss the peace treaty."]
84
- ]
85
- )
86
-
87
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
+ import numpy as np
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  from peft import PeftModel
6
 
7
# --- 1. CONFIGURATION ---
# Pre-formatted evaluation figures; rendered verbatim in the UI banner.
MODEL_METRICS = {"Accuracy": "89.20%", "F1_Score": "0.8931"}

# Hub repo the tokenizer and the PEFT adapter are loaded from.
ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"
# Checkpoint the sequence-classification base model is loaded from.
BASE_MODEL_ID = "distilbert-base-uncased"

# Class index -> human-readable category name.
CLASS_NAMES = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech",
}
16
 
17
# --- 2. MODEL LOADING ---
def load_model():
    """Load the base classifier, the tokenizer and the PEFT adapter onto CPU.

    The base model comes from ``BASE_MODEL_ID``; the tokenizer and the
    adapter both come from ``ADAPTER_REPO``. The combined model is moved to
    the CPU device and switched to eval mode before being returned.

    Returns:
        A ``(model, tokenizer, device)`` tuple.
    """
    # Label maps handed to the base model config, derived from CLASS_NAMES.
    index_to_name = dict(enumerate(CLASS_NAMES.values()))
    name_to_index = {name: index for index, name in CLASS_NAMES.items()}

    print("Loading Base Model...")
    backbone = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_ID,
        num_labels=len(CLASS_NAMES),
        id2label=index_to_name,
        label2id=name_to_index,
    )

    print("Loading Tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    print("Loading Adapters...")
    model = PeftModel.from_pretrained(backbone, ADAPTER_REPO)

    device = torch.device("cpu")
    model.to(device)
    model.eval()  # inference only

    return model, tokenizer, device
37
 
 
38
  model, tokenizer, device = load_model()
39
 
40
# --- 3. PREDICTION LOGIC ---
def predict(text):
    """Classify a news snippet and build the three UI outputs.

    Args:
        text: Raw text from the input box. ``None``, an empty string or a
            whitespace-only string is treated as "no input".

    Returns:
        A 3-tuple ``(label_markdown, badge_html, class_probs)``:
        a Markdown heading with the predicted class name, an HTML snippet
        showing the confidence on a colored background, and a dict mapping
        every class name to its probability. ``(None, None, None)`` is
        returned when there is no input.
    """
    # Check for a falsy value *before* stripping so a None input cannot
    # raise AttributeError on .strip().
    if not text or not text.strip():
        return None, None, None

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    # Drop only the batch axis; a bare squeeze() would remove every size-1 axis.
    probs = torch.nn.functional.softmax(logits, dim=1).squeeze(0).cpu().numpy()

    # 1. Top label. Convert to a plain int so the CLASS_NAMES lookup does not
    #    depend on a NumPy scalar hashing like a Python int.
    pred_idx = int(np.argmax(probs))
    pred_label = CLASS_NAMES[pred_idx]
    conf = float(probs[pred_idx])

    # 2. Probability dict for the chart.
    class_probs = {name: float(probs[idx]) for idx, name in CLASS_NAMES.items()}

    # 3. HTML for the confidence badge; color depends on the confidence tier.
    if conf > 0.85:
        bg_color, txt_color = "#d4edda", "#155724"  # Green
    elif conf > 0.60:
        bg_color, txt_color = "#fff3cd", "#856404"  # Yellow
    else:
        bg_color, txt_color = "#f8d7da", "#721c24"  # Red

    badge_html = f"""
    <div style='background-color: {bg_color}; color: {txt_color};
    padding: 8px 12px; border-radius: 5px; display: inline-block; font-weight: bold; font-size: 16px;'>
    Confidence: {conf:.2%}
    </div>
    """

    # Return: label text, badge HTML, chart data.
    return f"# {pred_label}", badge_html, class_probs
80
+
81
# --- 4. UI LAYOUT (gr.Blocks) ---
# Banner markup showing the figures stored in MODEL_METRICS.
banner_html = f"""
    <div style="background-color: #d1e7dd; color: #0f5132; padding: 15px; border-radius: 5px; border: 1px solid #badbcc; margin-bottom: 20px;">
    ✅ <b>Model Performance:</b> Accuracy: {MODEL_METRICS['Accuracy']} | F1 Score: {MODEL_METRICS['F1_Score']}
    </div>
    """

# One-column rows offered as clickable examples for the input box.
example_rows = [
    ["The stock market rallied today as tech companies reported record profits."],
    ["The local team won the championship after a stunning overtime goal."],
    ["NASA announces plans to launch a new rover to Mars next July."],
]

# Soft theme, two-column layout: input on the left, results on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:

    # Title and short description
    gr.Markdown("# 📰 NLP News Classifier")
    gr.Markdown("Classify news articles into World, Sports, Business, or Sci/Tech using DistilBERT + LoRA.")

    # Model-performance banner
    gr.HTML(banner_html)

    with gr.Row():
        # Left column: text input, submit button and examples
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=6,
                placeholder="Paste a news snippet here...",
                label="News Article",
            )
            btn = gr.Button("Classify Article", variant="primary")

            gr.Markdown("### Examples")
            gr.Examples(examples=example_rows, inputs=input_text)

        # Right column: the three outputs produced by predict()
        with gr.Column(scale=1):
            gr.Markdown("### Prediction")
            out_label = gr.Markdown()  # predicted class as a heading
            out_badge = gr.HTML()  # colored confidence badge

            gr.Markdown("### Probability Breakdown")
            out_chart = gr.Label(num_top_classes=4, label="Confidence Scores")

    # Clicking the button runs predict() and fills the three outputs in order.
    btn.click(
        fn=predict,
        inputs=input_text,
        outputs=[out_label, out_badge, out_chart],
    )

# Launch only when executed as a script.
if __name__ == "__main__":
    demo.launch()