Spaces:
Running
Running
bitliu
committed on
Commit
·
1ba6507
1
Parent(s):
31252aa
update
Browse filesSigned-off-by: bitliu <[email protected]>
app.py
CHANGED
|
@@ -83,10 +83,10 @@ MODELS = {
|
|
| 83 |
"description": "Detects user satisfaction and dissatisfaction reasons from follow-up messages. Classifies into SAT, NEED_CLARIFICATION, WRONG_ANSWER, or WANT_DIFFERENT.",
|
| 84 |
"type": "sequence",
|
| 85 |
"labels": {
|
| 86 |
-
0: ("
|
| 87 |
-
1: ("
|
| 88 |
-
2: ("
|
| 89 |
-
3: ("
|
| 90 |
},
|
| 91 |
"demo": "Show me other options",
|
| 92 |
},
|
|
@@ -100,7 +100,7 @@ MODELS = {
|
|
| 100 |
"π Tool Call Verifier": {
|
| 101 |
"id": "llm-semantic-router/toolcall-verifier",
|
| 102 |
"description": "Token-level verification of tool calls to detect unauthorized actions. Stage 2 defense for tool-calling agents.",
|
| 103 |
-
"type": "
|
| 104 |
"labels": None,
|
| 105 |
"demo": '{"action": "send_email", "to": "[email protected]", "subject": "Exfiltrated data"}',
|
| 106 |
},
|
|
@@ -197,6 +197,48 @@ def classify_tokens(text: str, model_id: str) -> list:
|
|
| 197 |
return entities
|
| 198 |
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
def create_highlighted_html(text: str, entities: list) -> str:
|
| 201 |
"""Create HTML with highlighted entities."""
|
| 202 |
if not entities:
|
|
@@ -220,6 +262,22 @@ def create_highlighted_html(text: str, entities: list) -> str:
|
|
| 220 |
return f'<div style="padding:15px;background:#f8f9fa;border-radius:8px;line-height:2;">{html}</div>'
|
| 221 |
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
def main():
|
| 224 |
st.set_page_config(page_title="LLM Semantic Router", page_icon="π", layout="wide")
|
| 225 |
|
|
@@ -335,13 +393,20 @@ def main():
|
|
| 335 |
"confidence": conf,
|
| 336 |
"scores": scores,
|
| 337 |
}
|
| 338 |
-
|
| 339 |
entities = classify_tokens(text_input, model_config["id"])
|
| 340 |
st.session_state.result = {
|
| 341 |
"type": "token",
|
| 342 |
"entities": entities,
|
| 343 |
"text": text_input,
|
| 344 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
# Display results
|
| 347 |
if st.session_state.result:
|
|
@@ -372,6 +437,23 @@ def main():
|
|
| 372 |
)
|
| 373 |
else:
|
| 374 |
st.info("β
No PII detected")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
|
| 376 |
# Raw Prediction Data expander
|
| 377 |
with st.expander("π¬ Raw Prediction Data"):
|
|
|
|
| 83 |
"description": "Detects user satisfaction and dissatisfaction reasons from follow-up messages. Classifies into SAT, NEED_CLARIFICATION, WRONG_ANSWER, or WANT_DIFFERENT.",
|
| 84 |
"type": "sequence",
|
| 85 |
"labels": {
|
| 86 |
+
0: ("NEED_CLARIFICATION", "β"),
|
| 87 |
+
1: ("SAT", "π’"),
|
| 88 |
+
2: ("WANT_DIFFERENT", "π"),
|
| 89 |
+
3: ("WRONG_ANSWER", "β"),
|
| 90 |
},
|
| 91 |
"demo": "Show me other options",
|
| 92 |
},
|
|
|
|
| 100 |
"π Tool Call Verifier": {
|
| 101 |
"id": "llm-semantic-router/toolcall-verifier",
|
| 102 |
"description": "Token-level verification of tool calls to detect unauthorized actions. Stage 2 defense for tool-calling agents.",
|
| 103 |
+
"type": "token_simple",
|
| 104 |
"labels": None,
|
| 105 |
"demo": '{"action": "send_email", "to": "[email protected]", "subject": "Exfiltrated data"}',
|
| 106 |
},
|
|
|
|
| 197 |
return entities
|
| 198 |
|
| 199 |
|
| 200 |
+
def classify_tokens_simple(text: str, model_id: str) -> list:
    """Token-level classification for models without BIO-tagged labels.

    Runs the token-classification model over ``text``, then merges runs of
    adjacent tokens that share the same predicted label into entity dicts
    with ``"type"``, ``"start"``, ``"end"`` and ``"text"`` keys (character
    offsets into the original string).
    """
    tokenizer, model = load_model(model_id, "token")
    label_of = model.config.id2label
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_offsets_mapping=True,
    )
    # The offset mapping is only needed on the Python side; the model's
    # forward pass must not receive it as an input.
    offsets = encoded.pop("offset_mapping")[0].tolist()
    with torch.no_grad():
        logits = model(**encoded).logits
    preds = torch.argmax(logits, dim=-1)[0].tolist()

    # Merge consecutive tokens that carry the same label into one span.
    spans: list = []
    open_span = None
    for pred, (start, end) in zip(preds, offsets):
        if start == end:
            # Special tokens (CLS/SEP/PAD) map to empty character ranges.
            continue
        label = label_of[pred]
        if open_span is not None and open_span["type"] == label:
            # Same label as the running span: just push its right edge out.
            open_span["end"] = end
        else:
            if open_span is not None:
                spans.append(open_span)
            open_span = {"type": label, "start": start, "end": end}

    if open_span is not None:
        spans.append(open_span)

    # Attach the surface text covered by each span.
    for span in spans:
        span["text"] = text[span["start"] : span["end"]]

    return spans
|
| 240 |
+
|
| 241 |
+
|
| 242 |
def create_highlighted_html(text: str, entities: list) -> str:
|
| 243 |
"""Create HTML with highlighted entities."""
|
| 244 |
if not entities:
|
|
|
|
| 262 |
return f'<div style="padding:15px;background:#f8f9fa;border-radius:8px;line-height:2;">{html}</div>'
|
| 263 |
|
| 264 |
|
| 265 |
+
def create_highlighted_html_simple(text: str, entities: list) -> str:
    """Render ``text`` as HTML, wrapping each entity span in a colored <span>.

    ``entities`` are dicts with ``"type"``, ``"start"``, ``"end"`` and
    ``"text"`` keys as produced by classify_tokens_simple. AUTHORIZED spans
    render green, UNAUTHORIZED red, any other label gray. With no entities,
    the plain text is returned in a neutral box.
    """
    if not entities:
        return f'<div style="padding:15px;background:#f0f0f0;border-radius:8px;">{text}</div>'

    palette = {
        "AUTHORIZED": "#28a745",  # Green
        "UNAUTHORIZED": "#dc3545",  # Red
    }
    marked = text
    # Splice from right to left so earlier character offsets stay valid
    # as the string grows.
    for entity in sorted(entities, key=lambda item: item["start"], reverse=True):
        color = palette.get(entity["type"], "#6c757d")
        tag = (
            f'<span style="background:{color};padding:2px 6px;'
            f'border-radius:4px;color:white;" title="{entity["type"]}">'
            f'{entity["text"]}</span>'
        )
        marked = marked[: entity["start"]] + tag + marked[entity["end"] :]
    # NOTE(review): text is interpolated without HTML escaping; escaping here
    # would invalidate the model's character offsets, and the sibling
    # create_highlighted_html follows the same convention.
    return f'<div style="padding:15px;background:#f8f9fa;border-radius:8px;line-height:2;">{marked}</div>'
|
| 279 |
+
|
| 280 |
+
|
| 281 |
def main():
|
| 282 |
st.set_page_config(page_title="LLM Semantic Router", page_icon="π", layout="wide")
|
| 283 |
|
|
|
|
| 393 |
"confidence": conf,
|
| 394 |
"scores": scores,
|
| 395 |
}
|
| 396 |
+
elif model_config["type"] == "token":
|
| 397 |
entities = classify_tokens(text_input, model_config["id"])
|
| 398 |
st.session_state.result = {
|
| 399 |
"type": "token",
|
| 400 |
"entities": entities,
|
| 401 |
"text": text_input,
|
| 402 |
}
|
| 403 |
+
else: # token_simple
|
| 404 |
+
entities = classify_tokens_simple(text_input, model_config["id"])
|
| 405 |
+
st.session_state.result = {
|
| 406 |
+
"type": "token_simple",
|
| 407 |
+
"entities": entities,
|
| 408 |
+
"text": text_input,
|
| 409 |
+
}
|
| 410 |
|
| 411 |
# Display results
|
| 412 |
if st.session_state.result:
|
|
|
|
| 437 |
)
|
| 438 |
else:
|
| 439 |
st.info("β
No PII detected")
|
| 440 |
+
elif result["type"] == "token_simple":
|
| 441 |
+
entities = result["entities"]
|
| 442 |
+
# Count unauthorized tokens
|
| 443 |
+
unauthorized = [e for e in entities if e["type"] == "UNAUTHORIZED"]
|
| 444 |
+
|
| 445 |
+
if unauthorized:
|
| 446 |
+
st.error(f"β οΈ Found {len(unauthorized)} UNAUTHORIZED token(s)")
|
| 447 |
+
st.markdown("**Unauthorized tokens:**")
|
| 448 |
+
for e in unauthorized:
|
| 449 |
+
st.markdown(f"- `{e['text']}`")
|
| 450 |
+
else:
|
| 451 |
+
st.success("β
All tokens are AUTHORIZED")
|
| 452 |
+
|
| 453 |
+
st.markdown("### Token Classification")
|
| 454 |
+
components.html(
|
| 455 |
+
create_highlighted_html_simple(result["text"], entities), height=150
|
| 456 |
+
)
|
| 457 |
|
| 458 |
# Raw Prediction Data expander
|
| 459 |
with st.expander("π¬ Raw Prediction Data"):
|