import json
import torch
import nltk
import benepar
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from utils.clean_text import clean_text
from utils.semantic_similarity import Encoder
from utils.syntactic_similarity import Parser
from utils.tfidf_similarity import TFIDF_Vectorizer
torch.set_default_device("cpu")
# Download models/data
nltk.download('punkt')
nltk.download('punkt_tab')
benepar.download('benepar_en3_large')
# Load dataset
data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")
# Load restaurant_by_source
with open("data/restaurant_by_source.json", "r") as f:
    restaurant_by_source = json.load(f)
# Compute TFIDF features
print("Computing TFIDF")
tfidf_vectorizer = TFIDF_Vectorizer(load_vectorizer=False)
restaurant_tfidf_features = tfidf_vectorizer.compute_tfidf_matrix(data["review_text_clean"])
# Extract embeddings
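# The CSV stores each embedding as a bracketed, space-separated string; convert it back to a NumPy array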
data["embedding"] = data["embedding"].apply(
    lambda x: np.fromstring(x.strip('[]'), sep=' ')
)
all_desc_embeddings = np.vstack(data["embedding"].values)
# Initialize encoder
encoder = Encoder()
# Initialize syntactic parser
parser = Parser()
def retrieve_candidates(query: str, n_candidates: int):
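    """Stage 1 retrieval: score every restaurant in `data` against the query.

    Combines dense-embedding cosine similarity (weight 0.8), syntactic
    subtree similarity (0.1), and TF-IDF similarity (0.1), then returns the
    indices of the top `n_candidates` rows, best first.
    """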
print(f"Retrieving {n_candidates} candidates...")
# Encode query
print("[RETRIEVAL] Encoding query")
query_emb = encoder.encode([query]).cpu().numpy()
# Semantic similarities
print("[RETRIEVAL] Computing semantic similarities")
desc_sem_sim = cosine_similarity(query_emb, all_desc_embeddings)[0]
# TF-IDF similarities
print("[RETRIEVAL] Computing TF-IDF")
tfidf_sim = tfidf_vectorizer.compute_tfidf_scores(query, restaurant_tfidf_features)
# Syntactic similarities
print("[RETRIEVAL] Computing syntactic similarities")
parsed_query = parser.parse_text(query)
parsed_query = parser.subtree_set(parsed_query)
syn_sims = []
for trees_list in tqdm(data["syntactic_tree"], total=len(data), desc="[RETRIEVAL] Computing syntactic similarities"):
review_sims = []
for review_tree_subs in trees_list:
if review_tree_subs is None:
review_tree_subs = set()
sim = parser.compute_syntactic_similarity(parsed_query, review_tree_subs)
review_sims.append(sim)
syn_sims.append(np.mean(review_sims))
# Combined Stage 1 score
syn_sims = np.array(syn_sims)
combined_stage1_scores = 0.8*desc_sem_sim + 0.1*syn_sims + 0.1*tfidf_sim
# Get top N candidates for Stage 2 reranking
candidates_idx = np.argsort(combined_stage1_scores)[-n_candidates:][::-1]
print(f"[RETRIEVAL] Results: {candidates_idx}")
return candidates_idx
def rerank(candidates_idx: np.ndarray, n_rec: int, data_sources: list = None) -> list:
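    """Stage 2 reranking: order Stage 1 candidates by popularity.

    Keeps the `n_rec` candidates with the highest `pop_score` and maps them to
    restaurant ids. The optional source filter is applied after the top-N cut,
    so fewer than `n_rec` ids may be returned.
    """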
print("Reranking...")
# Get popularity scores for stage 1 candidates
rerank_scores = data.loc[candidates_idx, "pop_score"].values
# Retrieve n_rec restaurant based on pop_score
topN_reranked_local_idx = np.argsort(rerank_scores)[-n_rec:][::-1]
topN_reranked_global_idx = candidates_idx[topN_reranked_local_idx]
# Get restaurant_id for final recommendations
restaurant_ids = data.loc[topN_reranked_global_idx, "id"].tolist()
# Filter to only data_source
if data_sources is not None:
print(f"[RERANK] Filtering to only source - {data_sources}")
restaurant_by_source_set = set()
for src in data_sources:
restaurant_by_source_set.update(restaurant_by_source[src])
restaurant_ids = [x for x in restaurant_ids if x in restaurant_by_source_set]
print(f"[RERANK] Final recommendations: {restaurant_ids}")
return restaurant_ids
def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30, data_sources: list = None):
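    """End-to-end pipeline: clean the query, retrieve Stage 1 candidates, rerank by popularity."""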
    query_clean = clean_text(query)
    candidates_idx = retrieve_candidates(query_clean, n_candidates)
    restaurant_ids = rerank(candidates_idx, n_rec, data_sources)
    return restaurant_ids
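
# Example usage (illustrative sketch: the query string below is made up, and any
# data_sources value must match the keys present in restaurant_by_source.json):
if __name__ == "__main__":
    recommendations = get_recommendations(
        "cozy Italian restaurant with homemade pasta and outdoor seating",
        n_candidates=100,
        n_rec=30,
        data_sources=None,  # e.g. a list of source keys to restrict results
    )
    print(recommendations)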