```python
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from tavily import TavilyClient
from openai import OpenAI
from crewai_tools import RagTool
from pydantic import Field, PrivateAttr
import os
from html import unescape
import re
import logging

class HybridRetrieverTool(RagTool):
    name: str = "Hybrid Retriever Tool"
    description: str = "Combines BM25 keyword scoring with semantic similarity for hybrid retrieval"
    alpha: float = Field(default=0.6, description="Weight between semantic and lexical scores")

    # Private attributes
    _embedder: SentenceTransformer = PrivateAttr()
    _tavily: TavilyClient = PrivateAttr()
    _client: OpenAI = PrivateAttr()

    def __init__(self, **data):
        super().__init__(**data)
        self._embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self._tavily = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
        self._client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    # 🧹 Text cleaning
    def _clean_text(self, text: str):
        """
        Clean Tavily content by removing HTML, bullets, boilerplate, and repetitive junk
        while preserving high-value plain text and extracting source URLs for citation.
        """
        if not text or len(text.strip()) < 10:
            return None, []

        # Extract URLs for citation before cleaning
        urls = re.findall(r'https?://\S+', text)

        # Decode HTML entities and remove tags
        text = unescape(text)
        text = re.sub(r"<[^>]+>", " ", text)          # strip HTML tags
        text = re.sub(r"!\[.*?\]\(.*?\)", " ", text)  # remove Markdown images
        text = re.sub(r"\[.*?\]\(.*?\)", " ", text)   # remove Markdown links
        text = re.sub(r"\S+\.(jpg|jpeg|png|gif|svg|webp|pdf)", " ", text, flags=re.I)
        text = re.sub(r"http\S+", " ", text)          # remove inline URLs

        # Remove layout and boilerplate junk
        text = re.sub(r"(Share|Tweet|Email|Login|Subscribe|Learn More|Read More|Click Here)+", " ", text, flags=re.I)
        text = re.sub(r"(Education Weekly Update.*?)+", " ", text, flags=re.I)
        text = re.sub(r"(\bAI\s*\+\s*){2,}", "AI ", text)          # collapse 'AI + AI + AI'
        text = re.sub(r"[•·●○◦‣⁃∙▪]+", " ", text)                  # remove bullet symbols
        text = re.sub(r"(?m)^\s*#.*$", " ", text)                  # remove Markdown headers
        text = re.sub(r"\b[A-Z]{2,}\b( [A-Z]{2,}\b)+", " ", text)  # strip ALL-CAPS runs
        text = text.replace("\xa0", " ")                           # remove non-breaking spaces
        text = re.sub(r"\s{2,}", " ", text).strip()                # normalize whitespace

        # Filter out boilerplate / short junk sections
        if any(kw in text.lower() for kw in [
            "education weekly update",
            "copyright",
            "terms of use",
            "cookie policy",
            "advertisement",
            "site map",
        ]):
            return None, []
        if len(text.split()) < 30:
            return None, []

        # Normalize casing (optional, but improves readability)
        text = text[0].upper() + text[1:] if len(text) > 1 else text
        return text, urls
    def _build_corpus(self, topic: str, top_k: int = 8):
        """Fetch up-to-date search results."""
        results = self._tavily.search(query=topic, max_results=50)
        raw_texts = [r.get("content", "").strip() for r in results.get("results", []) if r.get("content")]

        corpus, all_urls = [], []
        for t in raw_texts:
            clean_text, urls = self._clean_text(t)
            if clean_text:
                corpus.append(clean_text)
                all_urls.extend(urls)

        # Deduplicate and keep the top unique URLs
        all_urls = list(dict.fromkeys(all_urls))[:top_k]
        return corpus, all_urls
    # LLM reranker
    def _rerank(self, query: str, passages: list[str], top_n: int) -> list[str]:
        """
        Use an LLM to re-rank retrieved passages for contextual relevance to the query.
        """
        if not passages:
            return []

        try:
            formatted_passages = "\n\n".join(
                [f"Passage {i+1}:\n{p}" for i, p in enumerate(passages)]
            )
            prompt = f"""
You are a precise research assistant that ranks text passages for relevance.

Query:
"{query}"

Passages:
{formatted_passages}

Instructions:
- Rank passages by how directly and substantively they address the query.
- Ignore repetitive, boilerplate, or promotional content.
- Return ONLY the top {top_n} most relevant passages, in their original text form.
"""
            response = self._client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are an expert LLM reranker for information retrieval."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0,
            )
            ranked_text = response.choices[0].message.content.strip()

            reranked = re.split(r"Passage\s*\d+:", ranked_text)
            reranked = [p.strip() for p in reranked if len(p.strip()) > 20]

            if len(reranked) == 0:
                print("⚠️ Reranker returned no valid text, using original order.")
                return passages[:top_n]
            return reranked[:top_n]
        except Exception as e:
            logging.warning(f"Reranker failed: {e}")
            return passages[:top_n]
    def _run(self, query: str, top_k: int = 8) -> str:
        """
        Run hybrid search: BM25 + semantic similarity.
        """
        corpus, urls = self._build_corpus(query, top_k=top_k)
        if not corpus:
            return "No relevant content found."

        # Lexical relevance
        bm25 = BM25Okapi([doc.split() for doc in corpus])
        bm25_scores = np.array(bm25.get_scores(query.split()))

        # Semantic relevance
        emb_corpus = self._embedder.encode(corpus, convert_to_numpy=True, normalize_embeddings=True)
        emb_query = self._embedder.encode(query, convert_to_numpy=True, normalize_embeddings=True)
        sem_scores = np.dot(emb_corpus, emb_query)

        # Normalize scores
        if np.ptp(bm25_scores) == 0:
            bm25_norm = np.zeros_like(bm25_scores)  # keep BM25 usable even with a single document
        else:
            bm25_norm = (bm25_scores - bm25_scores.min()) / (np.ptp(bm25_scores) + 1e-8)
        sem_norm = (sem_scores - sem_scores.min()) / (np.ptp(sem_scores) + 1e-8)

        # Weighted fusion
        hybrid_scores = self.alpha * sem_norm + (1 - self.alpha) * bm25_norm
        top_indices = np.argsort(hybrid_scores)[::-1][:top_k]
        top_passages = [corpus[i] for i in top_indices]

        reranked = self._rerank(query, top_passages, top_n=top_k)
        return "\n\n".join(reranked)
    def summarize_passages(self, topic: str, passages, top_k: int = 8):
        """Summarize retrieved content into a coherent short digest, keeping citations."""
        if isinstance(passages, str):
            passages = [passages]

        # Clean and compress passages
        main_text = []
        urls = []
        for p in passages:
            text, found_urls = self._clean_text(p)
            if text:
                main_text.append(text)
                urls.extend(found_urls)

        if not main_text:
            return "No meaningful content found to summarize."

        # --- Limit and de-duplicate ---
        unique_texts = list(dict.fromkeys(main_text))[:5]  # prevent duplication
        text_block = " ".join(unique_texts)
        text_block = re.sub(r"\s{2,}", " ", text_block).strip()
        text_block = text_block[:5000]  # safety limit for token size
        unique_urls = list(dict.fromkeys(urls))[:top_k]

        # --- Structured summarization ---
        prompt = f"""
You are a research assistant creating a clean, readable summary.

Topic: {topic}

Condense the following information into **2–3 coherent paragraphs** that:
1. Focus on factual insights and trends, not raw data or footnotes.
2. Remove list items, footers, or numeric citations (like (1), (2)).
3. Retain key facts, organizations, or findings.
4. Avoid repeating words or phrases.
5. Conclude with a single “Sources” section listing the most relevant URLs.

Text to summarize:
{text_block}

Return output in Markdown format.
"""
        try:
            response = self._client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You are a concise, professional summarizer."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.3,
            )
            summary = response.choices[0].message.content.strip()
            if unique_urls:
                summary += "\n\n**Sources:**\n" + "\n".join(f"- [{u}]({u})" for u in unique_urls)
            return summary
        except Exception as e:
            return f"Summarization failed: {e}"
```