Spaces:
Sleeping
Sleeping
| from flask import ( | |
| Flask, | |
| jsonify, | |
| request, | |
| render_template_string, | |
| abort, | |
| ) | |
| from flask_cors import CORS | |
| import unicodedata | |
| import markdown | |
| import time | |
| import os | |
| import gc | |
| import base64 | |
| from io import BytesIO | |
| from random import randint | |
| import hashlib | |
| from colorama import Fore, Style, init as colorama_init | |
| import chromadb | |
| import posthog | |
| from chromadb.config import Settings | |
| from sentence_transformers import SentenceTransformer | |
| from werkzeug.middleware.proxy_fix import ProxyFix | |
colorama_init()

# --- Server / model configuration -------------------------------------------
host, port = "0.0.0.0", 7860
embedding_model = 'sentence-transformers/all-mpnet-base-v2'

print("Initializing ChromaDB")


def _discard_posthog_event(*args, **kwargs):
    """No-op stand-in for ``posthog.capture`` so no telemetry is sent."""
    return None


# Telemetry is disabled twice over: posthog.capture is stubbed out *before*
# the client is built, and the client settings opt out as well.
posthog.capture = _discard_posthog_event
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
# NOTE(review): no persist directory is configured here, so storage is
# presumably ephemeral (in-process) — confirm for the installed chromadb.

# The sentence-transformer's bound ``encode`` method is what gets passed to
# collections as their embedding function.
chromadb_embedder = SentenceTransformer(embedding_model)
chromadb_embed_fn = chromadb_embedder.encode
# --- Flask application setup -------------------------------------------------
app = Flask(__name__)

# Allow cross-domain requests (flask_cors with default options).
CORS(app)

# Upper bound on request body size: 100 MiB.
app.config.update(MAX_CONTENT_LENGTH=100 * 1024 * 1024)

# Rewrite request attributes from X-Forwarded-* headers. With x_for=2 the
# client address is taken from the second-from-right X-Forwarded-For entry.
# NOTE(review): per-chat collections are keyed on this address, so if fewer
# than two trusted proxies actually sit in front of the app, clients could
# spoof it — confirm the deployment's proxy depth.
app.wsgi_app = ProxyFix(
    app.wsgi_app,
    x_for=2,
    x_proto=1,
    x_host=1,
    x_prefix=1,
)
def get_real_ip():
    """Return the requesting client's address.

    This is ``request.remote_addr`` after the ProxyFix middleware has applied
    the forwarded headers, so it reflects the proxied client, not the proxy.
    """
    client_address = request.remote_addr
    return client_address
def index():
    """Serve ./README.md converted to HTML (with the markdown tables extension).

    NOTE(review): the HTML is passed through ``render_template_string``, so any
    Jinja syntax present in README.md is evaluated as a template. No route
    decorator for this view is visible here — confirm it is registered.
    """
    with open("./README.md", "r", encoding="utf8") as readme_file:
        readme_source = readme_file.read()
    readme_html = markdown.markdown(readme_source, extensions=["tables"])
    return render_template_string(readme_html)
def get_modules():
    """Return a JSON object listing the modules this server exposes."""
    enabled_modules = ['chromadb']
    payload = {"modules": enabled_modules}
    return jsonify(payload)
def chromadb_add_messages():
    """Upsert chat messages into the caller's per-chat ChromaDB collection.

    Expects a JSON object body with:
      - "chat_id": str
      - "messages": list of objects, each carrying "id", "content", "role"
        and "date", plus an optional "meta" (stored as "" when absent).

    The collection is named ``chat-<md5("<client ip>-<chat_id>")>``, so the
    same chat_id sent from different addresses lands in different collections.

    Returns:
        JSON ``{"count": <number of messages upserted>}``.

    Aborts with HTTP 400 when the body is not an object, when chat_id or
    messages is missing/mistyped, or when a message lacks a required key.

    NOTE(review): no route decorator is visible for this handler here —
    confirm it is registered.
    """
    data = request.get_json()
    # A JSON body of null / a list / a string would make the key lookups
    # below raise TypeError (HTTP 500); reject it as a client error instead.
    if not isinstance(data, dict) or not isinstance(data.get("chat_id"), str):
        abort(400, '"chat_id" is required')
    messages = data.get("messages")
    if not isinstance(messages, list):
        abort(400, '"messages" is required')
    # Validate every message up front so a malformed one yields a 400 rather
    # than a KeyError half-way through building the upsert payload.
    required_keys = ("id", "content", "role", "date")
    for message in messages:
        if not isinstance(message, dict) or any(k not in message for k in required_keys):
            abort(400, 'each message requires "id", "content", "role" and "date"')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    documents = [m["content"] for m in messages]
    ids = [m["id"] for m in messages]
    metadatas = [
        {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
        for m in messages
    ]

    # Skip the call entirely for an empty batch.
    if ids:
        collection.upsert(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
        )
    return jsonify({"count": len(ids)})
def chromadb_query():
    """Return the stored messages most similar to a query text.

    Expects a JSON object body with:
      - "chat_id": str
      - "query": str
      - "n_results": optional int (defaults to 1 when absent or not an int)

    The request is capped at the collection's current size; when that leaves
    nothing to fetch, an empty JSON list is returned without querying.

    Returns:
        JSON list of ``{"id", "date", "role", "meta", "content", "distance"}``
        objects in the order ChromaDB returns them.

    Aborts with HTTP 400 when the body is not an object or chat_id/query is
    missing or not a string.

    NOTE(review): no route decorator is visible for this handler here —
    confirm it is registered.
    """
    data = request.get_json()
    # Non-object JSON bodies would otherwise raise TypeError (HTTP 500).
    if not isinstance(data, dict) or not isinstance(data.get("chat_id"), str):
        abort(400, '"chat_id" is required')
    if not isinstance(data.get("query"), str):
        abort(400, '"query" is required')

    # bool is a subclass of int; don't let true/false pass as a result count.
    requested = data.get("n_results")
    if isinstance(requested, int) and not isinstance(requested, bool):
        n_results = requested
    else:
        n_results = 1

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    # Never ask for more results than the collection holds.
    n_results = min(collection.count(), n_results)
    if n_results <= 0:
        return jsonify([])

    query_result = collection.query(
        query_texts=[data["query"]],
        n_results=n_results,
    )
    # Each field holds one result list per query text; a single query was sent.
    messages = [
        {
            "id": doc_id,
            "date": metadata["date"],
            "role": metadata["role"],
            # Same default chromadb_add_messages stores when "meta" is absent.
            "meta": metadata.get("meta", ""),
            "content": document,
            "distance": distance,
        }
        for doc_id, document, metadata, distance in zip(
            query_result["ids"][0],
            query_result["documents"][0],
            query_result["metadatas"][0],
            query_result["distances"][0],
        )
    ]
    return jsonify(messages)
def chromadb_purge():
    """Delete the stored entries of the caller's per-chat collection.

    Expects a JSON object body with "chat_id": str. The collection is looked
    up (or created, if absent) under the same IP+chat_id-derived name used by
    the add/query handlers, then ``delete()`` is called on it with no filters.

    Returns:
        The plain-text response ``'Ok'`` with status 200.

    Aborts with HTTP 400 when the body is not an object or chat_id is missing
    or not a string.

    NOTE(review): no route decorator is visible for this handler here —
    confirm it is registered.
    """
    data = request.get_json()
    # Non-object JSON bodies would otherwise raise TypeError (HTTP 500).
    if not isinstance(data, dict) or not isinstance(data.get("chat_id"), str):
        abort(400, '"chat_id" is required')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )
    deleted = collection.delete()
    # NOTE(review): what Collection.delete() returns (and whether an unfiltered
    # call removes everything) depends on the installed chromadb version —
    # confirm. Guard the count so logging can't turn a purge into a 500.
    deleted_count = len(deleted) if deleted is not None else 0
    print("ChromaDB embeddings deleted", deleted_count)
    return 'Ok', 200
# Start Flask's built-in server on the configured interface and port.
# NOTE(review): this runs unconditionally at import time; consider an
# `if __name__ == "__main__":` guard if the module is ever imported.
app.run(port=port, host=host)