# Hugging Face Spaces status: Runtime error
| import os | |
| import sys | |
| import yaml | |
| import gradio as gr | |
| import uuid | |
# Absolute path of the directory containing this file.
current_dir = os.path.split(os.path.abspath(__file__))[0]
| from src.document_retrieval import DocumentRetrieval | |
| from utils.parsing.sambaparse import parse_doc_universal # added | |
| from utils.vectordb.vector_db import VectorDb | |
def handle_userinput(user_question, conversation_chain, history):
    """Send a question to the QA chain and append the exchange to the chat history.

    Args:
        user_question: Text typed by the user in the message box.
        conversation_chain: The chain produced by ``process_documents``, or
            ``None`` when no document has been processed yet. It is called as
            ``conversation_chain.invoke({"question": ...})`` and its result is
            read through the ``"answer"`` key.
        history: Current chat history as a list of ``(user_message,
            bot_message)`` pairs; ``None`` is treated as an empty history.

    Returns:
        ``(new_history, "")``: the updated history plus an empty string that
        clears the message box. Blank or whitespace-only questions leave the
        history unchanged. Failures are reported as the bot's reply instead
        of being raised, so the chat UI stays usable.
    """
    # Copy so the caller's list is never mutated; also normalizes None to [].
    chat_history = list(history) if history else []

    # Nothing to ask: keep the history as-is and just clear the input box.
    if not user_question or not user_question.strip():
        return chat_history, ""

    # No chain yet means no document has been processed; say so explicitly
    # instead of surfacing an AttributeError on None.
    if conversation_chain is None:
        notice = "Please upload and process a document before asking questions."
        return chat_history + [(user_question, notice)], ""

    try:
        response = conversation_chain.invoke({"question": user_question})
        answer = response["answer"]
    except Exception as e:  # UI boundary: show any failure as the bot's reply
        answer = f"An error occurred: {str(e)}"

    return chat_history + [(user_question, answer)], ""
def process_documents(files, collection_name, document_retrieval, vectorstore, conversation_chain, api_key=None):
    """Parse an uploaded document, index it into a new vector store, and build a QA chain.

    Args:
        files: The uploaded file from the file component, or ``None`` when
            nothing has been uploaded.
        collection_name: Current collection name; returned unchanged if
            processing does not complete.
        document_retrieval: Current ``DocumentRetrieval`` instance; returned
            unchanged if processing does not complete.
        vectorstore: Current vector store; returned unchanged if processing
            does not complete.
        conversation_chain: Current QA chain; returned unchanged if processing
            does not complete.
        api_key: Optional SambaNova API key typed by the user. When missing or
            blank, the ``SAMBANOVA_API_KEY`` environment variable is used.

    Returns:
        A 5-tuple ``(conversation_chain, vectorstore, document_retrieval,
        collection_name, status_message)``. On success the first four items
        are the newly built objects; otherwise they are exactly the inputs,
        so the caller never receives a mix of old and half-built state.
    """
    def _keep_state(message):
        # Return the caller's existing state untouched alongside a status message.
        return conversation_chain, vectorstore, document_retrieval, collection_name, message

    if files is None:
        return _keep_state("Please upload a PDF file before processing.")

    try:
        # A blank or whitespace-only key from the textbox falls back to the environment.
        if api_key and api_key.strip():
            sambanova_api_key = api_key.strip()
        else:
            sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')

        # Build everything into new_* locals; the inputs are only replaced on success.
        new_retrieval = DocumentRetrieval(sambanova_api_key)
        _, _, text_chunks = parse_doc_universal(doc=files)
        print(f'nb of chunks: {len(text_chunks)}')
        if not text_chunks:
            return _keep_state("No text could be extracted from the document.")

        embeddings = new_retrieval.load_embedding_model()
        # A fresh uuid-based name per call, so each processing run writes to its own collection.
        new_collection_name = f"collection_{uuid.uuid4()}"
        new_vectorstore = new_retrieval.create_vector_store(
            text_chunks, embeddings, output_db=None, collection_name=new_collection_name
        )
        new_retrieval.init_retriever(new_vectorstore)
        new_chain = new_retrieval.get_qa_retrieval_chain()
    except Exception as e:  # UI boundary: report the failure in the status box
        return _keep_state(f"An error occurred while processing: {str(e)}")

    return new_chain, new_vectorstore, new_retrieval, new_collection_name, "Complete! You can now ask questions."
# Warning text rendered as Markdown under the "Process" button.
caution_text = (
    "⚠️ Note: depending on the size of your document, "
    "this could take several minutes.\n"
)
# Gradio app: upload a PDF, index it into a vector store, then chat with it.
with gr.Blocks() as demo:
    # Session state holders, created without an initial value. process_documents
    # fills them; handle_userinput reads conversation_chain.
    vectorstore = gr.State()
    conversation_chain = gr.State()
    document_retrieval = gr.State()
    collection_name = gr.State()

    gr.Markdown("# Enterprise Knowledge Retriever", elem_id="title")
    gr.Markdown("Powered by LLama3.1-8B-Instruct on SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")
    api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")

    # Step 1: Add PDF file
    gr.Markdown("## 1️⃣ Upload PDF")
    # Gradio's file_types accepts categories ("image", "file", ...) or extensions
    # with a leading dot; a bare "pdf" matches neither the extension nor a MIME type.
    docs = gr.File(label="Add PDF file (single)", file_types=[".pdf"], file_count="single")

    # Step 2: Process PDF file
    gr.Markdown("## 2️⃣ Process document and create vector store")
    # NOTE(review): db_btn's value is not passed to any handler; ChromaDB is the only choice offered.
    db_btn = gr.Radio(["ChromaDB"], label="Vector store type", value="ChromaDB", type="index", info="Choose your vector store")
    setup_output = gr.Textbox(label="Processing status", visible=True, value="None")
    process_btn = gr.Button("🔄 Process")
    gr.Markdown(caution_text)

    # Preprocessing events: process_documents returns values for the four state
    # holders plus a status message for setup_output.
    process_btn.click(
        process_documents,
        inputs=[docs, collection_name, document_retrieval, vectorstore, conversation_chain, api_key],
        outputs=[conversation_chain, vectorstore, document_retrieval, collection_name, setup_output],
        concurrency_limit=20,
    )

    # Step 3: Chat with your data
    gr.Markdown("## 3️⃣ Chat with your document")
    chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
    msg = gr.Textbox(label="Ask questions about your data", show_label=True, placeholder="Enter your message...")
    clear_btn = gr.Button("Clear chat")
    # NOTE(review): sources_output is hidden and no event in this file writes to it.
    sources_output = gr.Textbox(label="Sources", visible=False)

    # Chatbot events: handle_userinput returns the updated history and "" to clear the box.
    msg.submit(handle_userinput, inputs=[msg, conversation_chain, chatbot], outputs=[chatbot, msg], queue=False)
    # Reset to an empty list (not None) so handle_userinput always receives a list history.
    clear_btn.click(lambda: ([], ""), inputs=None, outputs=[chatbot, msg], queue=False)
# Launch the Gradio server only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()