Nexari-Research commited on
Commit
9d1f57d
·
verified ·
1 Parent(s): 9d0288b

Update behavior_model.py

Browse files
Files changed (1) hide show
  1. behavior_model.py +348 -318
behavior_model.py CHANGED
@@ -1,328 +1,358 @@
1
- # behavior_model.py
2
  """
3
- Improved conversation flow + complexity router.
4
-
5
- Exports:
6
- - analyze_flow(messages, custom_labels=None, prefer_fast=True) -> dict
7
- Adds routing decisions:
8
- - flow_label: str
9
- - confidence: float
10
- - explanation: str
11
- - is_complex: bool
12
- - complexity_score: float (0.0 - 1.0)
13
- - route: "direct" | "planning" (direct => send to LLM immediately, planning => run planner)
14
- - scores: optional dict of label scores (if classifier used)
15
-
16
- Design goals:
17
- - Fast-path for short/simple requests (heuristics only) to reduce latency.
18
- - Lazy-load zero-shot classifier only when heuristics are ambiguous.
19
- - Thread-safe lazy loading.
20
  """
21
- import threading
22
- from typing import List, Dict, Any
23
- import traceback
24
- import re
25
 
26
- _flow_classifier = None
27
- _flow_lock = threading.Lock()
28
-
29
- _DEFAULT_LABELS = [
30
- "task_request",
31
- "clarification",
32
- "follow_up",
33
- "escalation",
34
- "small_talk",
35
- "information_seeking",
36
- "confirmation",
37
- "closing",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  ]
39
 
40
- # fast keyword sets and patterns for complexity heuristics
41
- _COMPLEX_KEYWORDS = {
42
- "task": ["implement", "create", "build", "generate", "write", "develop", "deploy", "setup", "configure", "install", "refactor", "optimize", "benchmark"],
43
- "analysis": ["explain", "why", "how", "analyze", "analysis", "compare", "evaluate", "breakdown", "diagnose"],
44
- "error": ["error", "exception", "traceback", "stacktrace", "crash", "bug", "not working", "fix", "debug"],
45
- "code_signs": ["```", "def ", "function(", "class ", "import ", "console.log", "{", "};", ";", "->", "std::", "#include"],
46
- "data": ["dataset", "csv", "json", "table", "rows", "columns", "api", "endpoint"],
47
- "math": ["calculate", "compute", "solve", "equation", "integral", "sum", "mean", "variance"]
48
- }
49
-
50
- _QUESTION_WORDS = set(["what","why","how","which","when","where","who","whom","whose","do","does","did","can","could","would","should","is","are","was","were","may","might"])
51
-
52
- # weights for heuristic scoring
53
- _WEIGHTS = {
54
- "word_count": 0.15,
55
- "sentence_count": 0.05,
56
- "has_code": 0.30,
57
- "has_numbers_or_urls": 0.05,
58
- "task_keywords": 0.20,
59
- "analysis_keywords": 0.15,
60
- "error_keywords": 0.10,
61
- "question_words_density": 0.10
62
- }
63
-
64
- def _load_flow_model():
65
- global _flow_classifier
66
- try:
67
- from transformers import pipeline
68
- _flow_classifier = pipeline("zero-shot-classification",
69
- model="typeform/distilbert-base-uncased-mnli")
70
- except Exception as e:
71
- print("BehaviorModel: failed to load flow classifier:", e)
72
- _flow_classifier = None
73
-
74
- def _ensure_flow_loaded():
75
- if _flow_classifier is None:
76
- with _flow_lock:
77
- if _flow_classifier is None:
78
- _load_flow_model()
79
-
80
- def _concat_recent_messages(messages: List[Dict], max_chars: int = 1200) -> str:
81
  if not messages:
82
- return ""
83
- rev = list(reversed(messages))
84
- parts = []
85
- total = 0
86
- for m in rev:
87
- c = (m.get("content") or "").strip()
88
- if not c:
89
- continue
90
- add = f"{m.get('role','user')}: {c}\n"
91
- if total + len(add) > max_chars:
92
- remaining = max_chars - total
93
- if remaining <= 0:
94
- break
95
- add = add[:remaining]
96
- parts.append(add)
97
- total += len(add)
98
- if total >= max_chars:
99
- break
100
- return "".join(reversed(parts)).strip()
101
-
102
- def _fast_complexity_score(text: str) -> Dict[str, Any]:
103
- """
104
- Returns a dict:
105
- { score: float (0-1), features: {...}, explanation: str }
106
- Higher score -> more complex.
 
 
 
 
 
 
 
 
 
 
 
 
107
  """
108
- t = (text or "").strip()
109
- if not t:
110
- return {"score": 0.0, "features": {}, "explanation": "empty text -> trivial"}
111
-
112
- # basic counts
113
- words = re.findall(r"\w+", t)
114
- word_count = len(words)
115
- sentence_count = max(1, len(re.findall(r"[.!?]+", t)) or 1)
116
-
117
- # flags
118
- lower = t.lower()
119
- has_code = any(sig in lower for sig in _COMPLEX_KEYWORDS["code_signs"])
120
- has_numbers = bool(re.search(r"\d+", t))
121
- has_url = bool(re.search(r"https?://|www\.|\.[a-z]{2,4}/", lower))
122
- has_numbers_or_urls = has_numbers or has_url
123
-
124
- # keyword signals
125
- task_kw = sum(1 for k in _COMPLEX_KEYWORDS["task"] if k in lower)
126
- analysis_kw = sum(1 for k in _COMPLEX_KEYWORDS["analysis"] if k in lower)
127
- error_kw = sum(1 for k in _COMPLEX_KEYWORDS["error"] if k in lower)
128
- math_kw = sum(1 for k in _COMPLEX_KEYWORDS["math"] if k in lower)
129
- data_kw = sum(1 for k in _COMPLEX_KEYWORDS["data"] if k in lower)
130
-
131
- # question word density
132
- qwords = sum(1 for w in re.findall(r"\w+", lower) if w in _QUESTION_WORDS)
133
- q_density = qwords / max(1, word_count)
134
-
135
- # compute raw score using weighted features
136
- score = 0.0
137
- score += min(word_count / 200.0, 1.0) * _WEIGHTS["word_count"]
138
- score += min(sentence_count / 6.0, 1.0) * _WEIGHTS["sentence_count"]
139
- score += (1.0 if has_code else 0.0) * _WEIGHTS["has_code"]
140
- score += (1.0 if has_numbers_or_urls else 0.0) * _WEIGHTS["has_numbers_or_urls"]
141
- score += min(task_kw / 3.0, 1.0) * _WEIGHTS["task_keywords"]
142
- score += min(analysis_kw / 3.0, 1.0) * _WEIGHTS["analysis_keywords"]
143
- score += min(error_kw / 2.0, 1.0) * _WEIGHTS["error_keywords"]
144
- score += min(q_density * 2.0, 1.0) * _WEIGHTS["question_words_density"] # scale
145
-
146
- # minor boosts for data/math keywords
147
- if math_kw or data_kw:
148
- score = min(score + 0.05, 1.0)
149
-
150
- # normalize (weights sum > 1 so clamp)
151
- score = max(0.0, min(score, 1.0))
152
-
153
- features = {
154
- "word_count": word_count,
155
- "sentence_count": sentence_count,
156
- "has_code": has_code,
157
- "has_numbers_or_urls": has_numbers_or_urls,
158
- "task_kw": task_kw,
159
- "analysis_kw": analysis_kw,
160
- "error_kw": error_kw,
161
- "math_kw": math_kw,
162
- "data_kw": data_kw,
163
- "q_density": round(q_density, 3)
164
  }
165
-
166
- # plain language explanation for fast path
167
- expl_parts = []
168
- if has_code:
169
- expl_parts.append("Detected code-like tokens")
170
- if task_kw:
171
- expl_parts.append(f"{task_kw} task-related keywords")
172
- if analysis_kw:
173
- expl_parts.append(f"{analysis_kw} analysis-related keywords")
174
- if error_kw:
175
- expl_parts.append(f"{error_kw} error/debug keywords")
176
- if word_count > 120:
177
- expl_parts.append("Long message (>120 words)")
178
- if q_density > 0.2:
179
- expl_parts.append("High question density")
180
-
181
- explanation = "; ".join(expl_parts) if expl_parts else "No strong complex signals detected"
182
-
183
- return {"score": round(score, 3), "features": features, "explanation": explanation}
184
-
185
- def _heuristic_flow(blob: str) -> Dict:
186
- # basic fallback from previous implementation, slightly adapted
187
- b = (blob or "").lower()
188
- if any(w in b for w in ["please", "could you", "can you", "i need", "i want", "please help"]):
189
- label, conf = "task_request", 0.55
190
- elif any(w in b for w in ["what do you mean", "clarify", "explain", "how so"]):
191
- label, conf = "clarification", 0.55
192
- elif any(w in b for w in ["thanks", "thank you", "bye", "goodbye", "see you"]):
193
- label, conf = "closing", 0.7
194
- elif any(w in b for w in ["hi", "hello", "hey", "namaste"]):
195
- label, conf = "small_talk", 0.6
196
- elif any(w in b for w in ["error", "not working", "frustrat", "angry", "problem"]):
197
- label, conf = "escalation", 0.6
198
- elif any(w in b for w in ["what is", "who is", "when is", "look up", "search", "find"]):
199
- label, conf = "information_seeking", 0.55
 
 
 
 
 
 
200
  else:
201
- label, conf = "follow_up", 0.4
202
-
203
- explanation = f"Fallback heuristic suggests '{label}' (confidence ~{conf})."
204
- return {"flow_label": label, "confidence": conf, "scores": {label: conf}, "explanation": explanation}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- def analyze_flow(messages: List[Dict], custom_labels: List[str] = None, prefer_fast: bool = True) -> Dict:
207
- """
208
- Main entry.
209
-
210
- prefer_fast: if True, prefer heuristic-only decisions when confident to reduce latency.
211
- Returns dict with:
212
- - flow_label, confidence, explanation
213
- - is_complex (bool)
214
- - complexity_score (0-1)
215
- - route: 'direct' or 'planning'
216
- - scores: optional (when classifier used)
217
- """
218
- try:
219
- text_blob = _concat_recent_messages(messages, max_chars=1200)
220
- labels = custom_labels or _DEFAULT_LABELS
221
-
222
- # run fast heuristic complexity scoring on the user's latest message (most relevant)
223
- last_user_msg = ""
224
- if messages:
225
- # find last user message content
226
- for m in reversed(messages):
227
- if m.get("role") == "user" and (m.get("content") or "").strip():
228
- last_user_msg = m.get("content").strip()
229
- break
230
-
231
- fast = _fast_complexity_score(last_user_msg or text_blob)
232
- complexity_score = float(fast.get("score", 0.0))
233
- features = fast.get("features", {})
234
- fast_expl = fast.get("explanation", "")
235
-
236
- # routing heuristics thresholds (tunable)
237
- DIRECT_THRESHOLD = 0.25 # <= -> direct (fast)
238
- PLANNING_THRESHOLD = 0.60 # >= -> planning (complex)
239
- ambig_low = DIRECT_THRESHOLD
240
- ambig_high = PLANNING_THRESHOLD
241
-
242
- # quick decision if confident and prefer_fast
243
- if prefer_fast and (complexity_score <= ambig_low or complexity_score >= ambig_high):
244
- route = "direct" if complexity_score <= ambig_low else "planning"
245
- is_complex = complexity_score >= ambig_high
246
- # attempt to pick a flow_label via heuristic (fast)
247
- fallback = _heuristic_flow(text_blob)
248
- label = fallback.get("flow_label", "follow_up")
249
- conf = round(0.5 + (0.5 * complexity_score) if is_complex else 0.4, 2)
250
- explanation = f"Fast-path decision: route='{route}'. {fast_expl} (score={complexity_score})."
251
- return {
252
- "flow_label": label,
253
- "confidence": conf,
254
- "explanation": explanation,
255
- "is_complex": bool(is_complex),
256
- "complexity_score": round(complexity_score, 3),
257
- "route": route,
258
- "features": features,
259
- "scores": {label: conf}
260
- }
261
-
262
- # If ambiguous or prefer classifier, try zero-shot classifier (lazy)
263
- _ensure_flow_loaded()
264
- if not _flow_classifier or not text_blob:
265
- # fallback
266
- fallback = _heuristic_flow(text_blob)
267
- explanation = f"Classifier unavailable; heuristic fallback. {fast_expl} (score={complexity_score})."
268
- # route by heuristic score
269
- route = "planning" if complexity_score >= PLANNING_THRESHOLD else "direct"
270
- is_complex = complexity_score >= PLANNING_THRESHOLD
271
- return {
272
- "flow_label": fallback.get("flow_label", "follow_up"),
273
- "confidence": fallback.get("confidence", 0.4),
274
- "explanation": explanation,
275
- "is_complex": bool(is_complex),
276
- "complexity_score": round(complexity_score, 3),
277
- "route": route,
278
- "features": features,
279
- "scores": fallback.get("scores")
280
- }
281
-
282
- # use classifier to get a more informed flow label
283
- try:
284
- result = _flow_classifier(text_blob, candidate_labels=labels, multi_label=False)
285
- if not result or 'labels' not in result:
286
- raise ValueError("classifier returned no labels")
287
- top_label = result['labels'][0]
288
- top_score = float(result['scores'][0] if result.get('scores') else 0.0)
289
- # decide complexity/route combining classifier and heuristic
290
- is_complex = complexity_score >= PLANNING_THRESHOLD or top_label in ("task_request", "escalation", "information_seeking")
291
- route = "planning" if is_complex or top_score < 0.5 else "direct"
292
- explanation = (
293
- f"Classifier suggests '{top_label}' (score={round(top_score,2)}). "
294
- f"Heuristic complexity score={complexity_score} ({fast_expl}). Routed to '{route}'."
295
- )
296
- scores = {lbl: float(s) for lbl, s in zip(result.get('labels', []), result.get('scores', []))}
297
- return {
298
- "flow_label": top_label,
299
- "confidence": round(top_score, 3),
300
- "explanation": explanation,
301
- "is_complex": bool(is_complex),
302
- "complexity_score": round(complexity_score, 3),
303
- "route": route,
304
- "features": features,
305
- "scores": scores
306
- }
307
- except Exception as e:
308
- # classifier error -> fallback heuristics
309
- traceback.print_exc()
310
- fallback = _heuristic_flow(text_blob)
311
- route = "planning" if complexity_score >= PLANNING_THRESHOLD else "direct"
312
- explanation = f"Classifier error; fallback to heuristic. {fast_expl} (score={complexity_score}). Error: {e}"
313
- return {
314
- "flow_label": fallback.get("flow_label", "follow_up"),
315
- "confidence": fallback.get("confidence", 0.4),
316
- "explanation": explanation,
317
- "is_complex": complexity_score >= PLANNING_THRESHOLD,
318
- "complexity_score": round(complexity_score, 3),
319
- "route": route,
320
- "features": features,
321
- "scores": fallback.get("scores")
322
- }
323
-
324
- except Exception as e:
325
- traceback.print_exc()
326
- return _heuristic_flow(_concat_recent_messages(messages))
327
-
328
- # End of behavior_model.py
 
1
+ # behavior_model.py -- REPLACED with "Neural Structure / MoE-style Dispatcher"
2
  """
3
+ Large, modular 'neural structure' dispatcher (software MoE) for intent/complexity routing.
4
+
5
+ How to use:
6
+ - Replace your existing behavior_model.py with this file.
7
+ - app.py expects analyze_flow(messages) -> dict with keys:
8
+ { route: "direct"|"planning", is_complex: bool, flow_label: str, confidence: float, explanation: str, experts: [...] }
9
+
10
+ Design:
11
+ - Feature extractor -> gating network (scoring) -> top-K expert selection -> combine/explain decision
12
+ - Experts are modular callables; by default they are heuristic "experts".
13
+ - To scale: implement Expert.run(...) to call real submodels/endpoints (local small models, remote microservices).
 
 
 
 
 
 
14
  """
 
 
 
 
15
 
16
+ from typing import List, Dict, Any, Callable, Tuple
17
+ import re
18
+ import math
19
+ import json
20
+ import os
21
+ import statistics
22
+
23
+ # ---------- Configurable constants ----------
24
+ TOP_K = int(os.environ.get("NS_TOP_K", "2")) # how many experts to activate per request
25
+ SOFTMAX_TEMPERATURE = float(os.environ.get("NS_TEMP", "1.0"))
26
+ MIN_COMPLEX_CONF_FOR_PLANNING = float(os.environ.get("NS_MIN_COMPLEX_CONF", "0.56"))
27
+ MAX_EXPERTS = int(os.environ.get("NS_MAX_EXPERTS", "12"))
28
+
29
+ # Weights (tunables)
30
+ WEIGHT_LENGTH = float(os.environ.get("NS_W_LENGTH", "1.0"))
31
+ WEIGHT_KEYWORD = float(os.environ.get("NS_W_KEYWORD", "1.9"))
32
+ WEIGHT_CODE = float(os.environ.get("NS_W_CODE", "2.4"))
33
+ WEIGHT_NUMERIC = float(os.environ.get("NS_W_NUMERIC", "1.2"))
34
+ WEIGHT_QUESTION = float(os.environ.get("NS_W_QUESTION", "0.6"))
35
+ WEIGHT_URGENT = float(os.environ.get("NS_W_URGENT", "2.2"))
36
+ WEIGHT_HISTORY = float(os.environ.get("NS_W_HISTORY", "0.8"))
37
+
38
+ # ---------- Regex / keyword lists ----------
39
+ _code_fence_re = re.compile(r"```.+?```", flags=re.DOTALL | re.IGNORECASE)
40
+ _inline_code_re = re.compile(r"`[^`]+`")
41
+ _number_re = re.compile(r"\b\d+(\.\d+)?\b")
42
+ _list_marker_re = re.compile(r"(^\s*[-*•]\s+)|(^\s*\d+\.\s+)", flags=re.MULTILINE)
43
+ _url_re = re.compile(r"https?://\S+")
44
+ _question_word_re = re.compile(r"^\s*(who|what|why|how|when|which|where)\b", flags=re.IGNORECASE)
45
+ _question_mark_re = re.compile(r"\?$")
46
+
47
+ _task_keywords = set(k.lower() for k in [
48
+ "build", "create", "implement", "develop", "deploy", "install", "setup", "configure",
49
+ "optimi", "debug", "fix", "error", "crash", "stacktrace", "exception", "traceback",
50
+ "code", "script", "function", "api", "endpoint", "database", "sql", "mongodb", "mysql",
51
+ "docker", "deno", "node", "express", "php", "python", "java", "rust", "golang", "compile",
52
+ "performance", "latency", "bandwidth", "optimization", "optimize",
53
+ "algorithm", "complexity", "big o", "time complexity", "space complexity",
54
+ "report", "plan", "design", "architecture", "integration", "migrate", "refactor",
55
+ "test case", "unit test", "e2e test",
56
+ "prove", "derive", "integral", "differentiate", "matrix", "neural network", "train", "model",
57
+ ])
58
+
59
+ _urgent_words = set(w.lower() for w in ["urgent", "asap", "immediately", "now", "critical", "important", "priority", "must"])
60
+ _short_chat_terms = set(w.lower() for w in ["hi", "hello", "thanks", "thank you", "bye", "ok", "okay", "nice", "cool", "🙂", "😊"])
61
+
62
+ # ---------- Utility functions ----------
63
+ def _word_count(text: str) -> int:
64
+ return len(re.findall(r"\w+", text)) if text else 0
65
+
66
+ def _has_code(text: str) -> bool:
67
+ if not text: return False
68
+ return bool(_code_fence_re.search(text) or _inline_code_re.search(text) or re.search(r"\bdef\s+\w+\(|;\s*$", text, flags=re.IGNORECASE))
69
+
70
+ def _has_list(text: str) -> bool:
71
+ return bool(_list_marker_re.search(text))
72
+
73
+ def _keyword_matches(text: str) -> int:
74
+ if not text: return 0
75
+ t = text.lower()
76
+ cnt = 0
77
+ for kw in _task_keywords:
78
+ if kw in t:
79
+ cnt += 1
80
+ return cnt
81
+
82
+ def _numeric_count(text: str) -> int:
83
+ return len(_number_re.findall(text or ""))
84
+
85
+ def _is_urgent(text: str) -> bool:
86
+ t = (text or "").lower()
87
+ return any(w in t for w in _urgent_words)
88
+
89
+ def _short_chat_score(text: str) -> bool:
90
+ t = (text or "").strip().lower()
91
+ if len(t.split()) <= 2 and any(tok in t for tok in _short_chat_terms):
92
+ return True
93
+ return False
94
+
95
+ def _question_score(text: str) -> float:
96
+ s = 0.0
97
+ if _question_mark_re.search(text or ""): s += 1.0
98
+ if _question_word_re.match((text or "").strip()): s += 0.6
99
+ return s
100
+
101
+ def _history_signal(messages: List[Dict[str,str]]) -> float:
102
+ # simple heuristic: if previous user messages contained technical keywords recently, boost
103
+ if not messages or len(messages) < 2: return 0.0
104
+ prev = " ".join(m.get("content","") for m in messages[-4:-1] if isinstance(m, dict))
105
+ return float(min(3, _keyword_matches(prev))) * 0.2
106
+
107
+ # ---------- Softmax helper ----------
108
+ def _softmax(scores: List[float], temp: float = 1.0) -> List[float]:
109
+ if not scores:
110
+ return []
111
+ exps = [math.exp(s / temp) for s in scores]
112
+ s = sum(exps)
113
+ if s == 0: return [1.0/len(scores)]*len(scores)
114
+ return [e/s for e in exps]
115
+
116
+ # ---------- Expert base classes ----------
117
+ class Expert:
118
+ name: str
119
+ description: str
120
+
121
+ def __init__(self, name:str, description:str):
122
+ self.name = name
123
+ self.description = description
124
+
125
+ def score(self, features: Dict[str,Any]) -> float:
126
+ """Return a heuristic affinity score (higher = more relevant)."""
127
+ # default neutral
128
+ return 0.0
129
+
130
+ def run(self, messages: List[Dict[str,str]], features: Dict[str,Any]) -> Dict[str,Any]:
131
+ """
132
+ Optionally run expert-specific logic (synchronously).
133
+ For now return metadata only. In production this could call a model endpoint.
134
+ """
135
+ return {"expert": self.name, "action": "noop", "note": "heuristic-only"}
136
+
137
+ # ---------- Concrete experts ----------
138
+ class ShortChatExpert(Expert):
139
+ def __init__(self):
140
+ super().__init__("short_chat", "Handles greetings/short conversational turns")
141
+
142
+ def score(self, f):
143
+ if f.get("short_chat"): return 5.0
144
+ return 0.1
145
+
146
+ def run(self, messages, features):
147
+ return {"expert": self.name, "action": "short_reply", "note": "Use concise response template."}
148
+
149
+ class CodeExpert(Expert):
150
+ def __init__(self):
151
+ super().__init__("code_expert", "Handles code, stacktraces, debugging tasks")
152
+
153
+ def score(self, f):
154
+ sc = 0.0
155
+ if f.get("has_code"): sc += 4.0
156
+ sc += 0.8 * f.get("kw_count",0)
157
+ sc += 0.6 * f.get("numeric_count",0)
158
+ return sc
159
+
160
+ def run(self, messages, features):
161
+ # Placeholder: in production call a code-specialized model or analyzer endpoint
162
+ return {"expert": self.name, "action": "analyze_code", "note": "Run code LLM or static-checker (not implemented)."}
163
+
164
+ class NLUExpert(Expert):
165
+ def __init__(self):
166
+ super().__init__("nlu_expert", "Deep intent and slot extraction / classification")
167
+
168
+ def score(self, f):
169
+ sc = 1.0 * f.get("kw_count",0)
170
+ sc += 0.8 * f.get("question_score",0)
171
+ sc += 0.4 * (f.get("word_count",0) / 30.0)
172
+ sc += 0.6 * f.get("history_signal",0)
173
+ return sc
174
+
175
+ def run(self, messages, features):
176
+ # Example: return intent classification tags (heuristic)
177
+ intent = "general"
178
+ if features.get("kw_count",0) >= 2 or features.get("has_code"):
179
+ intent = "technical_task"
180
+ elif features.get("short_chat"):
181
+ intent = "social"
182
+ return {"expert": self.name, "action": "classify_intent", "intent": intent}
183
+
184
+ class RAGExpert(Expert):
185
+ def __init__(self):
186
+ super().__init__("rag_expert", "Handles retrieval-augmented requests (RAG/agent)")
187
+
188
+ def score(self, f):
189
+ sc = 0.0
190
+ # if user mentions 'search', 'latest', has urls, or long context -> RAG useful
191
+ if f.get("has_url"): sc += 2.0
192
+ sc += 1.2 * f.get("kw_count",0)
193
+ sc += 0.9 * f.get("numeric_count",0)
194
+ if f.get("word_count",0) > 60: sc += 1.5
195
+ return sc
196
+
197
+ def run(self, messages, features):
198
+ # Placeholder: should trigger a retrieval job or agent
199
+ return {"expert": self.name, "action": "retrieve", "note": "Trigger RAG pipeline or agent (not implemented)."}
200
+
201
+ class SafetyExpert(Expert):
202
+ def __init__(self):
203
+ super().__init__("safety_expert", "Safety checks, identity questions, hallucination guards")
204
+
205
+ def score(self, f):
206
+ sc = 0.0
207
+ txt = f.get("last_text","").lower() if f.get("last_text") else ""
208
+ if any(w in txt for w in ["who created you","who made you","identity","where are you from"]):
209
+ sc += 3.0
210
+ # any suspicious tokens (email, ssn, credit card-like) -> safety
211
+ if re.search(r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b", txt):
212
+ sc += 4.0
213
+ return sc
214
+
215
+ def run(self, messages, features):
216
+ return {"expert": self.name, "action": "safety_check", "note": "Run policy checks."}
217
+
218
+ # Add more experts as needed...
219
+ _DEFAULT_EXPERTS: List[Expert] = [
220
+ ShortChatExpert(),
221
+ NLUExpert(),
222
+ CodeExpert(),
223
+ RAGExpert(),
224
+ SafetyExpert(),
225
  ]
226
 
227
+ # ---------- Core gating/routing function ----------
228
+ def _extract_features(messages: List[Dict[str,str]]) -> Dict[str,Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  if not messages:
230
+ return {"word_count": 0, "kw_count":0, "has_code": False, "numeric_count":0, "question_score":0.0, "short_chat": False, "has_url": False, "history_signal":0.0, "last_text":""}
231
+ last = messages[-1].get("content","") if isinstance(messages[-1], dict) else str(messages[-1])
232
+ prev = " ".join(m.get("content","") for m in messages[:-1] if isinstance(m, dict))
233
+ full = (prev + "\n" + last).strip()
234
+
235
+ features = {}
236
+ features["last_text"] = last
237
+ features["word_count"] = _word_count(last)
238
+ features["total_word_count"] = _word_count(full)
239
+ features["kw_count"] = _keyword_matches(full)
240
+ features["has_code"] = _has_code(full)
241
+ features["has_list"] = _has_list(full)
242
+ features["numeric_count"] = _numeric_count(full)
243
+ features["question_score"] = _question_score(last)
244
+ features["short_chat"] = _short_chat_score(last)
245
+ features["has_url"] = bool(_url_re.search(full))
246
+ features["is_urgent"] = _is_urgent(full)
247
+ features["history_signal"] = _history_signal(messages)
248
+ return features
249
+
250
+ def _gate_select_experts(features: Dict[str,Any], experts: List[Expert]) -> Tuple[List[Tuple[Expert,float]], List[float]]:
251
+ # compute raw scores per expert
252
+ raw_scores = [max(0.0, e.score(features)) for e in experts]
253
+ if not raw_scores:
254
+ return [], []
255
+
256
+ # normalize via softmax for relative weighting
257
+ probs = _softmax(raw_scores, temp=SOFTMAX_TEMPERATURE)
258
+ # select top-K experts by probability
259
+ indexed = list(enumerate(probs))
260
+ indexed.sort(key=lambda x: x[1], reverse=True)
261
+ top = indexed[:TOP_K]
262
+ chosen = [(experts[i], probs[i]) for i, _ in top]
263
+ return chosen, probs
264
+
265
+ # ---------- Public API: analyze_flow ----------
266
+ def analyze_flow(messages: List[Dict[str,str]]) -> Dict[str,Any]:
267
  """
268
+ Returns:
269
+ {
270
+ "route": "direct" / "planning",
271
+ "is_complex": bool,
272
+ "flow_label": str,
273
+ "confidence": float,
274
+ "explanation": str,
275
+ "experts": [ {"name":.., "score":.., "note":..}, ... ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  }
277
+ """
278
+ features = _extract_features(messages)
279
+ experts = _DEFAULT_EXPERTS.copy()
280
+
281
+ # gating
282
+ chosen, probs = _gate_select_experts(features, experts)
283
+
284
+ # Decide flow_label heuristics based on features
285
+ flow_label = "general"
286
+ if features.get("has_code") or features.get("kw_count",0) >= 2:
287
+ flow_label = "coding_request"
288
+ elif features.get("is_urgent"):
289
+ flow_label = "escalation"
290
+ elif features.get("kw_count",0) >= 1 and features.get("word_count",0) >= 25:
291
+ flow_label = "task_request"
292
+ elif features.get("short_chat"):
293
+ flow_label = "short_chat"
294
+ elif features.get("question_score",0) > 0.9 and features.get("word_count",0) < 25:
295
+ flow_label = "short_question"
296
+
297
+ # compute a complexity/confidence scalar from features + expert probs
298
+ feature_score = (
299
+ WEIGHT_LENGTH * (features.get("word_count",0) / 30.0) +
300
+ WEIGHT_KEYWORD * features.get("kw_count",0) +
301
+ WEIGHT_CODE * (4.0 if features.get("has_code") else 0.0) +
302
+ WEIGHT_NUMERIC * features.get("numeric_count",0) +
303
+ WEIGHT_QUESTION * features.get("question_score",0) +
304
+ WEIGHT_URGENT * (1.0 if features.get("is_urgent") else 0.0) +
305
+ WEIGHT_HISTORY * features.get("history_signal",0)
306
+ )
307
+
308
+ # Map to 0..1 via logistic
309
+ conf = 1.0 / (1.0 + math.exp(-0.45 * (feature_score - 2.0)))
310
+ conf = max(0.0, min(1.0, conf))
311
+
312
+ # route decision
313
+ is_complex = conf >= MIN_COMPLEX_CONF_FOR_PLANNING or features.get("has_code") or features.get("kw_count",0) >= 2
314
+ # short-chat override: always direct
315
+ if features.get("short_chat"):
316
+ route = "direct"
317
+ is_complex = False
318
  else:
319
+ route = "planning" if is_complex else "direct"
320
+
321
+ # Build explanation and expert list
322
+ expert_list = []
323
+ for e, p in chosen:
324
+ # we can call run() here for metadata without actually executing heavy ops
325
+ meta = e.run(messages, features)
326
+ expert_list.append({"name": e.name, "prob": round(float(p),4), "meta": meta})
327
+
328
+ explanation = ("features=" + json.dumps(features) + f" | feature_score={feature_score:.2f} | conf={conf:.3f} | chosen={[e.name for e,_ in chosen]}")
329
+ return {
330
+ "route": route,
331
+ "is_complex": bool(is_complex),
332
+ "flow_label": flow_label,
333
+ "confidence": round(float(conf), 3),
334
+ "explanation": explanation,
335
+ "experts": expert_list
336
+ }
337
 
338
+ # ---------- Debug helper ----------
339
+ def debug_flow(text: str, history: List[str] = None):
340
+ hist_msgs = [{"role":"user","content":h} for h in (history or [])]
341
+ hist_msgs.append({"role":"user","content": text})
342
+ return analyze_flow(hist_msgs)
343
+
344
+ # Example self-test when run directly
345
+ if __name__ == "__main__":
346
+ tests = [
347
+ "Hi 🙂",
348
+ "What is your name?",
349
+ "What is neural network",
350
+ "My app crashes with TypeError: undefined is not a function. Stacktrace: ```TypeError: ...``` How to fix?",
351
+ "Deploy my node app to Docker with Nginx and SSL — step-by-step please.",
352
+ "Quick: 2+2?"
353
+ ]
354
+ for t in tests:
355
+ print("----")
356
+ print("MSG:", t)
357
+ out = debug_flow(t)
358
+ print(json.dumps(out, indent=2))