File size: 18,977 Bytes
3289c58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
"""
Test chat message persistence across multiple turns in CugaLite.

This test verifies that chat messages are properly maintained and accumulated
across multiple conversation turns in the CugaAgent execution flow.
"""

import contextlib
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_core.messages import HumanMessage, AIMessage

from cuga.backend.cuga_graph.nodes.cuga_lite.cuga_agent_base import CugaAgent
from cuga.backend.cuga_graph.nodes.cuga_lite.cuga_lite_node import CugaLiteNode
from cuga.backend.cuga_graph.state.agent_state import AgentState, VariablesManager


class MockToolProvider:
    """Stand-in tool provider used by the tests.

    Exposes no apps and a single trivial tool that echoes its query back.
    """

    async def initialize(self):
        # The mock has nothing to set up.
        return None

    async def get_apps(self):
        # No applications are exposed by the mock.
        return list()

    async def get_all_tools(self):
        # Local import: langchain_core is only needed once tools are requested.
        from langchain_core.tools import tool

        def mock_tool(query: str) -> str:
            """A mock tool for testing."""
            return f"Mock response to: {query}"

        # Equivalent to decorating mock_tool with @tool.
        return [tool(mock_tool)]


class TestChatMessagesPersistence:
    """Test that chat messages persist across multiple turns."""

    def setup_method(self):
        """Reset state before each test."""
        VariablesManager().reset()

    def teardown_method(self):
        """Clean up after each test."""
        VariablesManager().reset()

    @staticmethod
    @contextlib.asynccontextmanager
    async def _agent_with_stream(stream_fn):
        """Yield an initialized ``CugaAgent`` whose graph streams via ``stream_fn``.

        ``agent.agent`` is patched with a mock while ``initialize()`` runs, then
        replaced by a ``MagicMock`` whose ``astream`` is ``stream_fn``. The patch
        is undone when the ``async with`` block exits.

        Args:
            stream_fn: Async generator function used as the compiled graph's
                ``astream``.
        """
        agent = CugaAgent(tool_provider=MockToolProvider())
        with patch.object(agent, 'agent') as mock_agent_graph:
            mock_compiled = MagicMock()
            mock_agent_graph.compile.return_value = mock_compiled
            mock_compiled.astream = stream_fn

            await agent.initialize()
            # Replace whatever initialize() left behind with our scripted graph.
            agent.agent = mock_compiled
            yield agent

    @pytest.mark.asyncio
    async def test_chat_messages_returned_from_execute(self):
        """Test that execute() returns updated chat messages."""

        # Simulate a CodeAct agent stream: echo the input messages plus one AI reply.
        async def mock_stream(*args, **kwargs):
            initial_state = args[0] if args else kwargs.get('initial_state', {})
            messages = initial_state.get('messages', [])
            yield {
                'messages': messages
                + [{'role': 'assistant', 'content': 'Here is the answer to your question.'}],
                'context': {},
            }

        async with self._agent_with_stream(mock_stream) as agent:
            # Simulate a conversation where the user previously asked "What is 2+2?".
            previous_chat_history = [HumanMessage(content="What is 2+2?"), AIMessage(content="4")]

            _, _, _, updated_chat_messages = await agent.execute(
                task="Now calculate 3+3", show_progress=False, chat_messages=previous_chat_history
            )

            assert updated_chat_messages is not None, "updated_chat_messages should not be None"
            # The mock appends exactly one AI reply, so the count is exact.
            assert len(updated_chat_messages) == 4, (
                "Should have previous 2 messages + new user message + AI response, "
                f"got {len(updated_chat_messages)}"
            )

            # Previous history is preserved, in order.
            assert isinstance(updated_chat_messages[0], HumanMessage)
            assert updated_chat_messages[0].content == "What is 2+2?"
            assert isinstance(updated_chat_messages[1], AIMessage)
            assert updated_chat_messages[1].content == "4"

            # The new user message follows the history.
            assert isinstance(updated_chat_messages[2], HumanMessage)
            assert updated_chat_messages[2].content == "Now calculate 3+3"

            # The last message is the AI response.
            assert isinstance(updated_chat_messages[-1], AIMessage)

    @pytest.mark.asyncio
    async def test_chat_messages_accumulate_across_turns(self):
        """Test that chat messages accumulate correctly across multiple turns."""

        async def mock_stream(*args, **kwargs):
            initial_state = args[0] if args else kwargs.get('initial_state', {})
            messages = initial_state.get('messages', [])

            # Return all messages plus a new AI response.
            yield {
                'messages': messages + [{'role': 'assistant', 'content': f'Response #{len(messages)}'}],
                'context': {},
            }

        async with self._agent_with_stream(mock_stream) as agent:
            # Turn 1 - no previous history.
            _, _, _, updated_chat_turn1 = await agent.execute(
                task="Question 1",
                show_progress=False,
                chat_messages=None,
            )

            # Turn 1 should return None when chat_messages is None.
            assert updated_chat_turn1 is None, "Should return None when no chat_messages provided"

            # Turn 2 - start with manually built history (what turn 1 would have stored).
            chat_history_turn2 = [HumanMessage(content="Question 1"), AIMessage(content="Answer 1")]
            _, _, _, updated_chat_turn2 = await agent.execute(
                task="Question 2",
                show_progress=False,
                chat_messages=chat_history_turn2,
            )

            assert updated_chat_turn2 is not None
            assert len(updated_chat_turn2) == 4, "Should have 2 user + 2 AI messages"
            assert updated_chat_turn2[0].content == "Question 1"
            assert updated_chat_turn2[1].content == "Answer 1"

            # Turn 3 - continue the conversation from turn 2's output.
            _, _, _, updated_chat_turn3 = await agent.execute(
                task="Question 3",
                show_progress=False,
                chat_messages=updated_chat_turn2,
            )

            assert updated_chat_turn3 is not None
            assert len(updated_chat_turn3) == 6, "Should have 3 user + 3 AI messages"

            # Messages alternate user / AI.
            for index, message in enumerate(updated_chat_turn3):
                expected_type = HumanMessage if index % 2 == 0 else AIMessage
                assert isinstance(message, expected_type), (
                    f"message {index} should be {expected_type.__name__}"
                )

    @pytest.mark.asyncio
    async def test_cuga_lite_node_updates_state_chat_messages(self):
        """Test that CugaLiteNode properly updates state.chat_messages."""
        node = CugaLiteNode()

        state = AgentState(input="Test query", url="http://example.com", final_answer="")

        mock_agent = AsyncMock()
        mock_agent.tools = []
        # NOTE(review): child attributes of an AsyncMock are AsyncMocks, so this
        # return_value is only seen if the node *awaits* get_langfuse_trace_id();
        # a synchronous call would get a coroutine. Confirm against CugaLiteNode.
        mock_agent.get_langfuse_trace_id.return_value = None

        # Simulate execute returning updated chat messages.
        initial_chat = [HumanMessage(content="Previous question"), AIMessage(content="Previous answer")]
        new_chat = initial_chat + [HumanMessage(content="Test query"), AIMessage(content="Test answer")]

        mock_agent.execute.return_value = (
            "Test answer",  # answer
            {"duration_seconds": 1.0, "total_tokens": 100},  # metrics
            [],  # state_messages
            new_chat,  # updated_chat_messages
        )

        with patch.object(node, 'create_agent', return_value=mock_agent):
            state.chat_messages = initial_chat

            await node.node(state)

            assert state.chat_messages is not None
            assert len(state.chat_messages) == 4, "Should have 4 messages after update"
            assert state.chat_messages[0].content == "Previous question"
            assert state.chat_messages[1].content == "Previous answer"
            assert state.chat_messages[2].content == "Test query"
            assert state.chat_messages[3].content == "Test answer"

    @pytest.mark.asyncio
    async def test_chat_messages_none_when_not_provided(self):
        """Test that chat_messages can be None when not provided."""

        async def mock_stream(*args, **kwargs):
            initial_state = args[0] if args else kwargs.get('initial_state', {})
            messages = initial_state.get('messages', [])

            yield {'messages': messages, 'context': {}}

        async with self._agent_with_stream(mock_stream) as agent:
            _, _, _, updated_chat_messages = await agent.execute(
                task="Simple query", show_progress=False, chat_messages=None
            )

            assert updated_chat_messages is None, "Should return None when chat_messages not provided"

    @pytest.mark.asyncio
    async def test_chat_messages_preserved_on_error(self):
        """Test that chat messages are handled correctly on execution error."""

        async def mock_stream_error(*args, **kwargs):
            raise Exception("Test error")
            # Unreachable, but its presence makes this an async generator (like the
            # real ``astream``), so the exception fires on first iteration instead of
            # ``async for`` failing on a bare coroutine.
            yield  # pragma: no cover

        async with self._agent_with_stream(mock_stream_error) as agent:
            initial_chat = [HumanMessage(content="Test question")]

            answer, metrics, _, updated_chat_messages = await agent.execute(
                task="Query that will fail", show_progress=False, chat_messages=initial_chat
            )

            # On error, no updated chat history is returned.
            assert updated_chat_messages is None
            assert "Error during execution" in answer
            assert metrics.get('error') is not None

    @pytest.mark.asyncio
    async def test_message_format_conversion(self):
        """Test that dict messages are properly converted to BaseMessage objects."""

        async def mock_stream(*args, **kwargs):
            # Return messages in dict format (as CodeAct does), ignoring the input.
            yield {
                'messages': [
                    {'role': 'user', 'content': 'User message'},
                    {'role': 'assistant', 'content': 'AI response'},
                ],
                'context': {},
            }

        async with self._agent_with_stream(mock_stream) as agent:
            initial_chat = [HumanMessage(content="User message")]
            _, _, _, updated_chat = await agent.execute(
                task="Test", show_progress=False, chat_messages=initial_chat
            )

            assert updated_chat is not None
            assert len(updated_chat) == 2
            assert isinstance(updated_chat[0], HumanMessage)
            assert isinstance(updated_chat[1], AIMessage)
            assert updated_chat[0].content == "User message"
            assert updated_chat[1].content == "AI response"

    @pytest.mark.asyncio
    async def test_real_world_conversation_flow(self):
        """
        Integration test simulating real-world conversation flow:
        - User asks question 1
        - Agent responds
        - User asks question 2 (referring to previous context)
        - Agent responds with full conversation history
        """
        # Number of times the fake graph has been streamed; selects the reply.
        turn = 0

        async def mock_stream(*args, **kwargs):
            nonlocal turn
            turn += 1
            initial_state = args[0] if args else kwargs.get('initial_state', {})
            messages = initial_state.get('messages', [])

            if turn == 1:
                response = "The capital of France is Paris."
            elif turn == 2:
                response = "Yes, Paris is also known for the Eiffel Tower."
            else:
                response = f"Response to turn {turn}"

            yield {'messages': messages + [{'role': 'assistant', 'content': response}], 'context': {}}

        async with self._agent_with_stream(mock_stream) as agent:
            # === Turn 1: First question (no history) ===
            print("\n=== TURN 1: First question ===")
            answer1, _, _, chat_messages_after_turn1 = await agent.execute(
                task="What is the capital of France?",
                show_progress=False,
                chat_messages=None,
            )

            # Since chat_messages was None, the returned chat_messages is also None.
            assert chat_messages_after_turn1 is None
            print(f"Answer 1: {answer1}")
            print("Chat messages after turn 1: None (as expected)")

            # === Turn 2: Follow-up question (with simulated history) ===
            print("\n=== TURN 2: Follow-up question ===")
            # A real app would store turn 1's conversation; simulate that here.
            conversation_history = [
                HumanMessage(content="What is the capital of France?"),
                AIMessage(content="The capital of France is Paris."),
            ]

            answer2, _, _, chat_messages_after_turn2 = await agent.execute(
                task="What is it known for?", show_progress=False, chat_messages=conversation_history
            )

            assert chat_messages_after_turn2 is not None
            assert len(chat_messages_after_turn2) == 4, (
                f"Expected 4 messages, got {len(chat_messages_after_turn2)}"
            )

            print(f"Answer 2: {answer2}")
            print(f"Chat messages after turn 2: {len(chat_messages_after_turn2)} messages")

            assert chat_messages_after_turn2[0].content == "What is the capital of France?"
            assert chat_messages_after_turn2[1].content == "The capital of France is Paris."
            assert chat_messages_after_turn2[2].content == "What is it known for?"
            assert "Eiffel Tower" in chat_messages_after_turn2[3].content

            print("\n✅ Full conversation history maintained correctly!")
            print("Conversation:")
            for i, msg in enumerate(chat_messages_after_turn2):
                role = "User" if isinstance(msg, HumanMessage) else "AI"
                print(f"  {i + 1}. {role}: {msg.content[:50]}...")

    @pytest.mark.asyncio
    async def test_empty_chat_history_starts_conversation(self):
        """
        Test that passing an empty list for chat_messages starts tracking conversation history.
        This simulates the real scenario where AgentState.chat_messages defaults to [].
        """

        async def mock_stream(*args, **kwargs):
            initial_state = args[0] if args else kwargs.get('initial_state', {})
            messages = initial_state.get('messages', [])

            yield {
                'messages': messages + [{'role': 'assistant', 'content': 'Hello! How can I help you?'}],
                'context': {},
            }

        async with self._agent_with_stream(mock_stream) as agent:
            # An empty list (not None) must still enable history tracking.
            empty_chat_history = []

            _, _, _, updated_chat_messages = await agent.execute(
                task="Hello",
                show_progress=False,
                chat_messages=empty_chat_history,
            )

            assert updated_chat_messages is not None, "Should return messages when starting from empty list"
            assert len(updated_chat_messages) == 2, (
                f"Expected 2 messages (user + AI), got {len(updated_chat_messages)}"
            )
            assert isinstance(updated_chat_messages[0], HumanMessage)
            assert updated_chat_messages[0].content == "Hello"
            assert isinstance(updated_chat_messages[1], AIMessage)
            assert updated_chat_messages[1].content == "Hello! How can I help you?"

            print("✅ Empty chat history correctly starts conversation tracking!")


if __name__ == "__main__":
    # Propagate pytest's exit status so a failing run is visible to the shell / CI;
    # "-s" disables output capture so the tests' print() output is shown.
    raise SystemExit(pytest.main([__file__, "-v", "-s"]))