Joaquin Villar committed
Commit 79f8da9 · verified · 1 Parent(s): ccb6a69

Create app.py

Files changed (1)
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel

# --- CONFIGURATION ---
# Replace with your specific repo name
ADAPTER_REPO = "jvillar-sheff/news-classifier-demo"
BASE_MODEL_ID = "distilbert-base-uncased"

CLASS_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

def load_model():
    print("Loading Base Model...")
    # 1. Load the base model (generic DistilBERT)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_ID,
        num_labels=len(CLASS_NAMES),
        id2label={k: v for k, v in CLASS_NAMES.items()},
        label2id={v: k for k, v in CLASS_NAMES.items()}
    )

    # 2. Load the tokenizer from YOUR repo (ensures consistency)
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    # 3. Load and apply your LoRA adapters
    print(f"Loading LoRA Adapters from {ADAPTER_REPO}...")
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)

    # Optimize for CPU (free-tier Spaces are CPU-only)
    device = torch.device("cpu")
    model.to(device)
    model.eval()

    return model, tokenizer, device

# Load the model once on startup
model, tokenizer, device = load_model()

def classify_news(text):
    if not text:
        return None

    # Preprocess
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128
    ).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)

    # Get probabilities
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

    # Format output
    results = {}
    for i, prob in enumerate(probabilities):
        results[CLASS_NAMES[i]] = float(prob)

    return results

# Create the interface
iface = gr.Interface(
    fn=classify_news,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Paste a news article here...",
        label="News Text"
    ),
    outputs=gr.Label(num_top_classes=4, label="Prediction"),
    title="AI News Classifier (DistilBERT + LoRA)",
    description="This model classifies news into World, Sports, Business, or Sci/Tech categories. Trained on AG News using Parameter-Efficient Fine-Tuning.",
    examples=[
        ["The stock market rallied today as tech companies reported record profits."],
        ["The team scored a goal in the final minute to win the championship."],
        ["New research shows that drinking coffee may increase life expectancy."],
        ["Diplomats gathered in Geneva to discuss the peace treaty."]
    ],
    theme="soft"
)

iface.launch()
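Once the Space is up, the same endpoint can also be exercised from Python without the browser UI. Below is a minimal sketch using gradio_client; it assumes the Space id matches ADAPTER_REPO and relies on the default "/predict" route that gr.Interface exposes for its single function.

from gradio_client import Client

# Assumption: the Space is hosted under the same id as ADAPTER_REPO.
client = Client("jvillar-sheff/news-classifier-demo")

# gr.Interface exposes its single function under the default "/predict" route.
result = client.predict(
    "The team scored a goal in the final minute to win the championship.",
    api_name="/predict",
)
print(result)

The returned payload corresponds to the gr.Label output (the per-class scores computed by classify_news), though its exact shape depends on the gradio_client version in use.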