File size: 4,070 Bytes
812c65f
888aba6
 
 
 
 
83dc914
888aba6
 
 
 
 
 
 
ce42873
888aba6
 
 
 
 
 
 
 
 
812c65f
 
 
af58613
ce42873
83dc914
 
 
888aba6
93e82aa
 
 
 
888aba6
 
 
 
 
 
 
 
 
83dc914
 
888aba6
83dc914
888aba6
 
 
83dc914
888aba6
 
 
83dc914
888aba6
 
 
83dc914
888aba6
 
 
 
83dc914
888aba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83dc914
 
888aba6
 
 
ce42873
83dc914
888aba6
 
 
 
 
 
 
 
 
 
 
83dc914
ce42873
 
 
 
 
 
83dc914
 
888aba6
 
ce42873
888aba6
812c65f
ce42873
888aba6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import torch 
import nltk
import benepar
import pandas as pd 
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity 

from utils.clean_text import clean_text 
from utils.semantic_similarity import Encoder 
from utils.syntactic_similarity import Parser 
from utils.tfidf_similarity import TFIDF_Vectorizer 

# ---------------------------------------------------------------------------
# Module-level setup: everything below runs once at import time and populates
# the globals used by retrieve_candidates() / rerank().
# ---------------------------------------------------------------------------

# Make CPU the default torch device for tensors created without an explicit device.
torch.set_default_device("cpu")

# Download models/data needed by the NLTK tokenizers and the benepar parser.
nltk.download('punkt')
nltk.download('punkt_tab')
benepar.download('benepar_en3_large')

# Load dataset. Row positions in `data` are the indices that
# retrieve_candidates() returns and rerank() consumes.
data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")

# Load restaurant_by_source: source name -> collection of restaurant ids
# (rerank() looks sources up by key and tests id membership).
# JSON is UTF-8 by spec, so don't depend on the platform's locale encoding.
with open("data/restaurant_by_source.json", "r", encoding="utf-8") as f:
    restaurant_by_source = json.load(f)

# Compute TFIDF features over the cleaned review text.
print("Computing TFIDF")
tfidf_vectorizer = TFIDF_Vectorizer(load_vectorizer=False)
# pandas reads empty text cells as NaN (a float); feed them to the vectorizer
# as empty documents instead of non-string values.
restaurant_tfidf_features = tfidf_vectorizer.compute_tfidf_matrix(
    data["review_text_clean"].fillna("")
)

# Extract embeddings: the CSV stores each vector as bracketed text such as
# "[0.1 0.2 ...]", so strip the brackets and parse whitespace-separated floats.
data["embedding"] = data["embedding"].apply(
    lambda x: np.fromstring(x.strip('[]'), sep=' ')
)
# Stack into one matrix whose row i corresponds to row position i of `data`.
all_desc_embeddings = np.vstack(data["embedding"].values)

# Initialize encoder (used to embed queries at retrieval time).
encoder = Encoder()

# Initialize syntactic parser (used to parse queries at retrieval time).
parser = Parser()

def retrieve_candidates(query: str, n_candidates: int):
    """Stage 1: score every row of ``data`` against ``query`` and keep the best.

    The stage-1 score is a fixed weighted sum of three signals:
    ``0.8 * semantic + 0.1 * syntactic + 0.1 * tfidf`` where

    * semantic  = cosine similarity of the encoded query to ``all_desc_embeddings``,
    * syntactic = mean ``parser.compute_syntactic_similarity`` over the row's
      ``syntactic_tree`` entries (0.0 when the row has none),
    * tfidf     = ``tfidf_vectorizer.compute_tfidf_scores`` against the
      precomputed ``restaurant_tfidf_features``.

    Args:
        query: Query text (``get_recommendations`` passes it through
            ``clean_text`` first).
        n_candidates: How many candidates to keep. Values <= 0 yield an empty
            result; values above ``len(data)`` return every row.

    Returns:
        np.ndarray of positional row indices into ``data``, highest score first.
        Rows whose combined score is NaN are ranked last.
    """
    print(f"Retrieving {n_candidates} candidates...")

    # Encode query
    print("[RETRIEVAL] Encoding query")
    query_emb = encoder.encode([query]).cpu().numpy()

    # Semantic similarities
    print("[RETRIEVAL] Computing semantic similarities")
    desc_sem_sim = cosine_similarity(query_emb, all_desc_embeddings)[0]

    # TF-IDF similarities
    print("[RETRIEVAL] Computing TF-IDF")
    tfidf_sim = tfidf_vectorizer.compute_tfidf_scores(query, restaurant_tfidf_features)

    # Syntactic similarities
    print("[RETRIEVAL] Computing syntactic similarities")
    query_subtrees = parser.subtree_set(parser.parse_text(query))

    syn_sims = []
    # NOTE(review): `data` comes straight from pd.read_csv, so unless this
    # column is parsed elsewhere each cell is a str and the inner loop iterates
    # over characters. Confirm the intended deserialization.
    for trees_list in tqdm(data["syntactic_tree"], total=len(data), desc="[RETRIEVAL] Computing syntactic similarities"):
        review_sims = [
            parser.compute_syntactic_similarity(
                query_subtrees,
                review_tree_subs if review_tree_subs is not None else set(),
            )
            for review_tree_subs in trees_list
        ]
        # np.mean([]) is NaN (plus a RuntimeWarning); treat "no trees" as 0.
        syn_sims.append(float(np.mean(review_sims)) if review_sims else 0.0)

    # Combined Stage 1 score
    syn_sims = np.asarray(syn_sims, dtype=float)
    combined_stage1_scores = 0.8 * desc_sem_sim + 0.1 * syn_sims + 0.1 * tfidf_sim

    # np.argsort places NaN after every real value, so a NaN score would end up
    # at the top of a descending ranking. Demote NaNs to -inf instead.
    combined_stage1_scores = np.where(
        np.isnan(combined_stage1_scores), -np.inf, combined_stage1_scores
    )

    # Get top N candidates for Stage 2 reranking. Taking a prefix of the
    # descending order avoids the `arr[-0:]` pitfall (n=0 returning all rows).
    descending = np.argsort(combined_stage1_scores)[::-1]
    candidates_idx = descending[:max(n_candidates, 0)]

    print(f"[RETRIEVAL] Results: {candidates_idx}")

    return candidates_idx


def rerank(candidates_idx: np.ndarray, n_rec: int, data_sources: list = None) -> list:
    """Stage 2: order stage-1 candidates by ``pop_score`` and return restaurant ids.

    Args:
        candidates_idx: Positional row indices into ``data`` (as returned by
            ``retrieve_candidates``).
        n_rec: Maximum number of recommendations to return. Values <= 0 yield
            an empty list.
        data_sources: Optional list of keys into ``restaurant_by_source``. When
            given, only candidates whose ``id`` belongs to one of those sources
            are eligible. The filter is applied BEFORE truncating to ``n_rec``,
            so up to ``n_rec`` matching restaurants are returned.

    Returns:
        List of ``id`` values, highest ``pop_score`` first. Candidates with a
        NaN ``pop_score`` are ranked last.

    Raises:
        KeyError: if a name in ``data_sources`` is not in ``restaurant_by_source``.
    """
    print("Reranking...")

    # Candidate indices are row POSITIONS (argsort output), so index
    # positionally rather than by index label.
    positions = np.asarray(candidates_idx, dtype=int)
    candidate_ids = data["id"].to_numpy()[positions]
    rerank_scores = data["pop_score"].to_numpy(dtype=float)[positions]

    # Filter to only data_source -- before truncation, otherwise the caller can
    # receive far fewer than n_rec results even when enough candidates match.
    if data_sources is not None:
        print(f"[RERANK] Filtering to only source - {data_sources}")
        restaurant_by_source_set = set()
        for src in data_sources:
            restaurant_by_source_set.update(restaurant_by_source[src])
        keep = np.array(
            [rid in restaurant_by_source_set for rid in candidate_ids], dtype=bool
        )
        candidate_ids = candidate_ids[keep]
        rerank_scores = rerank_scores[keep]

    # NaN would sort last and therefore come first in descending order.
    rerank_scores = np.where(np.isnan(rerank_scores), -np.inf, rerank_scores)

    # Retrieve n_rec restaurants based on pop_score (descending). Prefix
    # slicing avoids `arr[-0:]` returning everything when n_rec is 0.
    top_local = np.argsort(rerank_scores)[::-1][:max(n_rec, 0)]
    restaurant_ids = candidate_ids[top_local].tolist()

    print(f"[RERANK] Final recommendations: {restaurant_ids}")
    return restaurant_ids

def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30, data_sources: list = None):
    """Run the full two-stage pipeline for a raw user query.

    The query is normalized with ``clean_text``, stage-1 retrieval keeps
    ``n_candidates`` rows, and stage-2 reranking returns up to ``n_rec``
    restaurant ids, optionally restricted to ``data_sources``.
    """
    normalized = clean_text(query)
    shortlist = retrieve_candidates(normalized, n_candidates)
    return rerank(shortlist, n_rec, data_sources)