import os, json
from typing import List, Tuple

import numpy as np
import gradio as gr
import faiss
from sentence_transformers import SentenceTransformer

# EmbeddingGemma is a gated model; set HF_TOKEN (e.g. as a Space secret) to authenticate.
HF_TOKEN = os.getenv("HF_TOKEN")

# ---------- Paths (expects files committed under ./assets) ----------
APP_DIR    = os.path.dirname(__file__)
ASSETS_DIR = os.path.join(APP_DIR, "assets")
CACHE_DIR  = "/mnt/data/eg_space_cache"  # runtime scratch dir (created up front; unused below)
os.makedirs(CACHE_DIR, exist_ok=True)

CORPUS_JSON = os.path.join(ASSETS_DIR, "corpus.json")
EMB_FP32    = os.path.join(ASSETS_DIR, "doc_embs_fp32.npy")
EMB_FP16    = os.path.join(ASSETS_DIR, "doc_embs_fp16.npy")
FAISS_MAIN  = os.path.join(ASSETS_DIR, "faiss_ip_768.index")
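
# The files above are produced offline and committed to the repo. A minimal
# sketch of how they could be regenerated (illustrative only, never called here;
# where `docs` comes from, e.g. a 10k English Wikipedia sample, is an assumption):
def build_assets(docs: List[dict], encoder: SentenceTransformer) -> None:
    """Encode `docs` ([{"title", "text"}, ...]) and write the expected asset files."""
    embs = encoder.encode_document(
        [d["text"] for d in docs], normalize_embeddings=True, convert_to_numpy=True
    ).astype(np.float32)
    np.save(EMB_FP32, embs)                   # doc_embs_fp32.npy
    index = faiss.IndexFlatIP(embs.shape[1])  # inner product == cosine on unit vectors
    index.add(embs)
    faiss.write_index(index, FAISS_MAIN)      # faiss_ip_768.index
    with open(CORPUS_JSON, "w", encoding="utf-8") as f:
        json.dump(docs, f, ensure_ascii=False)  # corpus.json, same order as embeddings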

# ---------- Matryoshka dims ----------
# EmbeddingGemma uses Matryoshka Representation Learning: the first d components
# of a 768-dim embedding are themselves a usable d-dim embedding (after re-normalizing).
MATRYOSHKA_DIMS = [768, 512, 256, 128]
DEFAULT_DIMS    = 768
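
# Helper (added for clarity): Matryoshka embeddings stay meaningful when
# truncated to their leading `dims` components, but must be re-normalized
# afterwards so that inner products remain true cosine similarities.
def truncate_renorm(vecs: np.ndarray, dims: int) -> np.ndarray:
    v = np.ascontiguousarray(vecs[..., :dims], dtype=np.float32)
    norms = np.linalg.norm(v, axis=-1, keepdims=True)
    return v / np.maximum(norms, 1e-12)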

# ---------- Load corpus ----------
with open(CORPUS_JSON, "r", encoding="utf-8") as f:
    corpus = json.load(f)  # list of {"title","text"} in EXACT same order as embeddings

# ---------- Load embeddings ----------
if os.path.exists(EMB_FP32):
    doc_embs = np.load(EMB_FP32).astype(np.float32, copy=False)
elif os.path.exists(EMB_FP16):
    doc_embs = np.load(EMB_FP16).astype(np.float32)  # cast back for FAISS
else:
    raise FileNotFoundError("Expected assets/doc_embs_fp32.npy or assets/doc_embs_fp16.npy")

if doc_embs.ndim != 2 or doc_embs.shape[0] != len(corpus):
    raise ValueError("Embeddings shape mismatch vs corpus length.")

EMB_DIM = doc_embs.shape[1]  # should be 768

# ---------- Model (for queries + sentence-level ops) ----------
# encode_query / encode_document apply EmbeddingGemma's task-specific prompts.
model = SentenceTransformer("google/embeddinggemma-300m", token=HF_TOKEN)  # CPU is fine for queries

# ---------- FAISS indexes ----------
# Doc embeddings are assumed L2-normalized, so inner-product search == cosine similarity.
if os.path.exists(FAISS_MAIN):
    base_index_768 = faiss.read_index(FAISS_MAIN)
else:
    base_index_768 = faiss.IndexFlatIP(EMB_DIM)
    base_index_768.add(doc_embs.astype(np.float32, copy=False))

# Build per-dimension flat IP indexes from the loaded embeddings
class MultiDimFaiss:
    def __init__(self, doc_embs_full: np.ndarray):
        self.full = doc_embs_full
        self.indexes = {}
        for d in MATRYOSHKA_DIMS:
            if d == EMB_DIM:
                # Reuse the prebuilt (or just-built) full-dimension index.
                self.indexes[d] = base_index_768
            else:
                # Truncated vectors must be re-normalized before indexing.
                idx = faiss.IndexFlatIP(d)
                idx.add(truncate_renorm(self.full, d))
                self.indexes[d] = idx

    def search(self, q_vec: np.ndarray, top_k: int, dims: int) -> Tuple[np.ndarray, np.ndarray]:
        q = truncate_renorm(q_vec, dims)[None, :]
        return self.indexes[dims].search(q, top_k)

faiss_md = MultiDimFaiss(doc_embs)
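# Usage sketch: `scores, ids = faiss_md.search(q_emb, top_k=5, dims=256)` returns
# (1, top_k)-shaped FAISS score and index arrays; see do_search below.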

# ---------- Core ops ----------
def _format_snippet(text: str, max_len: int = 380) -> str:
    return text[:max_len] + ("…" if len(text) > max_len else "")

def do_search(query: str, top_k: int = 5, dims: int = DEFAULT_DIMS) -> List[List[str]]:
    if not query or not query.strip():
        return []
    q_emb = model.encode_query(
        query.strip(),
        normalize_embeddings=True,
        convert_to_numpy=True
    )
    scores, idxs = faiss_md.search(q_emb, top_k=top_k, dims=dims)
    rows = []
    for s, i in zip(scores[0].tolist(), idxs[0].tolist()):
        if i == -1:
            continue
        title = corpus[i]["title"]
        snippet = _format_snippet(corpus[i]["text"])
        rows.append([f"{s:.4f}", title, snippet])
    return rows

def do_similarity(text_a: str, text_b: str, dims: int = DEFAULT_DIMS) -> float:
    if not text_a or not text_b:
        return 0.0
    a = model.encode_document([text_a], normalize_embeddings=True, convert_to_numpy=True)[0]
    b = model.encode_document([text_b], normalize_embeddings=True, convert_to_numpy=True)[0]
    # Re-normalize after truncation so the dot product is a true cosine similarity.
    return float(np.dot(truncate_renorm(a, dims), truncate_renorm(b, dims)))

# ---------- Gradio UI ----------
with gr.Blocks(title="EmbeddingGemma × Wikipedia (EN corpus)") as demo:
    gr.Markdown(
        """
    # Demo: EmbeddingGemma × Wikipedia (EN corpus)
    
    This Space showcases [Google DeepMind’s EmbeddingGemma models](https://huggingface.co/collections/google/embeddinggemma-68b9ae3a72a82f0562a80dc4), on a pre-indexed **random 10k sample** of [English Wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia).
    You can try:
    
    - **Semantic search** (English queries)  
    - **Cross-lingual search** (queries in other languages → English articles)  
    - **Sentence similarity** (compare two texts)  
    
    🔗 Learn more in the [EmbeddingGemma blog post](https://huggingface.co/blog/embeddinggemma).
    """
    )

    with gr.Tabs():
        # 1) Semantic Search (EN-only corpus)
        with gr.TabItem("Semantic Search (EN corpus)"):
            with gr.Row():
                q = gr.Textbox(label="Query", value="Who discovered penicillin?")
                topk = gr.Slider(1, 20, value=5, step=1, label="Top-K")
                dims = gr.Dropdown([str(d) for d in MATRYOSHKA_DIMS], value=str(DEFAULT_DIMS), label="Embedding dims")
                run = gr.Button("Search")
            out = gr.Dataframe(headers=["score", "title", "snippet"], wrap=True)
            run.click(lambda query, k, d: do_search(query, int(k), int(d)), [q, topk, dims], out)

        # 2) Cross-Lingual (queries in FR/ES/etc → EN corpus)
        with gr.TabItem("Cross-Lingual (EN corpus)"):
            gr.Markdown("Type your query in **French/Spanish/Arabic**. Results come from the **English-only** corpus.")
            with gr.Row():
                qx = gr.Textbox(label="Query", value="¿Quién descubrió la penicilina?")
                topkx = gr.Slider(1, 20, value=5, step=1, label="Top-K")
                dimsx = gr.Dropdown([str(d) for d in MATRYOSHKA_DIMS], value=str(DEFAULT_DIMS), label="Embedding dims")
                runx = gr.Button("Search")
            outx = gr.Dataframe(headers=["score", "title", "snippet"], wrap=True)
            runx.click(lambda query, k, d: do_search(query, int(k), int(d)), [qx, topkx, dimsx], outx)

        # 3) Similarity
        with gr.TabItem("Similarity"):
            with gr.Row():
                a = gr.Textbox(lines=5, label="Text A", value="Alexander Fleming observed a mold that killed bacteria in 1928.")
                b = gr.Textbox(lines=5, label="Text B", value="La penicilina fue descubierta por Alexander Fleming en 1928.")
            dims2 = gr.Dropdown([str(d) for d in MATRYOSHKA_DIMS], value=str(DEFAULT_DIMS), label="Embedding dims")
            sim_btn = gr.Button("Compute Similarity")
            sim_out = gr.Number(label="Cosine similarity (-1..1)")
            sim_btn.click(lambda x, y, d: do_similarity(x, y, int(d)), [a, b, dims2], sim_out)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)