# AI_Blog_Writer/tools/hybrid_retriever_tool.py
import logging
import os
import re
from html import unescape

import numpy as np
from crewai_tools import RagTool
from openai import OpenAI
from pydantic import Field, PrivateAttr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from tavily import TavilyClient
class HybridRetrieverTool(RagTool):
name: str = "Hybrid Retriever Tool"
description: str = "Combines BM25 keyword scoring with semantic similarity for hybrid retrieval"
    alpha: float = Field(default=0.6, description="Weight given to semantic similarity vs. BM25 lexical score (0 = purely lexical, 1 = purely semantic)")
# Define private attributes
_embedder: SentenceTransformer = PrivateAttr()
_tavily: TavilyClient = PrivateAttr()
_client: OpenAI = PrivateAttr()
def __init__(self, **data):
super().__init__(**data)
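        # Both TAVILY_API_KEY and OPENAI_API_KEY are expected to be set in the
        # environment; without them the clients below cannot authenticate.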
self._embedder = SentenceTransformer("all-MiniLM-L6-v2")
self._tavily = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
self._client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# 🧹 Text Cleaning
def _clean_text(self, text: str):
"""
Clean Tavily content by removing HTML, bullets, boilerplate, and repetitive junk
while preserving high-value plain text and extracting source URLs for citation.
"""
if not text or len(text.strip()) < 10:
return None, []
# Extract URLs for citation before cleaning
urls = re.findall(r'https?://\S+', text)
# Decode HTML entities and remove tags
text = unescape(text)
text = re.sub(r"<[^>]+>", " ", text) # strip HTML tags
text = re.sub(r"!\[.*?\]\(.*?\)", " ", text) # remove Markdown images
text = re.sub(r"\[.*?\]\(.*?\)", " ", text) # remove Markdown links
text = re.sub(r"\S+\.(jpg|jpeg|png|gif|svg|webp|pdf)", " ", text, flags=re.I)
text = re.sub(r"http\S+", " ", text) # remove URLs inline
# Remove layout and boilerplate junk
text = re.sub(r"(Share|Tweet|Email|Login|Subscribe|Learn More|Read More|Click Here)+", " ", text, flags=re.I)
text = re.sub(r"(Education Weekly Update.*?)+", " ", text, flags=re.I)
text = re.sub(r"(\bAI\s*\+\s*){2,}", "AI ", text) # collapse 'AI + AI + AI'
text = re.sub(r"[•·●○◦‣⁃∙▪]+", " ", text) # remove bullet symbols
text = re.sub(r"(?m)^\s*#.*$", " ", text) # remove markdown headers
text = re.sub(r"\b[A-Z]{2,}\b( [A-Z]{2,}\b)+", " ", text) # collapse ALLCAPS runs
text = text.replace("\xa0", " ") # remove non-breaking spaces
text = re.sub(r"\s{2,}", " ", text).strip() # normalize whitespace
# Filter out boilerplate / short junk sections
if any(kw in text.lower() for kw in [
"education weekly update",
"copyright",
"terms of use",
"cookie policy",
"advertisement",
"site map",
]):
return None, []
if len(text.split()) < 30:
return None, []
# Normalize casing (optional but improves readability)
text = text[0].upper() + text[1:] if len(text) > 1 else text
return text, urls
def _build_corpus(self, topic: str, top_k: int = 8):
"""Fetch up-to-date search results."""
results = self._tavily.search(query=topic, max_results=50)
raw_texts = [r.get("content", "").strip() for r in results.get("results", []) if r.get("content")]
corpus, all_urls = [], []
for t in raw_texts:
clean_text, urls = self._clean_text(t)
if clean_text:
corpus.append(clean_text)
all_urls.extend(urls)
        # Deduplicate URLs (preserving order) and keep the top unique ones
all_urls = list(dict.fromkeys(all_urls))[:top_k]
return corpus, all_urls
# LLM reranker
def _rerank(self, query: str, passages: list[str], top_n: int) -> list[str]:
"""
Use an LLM to re-rank retrieved passages for contextual relevance to the query.
"""
if not passages:
return []
try:
formatted_passages = "\n\n".join(
[f"Passage {i+1}:\n{p}" for i, p in enumerate(passages)]
)
prompt = f"""
You are a precise research assistant that ranks text passages for relevance.
Query:
"{query}"
Passages:
{formatted_passages}
Instructions:
- Rank passages by how directly and substantively they address the query.
- Ignore repetitive, boilerplate, or promotional content.
- Return ONLY the top {top_n} most relevant passages, in their original text form.
"""
response = self._client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "You are an expert LLM reranker for information retrieval."},
{"role": "user", "content": prompt},
],
temperature=0,
)
ranked_text = response.choices[0].message.content.strip()
reranked = re.split(r"Passage\s*\d+:", ranked_text)
reranked = [p.strip() for p in reranked if len(p.strip()) > 20]
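            # The split above assumes the model labels its output with "Passage N:" headers;
            # if it returns a single unlabeled block, that block is kept as one passage.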
            if not reranked:
                logging.warning("Reranker returned no valid text; falling back to the original passage order.")
                return passages[:top_n]
return reranked[:top_n]
except Exception as e:
logging.warning(f"Reranker failed: {e}")
return passages[:top_n]
def _run(self, query: str, top_k: int = 8) -> str:
"""
Run hybrid search: BM25 + semantic similarity.
"""
corpus, urls = self._build_corpus(query, top_k=top_k)
if not corpus:
return "No relevant content found."
# Lexical relevance
bm25 = BM25Okapi([doc.split() for doc in corpus])
bm25_scores = np.array(bm25.get_scores(query.split()))
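        # BM25Okapi is built over whitespace-tokenized documents and the query is
        # split the same way, so this score reflects exact-token (lexical) overlap only.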
        # Semantic relevance: embeddings are L2-normalized, so the dot product equals cosine similarity
emb_corpus = self._embedder.encode(corpus, convert_to_numpy=True, normalize_embeddings=True)
emb_query = self._embedder.encode(query, convert_to_numpy=True, normalize_embeddings=True)
sem_scores = np.dot(emb_corpus, emb_query)
# Normalize scores
        if np.ptp(bm25_scores) == 0:
            bm25_norm = np.zeros_like(bm25_scores)  # all BM25 scores identical (e.g., a single doc): avoid divide-by-zero
        else:
            bm25_norm = (bm25_scores - bm25_scores.min()) / (np.ptp(bm25_scores) + 1e-8)
sem_norm = (sem_scores - sem_scores.min()) / (np.ptp(sem_scores) + 1e-8)
# Weighted fusion
hybrid_scores = self.alpha * sem_norm + (1 - self.alpha) * bm25_norm
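        # Illustrative check of the fusion above (hypothetical numbers): with the
        # default alpha = 0.6, a passage with sem_norm = 0.9 and bm25_norm = 0.2
        # scores 0.6 * 0.9 + 0.4 * 0.2 = 0.62, so semantic similarity carries most
        # of the weight in the final ranking.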
        top_indices = np.argsort(hybrid_scores)[::-1][:top_k]
top_passages = [corpus[i] for i in top_indices]
reranked = self._rerank(query, top_passages, top_n=top_k)
return "\n\n".join(reranked)
def summarize_passages(self, topic: str, passages, top_k: int = 8):
"""Summarize retrieved content into a coherent short digest, keeping citations."""
if isinstance(passages, str):
passages = [passages]
# Clean and compress passages
main_text = []
urls = []
for p in passages:
text, found_urls = self._clean_text(p)
if text:
main_text.append(text)
urls.extend(found_urls)
if not main_text:
return "No meaningful content found to summarize."
        # --- Deduplicate passages and cap how many feed the summary ---
        unique_texts = list(dict.fromkeys(main_text))[:5]  # prevent duplication
text_block = " ".join(unique_texts)
text_block = re.sub(r"\s{2,}", " ", text_block).strip()
text_block = text_block[:5000] # safety limit for token size
unique_urls = list(dict.fromkeys(urls))[:top_k]
# --- Structured summarization ---
prompt = f"""
You are a research assistant creating a clean, readable summary.
Topic: {topic}
Condense the following information into **2–3 coherent paragraphs** that:
1. Focus on factual insights and trends, not raw data or footnotes.
2. Remove list items, footers, or numeric citations (like (1), (2)).
3. Retain key facts, organizations, or findings.
4. Avoid repeating words or phrases.
        5. Do not invent citations or a "Sources" list; the actual source URLs are appended programmatically after summarization.
Text to summarize:
{text_block}
Return output in Markdown format.
"""
try:
response = self._client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "You are a concise, professional summarizer."},
{"role": "user", "content": prompt},
],
temperature=0.3
)
summary = response.choices[0].message.content.strip()
            if unique_urls:
                summary += "\n\n**Sources:**\n" + "\n".join(f"- [{u}]({u})" for u in unique_urls)
            return summary
except Exception as e:
return f"Summarization failed: {e}"