File size: 5,413 Bytes
15c8adb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74d3140
 
15c8adb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d84f5a4
 
15c8adb
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import json
from flask import Flask, request, jsonify
from flask_cors import CORS
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from dotenv import load_dotenv
import re


# Load environment variables from a local .env file (if one exists) before
# any client that reads them is constructed.
# NOTE(review): presumably this is where the Groq API key comes from — confirm.
load_dotenv()

# Flask application object; the route below is registered on it.
app = Flask(__name__)
# Enable CORS on the app with flask_cors defaults.
# NOTE(review): defaults are believed to allow all origins — confirm this is intended.
CORS(app)

# Module-level handle to the retrieval chain. It stays None until
# initialize_rag_chain() assigns it; chat() returns HTTP 500 while it is falsy.
rag_chain = None

def json_to_text(json_data):
    """
    Flatten the structured knowledge-base JSON into a single plain-text string.

    The resulting text is what gets chunked and embedded. Each top-level key
    becomes a section, followed by one blank line:

    * dict value   -> a "Key:" header, then one indented "  Sub Key: value"
                      line per entry.
    * list value   -> a "Key:" header, then one "  - ..." bullet per item;
                      dict items are expanded to one bullet per key/value
                      pair, any other item is rendered with str().
    * anything else -> a single "Key: value" line.

    Keys are prettified by replacing underscores with spaces and title-casing.
    Only one level of nesting is expanded; deeper structures are rendered via
    their default string representation.

    Args:
        json_data: Mapping with string keys (e.g. the result of json.load on
            a JSON object).

    Returns:
        The flattened text, or an empty string for an empty mapping.
    """
    def _label(raw_key):
        # "work_experience" -> "Work Experience"
        return raw_key.replace('_', ' ').title()

    # Collect lines and join once at the end instead of repeatedly growing a
    # string with "+=", which is quadratic in the worst case.
    lines = []
    for key, value in json_data.items():
        if isinstance(value, dict):
            lines.append(f"{_label(key)}:")
            lines.extend(
                f"  {_label(sub_key)}: {sub_value}"
                for sub_key, sub_value in value.items()
            )
        elif isinstance(value, list):
            lines.append(f"{_label(key)}:")
            for item in value:
                if isinstance(item, dict):
                    lines.extend(
                        f"  - {_label(item_key)}: {item_value}"
                        for item_key, item_value in item.items()
                    )
                else:
                    lines.append(f"  - {item}")
        else:
            lines.append(f"{_label(key)}: {value}")
        # Blank separator line after every top-level section.
        lines.append("")
    # Every collected line (including the blank separators) ends in "\n".
    return "".join(f"{line}\n" for line in lines)

def initialize_rag_chain():
    """
    Build the retrieval-augmented generation chain and store it in the
    module-level ``rag_chain``.

    Steps:
        1. Read ``data/knowledge_base.json`` (path is relative to the current
           working directory) and flatten it with ``json_to_text``.
        2. Split the text into 1000-character chunks with 200 overlap.
        3. Embed the chunks with the ``BAAI/bge-small-en-v1.5`` model.
        4. Index them in an in-memory FAISS store and expose a retriever.
        5. Create the Groq chat model (``qwen/qwen3-32b``, temperature 0).
        6. Combine prompt, LLM and retriever into a retrieval chain.

    Any exception raised along the way is caught and printed; ``rag_chain``
    is then left unchanged (``None`` on first start-up), so the failure is
    visible to callers only through that global.
    """
    global rag_chain
    try:
        # NOTE(review): presumably points the sentence-transformers model
        # cache at a writable directory on the hosting platform — confirm.
        os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/tmp'

        # 1. Load data from JSON file
        print("Loading data from knowledge_base.json...")
        data_dir = "data"
        json_path = os.path.join(data_dir, 'knowledge_base.json')

        # Decode explicitly as UTF-8: without it, open() falls back to the
        # platform default encoding and non-ASCII characters in the knowledge
        # base can be mis-decoded or raise UnicodeDecodeError.
        with open(json_path, 'r', encoding='utf-8') as f:
            knowledge_base = json.load(f)

        # Convert the entire JSON to a single text string
        text_content = json_to_text(knowledge_base)
        # Wrap it in a LangChain Document object
        documents = [Document(page_content=text_content)]

        # 2. Chunk the documents
        print("Splitting document into chunks...")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        chunks = text_splitter.split_documents(documents)
        print(f"Created {len(chunks)} text chunks.")

        # 3. Create Embeddings
        print("Initializing embedding model...")
        embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")

        # 4. Create FAISS Vector Store
        print("Creating FAISS vector store...")
        vector_store = FAISS.from_documents(chunks, embedding_model)
        retriever = vector_store.as_retriever()

        # 5. Initialize LLM
        print("Initializing Groq LLM...")
        llm = ChatGroq(temperature=0, model_name="qwen/qwen3-32b")

        # 6. Create RAG Chain
        print("Creating RAG chain...")
        # {context} is filled with the retrieved chunks and {input} with the
        # user's question when the chain is invoked.
        prompt = ChatPromptTemplate.from_template("""
        You are "KurianGPT", an expert and very friendly AI assistant providing information about Kurian Jose based on his resume and project documents.
        Answer the user's question based only on the following context.
        Refer too yourself as "KurianGPT" and Kurian as "Kurian".
        Avoid showing internal reasoning. Do not output <think> tags. Respond directly and professionally.
        Keep responses **very brief** (1–2 sentences max) unless the user asks for more details or examples.
        If the answer is not in the context, politely say that you can only answer questions regarding kurian's professional background and projects.
        refer to the user as "you" and Kurian as "Kurian".
        <context>
        {context}
        </context>

        Question: {input}
        """)
        document_chain = create_stuff_documents_chain(llm, prompt)
        rag_chain = create_retrieval_chain(retriever, document_chain)
        print("--- RAG Chain Initialized Successfully! ---")
    except Exception as e:
        # Deliberate best-effort: a failed initialization must not crash the
        # import of this module; the error is reported on stdout instead.
        print(f"Error during RAG initialization: {e}")

def strip_think_tags(text):
    """
    Remove every ``<think>...</think>`` block from *text* and trim the result.

    Matching is non-greedy and spans newlines, so each closed block is removed
    individually. Leading and trailing whitespace is stripped from what is
    left. An opening tag with no matching closing tag is left untouched.
    """
    think_block = re.compile(r"<think>.*?</think>", flags=re.DOTALL)
    without_reasoning = think_block.sub("", text)
    return without_reasoning.strip()

# App routes
@app.route('/api/chat', methods=['POST'])
def chat():
    """
    Answer a chat message using the RAG chain.

    Expects a JSON object body of the form ``{"message": "<question>"}``.

    Returns:
        200 with ``{"reply": <answer>}`` on success, with any
            ``<think>...</think>`` blocks stripped from the model output.
        400 with ``{"error": ...}`` when the body is not a JSON object or
            ``message`` is missing, not a string, or blank.
        500 with ``{"error": ...}`` when the chain is not initialized or the
            chain invocation raises.
    """
    if not rag_chain:
        return jsonify({'error': 'RAG chain is not initialized. Check server logs.'}), 500

    # silent=True makes a missing or malformed JSON body come back as None
    # instead of raising, so every client error is reported in this API's own
    # JSON error shape.
    data = request.get_json(silent=True)
    # A valid JSON body can still be null, a list, a string, ... — calling
    # .get() on those would raise and surface as an unhandled 500.
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    user_message = data.get('message')
    # Only a non-blank string is a usable question for the chain.
    if not isinstance(user_message, str) or not user_message.strip():
        return jsonify({'error': 'No message provided'}), 400
    try:
        result = rag_chain.invoke({"input": user_message})
        raw_response = result.get('answer', "I couldn't generate a response.")
        print(f"DEBUG RAW RESPONSE:\n{raw_response}\n----------------")

        clean_response = strip_think_tags(raw_response)
        return jsonify({'reply': clean_response})
    except Exception as e:
        # Top-level request boundary: log the details server-side and return
        # a generic message so internals are not leaked to the client.
        print(f"Error during chat processing: {e}")
        return jsonify({'error': 'An error occurred while processing your request.'}), 500

# Initialize the RAG chain when the application starts.
# This call runs at import time, so it also executes when a WSGI/serverless
# host imports this module rather than running it as a script. Failures are
# caught inside initialize_rag_chain(), which leaves rag_chain as None.
initialize_rag_chain()

# This part is for local development only and will not be used by Vercel
# NOTE(review): debug=True turns on Flask's debugger and auto-reloader; keep
# it confined to this local entry point.
if __name__ == '__main__':
    app.run(debug=True, port=5001)