|
|
""" |
|
|
Agent-Based RAG System using LangGraph |
|
|
Provides intelligent query answering with tool use and multi-hop reasoning |
|
|
""" |
|
|
import asyncio
import string
from typing import List, Dict, Any, TypedDict, Annotated, Optional

from typing_extensions import TypedDict
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
from loguru import logger

from models import SourceCitation, ChatResponse
from graph_store import GraphStore
from embedding_service import EmbeddingService
from llm_service import LLMService
|
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """Shared state dictionary handed from node to node in the RAG agent graph."""

    # Conversation messages; chat() seeds this with the user's HumanMessage.
    messages: List
    # The user's question, as passed to chat().
    query: str
    # PDF the question is scoped to; chat() stores "" when none was given.
    pdf_id: str
    # Holds the planner's "planned_tools" list plus the outputs of executed tools.
    tool_results: Dict[str, Any]
    # Human-readable reasons recorded by the planning node.
    reasoning_steps: List[str]
    # Answer text written by the synthesis node.
    final_answer: str
    # Source citations collected by the synthesis node.
    citations: List[SourceCitation]
    # Routing key read after planning: "tools" or "direct".
    next_action: str
|
|
|
|
|
|
|
|
class RAGAgent:
    """
    Intelligent RAG agent that uses multiple tools to answer questions.

    The agent runs a three-node LangGraph workflow: ``plan`` picks tools with
    keyword heuristics, ``execute_tools`` runs them, and ``synthesize`` asks
    the LLM service to write an answer from the collected context.  If the
    workflow raises, ``chat`` falls back to a plain vector-search RAG call.

    Tools available:
    1. vector_search - Semantic search through document chunks
    2. graph_search - Find concepts in knowledge graph
    3. get_node_details - Get detailed info about a graph node
    4. get_related_nodes - Traverse graph relationships
    5. get_chunk_by_id - Retrieve specific chunks for citations
    """

    # Substrings of the lower-cased query that make the planner add a graph lookup.
    _RELATION_KEYWORDS = ("relate", "connection", "link", "between")
    _DEFINITION_KEYWORDS = ("what is", "define", "explain")
    # Words that are never tried as knowledge-graph concepts.
    _CONCEPT_STOPWORDS = frozenset({"what", "how", "does", "relate"})
    # Maximum citation snippet length before it is truncated with "...".
    _SNIPPET_LENGTH = 120
    # Answer returned when the workflow finishes without producing any text.
    _NO_ANSWER = "I couldn't generate an answer."

    def __init__(self,
                 graph_store: GraphStore,
                 embedding_service: EmbeddingService,
                 llm_service: LLMService):
        """Initialize the RAG agent with necessary services.

        Args:
            graph_store: Knowledge-graph backend used by the graph tools.
            embedding_service: Vector index used for chunk retrieval.
            llm_service: LLM wrapper used to write the final answer.
        """
        self.graph_store = graph_store
        self.embedding_service = embedding_service
        self.llm_service = llm_service

        # Tool objects are built lazily on first use and then reused.
        self._tools_by_name = None

        self.workflow = self._build_workflow()
        self.app = self.workflow.compile()

        logger.info("β RAG Agent initialized with LangGraph workflow")

    @staticmethod
    def _plain_value(value: Any) -> Any:
        """Return ``value.value`` for enum-like objects, otherwise ``value`` itself."""
        return getattr(value, "value", value)

    @classmethod
    def _make_snippet(cls, text: str) -> str:
        """Shorten ``text`` to the citation snippet length, appending "..." if cut."""
        if len(text) > cls._SNIPPET_LENGTH:
            return text[:cls._SNIPPET_LENGTH] + "..."
        return text

    def _get_tool_map(self) -> Dict[str, Any]:
        """Return the name -> tool mapping, creating the tools on first use."""
        # getattr() keeps this working on instances built without __init__.
        tool_map = getattr(self, "_tools_by_name", None)
        if tool_map is None:
            tool_map = {built.name: built for built in self._create_tools()}
            self._tools_by_name = tool_map
        return tool_map

    def _candidate_concepts(self, query: str) -> List[str]:
        """Extract candidate concept words from a query for graph lookups.

        Words are lower-cased, stripped of surrounding punctuation (so
        "attention?" is looked up as "attention"), and kept only when longer
        than four characters and not a stop word.  Order is preserved and
        duplicates are dropped.
        """
        candidates = []
        for raw_word in query.lower().split():
            word = raw_word.strip(string.punctuation)
            if (len(word) > 4
                    and word not in self._CONCEPT_STOPWORDS
                    and word not in candidates):
                candidates.append(word)
        return candidates

    def _create_tools(self):
        """Create tool functions for the agent.

        Each tool is a closure over this agent's services.  Tools log and
        swallow their own exceptions, returning an empty list or ``None`` so
        that one failing tool does not abort the whole workflow.

        Returns:
            The tools in the order vector_search, graph_search,
            get_node_details, get_related_nodes, get_chunk_by_id.
        """

        @tool
        def vector_search(query: str, top_k: int = 5,
                          pdf_id: Optional[str] = None) -> List[Dict[str, Any]]:
            """
            Search document chunks using semantic similarity.

            Args:
                query: The search query
                top_k: Number of results to return
                pdf_id: Optional PDF identifier to restrict the search to

            Returns:
                List of relevant chunks with metadata and scores
            """
            logger.info(f"π Tool: vector_search('{query}', top_k={top_k})")

            try:
                results = self.embedding_service.search(
                    query=query,
                    top_k=top_k,
                    filter_pdf_id=pdf_id,
                )

                formatted_results = [
                    {
                        "text": metadata.get("text", ""),
                        "page_number": metadata.get("page_number", 0),
                        "chunk_id": metadata.get("chunk_id", ""),
                        "score": float(score),
                    }
                    for metadata, score in results
                ]

                logger.info(f" β Found {len(formatted_results)} chunks")
                return formatted_results

            except Exception as e:
                logger.error(f" β vector_search failed: {e}")
                return []

        @tool
        def graph_search(concept: str) -> Optional[Dict[str, Any]]:
            """
            Find a concept node in the knowledge graph.

            Args:
                concept: The concept to search for

            Returns:
                Node information if found, None otherwise
            """
            logger.info(f"π Tool: graph_search('{concept}')")

            try:
                node = self.graph_store.get_node_by_label(concept)

                if not node:
                    logger.info(f" β No node found for '{concept}'")
                    return None

                logger.info(f" β Found node: {node.label}")
                return {
                    "node_id": node.node_id,
                    "label": node.label,
                    "type": self._plain_value(node.type),
                    "importance": node.importance_score,
                }

            except Exception as e:
                logger.error(f" β graph_search failed: {e}")
                return None

        @tool
        def get_node_details(node_id: str) -> Optional[Dict[str, Any]]:
            """
            Get detailed information about a graph node.

            Args:
                node_id: The ID of the node

            Returns:
                Detailed node information including supporting chunks
            """
            logger.info(f"π Tool: get_node_details('{node_id}')")

            try:
                node = self.graph_store.get_node(node_id)

                if not node:
                    logger.info(f" β Node not found")
                    return None

                # Only the first five supporting chunks are reported.
                chunks = [
                    {
                        "page_number": chunk.page_number,
                        "snippet": chunk.snippet,
                        "score": chunk.score,
                    }
                    for chunk in node.supporting_chunks[:5]
                ]

                logger.info(f" β Got details for {node.label}")
                return {
                    "label": node.label,
                    "type": self._plain_value(node.type),
                    "importance": node.importance_score,
                    "supporting_chunks": chunks,
                }

            except Exception as e:
                logger.error(f" β get_node_details failed: {e}")
                return None

        @tool
        def get_related_nodes(node_id: str, max_neighbors: int = 5) -> List[Dict[str, Any]]:
            """
            Get nodes related to a given node (graph traversal).

            Args:
                node_id: The ID of the starting node
                max_neighbors: Maximum number of related nodes to return

            Returns:
                List of related nodes with relationship information
            """
            logger.info(f"π Tool: get_related_nodes('{node_id}', max={max_neighbors})")

            try:
                neighbors = self.graph_store.get_neighbors(node_id)

                related = [
                    {
                        "node_id": neighbor_node.node_id,
                        "label": neighbor_node.label,
                        "relation": self._plain_value(edge.relation),
                        "confidence": edge.confidence,
                    }
                    for neighbor_node, edge in neighbors[:max_neighbors]
                ]

                logger.info(f" β Found {len(related)} related nodes")
                return related

            except Exception as e:
                logger.error(f" β get_related_nodes failed: {e}")
                return []

        @tool
        def get_chunk_by_id(chunk_id: str) -> Optional[Dict[str, Any]]:
            """
            Retrieve a specific chunk by its ID (for detailed citations).

            Args:
                chunk_id: The chunk identifier

            Returns:
                Chunk content and metadata
            """
            logger.info(f"π Tool: get_chunk_by_id('{chunk_id}')")

            try:
                results = self.embedding_service.search_by_chunk_ids([chunk_id], top_k=1)

                if not results:
                    logger.info(f" β Chunk not found")
                    return None

                metadata, _score = results[0]
                logger.info(f" β Found chunk")
                return {
                    "text": metadata.get("text", ""),
                    "page_number": metadata.get("page_number", 0),
                    "chunk_id": chunk_id,
                }

            except Exception as e:
                logger.error(f" β get_chunk_by_id failed: {e}")
                return None

        return [vector_search, graph_search, get_node_details, get_related_nodes, get_chunk_by_id]

    def _build_workflow(self) -> StateGraph:
        """Build the LangGraph workflow for the agent.

        The graph is START -> plan -> (execute_tools ->) synthesize -> END;
        ``_should_use_tools`` decides whether the tool step is skipped.

        Returns:
            The uncompiled StateGraph.
        """
        workflow = StateGraph(AgentState)

        workflow.add_node("plan", self._plan_node)
        workflow.add_node("execute_tools", self._execute_tools_node)
        workflow.add_node("synthesize", self._synthesize_node)

        workflow.add_edge(START, "plan")
        workflow.add_conditional_edges(
            "plan",
            self._should_use_tools,
            {
                "tools": "execute_tools",
                "direct": "synthesize",
            },
        )
        workflow.add_edge("execute_tools", "synthesize")
        workflow.add_edge("synthesize", END)

        return workflow

    def _plan_node(self, state: AgentState) -> AgentState:
        """Agent decides which tools to use.

        Vector search is always planned.  A graph lookup is added when the
        query contains a relationship or definition keyword; it is listed
        only once even if both keyword groups match.

        Args:
            state: Current workflow state; ``query`` is read.

        Returns:
            The same state with ``tool_results``, ``reasoning_steps`` and
            ``next_action`` filled in.
        """
        logger.info("π€ Agent: Planning which tools to use...")

        lowered_query = state["query"].lower()

        tools_to_use = ["vector_search"]
        reasoning = ["Use vector search for semantic document retrieval"]

        graph_reasons = []
        if any(keyword in lowered_query for keyword in self._RELATION_KEYWORDS):
            graph_reasons.append("Query asks about relationships - use graph search")
        if any(keyword in lowered_query for keyword in self._DEFINITION_KEYWORDS):
            graph_reasons.append("Query asks for concept definition - check graph")

        if graph_reasons:
            tools_to_use.append("graph_search")
            reasoning.extend(graph_reasons)

        state["tool_results"] = {"planned_tools": tools_to_use}
        state["reasoning_steps"] = reasoning
        state["next_action"] = "tools" if tools_to_use else "direct"

        logger.info(f" Plan: {tools_to_use}")
        return state

    def _should_use_tools(self, state: AgentState) -> str:
        """Decide if tools are needed.

        Returns:
            The routing key stored by the planner, or "direct" if it is missing.
        """
        return state.get("next_action", "direct")

    def _execute_tools_node(self, state: AgentState) -> AgentState:
        """Execute the planned tools.

        Runs vector search when planned.  When a graph lookup is planned, the
        first two candidate concepts from the query are tried in order; the
        first one found is stored together with up to three related nodes.

        Args:
            state: Current workflow state; ``query``, ``pdf_id`` and
                ``tool_results["planned_tools"]`` are read.

        Returns:
            The same state with the tool outputs merged into ``tool_results``.
        """
        logger.info("π§ Agent: Executing tools...")

        query = state["query"]
        planned_tools = state["tool_results"].get("planned_tools", [])
        # chat() stores "" when no PDF was given; pass None in that case, as
        # the fallback path does, so the search is not filtered.
        pdf_id = state.get("pdf_id") or None
        tools_map = self._get_tool_map()
        results = {}

        if "vector_search" in planned_tools:
            results["vector_results"] = tools_map["vector_search"].invoke(
                {"query": query, "top_k": 5, "pdf_id": pdf_id}
            )

        if "graph_search" in planned_tools:
            for concept in self._candidate_concepts(query)[:2]:
                node_result = tools_map["graph_search"].invoke({"concept": concept})
                if not node_result:
                    continue

                results[f"graph_node_{concept}"] = node_result
                results[f"related_{concept}"] = tools_map["get_related_nodes"].invoke(
                    {"node_id": node_result["node_id"], "max_neighbors": 3}
                )
                # One matched concept is enough context for the answer.
                break

        state["tool_results"].update(results)
        logger.info(f" β Executed {len(planned_tools)} tools, got {len(results)} results")
        return state

    async def _synthesize_node(self, state: AgentState) -> AgentState:
        """Synthesize final answer from tool results.

        Builds a text context from the top three vector hits and any graph
        results, then asks the LLM service for an answer.

        Args:
            state: Current workflow state; ``query`` and ``tool_results`` are read.

        Returns:
            The same state with ``final_answer`` and ``citations`` set.
        """
        logger.info("π― Agent: Synthesizing answer...")

        query = state["query"]
        tool_results = state["tool_results"]

        context_parts = []
        citations = []

        vector_results = tool_results.get("vector_results") or []
        for position, hit in enumerate(vector_results[:3], start=1):
            context_parts.append(f"[Source {position}, p.{hit['page_number']}]: {hit['text']}")
            citations.append(SourceCitation(
                page_number=hit["page_number"],
                snippet=self._make_snippet(hit["text"]),
                chunk_id=hit["chunk_id"],
                score=hit["score"],
            ))

        for key, value in tool_results.items():
            if key.startswith("graph_node_"):
                try:
                    importance_text = f"{value['importance']:.2f}"
                except (TypeError, ValueError):
                    # A missing or non-numeric score must not abort synthesis.
                    importance_text = "n/a"
                context_parts.append(
                    f"[Graph Node]: '{value['label']}' is a {value['type']} (importance: {importance_text})"
                )
            elif key.startswith("related_") and value:
                relations = ", ".join(f"{r['label']} ({r['relation']})" for r in value)
                context_parts.append(f"[Related Concepts]: {relations}")

        context = "\n\n".join(context_parts)

        answer = await self.llm_service.agent_synthesize(query, context)

        state["final_answer"] = answer
        state["citations"] = citations

        logger.info(" β Answer synthesized")
        return state

    async def chat(self, query: str, pdf_id: Optional[str] = None, include_citations: bool = True) -> ChatResponse:
        """
        Main entry point for agent-based chat

        Runs the LangGraph workflow; if it raises, the error is logged with
        its traceback and the simple RAG fallback is used instead.

        Args:
            query: User's question
            pdf_id: Optional PDF context
            include_citations: Whether to include source citations

        Returns:
            ChatResponse with answer and citations
        """
        logger.info(f"\n{'='*80}")
        logger.info(f"π€ Agent-Based RAG Query: '{query}'")
        logger.info(f"{'='*80}")

        initial_state = {
            "messages": [HumanMessage(content=query)],
            "query": query,
            "pdf_id": pdf_id or "",
            "tool_results": {},
            "reasoning_steps": [],
            "final_answer": "",
            "citations": [],
            "next_action": "",
        }

        try:
            final_state = await self.app.ainvoke(initial_state)

            # "final_answer" is always present in the state, so a .get()
            # default would never apply; substitute on empty/None instead.
            answer = final_state.get("final_answer") or self._NO_ANSWER
            citations = final_state.get("citations") or []

            if not include_citations:
                citations = []

            logger.info(f"β Agent completed successfully")
            logger.info(f" Answer length: {len(answer)} chars")
            logger.info(f" Citations: {len(citations)}")
            logger.info(f"{'='*80}\n")

            return ChatResponse(
                answer=answer,
                sources=citations[:5],
            )

        except Exception as e:
            # loguru runs str.format() on the message whenever extra arguments
            # are given, so the error text is passed as an argument (braces in
            # it are then harmless) and opt(exception=True) adds the traceback.
            logger.opt(exception=True).error("β Agent failed: {}", e)

            logger.warning("Falling back to simple RAG...")
            return await self._fallback_simple_rag(
                query, pdf_id, include_citations=include_citations
            )

    async def _fallback_simple_rag(self, query: str, pdf_id: Optional[str] = None,
                                   include_citations: bool = True) -> ChatResponse:
        """Fallback to simple RAG if agent fails.

        Retrieves the top five chunks, answers from the first three, and
        cites up to five.  Never raises: any error yields a generic apology.

        Args:
            query: User's question
            pdf_id: Optional PDF identifier used to filter the search
            include_citations: Whether to include source citations

        Returns:
            ChatResponse with the answer and, if requested, its sources
        """
        try:
            results = self.embedding_service.search(query=query, top_k=5, filter_pdf_id=pdf_id)

            if not results:
                return ChatResponse(
                    answer="I couldn't find relevant information to answer your question.",
                    sources=[],
                )

            context_chunks = [
                {
                    "page_number": meta.get("page_number", 0),
                    "text": meta.get("text", ""),
                }
                for meta, _score in results[:3]
            ]

            answer = await self.llm_service.rag_chat(query, context_chunks)

            sources = []
            if include_citations:
                sources = [
                    SourceCitation(
                        page_number=meta.get("page_number", 0),
                        snippet=self._make_snippet(meta.get("text", "")),
                        chunk_id=meta.get("chunk_id", ""),
                        score=float(score),
                    )
                    for meta, score in results[:5]
                ]

            return ChatResponse(answer=answer, sources=sources)

        except Exception as e:
            logger.error(f"Fallback RAG also failed: {e}")
            return ChatResponse(
                answer="I encountered an error processing your question.",
                sources=[],
            )
|
|
|