Spaces:
Sleeping
Sleeping
| # --- Imports --- | |
| import gradio as gr | |
| from langchain_groq import ChatGroq | |
| from langchain_community.tools.tavily_search import TavilySearchResults # Updated import | |
| from langchain_core.messages import HumanMessage, AIMessage, BaseMessage | |
| import os | |
| from dotenv import load_dotenv | |
| from typing import TypedDict, List, Optional, Dict, Sequence | |
| from langgraph.graph import StateGraph, END | |
# --- Environment Setup ---
# Pull variables from a local .env file (if present) into the environment.
load_dotenv()

TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Fail fast at import time: both keys are needed by the components below.
if not (TAVILY_API_KEY and GROQ_API_KEY):
    raise ValueError("API keys for Tavily and Groq must be set in .env file")
# --- Constants ---
# Question shown for each preference. The insertion order of this mapping is
# the order in which the questions are asked.
# NOTE(review): the leading symbols in the strings below look like mis-decoded
# emoji; confirm the intended characters against the original file encoding.
QUESTIONS_MAP = {
    "destination": "π Where would you like to travel?",
    "budget": "π° What is your budget for this trip?",
    "activities": "π€Έ What kind of activities do you prefer? (e.g., adventure, relaxation, sightseeing)",
    "duration": "β³ How many days do you plan to stay?",
    "accommodation": "π¨ Do you prefer hotels, hostels, or Airbnbs?",
}

# Field names in asking order (same sequence as the keys of QUESTIONS_MAP).
ORDERED_FIELDS = list(QUESTIONS_MAP)

# Fixed messages shown by the chat UI.
INITIAL_MESSAGE = "π Welcome! Type `START` to begin planning your travel itinerary."
START_CONFIRMATION = f"π Great! Let's plan your trip. {QUESTIONS_MAP[ORDERED_FIELDS[0]]}"
RESTART_MESSAGE = "β Restarted! Type `START` to begin again."
INVALID_START_MESSAGE = "β Please type `START` to begin the travel itinerary process."
ITINERARY_READY_MESSAGE = "β Your travel itinerary is ready!"
# --- LangChain / LangGraph Components ---
# Chat model that writes the final itinerary.
llm = ChatGroq(model="llama3-8b-8192", api_key=GROQ_API_KEY)

# Tavily web-search tool, capped at five results per query. It is called
# directly through ``.invoke`` rather than being handed to an agent.
search_tool = TavilySearchResults(tavily_api_key=TAVILY_API_KEY, max_results=5)
# Define the state for our graph
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the travel-planning graph."""

    # Conversation history.
    messages: Sequence[BaseMessage]
    # User preferences collected so far, keyed by field name.
    user_info: Dict[str, str]
    # Names of the fields that still have to be asked.
    missing_fields: List[str]
    # Text produced by the Tavily search step (None until it has run).
    search_results: Optional[str]
    # Final generated itinerary (None until it has been produced).
    itinerary: Optional[str]
# --- Node Functions ---
def process_user_input(state: AgentState) -> Dict:
    """Record the latest user reply as the answer to the next pending field.

    The reply is assumed to answer the *first* entry of ``missing_fields``.

    Args:
        state: Current graph state; reads ``messages``, ``missing_fields`` and
            ``user_info``.

    Returns:
        A partial state update containing the new ``user_info`` and
        ``missing_fields`` values, or an empty dict when there is nothing to
        record (no messages, the last message is not a ``HumanMessage``, or no
        field is pending).
    """
    messages = state.get('messages') or []
    if not messages:
        # Nothing to read; avoids an IndexError on an empty history.
        return {}

    last_message = messages[-1]
    if not isinstance(last_message, HumanMessage):
        # Should not happen in normal flow, but good safeguard
        return {}

    missing_fields = state.get('missing_fields') or []
    if not missing_fields:
        # All info gathered; this node only captures answers to questions.
        return {}

    user_response = last_message.content.strip()
    field_to_update, *updated_missing_fields = missing_fields

    # Build a new dict rather than mutating the one owned by the incoming
    # state, so the caller's copy is only changed through the returned update.
    user_info = {**(state.get('user_info') or {}), field_to_update: user_response}

    print(f"--- Processed Input: Field '{field_to_update}' = '{user_response}' ---")
    print(f"--- Remaining Fields: {updated_missing_fields} ---")
    return {"user_info": user_info, "missing_fields": updated_missing_fields}
def ask_next_question(state: AgentState) -> Dict:
    """Append the question for the next pending field to the conversation.

    Args:
        state: Current graph state; reads ``messages`` and ``missing_fields``.

    Returns:
        A partial state update whose ``messages`` list is the existing history
        plus one ``AIMessage``: the question for the first missing field, or a
        fallback notice when no field is pending.
    """
    # Copy into a list first: ``messages`` is typed as a Sequence, and
    # ``tuple + list`` would raise a TypeError.
    history = list(state.get('messages') or [])
    missing_fields = state.get('missing_fields') or []

    if not missing_fields:
        # Should not be called if no fields are missing, but handle defensively
        fallback = AIMessage(content="Something went wrong, no more questions to ask.")
        return {"messages": history + [fallback]}

    next_field = missing_fields[0]
    question = QUESTIONS_MAP[next_field]
    print(f"--- Asking Question for: {next_field} ---")
    return {"messages": history + [AIMessage(content=question)]}
def _search_hit_to_text(hit: object) -> str:
    """Return the readable text of a single search hit."""
    if isinstance(hit, dict):
        content = hit.get('content')
        if content:
            url = hit.get('url')
            return f"{content} (source: {url})" if url else str(content)
        return str(hit)
    # Document-like objects expose their text as ``page_content``.
    return getattr(hit, 'page_content', str(hit))


def _search_results_to_text(results: object) -> str:
    """Flatten whatever the search tool returned into one prompt-ready string."""
    if isinstance(results, list):
        return "\n\n".join(_search_hit_to_text(hit) for hit in results)
    if isinstance(results, dict):
        for key in ('answer', 'result'):
            if results.get(key):
                return str(results[key])
    return str(results)


def run_search(state: AgentState) -> Dict:
    """Run a Tavily search built from the collected user preferences.

    Args:
        state: Current graph state; reads ``user_info``.

    Returns:
        A partial state update with ``search_results`` set to the flattened
        search text, or to a failure notice if the search raised. The failure
        is reported as text, not re-raised, so the itinerary step still runs.
    """
    user_info = state.get('user_info') or {}
    print("--- Running Search ---")

    # Construct a more descriptive query for the search engine
    query_parts = [f"Travel itinerary ideas for {user_info.get('destination', 'anywhere')}"]
    if user_info.get('duration'):
        query_parts.append(f"for about {user_info.get('duration')} days.")
    if user_info.get('budget'):
        query_parts.append(f"User's budget is described as: '{user_info.get('budget')}'.")
    if user_info.get('activities'):
        query_parts.append(f"User is looking for activities like: '{user_info.get('activities')}'.")
    if user_info.get('accommodation'):
        query_parts.append(f"User's accommodation preference is: '{user_info.get('accommodation')}'.")
    search_query = " ".join(query_parts)
    print(f"--- Refined Search Query: {search_query} ---")

    try:
        results = search_tool.invoke(search_query)
        search_results_str = _search_results_to_text(results)
    except Exception as e:
        # Best-effort boundary: any search failure degrades to a text notice.
        print(f"--- Search Failed: {e} ---")
        return {"search_results": f"Search failed or timed out. Details: {e}"}

    print(f"--- Search Results Found ---")
    return {"search_results": search_results_str}
def generate_itinerary(state: AgentState) -> Dict:
    """Generate the final itinerary with the LLM from preferences and search text.

    Args:
        state: Current graph state; reads ``user_info``, ``search_results`` and
            ``messages``.

    Returns:
        A partial state update with ``itinerary`` (the generated text, or
        ``None`` if the LLM call raised) and ``messages`` (the existing history
        plus one ``AIMessage`` carrying either the itinerary or an error
        notice).
    """
    user_info = state.get('user_info') or {}
    # ``search_results`` is normally present in the state but may be None or
    # empty, so a plain ``.get(key, default)`` would never apply the default.
    search_results = state.get('search_results') or "No search results available."
    history = list(state.get('messages') or [])
    print("--- Generating Itinerary ---")

    # --- Enhanced Prompt ---
    itinerary_prompt = f"""
You are an expert travel planner. Create a detailed and engaging travel itinerary based on the following user preferences:
**User Preferences:**
- Destination: {user_info.get('destination', 'Not specified')}
- Duration: {user_info.get('duration', 'Not specified')} days
- Budget Description: '{user_info.get('budget', 'Not specified')}'
- Preferred Activities Description: '{user_info.get('activities', 'Not specified')}'
- Preferred Accommodation Description: '{user_info.get('accommodation', 'Not specified')}'
**Your Task:**
1. **Interpret Preferences:** Carefully interpret the user's descriptions for budget, activities, and accommodation.
* If the budget is descriptive (e.g., 'moderate', 'budget-friendly', 'a bit flexible', 'around $X'), tailor suggestions to match that level for the specific destination. Avoid extreme high-cost or only free options unless explicitly requested. 'Moderate' usually implies a balance of value, comfort, and experiences.
* If activities are described generally (e.g., 'mix of famous and offbeat', 'relaxing', 'cultural immersion', 'adventure'), create an itinerary that reflects this. A 'mix' should include popular landmarks and hidden gems. 'Relaxing' should include downtime.
* Interpret accommodation descriptions (e.g., 'mid-range', 'cheap but clean', 'boutique hotel') based on typical offerings at the destination.
2. **Use Search Results:** Incorporate relevant and specific suggestions from the search results below, but *only* if they align with the interpreted user preferences. Do not blindly copy search results.
3. **Create a Coherent Plan:** Structure the itinerary logically, often day-by-day. Include suggestions for specific activities, potential dining spots (matching budget), and estimated timings where appropriate.
4. **Engaging Tone:** Present the itinerary in an exciting and appealing way.
NOTE :- Do not go over the Budget Description. No matter the case , Always try to stay under the Budget Description
**Supporting Search Results:**
```
{search_results}
```
**Generate the Itinerary:**
"""
    # --- End of Enhanced Prompt ---

    try:
        response = llm.invoke(itinerary_prompt)
    except Exception as e:
        # Report the failure in the chat instead of crashing the graph run.
        print(f"--- Itinerary Generation Failed: {e} ---")
        error_message = AIMessage(
            content=f"Sorry, I encountered an error while generating the itinerary. Details: {e}"
        )
        return {"itinerary": None, "messages": history + [error_message]}

    # Fall back to the string form if the response has no ``content`` attribute.
    itinerary_content = getattr(response, 'content', str(response))
    print("--- Itinerary Generated ---")
    final_message = AIMessage(content=f"{ITINERARY_READY_MESSAGE}\n\n{itinerary_content}")
    return {"itinerary": itinerary_content, "messages": history + [final_message]}
def should_ask_question_or_search(state: AgentState) -> str:
    """Pick the next node: keep asking while fields are missing, else search.

    Returns:
        ``"ask_next_question"`` when ``missing_fields`` is non-empty,
        otherwise ``"run_search"``.
    """
    if state.get('missing_fields', []):
        print("--- Condition: More info needed, ask next question ---")
        return "ask_next_question"

    print("--- Condition: All info gathered, proceed to search ---")
    return "run_search"
# --- Build the Graph ---
graph_builder = StateGraph(AgentState)

# Nodes: one per step of the conversation / planning pipeline.
graph_builder.add_node("process_user_input", process_user_input)
graph_builder.add_node("ask_next_question", ask_next_question)
graph_builder.add_node("run_search", run_search)
graph_builder.add_node("generate_itinerary", generate_itinerary)

# Every run starts by recording the user's latest reply.
graph_builder.set_entry_point("process_user_input")

# After recording a reply, either ask the next question or start the search,
# depending on what should_ask_question_or_search returns.
graph_builder.add_conditional_edges(
    "process_user_input",
    should_ask_question_or_search,
    {
        "ask_next_question": "ask_next_question",
        "run_search": "run_search",
    },
)

# Asking a question ends the run; searching leads to itinerary generation,
# which then ends the run.
graph_builder.add_edge("ask_next_question", END)
graph_builder.add_edge("run_search", "generate_itinerary")
graph_builder.add_edge("generate_itinerary", END)

# Compile the graph
travel_agent_app = graph_builder.compile()
# --- Gradio Interface ---
# Function to handle the conversation logic with LangGraph state
def handle_user_message(user_input: str, history: List[List[str | None]], current_state_dict: Optional[dict]) -> tuple:
    """Advance the conversation by one user message.

    Args:
        user_input: Raw text typed by the user.
        history: Chat history as ``[user_msg, assistant_msg]`` pairs; it is
            extended in place with the new exchange.
        current_state_dict: State kept between calls, or ``None`` on the first
            interaction.

    Returns:
        ``(history, state, "")``: the updated chat history, the updated state
        dictionary, and an empty string for the input box.
    """
    user_input_cleaned = user_input.strip().lower()

    # Initialize state if it doesn't exist (first interaction)
    if current_state_dict is None:
        current_state_dict = {
            "messages": [AIMessage(content=INITIAL_MESSAGE)],
            "user_info": {},
            "missing_fields": [],  # Will be populated if user types START
            "search_results": None,
            "itinerary": None,
        }

    if not current_state_dict.get("missing_fields"):
        # No question is pending: either start a new session or tell the user
        # what to do next.
        if user_input_cleaned == "start":
            print("--- Received START ---")
            current_state_dict["missing_fields"] = list(ORDERED_FIELDS)
            current_state_dict["user_info"] = {}
            # Drop results of a previous session so they cannot leak into the
            # new one.
            current_state_dict["search_results"] = None
            current_state_dict["itinerary"] = None
            current_state_dict["messages"] = [AIMessage(content=START_CONFIRMATION)]
            # No graph execution needed yet, just return the first question.
            history.append([None, START_CONFIRMATION])  # Gradio format needs None for user message here
        elif current_state_dict.get("itinerary"):
            # Checked before the generic "waiting for START" case, which would
            # otherwise catch this input as well.
            print("--- Itinerary already generated, waiting for START OVER ---")
            history.append([user_input, "Itinerary generated. Please click 'Start Over' to plan a new trip."])
        else:
            print("--- Waiting for START ---")
            current_state_dict["messages"] = list(current_state_dict.get("messages") or []) + [
                HumanMessage(content=user_input),
                AIMessage(content=INVALID_START_MESSAGE),
            ]
            history.append([user_input, INVALID_START_MESSAGE])
        return history, current_state_dict, ""

    # A question is pending: treat the input as its answer and run the graph.
    print(f"--- User Input: {user_input} ---")
    current_messages = list(current_state_dict.get("messages") or [])
    current_messages.append(HumanMessage(content=user_input))
    current_state_dict["messages"] = current_messages

    print("--- Invoking Graph ---")
    # Ensure state keys match AgentState before invoking
    graph_input_state = AgentState(
        messages=current_state_dict.get("messages", []),
        user_info=current_state_dict.get("user_info", {}),
        missing_fields=current_state_dict.get("missing_fields", []),
        search_results=current_state_dict.get("search_results"),
        itinerary=current_state_dict.get("itinerary"),
    )
    final_state = travel_agent_app.invoke(graph_input_state)
    print("--- Graph Execution Complete ---")

    # Update the state dictionary from the graph's final state
    current_state_dict.update(final_state)

    # The graph appends its reply to state['messages']; show the last one if
    # it is an AI message.
    final_messages = final_state.get('messages') or []
    last_message = final_messages[-1] if final_messages else None
    ai_response = last_message.content if isinstance(last_message, AIMessage) else "Error: No response."
    history.append([user_input, ai_response])

    # Return updated history, the persistent state dictionary, and clear the input box
    return history, current_state_dict, ""
# Function to reset the state (Start Over button)
def start_over() -> tuple:
    """Return a fresh chat history, a fresh state dict and an empty input value.

    The history is a single ``[None, RESTART_MESSAGE]`` pair and the state
    starts with no collected info and no pending fields.
    """
    print("--- Starting Over ---")
    fresh_state = dict(
        messages=[AIMessage(content=RESTART_MESSAGE)],
        user_info={},
        missing_fields=[],
        search_results=None,
        itinerary=None,
    )
    # Gradio history format: list of [user_msg, assistant_msg] pairs
    fresh_history = [[None, RESTART_MESSAGE]]
    return fresh_history, fresh_state, ""
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# π AI-Powered Travel Itinerary Generator (LangGraph Version)")
    gr.Markdown("Type `START` to begin planning your trip. Answer the questions, and I'll generate a personalized itinerary for you!")

    # Conversation state carried between callbacks; None until the first message.
    agent_state = gr.State(value=None)

    # Chat window, pre-filled with the welcome message.
    chatbot = gr.Chatbot(
        value=[[None, INITIAL_MESSAGE]],
        label="Travel Bot",
        bubble_full_width=False,
    )
    user_input = gr.Textbox(placeholder="Type here...", label="Your Message", scale=3)
    submit_btn = gr.Button("Send", scale=1)
    start_over_btn = gr.Button("Start Over", scale=1)

    # "Send" and pressing Enter in the textbox both run the message handler,
    # which updates the chat, the state, and the input box.
    submit_btn.click(
        fn=handle_user_message,
        inputs=[user_input, chatbot, agent_state],
        outputs=[chatbot, agent_state, user_input],
    )
    user_input.submit(
        fn=handle_user_message,
        inputs=[user_input, chatbot, agent_state],
        outputs=[chatbot, agent_state, user_input],
    )

    # "Start Over" replaces the chat, the state, and the input box contents.
    start_over_btn.click(
        fn=start_over,
        inputs=[],
        outputs=[chatbot, agent_state, user_input],
    )

# --- Run the App ---
if __name__ == "__main__":
    # debug=True provides more logs
    app.launch(debug=True)