Nexari-Research committed
Commit fbc8175 · verified · 1 Parent(s): 57e3787

Update behavior_model.py

Files changed (1)
  1. behavior_model.py +252 -50
behavior_model.py CHANGED
@@ -1,21 +1,27 @@
 
  """
- behavior_model.py

- Lightweight neural conversation flow detector for multi-turn input.
- Uses a zero-shot classifier (transformers) for label scoring and
- embeddings heuristics to consider recent-turn context.

- Output: analyze_flow(messages) -> dict {
-     "flow_label": str,
-     "confidence": float,
-     "explanation": str
- }
-
- Lazy-loads the zero-shot classifier and falls back to heuristics.
  """
  import threading
- from typing import List, Dict
  import traceback

  _flow_classifier = None
  _flow_lock = threading.Lock()
@@ -31,6 +37,30 @@ _DEFAULT_LABELS = [
      "closing",
  ]

  def _load_flow_model():
      global _flow_classifier
      try:
@@ -69,45 +99,93 @@ def _concat_recent_messages(messages: List[Dict], max_chars: int = 1200) -> str:
          break
      return "".join(reversed(parts)).strip()

- def analyze_flow(messages: List[Dict], custom_labels: List[str] = None) -> Dict:
-     try:
-         text_blob = _concat_recent_messages(messages, max_chars=1200)
-         labels = custom_labels or _DEFAULT_LABELS
-         _ensure_flow_loaded()
-         if not _flow_classifier or not text_blob:
-             return _heuristic_flow(text_blob)
-
-         result = _flow_classifier(text_blob, candidate_labels=labels, multi_label=False)
-         if not result or 'labels' not in result:
-             return _heuristic_flow(text_blob)
-
-         top_label = result['labels'][0]
-         top_score = float(result['scores'][0] if result.get('scores') else 0.0)
-         explanation_map = {
-             "task_request": "User requests a concrete task — prefer actionable steps.",
-             "clarification": "User asks for clarification — ask concise clarifying questions.",
-             "follow_up": "Continuation reference prior answer and continue.",
-             "escalation": "User shows dissatisfaction — de-escalate, propose solutions.",
-             "small_talk": "Casual conversation — be friendly and short.",
-             "information_seeking": "Seeking facts be concise and cite if possible.",
-             "confirmation": "Yes/no or confirmatory — respond succinctly.",
-             "closing": "Conversation ending — provide short wrap-up."
-         }
-         explanation = explanation_map.get(top_label, "Follow user's flow and be concise.")
-         scores = {lbl: float(s) for lbl, s in zip(result.get('labels', []), result.get('scores', []))}
-         return {
-             "flow_label": top_label,
-             "confidence": top_score,
-             "scores": scores,
-             "explanation": explanation
-         }
-     except Exception as e:
-         traceback.print_exc()
-         return _heuristic_flow(_concat_recent_messages(messages))

  def _heuristic_flow(blob: str) -> Dict:
      b = (blob or "").lower()
-     if any(w in b for w in ["please", "could you", "can you", "i need", "i want"]):
          label, conf = "task_request", 0.55
      elif any(w in b for w in ["what do you mean", "clarify", "explain", "how so"]):
          label, conf = "clarification", 0.55
@@ -117,10 +195,134 @@ def _heuristic_flow(blob: str) -> Dict:
          label, conf = "small_talk", 0.6
      elif any(w in b for w in ["error", "not working", "frustrat", "angry", "problem"]):
          label, conf = "escalation", 0.6
-     elif any(w in b for w in ["what is", "who is", "when is", "look up", "search"]):
          label, conf = "information_seeking", 0.55
      else:
          label, conf = "follow_up", 0.4

-     explanation = f"Fallback heuristic suggests '{label}' (confidence ~{conf}). Mirror user's last message and proceed accordingly."
      return {"flow_label": label, "confidence": conf, "scores": {label: conf}, "explanation": explanation}

+ # behavior_model.py
  """
+ Improved conversation flow + complexity router.
+
+ Exports:
+ - analyze_flow(messages, custom_labels=None, prefer_fast=True) -> dict
+   Adds routing decisions:
+     - flow_label: str
+     - confidence: float
+     - explanation: str
+     - is_complex: bool
+     - complexity_score: float (0.0 - 1.0)
+     - route: "direct" | "planning" (direct => send to LLM immediately, planning => run planner)
+     - scores: optional dict of label scores (if classifier used)
+
+ Design goals:
+ - Fast-path for short/simple requests (heuristics only) to reduce latency.
+ - Lazy-load the zero-shot classifier only when heuristics are ambiguous.
+ - Thread-safe lazy loading.
  """
  import threading
+ from typing import List, Dict, Any
  import traceback
+ import re

  _flow_classifier = None
  _flow_lock = threading.Lock()
 
      "closing",
  ]

+ # fast keyword sets and patterns for complexity heuristics
+ _COMPLEX_KEYWORDS = {
+     "task": ["implement", "create", "build", "generate", "write", "develop", "deploy", "setup", "configure", "install", "refactor", "optimize", "benchmark"],
+     "analysis": ["explain", "why", "how", "analyze", "analysis", "compare", "evaluate", "breakdown", "diagnose"],
+     "error": ["error", "exception", "traceback", "stacktrace", "crash", "bug", "not working", "fix", "debug"],
+     "code_signs": ["```", "def ", "function(", "class ", "import ", "console.log", "{", "};", ";", "->", "std::", "#include"],
+     "data": ["dataset", "csv", "json", "table", "rows", "columns", "api", "endpoint"],
+     "math": ["calculate", "compute", "solve", "equation", "integral", "sum", "mean", "variance"]
+ }
+
+ _QUESTION_WORDS = set(["what","why","how","which","when","where","who","whom","whose","do","does","did","can","could","would","should","is","are","was","were","may","might"])
+
+ # weights for heuristic scoring
+ _WEIGHTS = {
+     "word_count": 0.15,
+     "sentence_count": 0.05,
+     "has_code": 0.30,
+     "has_numbers_or_urls": 0.05,
+     "task_keywords": 0.20,
+     "analysis_keywords": 0.15,
+     "error_keywords": 0.10,
+     "question_words_density": 0.10
+ }
+
  def _load_flow_model():
      global _flow_classifier
      try:
 
          break
      return "".join(reversed(parts)).strip()

+ def _fast_complexity_score(text: str) -> Dict[str, Any]:
+     """
+     Returns a dict:
+         { score: float (0-1), features: {...}, explanation: str }
+     Higher score -> more complex.
+     """
+     t = (text or "").strip()
+     if not t:
+         return {"score": 0.0, "features": {}, "explanation": "empty text -> trivial"}
+
+     # basic counts
+     words = re.findall(r"\w+", t)
+     word_count = len(words)
+     sentence_count = max(1, len(re.findall(r"[.!?]+", t)) or 1)
+
+     # flags
+     lower = t.lower()
+     has_code = any(sig in lower for sig in _COMPLEX_KEYWORDS["code_signs"])
+     has_numbers = bool(re.search(r"\d+", t))
+     has_url = bool(re.search(r"https?://|www\.|\.[a-z]{2,4}/", lower))
+     has_numbers_or_urls = has_numbers or has_url
+
+     # keyword signals
+     task_kw = sum(1 for k in _COMPLEX_KEYWORDS["task"] if k in lower)
+     analysis_kw = sum(1 for k in _COMPLEX_KEYWORDS["analysis"] if k in lower)
+     error_kw = sum(1 for k in _COMPLEX_KEYWORDS["error"] if k in lower)
+     math_kw = sum(1 for k in _COMPLEX_KEYWORDS["math"] if k in lower)
+     data_kw = sum(1 for k in _COMPLEX_KEYWORDS["data"] if k in lower)
+
+     # question word density
+     qwords = sum(1 for w in re.findall(r"\w+", lower) if w in _QUESTION_WORDS)
+     q_density = qwords / max(1, word_count)
+
+     # compute raw score using weighted features
+     score = 0.0
+     score += min(word_count / 200.0, 1.0) * _WEIGHTS["word_count"]
+     score += min(sentence_count / 6.0, 1.0) * _WEIGHTS["sentence_count"]
+     score += (1.0 if has_code else 0.0) * _WEIGHTS["has_code"]
+     score += (1.0 if has_numbers_or_urls else 0.0) * _WEIGHTS["has_numbers_or_urls"]
+     score += min(task_kw / 3.0, 1.0) * _WEIGHTS["task_keywords"]
+     score += min(analysis_kw / 3.0, 1.0) * _WEIGHTS["analysis_keywords"]
+     score += min(error_kw / 2.0, 1.0) * _WEIGHTS["error_keywords"]
+     score += min(q_density * 2.0, 1.0) * _WEIGHTS["question_words_density"]  # scale
+
+     # minor boosts for data/math keywords
+     if math_kw or data_kw:
+         score = min(score + 0.05, 1.0)
+
+     # normalize (weights sum > 1 so clamp)
+     score = max(0.0, min(score, 1.0))
+
+     features = {
+         "word_count": word_count,
+         "sentence_count": sentence_count,
+         "has_code": has_code,
+         "has_numbers_or_urls": has_numbers_or_urls,
+         "task_kw": task_kw,
+         "analysis_kw": analysis_kw,
+         "error_kw": error_kw,
+         "math_kw": math_kw,
+         "data_kw": data_kw,
+         "q_density": round(q_density, 3)
+     }
+
+     # plain-language explanation for the fast path
+     expl_parts = []
+     if has_code:
+         expl_parts.append("Detected code-like tokens")
+     if task_kw:
+         expl_parts.append(f"{task_kw} task-related keywords")
+     if analysis_kw:
+         expl_parts.append(f"{analysis_kw} analysis-related keywords")
+     if error_kw:
+         expl_parts.append(f"{error_kw} error/debug keywords")
+     if word_count > 120:
+         expl_parts.append("Long message (>120 words)")
+     if q_density > 0.2:
+         expl_parts.append("High question density")
+
+     explanation = "; ".join(expl_parts) if expl_parts else "No strong complex signals detected"
+
+     return {"score": round(score, 3), "features": features, "explanation": explanation}

  def _heuristic_flow(blob: str) -> Dict:
+     # basic fallback from previous implementation, slightly adapted
      b = (blob or "").lower()
+     if any(w in b for w in ["please", "could you", "can you", "i need", "i want", "please help"]):
          label, conf = "task_request", 0.55
      elif any(w in b for w in ["what do you mean", "clarify", "explain", "how so"]):
          label, conf = "clarification", 0.55
 
          label, conf = "small_talk", 0.6
      elif any(w in b for w in ["error", "not working", "frustrat", "angry", "problem"]):
          label, conf = "escalation", 0.6
+     elif any(w in b for w in ["what is", "who is", "when is", "look up", "search", "find"]):
          label, conf = "information_seeking", 0.55
      else:
          label, conf = "follow_up", 0.4

+     explanation = f"Fallback heuristic suggests '{label}' (confidence ~{conf})."
      return {"flow_label": label, "confidence": conf, "scores": {label: conf}, "explanation": explanation}
+
+ def analyze_flow(messages: List[Dict], custom_labels: List[str] = None, prefer_fast: bool = True) -> Dict:
+     """
+     Main entry.
+
+     prefer_fast: if True, prefer heuristic-only decisions when confident to reduce latency.
+     Returns dict with:
+       - flow_label, confidence, explanation
+       - is_complex (bool)
+       - complexity_score (0-1)
+       - route: 'direct' or 'planning'
+       - scores: optional (when classifier used)
+     """
+     try:
+         text_blob = _concat_recent_messages(messages, max_chars=1200)
+         labels = custom_labels or _DEFAULT_LABELS
+
+         # run fast heuristic complexity scoring on the user's latest message (most relevant)
+         last_user_msg = ""
+         if messages:
+             # find last user message content
+             for m in reversed(messages):
+                 if m.get("role") == "user" and (m.get("content") or "").strip():
+                     last_user_msg = m.get("content").strip()
+                     break
+
+         fast = _fast_complexity_score(last_user_msg or text_blob)
+         complexity_score = float(fast.get("score", 0.0))
+         features = fast.get("features", {})
+         fast_expl = fast.get("explanation", "")
+
+         # routing heuristics thresholds (tunable)
+         DIRECT_THRESHOLD = 0.25   # <= -> direct (fast)
+         PLANNING_THRESHOLD = 0.60  # >= -> planning (complex)
+         ambig_low = DIRECT_THRESHOLD
+         ambig_high = PLANNING_THRESHOLD
+
+         # quick decision if confident and prefer_fast
+         if prefer_fast and (complexity_score <= ambig_low or complexity_score >= ambig_high):
+             route = "direct" if complexity_score <= ambig_low else "planning"
+             is_complex = complexity_score >= ambig_high
+             # attempt to pick a flow_label via heuristic (fast)
+             fallback = _heuristic_flow(text_blob)
+             label = fallback.get("flow_label", "follow_up")
+             conf = round(0.5 + (0.5 * complexity_score) if is_complex else 0.4, 2)
+             explanation = f"Fast-path decision: route='{route}'. {fast_expl} (score={complexity_score})."
+             return {
+                 "flow_label": label,
+                 "confidence": conf,
+                 "explanation": explanation,
+                 "is_complex": bool(is_complex),
+                 "complexity_score": round(complexity_score, 3),
+                 "route": route,
+                 "features": features,
+                 "scores": {label: conf}
+             }
+
+         # If ambiguous or prefer classifier, try zero-shot classifier (lazy)
+         _ensure_flow_loaded()
+         if not _flow_classifier or not text_blob:
+             # fallback
+             fallback = _heuristic_flow(text_blob)
+             explanation = f"Classifier unavailable; heuristic fallback. {fast_expl} (score={complexity_score})."
+             # route by heuristic score
+             route = "planning" if complexity_score >= PLANNING_THRESHOLD else "direct"
+             is_complex = complexity_score >= PLANNING_THRESHOLD
+             return {
+                 "flow_label": fallback.get("flow_label", "follow_up"),
+                 "confidence": fallback.get("confidence", 0.4),
+                 "explanation": explanation,
+                 "is_complex": bool(is_complex),
+                 "complexity_score": round(complexity_score, 3),
+                 "route": route,
+                 "features": features,
+                 "scores": fallback.get("scores")
+             }
+
+         # use classifier to get a more informed flow label
+         try:
+             result = _flow_classifier(text_blob, candidate_labels=labels, multi_label=False)
+             if not result or 'labels' not in result:
+                 raise ValueError("classifier returned no labels")
+             top_label = result['labels'][0]
+             top_score = float(result['scores'][0] if result.get('scores') else 0.0)
+             # decide complexity/route combining classifier and heuristic
+             is_complex = complexity_score >= PLANNING_THRESHOLD or top_label in ("task_request", "escalation", "information_seeking")
+             route = "planning" if is_complex or top_score < 0.5 else "direct"
+             explanation = (
+                 f"Classifier suggests '{top_label}' (score={round(top_score, 2)}). "
+                 f"Heuristic complexity score={complexity_score} ({fast_expl}). Routed to '{route}'."
+             )
+             scores = {lbl: float(s) for lbl, s in zip(result.get('labels', []), result.get('scores', []))}
+             return {
+                 "flow_label": top_label,
+                 "confidence": round(top_score, 3),
+                 "explanation": explanation,
+                 "is_complex": bool(is_complex),
+                 "complexity_score": round(complexity_score, 3),
+                 "route": route,
+                 "features": features,
+                 "scores": scores
+             }
+         except Exception as e:
+             # classifier error -> fallback heuristics
+             traceback.print_exc()
+             fallback = _heuristic_flow(text_blob)
+             route = "planning" if complexity_score >= PLANNING_THRESHOLD else "direct"
+             explanation = f"Classifier error; fallback to heuristic. {fast_expl} (score={complexity_score}). Error: {e}"
+             return {
+                 "flow_label": fallback.get("flow_label", "follow_up"),
+                 "confidence": fallback.get("confidence", 0.4),
+                 "explanation": explanation,
+                 "is_complex": complexity_score >= PLANNING_THRESHOLD,
+                 "complexity_score": round(complexity_score, 3),
+                 "route": route,
+                 "features": features,
+                 "scores": fallback.get("scores")
+             }
+
+     except Exception as e:
+         traceback.print_exc()
+         return _heuristic_flow(_concat_recent_messages(messages))
+
+ # End of behavior_model.py
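
A minimal usage sketch of the routing API introduced in this commit, assuming the module is importable as behavior_model; the handlers plan_and_respond and respond_directly below are hypothetical placeholders for whatever the host application does with each route, not part of the commit. With the default thresholds, a complexity score at or below 0.25 routes to "direct", a score at or above 0.60 routes to "planning", and scores in between fall through to the lazily loaded zero-shot classifier.

from behavior_model import analyze_flow  # assumes behavior_model.py is on the import path

messages = [
    {"role": "user", "content": "My deploy script crashes with a traceback. Can you help me debug and fix it?"},
]

decision = analyze_flow(messages)  # prefer_fast=True by default

if decision["route"] == "planning":
    plan_and_respond(messages, decision)    # hypothetical planner handler
else:
    respond_directly(messages, decision)    # hypothetical direct-LLM handler

print(decision["flow_label"], decision.get("complexity_score"), decision.get("is_complex"))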