import os
import shutil
import sys
from typing import Any, Dict, List, Optional

import torch
import yaml
from dotenv import load_dotenv
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.prompts import BasePromptTemplate, load_prompt
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.retrievers import BaseRetriever
from transformers import AutoModelForSequenceClassification, AutoTokenizer

current_dir = os.path.dirname(os.path.abspath(__file__))  # src/ directory
kit_dir = os.path.abspath(os.path.join(current_dir, '..'))  # EKR/ directory
repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))  # repository root

sys.path.append(kit_dir)
sys.path.append(repo_dir)

# import streamlit as st
from utils.model_wrappers.api_gateway import APIGateway
from utils.parsing.sambaparse import parse_doc_universal
from utils.vectordb.vector_db import VectorDb
from utils.visual.env_utils import get_wandb_key

CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')

# load_dotenv(os.path.join(kit_dir, '.env'))

# Handle the WANDB_API_KEY resolution before importing weave
# wandb_api_key = get_wandb_key()
# If WANDB_API_KEY is set, proceed with weave initialization
# if wandb_api_key:
#     import weave
#     # Initialize Weave with your project name
#     weave.init('sambanova_ekr')
# else:
#     print('WANDB_API_KEY is not set. Weave initialization skipped.')

class RetrievalQAChain(Chain):
    """Retrieval chain for question answering over retrieved documents."""

    retriever: BaseRetriever
    rerank: bool = True
    llm: BaseLanguageModel
    qa_prompt: BasePromptTemplate
    final_k_retrieved_documents: int = 3

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return ['question']

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        return ['answer', 'source_documents']

    def _format_docs(self, docs: List[Document]) -> str:
        return '\n\n'.join(doc.page_content for doc in docs)

    def rerank_docs(self, query: str, docs: List[Document], final_k: int) -> List[Document]:
        """Score each retrieved document against the query and keep the top final_k."""
        # Lazy hardcoding for now: the reranker model is re-loaded on every call
        tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
        reranker = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
        pairs = [[query, d.page_content] for d in docs]
        with torch.no_grad():
            inputs = tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=512,
            )
            scores = reranker(**inputs, return_dict=True).logits.view(-1).float()
        scores_list = scores.tolist()
        scores_sorted_idx = sorted(range(len(scores_list)), key=lambda k: scores_list[k], reverse=True)
        docs_sorted = [docs[k] for k in scores_sorted_idx]
        # docs_sorted = [docs[k] for k in scores_sorted_idx if scores_list[k] > 0]
        return docs_sorted[:final_k]

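    # Note: the tokenizer and reranker above are re-loaded from the Hub on every
    # call. A minimal caching sketch (illustrative; the helper name below is
    # hypothetical) would load them once and reuse them across calls:
    #
    #   from functools import lru_cache
    #
    #   @lru_cache(maxsize=1)
    #   def _load_reranker():
    #       tok = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
    #       model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
    #       return tok, model
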
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        qa_chain = self.qa_prompt | self.llm | StrOutputParser()
        response: Dict[str, Any] = {}
        documents = self.retriever.invoke(inputs['question'])
        if self.rerank:
            documents = self.rerank_docs(inputs['question'], documents, self.final_k_retrieved_documents)
        docs = self._format_docs(documents)
        response['answer'] = qa_chain.invoke({'question': inputs['question'], 'context': docs})
        response['source_documents'] = documents
        return response

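# Usage sketch for RetrievalQAChain (illustrative; assumes a retriever, llm,
# and qa_prompt have already been built -- the variable names here are
# hypothetical):
#
#   chain = RetrievalQAChain(retriever=retriever, llm=llm, qa_prompt=qa_prompt)
#   result = chain.invoke({'question': 'What does the document say about X?'})
#   print(result['answer'], result['source_documents'])
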
class DocumentRetrieval:
    def __init__(self, sambanova_api_key: str) -> None:
        self.vectordb = VectorDb()
        (
            self.api_info,
            self.llm_info,
            self.embedding_model_info,
            self.retrieval_info,
            self.prompts,
            self.prod_mode,
        ) = self.get_config_info()
        self.retriever: Optional[BaseRetriever] = None
        self.llm = self.set_llm(sambanova_api_key)

    def get_config_info(self):
        """Load the YAML config file."""
        with open(CONFIG_PATH, 'r') as yaml_file:
            config = yaml.safe_load(yaml_file)
        api_info = config['api']
        llm_info = config['llm']
        embedding_model_info = config['embedding_model']
        retrieval_info = config['retrieval']
        prompts = config['prompts']
        prod_mode = config['prod_mode']
        return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode

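    # For reference, config.yaml is expected to provide the keys read above.
    # A minimal sketch (the values are illustrative placeholders, not shipped
    # defaults):
    #
    #   api: "sncloud"
    #   llm:
    #     coe: true
    #     do_sample: false
    #     max_tokens_to_generate: 1200
    #     temperature: 0.0
    #     select_expert: "Meta-Llama-3-70B-Instruct"
    #   embedding_model:
    #     type: "cpu"
    #     batch_size: 1
    #     coe: false
    #     select_expert: "e5-mistral-7b-instruct"
    #   retrieval:
    #     rerank: true
    #     score_threshold: 0.3
    #     k_retrieved_documents: 15
    #     final_k_retrieved_documents: 3
    #   prompts:
    #     qa_prompt: "prompts/qa_prompt.yaml"
    #   prod_mode: false
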
    def set_llm(self, sambanova_api_key: str):
        # Previous Streamlit-session-based key resolution, kept for reference:
        # if self.prod_mode:
        #     sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
        # else:
        #     if 'SAMBANOVA_API_KEY' in st.session_state:
        #         sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') or st.session_state.SAMBANOVA_API_KEY
        #     else:
        #         sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
        llm = APIGateway.load_llm(
            type=self.api_info,
            streaming=True,
            coe=self.llm_info['coe'],
            do_sample=self.llm_info['do_sample'],
            max_tokens_to_generate=self.llm_info['max_tokens_to_generate'],
            temperature=self.llm_info['temperature'],
            select_expert=self.llm_info['select_expert'],
            process_prompt=False,
            sambanova_api_key=sambanova_api_key,
        )
        return llm

    def parse_doc(self, docs: List, additional_metadata: Optional[Dict] = None) -> List[Document]:
        """
        Parse the uploaded documents and return a list of LangChain documents.

        Args:
            docs (List): A list of (file_name, file_object) tuples for the uploaded files.
            additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed
                documents. Defaults to an empty dictionary.

        Returns:
            List[Document]: A list of LangChain documents.
        """
        if additional_metadata is None:
            additional_metadata = {}

        # Create the data/tmp folder if it doesn't exist
        temp_folder = os.path.join(kit_dir, 'data/tmp')
        if not os.path.exists(temp_folder):
            os.makedirs(temp_folder)
        else:
            # If there are already files there, delete them
            for filename in os.listdir(temp_folder):
                file_path = os.path.join(temp_folder, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)
                except Exception as e:
                    print(f'Failed to delete {file_path}. Reason: {e}')

        # Save all selected files to the tmp dir under their file names
        for file_name, file_obj in docs:
            temp_file = os.path.join(temp_folder, file_name)
            with open(temp_file, 'wb') as f:
                f.write(file_obj.read())

        # Pass the temp folder into parse_doc_universal for processing
        _, _, langchain_docs = parse_doc_universal(doc=temp_folder, additional_metadata=additional_metadata)
        return langchain_docs

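    # Usage sketch (illustrative; 'report.pdf' and 'document_retrieval' are
    # hypothetical names):
    #
    #   document_retrieval = DocumentRetrieval(sambanova_api_key='...')
    #   with open('report.pdf', 'rb') as f:
    #       langchain_docs = document_retrieval.parse_doc([('report.pdf', f)])
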
    def load_embedding_model(self):
        embeddings = APIGateway.load_embedding_model(
            type=self.embedding_model_info['type'],
            batch_size=self.embedding_model_info['batch_size'],
            coe=self.embedding_model_info['coe'],
            select_expert=self.embedding_model_info['select_expert'],
        )
        return embeddings

    def create_vector_store(self, text_chunks, embeddings, output_db=None, collection_name=None):
        print(f'Creating collection {collection_name}')
        vectorstore = self.vectordb.create_vector_store(
            text_chunks, embeddings, output_db=output_db, collection_name=collection_name, db_type='chroma'
        )
        return vectorstore

    def load_vdb(self, db_path, embeddings, collection_name=None):
        print(f'Loading collection {collection_name}')
        vectorstore = self.vectordb.load_vdb(db_path, embeddings, db_type='chroma', collection_name=collection_name)
        return vectorstore

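    # Persist/load round-trip sketch (illustrative; 'chunks' and 'my_docs' are
    # hypothetical):
    #
    #   embeddings = document_retrieval.load_embedding_model()
    #   document_retrieval.create_vector_store(
    #       chunks, embeddings, output_db=PERSIST_DIRECTORY, collection_name='my_docs'
    #   )
    #   vectorstore = document_retrieval.load_vdb(
    #       PERSIST_DIRECTORY, embeddings, collection_name='my_docs'
    #   )
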
    def init_retriever(self, vectorstore):
        if self.retrieval_info['rerank']:
            # Retrieve a larger candidate pool so the reranker can pick the best few
            self.retriever = vectorstore.as_retriever(
                search_type='similarity_score_threshold',
                search_kwargs={
                    'score_threshold': self.retrieval_info['score_threshold'],
                    'k': self.retrieval_info['k_retrieved_documents'],
                },
            )
        else:
            # Without reranking, retrieve only the final number of documents
            self.retriever = vectorstore.as_retriever(
                search_type='similarity_score_threshold',
                search_kwargs={
                    'score_threshold': self.retrieval_info['score_threshold'],
                    'k': self.retrieval_info['final_k_retrieved_documents'],
                },
            )

    def get_qa_retrieval_chain(self):
        """
        Generate a QA retrieval chain using a language model.

        This function uses a language model, specifically a SambaNova LLM, to generate a QA retrieval chain
        over the previously initialized retriever (see init_retriever).

        Returns:
            RetrievalQAChain: A chain ready for QA without memory.
        """
        retrievalQAChain = RetrievalQAChain(
            retriever=self.retriever,
            llm=self.llm,
            qa_prompt=load_prompt(os.path.join(kit_dir, self.prompts['qa_prompt'])),
            rerank=self.retrieval_info['rerank'],
            final_k_retrieved_documents=self.retrieval_info['final_k_retrieved_documents'],
        )
        return retrievalQAChain

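    # End-to-end sketch (illustrative; continues the hypothetical objects from
    # the sketches above):
    #
    #   document_retrieval.init_retriever(vectorstore)
    #   qa_chain = document_retrieval.get_qa_retrieval_chain()
    #   result = qa_chain.invoke({'question': 'Summarize the key findings.'})
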
    def get_conversational_qa_retrieval_chain(self):
        """
        Generate a conversational retrieval QA chain using a language model.

        This function uses a language model, specifically a SambaNova LLM, to generate a conversational QA
        retrieval chain based on the chat history and the relevant retrieved content from the previously
        initialized retriever.

        Returns:
            A chain ready for QA with memory.
        """