Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
| import argparse | |
| import pandas as pd | |
| import time | |
| from typing import Any, Dict, Optional | |
| from langchain_core.callbacks import CallbackManagerForChainRun | |
| from langchain.prompts import load_prompt | |
| from langchain_core.output_parsers import StrOutputParser | |
| from transformers import AutoTokenizer | |
# Resolve the kit and repo roots relative to this file so the local package
# below can be imported no matter which directory the script is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))
# NOTE: these appends must happen before the project-local import below.
sys.path.append(kit_dir)
sys.path.append(repo_dir)
from enterprise_knowledge_retriever.src.document_retrieval import DocumentRetrieval, RetrievalQAChain
class TimedRetrievalQAChain(RetrievalQAChain):
    """RetrievalQAChain variant whose ``_call`` also records wall-clock
    timestamps (start, end of retrieval/preprocessing, end of LLM call)
    in the response so callers can compute per-stage latency."""

    # Override _call to attach timing information to the response.
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        # Prompt -> LLM -> plain-string parser pipeline.
        answer_pipeline = self.qa_prompt | self.llm | StrOutputParser()

        t_start = time.time()
        question = inputs["question"]
        retrieved_docs = self.retriever.invoke(question)
        if self.rerank:
            retrieved_docs = self.rerank_docs(
                question, retrieved_docs, self.final_k_retrieved_documents
            )
        context_text = self._format_docs(retrieved_docs)
        t_preprocessing_done = time.time()

        answer_text = answer_pipeline.invoke({"question": question, "context": context_text})
        t_llm_done = time.time()

        return {
            "answer": answer_text,
            "source_documents": retrieved_docs,
            "start_time": t_start,
            "end_preprocessing_time": t_preprocessing_done,
            "end_llm_time": t_llm_done,
        }
def analyze_times(answer, start_time, end_preprocessing_time, end_llm_time, tokenizer):
    """Compute per-stage timing and throughput metrics for one QA generation.

    Args:
        answer: generated answer text; tokenized to count output tokens.
        start_time: timestamp taken before retrieval began.
        end_preprocessing_time: timestamp taken after retrieval/formatting.
        end_llm_time: timestamp taken after the LLM finished.
        tokenizer: object exposing ``encode(str) -> sequence`` (e.g. a
            HuggingFace tokenizer).

    Returns:
        dict with keys ``preprocessing_time``, ``llm_time``, ``token_count``
        and ``tokens_per_second``.
    """
    preprocessing_time = end_preprocessing_time - start_time
    llm_time = end_llm_time - end_preprocessing_time
    token_count = len(tokenizer.encode(answer))
    # Guard against a zero-length LLM interval (e.g. clock resolution or a
    # cached/instant response), which would otherwise raise ZeroDivisionError.
    tokens_per_second = token_count / llm_time if llm_time > 0 else 0.0
    return {
        "preprocessing_time": preprocessing_time,
        "llm_time": llm_time,
        "token_count": token_count,
        "tokens_per_second": tokens_per_second,
    }
def generate(qa_chain, question, tokenizer):
    """Run one question through the QA chain.

    Args:
        qa_chain: chain whose ``invoke`` returns answer, source documents and
            the timing fields produced by ``TimedRetrievalQAChain._call``.
        question: the question string to answer.
        tokenizer: tokenizer forwarded to ``analyze_times`` for token counting.

    Returns:
        Tuple of (answer text, set of source filenames, timing-metrics dict).
    """
    response = qa_chain.invoke({"question": question})
    answer = response.get('answer')
    # Set comprehension (was set([...]) — C403); a document chunk can be
    # retrieved more than once, so deduplicate by filename.
    sources = {str(sd.metadata["filename"]) for sd in response["source_documents"]}
    times = analyze_times(
        answer,
        response.get("start_time"),
        response.get("end_preprocessing_time"),
        response.get("end_llm_time"),
        tokenizer,
    )
    return answer, sources, times
def process_bulk_QA(vectordb_path, questions_file_path):
    """Answer every question in an xlsx file using a stored vector DB.

    Args:
        vectordb_path: path to an existing vector store directory.
        questions_file_path: xlsx file with questions in a 'Questions' column.

    Returns:
        Path of the output xlsx file holding answers, sources and timings.

    Raises:
        FileNotFoundError: if either input path does not exist.
    """
    # Validate inputs up front. The original `raise f"..."` raised a plain
    # string, which is itself a TypeError at runtime — use a real exception.
    if not os.path.exists(vectordb_path):
        raise FileNotFoundError(f"vector db path {vectordb_path} does not exist")
    if not os.path.exists(questions_file_path):
        raise FileNotFoundError(f"questions file path {questions_file_path} does not exist")

    documentRetrieval = DocumentRetrieval()
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")

    # load the vectorstore
    embeddings = documentRetrieval.load_embedding_model()
    vectorstore = documentRetrieval.load_vdb(vectordb_path, embeddings)
    print("Database loaded")
    documentRetrieval.init_retriever(vectorstore)
    print("retriever initialized")

    # get qa chain
    qa_chain = TimedRetrievalQAChain(
        retriever=documentRetrieval.retriever,
        llm=documentRetrieval.llm,
        qa_prompt=load_prompt(os.path.join(kit_dir, documentRetrieval.prompts["qa_prompt"])),
        rerank=documentRetrieval.retrieval_info["rerank"],
        final_k_retrieved_documents=documentRetrieval.retrieval_info["final_k_retrieved_documents"],
    )

    df = pd.read_excel(questions_file_path)
    print(df)
    output_file_path = questions_file_path.replace('.xlsx', '_output.xlsx')
    if 'Answer' not in df.columns:
        df['Answer'] = ''
        df['Sources'] = ''
        df['preprocessing_time'] = ''
        df['llm_time'] = ''
        df['token_count'] = ''
        df['tokens_per_second'] = ''
    for index, row in df.iterrows():
        # Empty cells read back from an existing sheet arrive as NaN (float),
        # which has no .strip(); treat NaN the same as an empty answer.
        answer_cell = row['Answer']
        if pd.isna(answer_cell) or str(answer_cell).strip() == '':
            try:
                # Generate the answer
                print(f"Generating answer for row {index}")
                answer, sources, times = generate(qa_chain, row['Questions'], tokenizer)
                df.at[index, 'Answer'] = answer
                df.at[index, 'Sources'] = sources
                df.at[index, 'preprocessing_time'] = times.get("preprocessing_time")
                df.at[index, 'llm_time'] = times.get("llm_time")
                df.at[index, 'token_count'] = times.get("token_count")
                df.at[index, 'tokens_per_second'] = times.get("tokens_per_second")
            except Exception as e:
                print(f"Error processing row {index}: {e}")
            # Save the file after each iteration to avoid data loss
            df.to_excel(output_file_path, index=False)
        else:
            print(f"Skipping row {index} because 'Answer' is already in the document")
    return output_file_path
| else: | |
| raise f"questions file path {questions_file_path} does not exist" | |
if __name__ == "__main__":
    # CLI entry point: answer every question in an Excel sheet against a
    # previously built vector DB and report where the results were written.
    arg_parser = argparse.ArgumentParser(description='use a vectordb and an excel file with questions in the first column and generate answers for all the questions')
    arg_parser.add_argument('vectordb_path', type=str, help='vector db path with stored documents for RAG')
    arg_parser.add_argument('questions_path', type=str, help='xlsx file containing questions in a column named Questions')
    cli_args = arg_parser.parse_args()

    # process in bulk
    result_path = process_bulk_QA(cli_args.vectordb_path, cli_args.questions_path)
    print(f"Finished, responses in: {result_path}")