File size: 8,185 Bytes
0646b18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import json
import uuid
from typing import Literal, Optional, Dict, Callable

from langchain_core.messages import HumanMessage, ToolCall, BaseMessage
from loguru import logger

from cuga.backend.activity_tracker.tracker import ActivityTracker, Step
from cuga.backend.cuga_graph.nodes.shared.base_agent import create_partial
from cuga.backend.cuga_graph.nodes.chat.chat_agent.chat_agent import ChatAgent
from cuga.backend.cuga_graph.nodes.shared.base_node import BaseNode
from cuga.backend.cuga_graph.nodes.human_in_the_loop.followup_model import (
    create_flow_approve,
    create_new_flow_approve,
)
from cuga.backend.cuga_graph.state.agent_state import AgentState
from cuga.backend.cuga_graph.utils.nodes_names import NodeNames, ActionIds

from langgraph.types import Command
from cuga.config import settings


tracker = ActivityTracker()

ENABLE_SAVE_REUSE = settings.features.save_reuse


class ChatHumanInTheLoopHandler:
    """Handler for chat-specific human-in-the-loop interactions"""

    def __init__(self):
        self._action_handlers: Dict[str, Callable] = {
            ActionIds.FLOW_APPROVE: self._handle_tool_execute,
            # Add chat-specific action handlers here
            # Example: ActionIds.TOOL_EXECUTE: self._handle_tool_execute,
            # ActionIds.CHAT_CONTINUE: self._handle_chat_continue,
        }

    def handle_human_response(self, state: AgentState, node_name: str) -> Command:
        """Handle any human response based on action_id"""
        action_id = state.hitl_response.action_id

        if action_id in self._action_handlers:
            return self._action_handlers[action_id](state, node_name)

        # Default fallback for chat - continue to final answer
        return Command(update=state.model_dump(), goto=NodeNames.FINAL_ANSWER_AGENT)

    def add_action_handler(self, action_id: str, handler: Callable):
        """Add a custom action handler"""
        self._action_handlers[action_id] = handler

    def _handle_tool_execute(self, state: AgentState, node_name: str) -> Command:
        """Handle tool execution approval"""
        state.sender = node_name
        return Command(update=state.model_dump(), goto=NodeNames.WAIT_FOR_RESPONSE)

    def _handle_chat_continue(self, state: AgentState, node_name: str) -> Command:
        """Handle continuing chat conversation"""
        state.sender = node_name
        return Command(update=state.model_dump(), goto=NodeNames.FINAL_ANSWER_AGENT)


class ChatNode(BaseNode):
    def __init__(self):
        super().__init__()
        self.chat_agent: Optional[ChatAgent] = None
        self.hitl_handler = ChatHumanInTheLoopHandler()
        self._initialized = False

    @classmethod
    async def create(cls):
        """Factory method to create and initialize the class"""
        instance = cls()
        instance.chat_agent = ChatAgent()
        if settings.features.chat:
            await instance.chat_agent.setup()
        instance.node = create_partial(
            ChatNode.node_handler,
            agent=instance.chat_agent,
            hitl_handler=instance.hitl_handler,
            name=instance.chat_agent.name,
        )
        instance._initialized = True
        return instance

    @staticmethod
    def format_function_call(func_dict):
        name = func_dict["name"]
        args = func_dict["args"]

        def format_value(v):
            if isinstance(v, str):
                return f"'{v}'"
            elif isinstance(v, (list, dict)):
                return repr(v)
            else:
                return str(v)

        arg_strings = [f"{k}={format_value(v)}" for k, v in args.items()]
        return f"{name}({', '.join(arg_strings)})"

    @staticmethod
    async def node_handler(
        state: AgentState, agent: ChatAgent, hitl_handler: ChatHumanInTheLoopHandler, name: str
    ) -> Command[Literal["FinalAnswerAgent", "TaskAnalyzerAgent", "SuggestHumanActions"]]:
        # Handle human-in-the-loop responses
        if (
            state.sender == NodeNames.WAIT_FOR_RESPONSE
            and state.hitl_response.action_id == ActionIds.FLOW_APPROVE
        ):
            tool = ToolCall(**state.hitl_response.additional_data.tool)
            res = await agent.execute_tool(tool)
            parsed_result = res
            if isinstance(res, str):
                try:
                    parsed_result = json.loads(res)
                except (json.JSONDecodeError, TypeError):
                    # If parsing fails, keep original string
                    parsed_result = res
            # Get tool details
            tool_name = tool.get("name")
            tool_args = tool.get("args")
            # Add to variable manager

            var_name = f"tool_result_{str(uuid.uuid4())[:5]}"
            state.variables_manager.add_variable(
                parsed_result, var_name, f"Result of tool {tool_name} with args {tool_args}"
            )
            state.sender = "ChatAgentTool"
            state.last_planner_answer = state.variables_manager.present_variable(var_name)
            return Command(update=state.model_dump(), goto=NodeNames.FINAL_ANSWER_AGENT)

        if (
            state.sender == NodeNames.WAIT_FOR_RESPONSE
            and state.hitl_response.action_id == ActionIds.NEW_FLOW_APPROVE
        ):
            logger.debug("tool call in chat node")
            tool = ToolCall(**state.hitl_response.additional_data.tool)
            state.input = tool.get("args").get("user_task")
            state.sender = "ChatAgent"
            return Command(update=state.model_dump(), goto=NodeNames.TASK_ANALYZER_AGENT)

        # If chat feature is disabled, go directly to task analyzer
        if not settings.features.chat:
            state.sender = name
            return Command(update=state.model_dump(), goto=NodeNames.TASK_ANALYZER_AGENT)

        # Process chat input
        state.sender = name
        state.chat_agent_messages.append(HumanMessage(content=state.input))
        res: BaseMessage = await agent.invoke(state.chat_agent_messages, state)
        state.chat_agent_messages.append(res)
        # Handle tool calls - require human approval
        if ENABLE_SAVE_REUSE and res.tool_calls and res.tool_calls[0].get("name") == "run_new_flow":
            state.final_answer = state.chat_agent_messages[-1].content
            state.sender = name
            state.hitl_action = create_new_flow_approve(tool=res.tool_calls[0])
            return Command(update=state.model_dump(), goto=NodeNames.SUGGEST_HUMAN_ACTIONS)

        if ENABLE_SAVE_REUSE and res.tool_calls:
            state.final_answer = state.chat_agent_messages[-1].content
            state.sender = name
            state.hitl_action = create_flow_approve(tool=res.tool_calls[0])
            return Command(update=state.model_dump(), goto=NodeNames.SUGGEST_HUMAN_ACTIONS)

        if (
            not ENABLE_SAVE_REUSE
            and res.tool_calls
            and len(res.tool_calls) > 0
            and res.tool_calls[0].get("name") == "execute_task"
        ):
            logger.debug(f"tool call in chat node {res.tool_calls[0]}")
            variables_rel = res.tool_calls[0].get("args").get("relevant_variables")
            if variables_rel and len(variables_rel) > 0:
                state.input = (
                    f"task: {res.tool_calls[0].get('args').get('task')}"
                    + f"\n relevant variables from history: {res.tool_calls[0].get('args').get('relevant_variables')}"
                )
            else:
                state.input = res.tool_calls[0].get("args").get("task")
            return Command(update=state.model_dump(), goto="TaskAnalyzerAgent")
        # Regular chat response - add to messages and continue+
        res.content = state.variables_manager.replace_variables_placeholders(res.content)
        state.messages.append(res)
        tracker.collect_step(
            step=Step(
                name=name,
                data=res.content,
                current_url=state.url,
            )
        )
        state.final_answer = state.chat_agent_messages[-1].content

        return Command(update=state.model_dump(), goto=NodeNames.FINAL_ANSWER_AGENT)