# Spaces: Paused
import html
import json
import os
import re
import shutil

import gradio as gr
import lancedb
import pandas as pd
import requests
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
# --- Runtime configuration -------------------------------------------------
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model_name = "PleIAs/Pleias-3b-rag"

# A Hugging Face token is required; fail before any download is attempted.
hf_token = os.environ.get('HF_TOKEN')
if not hf_token:
    raise ValueError("Please set the HF_TOKEN environment variable")

# Load tokenizer and model, then place the model on the selected device
# (nn.Module.to returns the module itself, so the call can be chained).
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(device)

# Generation ends on the answer terminator, so it is registered as EOS.
tokenizer.eos_token = "<|answer_end|>"
eos_token_id = tokenizer.eos_token_id
# NOTE(review): pad_token is aliased to EOS here, but the pad_token_id
# assignment right after appears to override it — confirm which is intended.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = 1

# --- Decoding parameters ---------------------------------------------------
temperature = 0.0
max_new_tokens = 1200
top_p = 0.95  # NOTE(review): not referenced anywhere in this file
repetition_penalty = 1.0
min_new_tokens = 600
early_stopping = False

# --- Retrieval index -------------------------------------------------------
# LanceDB database and table queried by hybrid_search().
db = lancedb.connect("content19/lancedb_data")
table = db.open_table("edunat19")
def hybrid_search(text, limit=5):
    """Retrieve sources for *text* from the module-level LanceDB ``table``.

    Runs a hybrid query, requests at most ``limit`` rows and skips any row
    whose ``hash`` has already been emitted.

    Args:
        text: The user's query string.
        limit: Number of rows requested from LanceDB. Defaults to 5, the
            value that was previously hard-coded here.

    Returns:
        ``(document, document_html)``. ``document`` joins the sources with
        newlines in the model's ``<|source_start|>...<|source_end|>`` prompt
        format. ``document_html`` is the same sources rendered as HTML,
        wrapped in ``<div id="source_listing">``.
    """
    results = table.search(text, query_type="hybrid").limit(limit).to_pandas()

    seen_hashes = set()
    document = []
    document_html = []
    for _, row in results.iterrows():
        hash_id = str(row['hash'])
        # Skip rows whose hash has already been emitted.
        if hash_id in seen_hashes:
            continue
        seen_hashes.add(hash_id)
        title = row['section']
        content = row['text']
        # The model prompt receives the raw source text...
        document.append(
            f"<|source_start|><|source_id_start|>{hash_id}<|source_id_end|>{title}\n{content}<|source_end|>"
        )
        # ...whereas the HTML view escapes it, so '<', '&' or quotes inside a
        # source cannot break the page markup. The <p> is closed explicitly.
        safe_id = html.escape(hash_id)
        document_html.append(
            f'<div class="source" id="{safe_id}"><p><b>{safe_id}</b> : '
            f'{html.escape(str(title))}<br>{html.escape(str(content))}</p></div>'
        )

    document = "\n".join(document)
    document_html = '<div id="source_listing">' + "".join(document_html) + "</div>"
    return document, document_html
class pleiasBot:
    """Retrieval-augmented question answering over the module-level model.

    ``predict`` retrieves sources with ``hybrid_search``, builds the Pleias
    RAG prompt, generates greedily and returns three HTML fragments:
    analysis, answer and sources.
    """

    def __init__(self, system_prompt="Tu es Appli, un asistant de recherche qui donne des responses sourcées"):
        # NOTE(review): system_prompt is stored but predict() never inserts it
        # into the prompt — confirm whether that is intended.
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Answer *user_message* using retrieved sources.

        Args:
            user_message: The user's question or instruction.

        Returns:
            ``(analysis_html, answer_html, sources_html)``. If generation or
            post-processing raises, the traceback is printed and
            ``(None, None, None)`` is returned.
        """
        fiches, fiches_html = hybrid_search(user_message)
        detailed_prompt = f"""<|query_start|>{user_message}<|query_end|>\n{fiches}\n<|source_analysis_start|>"""
        # Single unpadded sequence, so every position is attended to.
        input_ids = tokenizer.encode(detailed_prompt, return_tensors="pt").to(device)
        attention_mask = torch.ones_like(input_ids)
        try:
            # Greedy decoding (do_sample=False). `temperature` is deliberately
            # not passed: it has no effect without sampling and transformers
            # warns when it is given together with do_sample=False.
            output = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                early_stopping=early_stopping,
                min_new_tokens=min_new_tokens,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            # Decode only the newly generated tokens; special tokens are kept
            # because they delimit the analysis and answer sections.
            generated_text = tokenizer.decode(output[0][input_ids.shape[-1]:])

            # Split at the FIRST end-of-analysis marker. The previous
            # split()+len()==2 check dropped all structure whenever the marker
            # appeared more than once.
            analysis, separator, answer = generated_text.partition("<|source_analysis_end|>")
            if separator:
                analysis = analysis.strip()
                answer = answer.replace("<|answer_start|>", "").replace("<|answer_end|>", "").strip()
                analysis_text = '<h2 style="text-align:center">Analyse des sources</h2>\n<div class="generation">' + format_references(analysis) + "</div>"
                answer_text = '<h2 style="text-align:center">Réponse</h2>\n<div class="generation">' + format_references(answer) + "</div>"
            else:
                # No analysis marker: show the whole generation as the answer.
                analysis_text = ""
                answer_text = format_references(generated_text)
            fiches_html = '<h2 style="text-align:center">Sources</h2>\n' + fiches_html
            return analysis_text, answer_text, fiches_html
        except Exception as e:
            # UI callback boundary: report the failure and leave the panes empty.
            print(f"Error during generation: {str(e)}")
            import traceback
            traceback.print_exc()
            return None, None, None
# Matches <ref name="ID">"QUOTE"</ref>, optionally followed by a period and
# whitespace. The quote may contain '"' but may not run past the first </ref>.
_REF_PATTERN = re.compile(
    r'<ref name="([^"]+)">"((?:(?!</ref>).)*)"</ref>(\.\s*)?',
    re.DOTALL,
)


def format_references(text):
    """Render ``<ref>`` citations in generated text as numbered HTML tooltips.

    Each ``<ref name="ID">"QUOTE"</ref>`` becomes a ``[n]`` marker, numbered
    from 1 within this call, whose hover tooltip shows ID and QUOTE.

    - Whitespace immediately before a citation is dropped.
    - If a period follows the citation, the period and any following
      whitespace become ``.<br>``.
    - Otherwise the following text is kept unchanged. Previously, citations
      without a period were left as raw ``<ref>`` tags.

    All text copied from *text* is HTML-escaped.

    Args:
        text: Raw generated text.

    Returns:
        An HTML string.
    """
    parts = []
    current_pos = 0
    for ref_number, match in enumerate(_REF_PATTERN.finditer(text), start=1):
        # Text before the citation, with trailing whitespace removed.
        parts.append(html.escape(text[current_pos:match.start()].rstrip(), quote=False))
        ref_id = html.escape(match.group(1), quote=False)
        ref_text = html.escape(match.group(2).strip(), quote=False)
        # Keep the original ".<br>" layout only when a period was present.
        suffix = ".<br>" if match.group(3) is not None else ""
        parts.append(
            f'<span class="tooltip"><strong>[{ref_number}]</strong>'
            f'<span class="tooltiptext"><strong>{ref_id}</strong>: {ref_text}</span></span>{suffix}'
        )
        current_pos = match.end()
    # Any text after the last citation.
    parts.append(html.escape(text[current_pos:], quote=False))
    return "".join(parts)
# One shared bot, created with its default settings at import time; the
# Gradio callback routes every question through it.
pleias_bot = pleiasBot()
# Stylesheet passed to gr.Blocks. It styles the generated panes (.generation),
# the source cards (.source), the citation tooltips (.tooltip/.tooltiptext)
# and the source listing. #source_listing uses `display: flow-root` so that it
# contains its floated .source children instead of collapsing to zero height.
css = """
.generation {
margin-left: 2em;
margin-right: 2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float: left;
max-width: 17%;
margin-left: 2%;
}
#source_listing {
display: flow-root;
}
.tooltip {
position: relative;
display: inline-block;
color: #183EFA;
font-weight: bold;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
background-color: #fff;
color: #000;
text-align: left;
padding: 12px;
border-radius: 6px;
border: 1px solid #e5e7eb;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
position: absolute;
z-index: 1;
bottom: 125%;
left: 50%;
transform: translateX(-50%);
min-width: 300px;
max-width: 400px;
white-space: normal;
font-size: 0.9em;
line-height: 1.4;
}
.tooltip:hover .tooltiptext {
visibility: visible;
}
.tooltip .tooltiptext::after {
content: "";
position: absolute;
top: 100%;
left: 50%;
margin-left: -5px;
border-width: 5px;
border-style: solid;
border-color: #fff transparent transparent transparent;
}
.section-title {
font-weight: bold;
font-size: 15px;
margin-bottom: 1em;
margin-top: 1em;
}
"""
def gradio_interface(user_message):
    """Gradio click handler: answer *user_message* with the shared bot.

    Returns the (analysis HTML, answer HTML, sources HTML) triple from
    ``pleias_bot.predict``, in the order of the click handler's outputs.
    """
    analysis_html, answer_html, sources_html = pleias_bot.predict(user_message)
    return (analysis_html, answer_html, sources_html)
# Gradio UI: ASCII-art banner, question box, side-by-side analysis/answer
# panes, and the retrieved sources underneath.
with gr.Blocks(css=css) as demo:
    # Banner on a black bar.
    gr.HTML("""
<div style="display: flex; justify-content: center; width: 100%; background-color: black; padding: 5px 0;">
<pre style="font-family: monospace; line-height: 1.2; font-size: 12px; color: #00ffea; margin: 0;">
_ _ ______ ___ _____
| | (_) | ___ \\/ _ \\| __ \\
_ __ | | ___ _ __ _ ___ ______ | |_/ / /_\\ \\ | \\/
| '_ \\| |/ _ \\ |/ _` / __| |______| | /| _ | | __
| |_) | | __/ | (_| \\__ \\ | |\\ \\| | | | |_\\ \\
| .__/|_|\\___|_|\\__,_|___/ \\_| \\_\\_| |_/\\____/
| |
|_| </pre>
</div>
""")

    # Question input and submit button.
    with gr.Column(scale=1):
        text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
        text_button = gr.Button("Interroger pleias-RAG")

    # Analysis (narrower, scale 2) next to the answer (wider, scale 3).
    with gr.Row():
        with gr.Column(scale=2):
            text_output = gr.HTML(label="Analyse des sources")
        with gr.Column(scale=3):
            response_output = gr.HTML(label="Réponse")

    # Retrieved sources along the bottom.
    with gr.Row():
        embedding_output = gr.HTML(label="Les sources utilisées")

    text_button.click(
        fn=gradio_interface,
        inputs=text_input,
        outputs=[text_output, response_output, embedding_output],
    )

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()