import gradio as gr
import spaces
import torch
from huggingface_hub import login
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, TextIteratorStreamer
from threading import Thread

from config import (
    HF_TOKEN,
    EMBEDDING_MODEL_ID,
    LLM_MODEL_ID,
    DEFAULT_MESSAGE_NO_MATCH,
    get_all_game_data,
    BASE_SIMILARITY_THRESHOLD,
    FOLLOWUP_SIMILARITY_THRESHOLD,
    silksong_theme,
    silksong_css,
)


class ChatContext:
    """Holds the conversational state, including the current context and thresholds."""

    def __init__(self):
        self.context_index = -1
        self.base_similarity = BASE_SIMILARITY_THRESHOLD
        self.followup_similarity = FOLLOWUP_SIMILARITY_THRESHOLD


print("Logging into Hugging Face Hub...")
login(token=HF_TOKEN)

print("Initializing embedding model...")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_ID)

print("Initializing language model...")
llm_pipeline = pipeline(
    "text-generation",
    model=LLM_MODEL_ID,
    device_map="auto",
    dtype="auto",
)

knowledge_base = get_all_game_data(embedding_model)
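# The knowledge base maps each game title to a list of pre-embedded documents.
# The entry shape below is assumed from how items are used in this file
# (item["text"], item["embedding"], item["metadata"]["source"]); the actual
# construction happens in config.get_all_game_data. A minimal sketch of one entry:
#
#   {
#       "text": "...",                  # raw passage handed to the LLM as CONTEXT
#       "embedding": <torch.Tensor>,    # pre-computed document embedding
#       "metadata": {"source": "..."},  # provenance label, used for logging
#   }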


def _select_content(title: str) -> list[dict]:
    """Helper to safely get the knowledge base for a specific title."""
    return knowledge_base.get(title, [])


@torch.no_grad()
def find_best_context(query: str, contents: list[dict], similarity_threshold: float) -> int:
    """Finds the most relevant document index based on semantic similarity."""
    if not query or not contents:
        return -1

    query_embedding = embedding_model.encode(
        query, prompt_name="query", convert_to_tensor=True
    ).to(embedding_model.device)

    try:
        # Stack pre-computed tensors from our knowledge base
        contents_embeddings = torch.stack(
            [item["embedding"] for item in contents]
        ).to(embedding_model.device)
    except (RuntimeError, IndexError, TypeError) as e:
        print(f"Warning: Could not stack content embeddings. Error: {e}")
        return -1

    # Compute cosine similarity between the 1 query embedding and N content embeddings
    similarities = util.pytorch_cos_sim(query_embedding, contents_embeddings)
    if similarities.numel() == 0:
        print("Warning: Similarity computation returned an empty tensor.")
        return -1

    # Get the index and score of the top match
    best_index = similarities.argmax().item()
    best_score = similarities[0, best_index].item()
    print(f"Best score: {best_score:.4f} (Threshold: {similarity_threshold})")

    if best_score >= similarity_threshold:
        print(f"Using \"{contents[best_index]['metadata']['source']}\"...")
        return best_index

    print("No context met the similarity threshold.")
    return -1


@spaces.GPU
def respond(message: str, history: list, title: str, chat_context: ChatContext):
    """Generates a streaming response from the LLM based on the best context found."""
    default_threshold = chat_context.base_similarity
    followup_threshold = chat_context.followup_similarity

    contents = _select_content(title)
    if not contents:
        print(f"No content found for {title}")
        chat_context.context_index = -1  # Reset the stored context
        yield DEFAULT_MESSAGE_NO_MATCH, chat_context
        return

    if len(history) == 0:
        # Clear context on a new conversation
        print("New conversation started. Clearing context.")
        chat_context.context_index = -1

    # Determine threshold: use the follow-up threshold ONLY if we have a valid previous context.
    similarity_threshold = followup_threshold if chat_context.context_index != -1 else default_threshold
    print(f"Using {'follow-up' if chat_context.context_index != -1 else 'default'} threshold: {similarity_threshold}")

    # Find the best new context based on the current message
    found_context_index = find_best_context(message, contents, similarity_threshold)

    if found_context_index >= 0:
        # A new, relevant context was found; store it.
        chat_context.context_index = found_context_index
    elif chat_context.context_index >= 0:
        # PASS: a follow-up question with no new context. Reuse the old one.
        print("No new context found, reusing previous context for follow-up.")
    else:
        # FAILURE: no new context was found AND no previous context exists.
        print("No context found and no previous context. Yielding no match.")
        yield DEFAULT_MESSAGE_NO_MATCH, chat_context
        return

    system_prompt = (
        f"Answer the following QUESTION based only on the CONTEXT provided. "
        f"If the answer cannot be found in the CONTEXT, write \"{DEFAULT_MESSAGE_NO_MATCH}\"\n"
        f"---\nCONTEXT:\n{contents[chat_context.context_index]['text']}\n"
    )
    user_prompt = f"QUESTION:\n{message}"

    messages = [{"role": "system", "content": system_prompt}]
    # Add previous turns (history) after the system prompt but before the current question
    messages.extend(history)
    messages.append({"role": "user", "content": user_prompt})

    # Debug-print the conversation being sent (excluding the large system prompt)
    for item in messages[1:]:
        print(f"[{item['role']}] {item['content']}")

    streamer = TextIteratorStreamer(llm_pipeline.tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=llm_pipeline,
        kwargs=dict(
            text_inputs=messages,
            streamer=streamer,
            max_new_tokens=512,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
        ),
    )
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        # Yield the partial response AND the current state
        yield response, chat_context
    print(f"[assistant] {response}")


# --- GRADIO UI ---
# Defines the web interface for the chatbot.


def on_title_changed(context_state: ChatContext) -> tuple[str, ChatContext]:
    """Resets the context display and state when the game is changed."""
    context_state.context_index = -1
    return """
Speak, little traveler. What secrets of Pharloom do you seek?
(Note: This bot has limited knowledge.)
Disclaimer: