import os
import tempfile
from typing import List, Optional, Tuple

import gradio as gr
from llama_cpp import Llama
from chromadb.config import Settings
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from huggingface_hub.file_download import http_get

from __init__ import *
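# NOTE: the star import above is assumed to provide the project-level constants
# used throughout this file (MODELS_DIR, DICT_REPO_AND_MODELS, EMBEDDER_NAME,
# LOADER_MAPPING, SYSTEM_PROMPT, ROLE_TOKENS, LINEBREAK_TOKEN, BOT_TOKEN,
# MAX_NEW_TOKENS, FAVICON_PATH, BLOCK_CSS); they are not defined here.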
class LocalChatGPT:
    def __init__(self):
        self.llama_model: Optional[Llama] = None
        self.embeddings: HuggingFaceEmbeddings = self.initialize_app()

    def initialize_app(self) -> HuggingFaceEmbeddings:
        """
        Download the first model from the list (if needed), load it into
        llama.cpp and return the embedding model used for indexing.
        :return: the HuggingFaceEmbeddings instance
        """
        os.makedirs(MODELS_DIR, exist_ok=True)
        model_url, model_name = list(DICT_REPO_AND_MODELS.items())[0]
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs(os.path.dirname(final_model_path), exist_ok=True)

        if not os.path.exists(final_model_path):
            with open(final_model_path, "wb") as f:
                http_get(model_url, f)

        self.llama_model = Llama(
            model_path=final_model_path,
            n_ctx=2000,
            n_parts=1,
        )
        return HuggingFaceEmbeddings(model_name=EMBEDDER_NAME, cache_folder=MODELS_DIR)
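    # DICT_REPO_AND_MODELS is assumed to map a direct download URL to a local
    # file name, e.g. (hypothetical entry):
    #   {"https://huggingface.co/<repo>/resolve/main/model-q4_K.gguf": "model-q4_K.gguf"}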
    def load_model(self, model_name):
        """
        Download (if needed) and load the model selected in the dropdown.
        :param model_name: file name of the model to load
        :return: the loaded model name, echoed back to the dropdown
        """
        final_model_path = os.path.join(MODELS_DIR, model_name)
        os.makedirs(os.path.dirname(final_model_path), exist_ok=True)

        if not os.path.exists(final_model_path):
            with open(final_model_path, "wb") as f:
                # Look up the download URL whose value matches the selected name
                if model_url := [url for url, name in DICT_REPO_AND_MODELS.items() if name == model_name]:
                    http_get(model_url[0], f)

        self.llama_model = Llama(
            model_path=final_model_path,
            n_ctx=2000,
            n_parts=1,
        )
        return model_name
    def load_single_document(self, file_path: str) -> Document:
        """
        Load a single document with the loader registered for its extension.
        :param file_path: path to the uploaded file
        :return: the parsed Document
        """
        ext: str = "." + file_path.rsplit(".", 1)[-1]
        assert ext in LOADER_MAPPING, f"Unsupported file extension: {ext}"
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        return loader.load()[0]
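    # LOADER_MAPPING is assumed to map an extension to a (loader_class, kwargs)
    # pair, e.g. (hypothetical entry): {".txt": (TextLoader, {"encoding": "utf-8"})}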
    def get_message_tokens(self, model: Llama, role: str, content: str) -> list:
        """
        Tokenize one chat message and frame it with role and end-of-turn tokens.
        :param model: loaded llama.cpp model
        :param role: message role key in ROLE_TOKENS
        :param content: message text
        :return: token list for the message
        """
        message_tokens: list = model.tokenize(content.encode("utf-8"))
        # tokenize() prepends BOS, so the role token and line break go right after it
        message_tokens.insert(1, ROLE_TOKENS[role])
        message_tokens.insert(2, LINEBREAK_TOKEN)
        message_tokens.append(model.token_eos())
        return message_tokens
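    # Resulting layout (assuming ROLE_TOKENS/LINEBREAK_TOKEN follow the usual
    # saiga-style prompt format):
    #   [BOS, ROLE_TOKENS[role], LINEBREAK_TOKEN, <content tokens...>, EOS]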
    def get_system_tokens(self, model: Llama) -> list:
        """
        Tokenize the system prompt that opens every conversation.
        :param model: loaded llama.cpp model
        :return: token list for the system message
        """
        system_message: dict = {"role": "system", "content": SYSTEM_PROMPT}
        return self.get_message_tokens(model, **system_message)
    def upload_files(self, files: List[tempfile.TemporaryFile]) -> List[str]:
        """
        :param files: files received from the Gradio uploader
        :return: their paths on disk
        """
        return [f.name for f in files]
    def process_text(self, text: str) -> Optional[str]:
        """
        Drop near-empty lines and discard fragments too short to index.
        :param text: raw fragment text
        :return: cleaned text, or None if fewer than 10 characters remain
        """
        lines: list = text.split("\n")
        lines = [line for line in lines if len(line.strip()) > 2]
        text = "\n".join(lines).strip()
        return None if len(text) < 10 else text
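    # e.g. process_text("Chapter 1\n--\nSome paragraph text")
    #   -> "Chapter 1\nSome paragraph text"  (the 2-character "--" line is dropped)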
    def update_text_db(
            self,
            db: Optional[Chroma],
            fixed_documents: List[Document],
            ids: List[str]
    ) -> Optional[Tuple[Chroma, str]]:
        """
        If the same set of files is being re-indexed, replace the stored
        fragments in place and return the updated DB; otherwise return None.
        """
        if db:
            data: dict = db.get()
            files_db = {dict_data['source'].split('/')[-1] for dict_data in data["metadatas"]}
            files_load = {doc.metadata["source"].split('/')[-1] for doc in fixed_documents}
            if files_load == files_db:
                # db.delete([item for item in data['ids'] if item not in ids])
                # db.update_documents(ids, fixed_documents)
                # Delete the old fragments and add the new ones under fresh ids,
                # since the id set changes whenever the chunking parameters do
                db.delete(data['ids'])
                db.add_texts(
                    texts=[doc.page_content for doc in fixed_documents],
                    metadatas=[doc.metadata for doc in fixed_documents],
                    ids=ids
                )
                file_warning = f"Uploaded {len(fixed_documents)} fragments! You can ask questions."
                return db, file_warning
        return None
    def build_index(
            self,
            file_paths: List[str],
            db: Optional[Chroma],
            chunk_size: int,
            chunk_overlap: int
    ):
        """
        Split the uploaded files into fragments and (re)build the Chroma index.
        :param file_paths: paths of the uploaded files
        :param db: existing Chroma DB, if any
        :param chunk_size: splitter chunk size in characters
        :param chunk_overlap: overlap between consecutive chunks
        :return: the Chroma DB and a status message
        """
        documents: List[Document] = [self.load_single_document(path) for path in file_paths]
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        documents = text_splitter.split_documents(documents)

        fixed_documents: List[Document] = []
        for doc in documents:
            doc.page_content = self.process_text(doc.page_content)
            if not doc.page_content:
                continue
            fixed_documents.append(doc)

        # Exactly one id per fragment, derived from the fragment's source file name
        ids: List[str] = [
            f"{doc.metadata['source'].split('/')[-1].replace('.txt', '')}{i}"
            for i, doc in enumerate(fixed_documents, start=1)
        ]

        # Re-use the existing DB when the same files are re-indexed
        updated = self.update_text_db(db, fixed_documents, ids)
        if updated:
            return updated

        db = Chroma.from_documents(
            documents=fixed_documents,
            embedding=self.embeddings,
            ids=ids,
            client_settings=Settings(
                anonymized_telemetry=False,
                persist_directory="db"
            )
        )
        file_warning = f"Uploaded {len(fixed_documents)} fragments! You can ask questions."
        return db, file_warning
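    # e.g. two fragments split out of "report.txt" get the ids
    # ["report1", "report2"]; Chroma requires exactly one id per document.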
    def user(self, message, history):
        """
        Append the new user message to the chat history and clear the textbox.
        """
        new_history = history + [[message, None]]
        return "", new_history
    def regenerate_response(self, history):
        """
        Re-run retrieval and generation for the last message in the history.
        :param history: current chat history
        :return: an empty textbox value and the unchanged history
        """
        return "", history
    def retrieve(self, history, db: Optional[Chroma], retrieved_docs):
        """
        Retrieve the fragments most similar to the last user message.
        :param history: current chat history
        :param db: Chroma DB with the indexed fragments
        :param retrieved_docs: current contents of the fragments textbox
        :return: the retrieved fragments as a single string
        """
        if db:
            last_user_message = history[-1][0]
            try:
                docs = db.similarity_search(last_user_message, k=4)
                # retriever = db.as_retriever(search_kwargs={"k": k_documents})
                # docs = retriever.get_relevant_documents(last_user_message)
            except RuntimeError:
                # Fall back to a single document if the wider search fails
                docs = db.similarity_search(last_user_message, k=1)
                # retriever = db.as_retriever(search_kwargs={"k": 1})
                # docs = retriever.get_relevant_documents(last_user_message)
            source_docs = set()
            for doc in docs:
                for content in doc.metadata.values():
                    source_docs.add(content.split("/")[-1])
            retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
            retrieved_docs = f"Source document(s): {', '.join(source_docs)}.\n\n{retrieved_docs}"
        return retrieved_docs
    def bot(self, history, retrieved_docs):
        """
        Generate the assistant's reply token by token, streaming it into the chat.
        :param history: current chat history
        :param retrieved_docs: fragments retrieved for the last question
        :return: yields the history with the partial reply filled in
        """
        if not history:
            return
        tokens = self.get_system_tokens(self.llama_model)[:]
        tokens.append(LINEBREAK_TOKEN)

        # Replay the previous turns so the model sees the whole conversation
        # (assumes ROLE_TOKENS also defines a "bot" entry)
        for user_message, bot_message in history[:-1]:
            message_tokens = self.get_message_tokens(model=self.llama_model, role="user", content=user_message)
            tokens.extend(message_tokens)
            if bot_message:
                message_tokens = self.get_message_tokens(model=self.llama_model, role="bot", content=bot_message)
                tokens.extend(message_tokens)

        last_user_message = history[-1][0]
        if retrieved_docs:
            last_user_message = f"Context: {retrieved_docs}\n\nUsing the context, answer the question: " \
                                f"{last_user_message}"
        message_tokens = self.get_message_tokens(model=self.llama_model, role="user", content=last_user_message)
        tokens.extend(message_tokens)

        # Open the bot's turn and let the model complete it
        role_tokens = [self.llama_model.token_bos(), BOT_TOKEN, LINEBREAK_TOKEN]
        tokens.extend(role_tokens)
        generator = self.llama_model.generate(
            tokens,
            top_k=30,
            top_p=0.9,
            temp=0.1
        )

        partial_text = ""
        for i, token in enumerate(generator):
            if token == self.llama_model.token_eos() or (MAX_NEW_TOKENS is not None and i >= MAX_NEW_TOKENS):
                break
            partial_text += self.llama_model.detokenize([token]).decode("utf-8", "ignore")
            history[-1][1] = partial_text
            yield history
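    # Prompt layout fed to generate() (assuming the usual saiga-style tokens):
    #   <s>system\n{SYSTEM_PROMPT}</s>
    #   <s>user\n{question (+ retrieved context)}</s>
    #   <s>bot\n   <- generation continues from here until EOS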
    def run(self):
        """
        Build the Gradio UI and launch the app.
        """
        with gr.Blocks(theme=gr.themes.Soft(), css=BLOCK_CSS) as demo:
            db = gr.State(None)  # holds the Optional[Chroma] index for the session
            favicon = f'<img src="{FAVICON_PATH}" width="48px" style="display: inline">'
            gr.Markdown(
                f"""<h1><center>{favicon} GPT-based text assistant</center></h1>"""
            )

            with gr.Row(elem_id="model_selector_row"):
                models: list = list(DICT_REPO_AND_MODELS.values())
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if models else "",
                    interactive=True,
                    show_label=False,
                    container=False,
                )

            with gr.Row():
                with gr.Column(scale=5):
                    chatbot = gr.Chatbot(label="Dialogue", height=400)
                with gr.Column(min_width=200, scale=4):
                    retrieved_docs = gr.Textbox(
                        label="Extracted fragments",
                        placeholder="Will appear after you ask a question",
                        interactive=False
                    )

            with gr.Row():
                with gr.Column(scale=20):
                    msg = gr.Textbox(
                        label="Send a message",
                        show_label=False,
                        placeholder="Send a message",
                        container=False
                    )
                with gr.Column(scale=3, min_width=100):
                    submit = gr.Button("📤 Send", variant="primary")

            with gr.Row():
                # gr.Button(value="👍 Like")
                # gr.Button(value="👎 Dislike")
                stop = gr.Button(value="⛔ Stop")
                regenerate = gr.Button(value="🔄 Repeat")
                clear = gr.Button(value="🗑️ Clear")

            # # Upload files
            # file_output.upload(
            #     fn=self.upload_files,
            #     inputs=[file_output],
            #     outputs=[file_paths],
            #     queue=True,
            # ).success(
            #     fn=self.build_index,
            #     inputs=[file_paths, db, chunk_size, chunk_overlap],
            #     outputs=[db, file_warning],
            #     queue=True
            # )
            model_selector.change(
                fn=self.load_model,
                inputs=[model_selector],
                outputs=[model_selector]
            )

            # Pressing Enter
            submit_event = msg.submit(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            # Pressing the button
            submit_click_event = submit.click(
                fn=self.user,
                inputs=[msg, chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            # Stop generation
            stop.click(
                fn=None,
                inputs=None,
                outputs=None,
                cancels=[submit_event, submit_click_event],
                queue=False,
            )

            # Regenerate the last answer
            regenerate.click(
                fn=self.regenerate_response,
                inputs=[chatbot],
                outputs=[msg, chatbot],
                queue=False,
            ).success(
                fn=self.retrieve,
                inputs=[chatbot, db, retrieved_docs],
                outputs=[retrieved_docs],
                queue=True,
            ).success(
                fn=self.bot,
                inputs=[chatbot, retrieved_docs],
                outputs=chatbot,
                queue=True,
            )

            # Clear history
            clear.click(lambda: None, None, chatbot, queue=False)

        demo.queue(max_size=128, default_concurrency_limit=10, api_open=False)
        demo.launch(server_name="0.0.0.0", max_threads=200)
if __name__ == "__main__":
    local_chat_gpt = LocalChatGPT()
    local_chat_gpt.run()
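# To run locally (assuming this file is the Space's app.py): `python app.py`;
# Gradio serves on 0.0.0.0, port 7860 by default unless GRADIO_SERVER_PORT
# overrides it.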