knguyen471 commited on
Commit
ce42873
·
verified ·
1 Parent(s): 78f14ab

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +6 -5
  2. main.py +11 -14
app.py CHANGED
@@ -108,6 +108,7 @@ with gr.Blocks(
108
  data_source = gr.Dropdown(
109
  choices=["Michelin Guide", "Google", "Yelp"],
110
  value="Yelp",
 
111
  label="Data Source",
112
  info="Select restaurant data source"
113
  )
@@ -142,10 +143,10 @@ with gr.Blocks(
142
 
143
  examples = [
144
  ["Italian pasta", "Yelp", 10],
145
- ["sushi", "Michelin", 10],
146
  ["romantic dinner", "Google", 8],
147
  ["family-friendly pizza", "Yelp", 10],
148
- ["best seafood", "Michelin", 10],
149
  ["cheap burger", "Google", 10]
150
  ]
151
 
@@ -172,7 +173,7 @@ if __name__ == "__main__":
172
  print("Opening at http://127.0.0.1:7860\n")
173
 
174
  # if run locally
175
- # app.launch(share=False, server_name="127.0.0.1", server_port=7860, inbrowser=True)
176
 
177
- # if run on HF Space
178
- app.launch(ssr_mode=False)
 
108
  data_source = gr.Dropdown(
109
  choices=["Michelin Guide", "Google", "Yelp"],
110
  value="Yelp",
111
+ multiselect=True,
112
  label="Data Source",
113
  info="Select restaurant data source"
114
  )
 
143
 
144
  examples = [
145
  ["Italian pasta", "Yelp", 10],
146
+ ["sushi", "Michelin Guide", 10],
147
  ["romantic dinner", "Google", 8],
148
  ["family-friendly pizza", "Yelp", 10],
149
+ ["best seafood", "Michelin Guide", 10],
150
  ["cheap burger", "Google", 10]
151
  ]
152
 
 
173
  print("Opening at http://127.0.0.1:7860\n")
174
 
175
  # if run locally
176
+ app.launch(share=False, server_name="127.0.0.1", server_port=7860, inbrowser=True)
177
 
178
+ # # if run on HF Space
179
+ # app.launch(ssr_mode=False)
main.py CHANGED
@@ -12,11 +12,7 @@ from utils.semantic_similarity import Encoder
12
  from utils.syntactic_similarity import Parser
13
  from utils.tfidf_similarity import TFIDF_Vectorizer
14
 
15
- # Set default device to CUDA if available, otherwise CPU
16
- if torch.cuda.is_available():
17
- torch.set_default_device("cuda")
18
- else:
19
- torch.set_default_device("cpu")
20
 
21
  # Download models/data
22
  nltk.download('punkt')
@@ -30,9 +26,7 @@ data = pd.read_csv("data/toy_data_aggregated_embeddings.csv")
30
  with open("data/restaurant_by_source.json", "r") as f:
31
  restaurant_by_source = json.load(f)
32
 
33
- # Load precomputed TF-IDF features
34
- # restaurant_tfidf_features = np.load("data/toy_data_tfidf_features.npz")
35
-
36
  print("Computing TFIDF")
37
  tfidf_vectorizer = TFIDF_Vectorizer(load_vectorizer=False)
38
  restaurant_tfidf_features = tfidf_vectorizer.compute_tfidf_matrix(data["review_text_clean"])
@@ -91,7 +85,7 @@ def retrieve_candidates(query: str, n_candidates: int):
91
  return candidates_idx
92
 
93
 
94
- def rerank(candidates_idx: np.ndarray, n_rec: int = 10, data_source: str = None) -> list:
95
  print("Reranking...")
96
 
97
  # Get popularity scores for stage 1 candidates
@@ -105,15 +99,18 @@ def rerank(candidates_idx: np.ndarray, n_rec: int = 10, data_source: str = None)
105
  restaurant_ids = data.loc[topN_reranked_global_idx, "id"].tolist()
106
 
107
  # Filter to only data_source
108
- print(f"[RERANK] Filtering to only source - {data_source}")
109
- restaurant_by_source_set = set(restaurant_by_source[data_source])
110
- restaurant_ids = [x for x in restaurant_ids if x in restaurant_by_source_set]
 
 
 
111
 
112
  print(f"[RERANK] Final recommendations: {restaurant_ids}")
113
  return restaurant_ids
114
 
115
- def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30, data_source: str = None):
116
  query_clean = clean_text(query)
117
  candidates_idx = retrieve_candidates(query_clean, n_candidates)
118
- restaurant_ids = rerank(candidates_idx, n_rec, data_source)
119
  return restaurant_ids
 
12
  from utils.syntactic_similarity import Parser
13
  from utils.tfidf_similarity import TFIDF_Vectorizer
14
 
15
+ torch.set_default_device("cpu")
 
 
 
 
16
 
17
  # Download models/data
18
  nltk.download('punkt')
 
26
  with open("data/restaurant_by_source.json", "r") as f:
27
  restaurant_by_source = json.load(f)
28
 
29
+ # Compute TFIDF features
 
 
30
  print("Computing TFIDF")
31
  tfidf_vectorizer = TFIDF_Vectorizer(load_vectorizer=False)
32
  restaurant_tfidf_features = tfidf_vectorizer.compute_tfidf_matrix(data["review_text_clean"])
 
85
  return candidates_idx
86
 
87
 
88
+ def rerank(candidates_idx: np.ndarray, n_rec: int, data_sources: list = None) -> list:
89
  print("Reranking...")
90
 
91
  # Get popularity scores for stage 1 candidates
 
99
  restaurant_ids = data.loc[topN_reranked_global_idx, "id"].tolist()
100
 
101
  # Filter to only data_source
102
+ if data_sources is not None:
103
+ print(f"[RERANK] Filtering to only source - {data_sources}")
104
+ restaurant_by_source_set = set()
105
+ for src in data_sources:
106
+ restaurant_by_source_set.update(restaurant_by_source[src])
107
+ restaurant_ids = [x for x in restaurant_ids if x in restaurant_by_source_set]
108
 
109
  print(f"[RERANK] Final recommendations: {restaurant_ids}")
110
  return restaurant_ids
111
 
112
+ def get_recommendations(query: str, n_candidates: int = 100, n_rec: int = 30, data_sources: list = None):
113
  query_clean = clean_text(query)
114
  candidates_idx = retrieve_candidates(query_clean, n_candidates)
115
+ restaurant_ids = rerank(candidates_idx, n_rec, data_sources)
116
  return restaurant_ids