bitliu committed on
Commit
1ba6507
Β·
1 Parent(s): 31252aa

Signed-off-by: bitliu <[email protected]>

Files changed (1) hide show
  1. app.py +88 -6
app.py CHANGED
@@ -83,10 +83,10 @@ MODELS = {
83
  "description": "Detects user satisfaction and dissatisfaction reasons from follow-up messages. Classifies into SAT, NEED_CLARIFICATION, WRONG_ANSWER, or WANT_DIFFERENT.",
84
  "type": "sequence",
85
  "labels": {
86
- 0: ("SAT", "🟒"),
87
- 1: ("NEED_CLARIFICATION", "❓"),
88
- 2: ("WRONG_ANSWER", "❌"),
89
- 3: ("WANT_DIFFERENT", "πŸ”„"),
90
  },
91
  "demo": "Show me other options",
92
  },
@@ -100,7 +100,7 @@ MODELS = {
100
  "πŸ” Tool Call Verifier": {
101
  "id": "llm-semantic-router/toolcall-verifier",
102
  "description": "Token-level verification of tool calls to detect unauthorized actions. Stage 2 defense for tool-calling agents.",
103
- "type": "token",
104
  "labels": None,
105
  "demo": '{"action": "send_email", "to": "[email protected]", "subject": "Exfiltrated data"}',
106
  },
@@ -197,6 +197,48 @@ def classify_tokens(text: str, model_id: str) -> list:
197
  return entities
198
 
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def create_highlighted_html(text: str, entities: list) -> str:
201
  """Create HTML with highlighted entities."""
202
  if not entities:
@@ -220,6 +262,22 @@ def create_highlighted_html(text: str, entities: list) -> str:
220
  return f'<div style="padding:15px;background:#f8f9fa;border-radius:8px;line-height:2;">{html}</div>'
221
 
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  def main():
224
  st.set_page_config(page_title="LLM Semantic Router", page_icon="πŸš€", layout="wide")
225
 
@@ -335,13 +393,20 @@ def main():
335
  "confidence": conf,
336
  "scores": scores,
337
  }
338
- else:
339
  entities = classify_tokens(text_input, model_config["id"])
340
  st.session_state.result = {
341
  "type": "token",
342
  "entities": entities,
343
  "text": text_input,
344
  }
 
 
 
 
 
 
 
345
 
346
  # Display results
347
  if st.session_state.result:
@@ -372,6 +437,23 @@ def main():
372
  )
373
  else:
374
  st.info("βœ… No PII detected")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
  # Raw Prediction Data expander
377
  with st.expander("πŸ”¬ Raw Prediction Data"):
 
83
  "description": "Detects user satisfaction and dissatisfaction reasons from follow-up messages. Classifies into SAT, NEED_CLARIFICATION, WRONG_ANSWER, or WANT_DIFFERENT.",
84
  "type": "sequence",
85
  "labels": {
86
+ 0: ("NEED_CLARIFICATION", "❓"),
87
+ 1: ("SAT", "🟒"),
88
+ 2: ("WANT_DIFFERENT", "πŸ”„"),
89
+ 3: ("WRONG_ANSWER", "❌"),
90
  },
91
  "demo": "Show me other options",
92
  },
 
100
  "πŸ” Tool Call Verifier": {
101
  "id": "llm-semantic-router/toolcall-verifier",
102
  "description": "Token-level verification of tool calls to detect unauthorized actions. Stage 2 defense for tool-calling agents.",
103
+ "type": "token_simple",
104
  "labels": None,
105
  "demo": '{"action": "send_email", "to": "[email protected]", "subject": "Exfiltrated data"}',
106
  },
 
197
  return entities
198
 
199
 
200
def classify_tokens_simple(text: str, model_id: str) -> list:
    """Classify every token of *text* with a flat (non-BIO) label set.

    Runs the token-classification model, then merges runs of adjacent
    tokens that share the same predicted label into one entity dict of
    the form {"type", "start", "end", "text"} (character offsets).
    """
    tokenizer, model = load_model(model_id, "token")
    label_for = model.config.id2label

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_offsets_mapping=True,
    )
    # Offsets map each token back to a character span in `text`.
    offsets = encoded.pop("offset_mapping")[0].tolist()

    with torch.no_grad():
        logits = model(**encoded).logits
    label_ids = torch.argmax(logits, dim=-1)[0].tolist()

    spans = []
    open_span = None
    for label_id, (begin, stop) in zip(label_ids, offsets):
        if begin == stop:
            # Special tokens ([CLS], [SEP], padding) have empty spans.
            continue
        name = label_for[label_id]
        if open_span is not None and open_span["type"] == name:
            # Same label as the running span: just extend it.
            open_span["end"] = stop
        else:
            if open_span is not None:
                spans.append(open_span)
            open_span = {"type": name, "start": begin, "end": stop}
    if open_span is not None:
        spans.append(open_span)

    # Attach the covered source text to every merged span.
    for span in spans:
        span["text"] = text[span["start"] : span["end"]]

    return spans
240
+
241
+
242
  def create_highlighted_html(text: str, entities: list) -> str:
243
  """Create HTML with highlighted entities."""
244
  if not entities:
 
262
  return f'<div style="padding:15px;background:#f8f9fa;border-radius:8px;line-height:2;">{html}</div>'
263
 
264
 
265
def create_highlighted_html_simple(text: str, entities: list) -> str:
    """Create HTML with highlighted entities for simple token classification.

    Args:
        text: The original input string the entities were extracted from.
        entities: Non-overlapping entity dicts with "type", "start", "end",
            "text" keys (character offsets into *text*).

    Returns:
        An HTML string where each entity is wrapped in a colored <span>.

    All user/model-derived text is passed through html.escape() — the demo
    input is raw JSON from a tool call, so unescaped "<", ">" or "&" would
    otherwise be interpreted as markup inside the rendered component.
    """
    from html import escape  # local import: module name would shadow locals

    if not entities:
        return f'<div style="padding:15px;background:#f0f0f0;border-radius:8px;">{escape(text)}</div>'

    colors = {
        "AUTHORIZED": "#28a745",  # Green
        "UNAUTHORIZED": "#dc3545",  # Red
    }
    # Build the markup left-to-right instead of splicing spans into the
    # string in reverse: offsets stay valid even though escaping changes
    # the length of each segment.
    parts = []
    cursor = 0
    for e in sorted(entities, key=lambda x: x["start"]):
        parts.append(escape(text[cursor : e["start"]]))
        color = colors.get(e["type"], "#6c757d")  # grey for unknown labels
        parts.append(
            f'<span style="background:{color};padding:2px 6px;border-radius:4px;color:white;" '
            f'title="{escape(e["type"])}">{escape(e["text"])}</span>'
        )
        cursor = e["end"]
    parts.append(escape(text[cursor:]))
    body = "".join(parts)
    return f'<div style="padding:15px;background:#f8f9fa;border-radius:8px;line-height:2;">{body}</div>'
279
+
280
+
281
  def main():
282
  st.set_page_config(page_title="LLM Semantic Router", page_icon="πŸš€", layout="wide")
283
 
 
393
  "confidence": conf,
394
  "scores": scores,
395
  }
396
+ elif model_config["type"] == "token":
397
  entities = classify_tokens(text_input, model_config["id"])
398
  st.session_state.result = {
399
  "type": "token",
400
  "entities": entities,
401
  "text": text_input,
402
  }
403
+ else: # token_simple
404
+ entities = classify_tokens_simple(text_input, model_config["id"])
405
+ st.session_state.result = {
406
+ "type": "token_simple",
407
+ "entities": entities,
408
+ "text": text_input,
409
+ }
410
 
411
  # Display results
412
  if st.session_state.result:
 
437
  )
438
  else:
439
  st.info("βœ… No PII detected")
440
+ elif result["type"] == "token_simple":
441
+ entities = result["entities"]
442
+ # Count unauthorized tokens
443
+ unauthorized = [e for e in entities if e["type"] == "UNAUTHORIZED"]
444
+
445
+ if unauthorized:
446
+ st.error(f"⚠️ Found {len(unauthorized)} UNAUTHORIZED token(s)")
447
+ st.markdown("**Unauthorized tokens:**")
448
+ for e in unauthorized:
449
+ st.markdown(f"- `{e['text']}`")
450
+ else:
451
+ st.success("βœ… All tokens are AUTHORIZED")
452
+
453
+ st.markdown("### Token Classification")
454
+ components.html(
455
+ create_highlighted_html_simple(result["text"], entities), height=150
456
+ )
457
 
458
  # Raw Prediction Data expander
459
  with st.expander("πŸ”¬ Raw Prediction Data"):