import os
import logging
from contextlib import asynccontextmanager
from typing import List, Optional, Literal, Dict, Any

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict
from sentence_transformers import SparseEncoder
from transformers import AutoTokenizer

# --------------------------------------------------------------------------------------
# Logging
# --------------------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("main")
# --------------------------------------------------------------------------------------
# Device selection — intentionally NEVER choose MPS for SPLADE due to sparse-op gaps
# --------------------------------------------------------------------------------------
def choose_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    # Avoid MPS for SPLADE (missing sparse ops). Default to CPU instead.
    return "cpu"


DEVICE = choose_device()
logger.info(f"Selected device: {DEVICE}")
# --------------------------------------------------------------------------------------
# Model loading
# --------------------------------------------------------------------------------------
MODEL_ID = "sparse-encoder/splade-robbert-dutch-base-v1"


def load_sparse_encoder(model_id: str, device: str) -> SparseEncoder:
    """Load SparseEncoder. Prefer safetensors when available, but fall back to .bin.

    Torch >= 2.6 is required by Transformers to load .bin safely.
    """
    # Do NOT force safetensors globally; some repos only publish .bin
    os.environ.pop("TRANSFORMERS_USE_SAFETENSORS", None)
    try:
        logger.info(f"Loading Dutch SPLADE model on {device}...")
        return SparseEncoder(model_id, device=device, model_kwargs={"use_safetensors": True})
    except OSError as e:
        if "does not appear to have a file named model.safetensors" in str(e):
            logger.info("No safetensors in repo; retrying with .bin weights.")
            return SparseEncoder(model_id, device=device)
        raise
model: Optional[SparseEncoder] = None
# Tokenizer for mapping vocab ids -> readable tokens in explanations
tokenizer: Optional[AutoTokenizer] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global model, tokenizer
    try:
        model = load_sparse_encoder(MODEL_ID, DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        logger.info("Model & tokenizer loaded.")
        yield
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    finally:
        # Allow GC to clean up if server stops
        pass
app = FastAPI(title="Sparse Embedding API", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --------------------------------------------------------------------------------------
# Schemas
# --------------------------------------------------------------------------------------
class HealthResponse(BaseModel):
    # Pydantic v2 warns about names starting with model_; allow them explicitly
    model_config = ConfigDict(protected_namespaces=())

    model_loaded: bool
    model_name: str
    device: str


class EmbeddingsRequest(BaseModel):
    texts: List[str]
    mode: Literal["query", "document"] = "query"
    normalize: bool = True
    # Keep payloads light; 0/None means no cap
    max_active_dims: Optional[int] = 0


class EmbeddingRow(BaseModel):
    indices: List[int]
    weights: List[float]


class EmbeddingsResponse(BaseModel):
    data: List[EmbeddingRow]
    dim: int
    info: Dict[str, Any]


# --- Similarity API ---
class SimilarityRequest(BaseModel):
    queries: List[str]
    documents: List[str]
    normalize: bool = True
    max_active_dims: Optional[int] = 0
    top_k: Optional[int] = 5


class SimilarityHit(BaseModel):
    doc_index: int
    score: float
    text: str


class SimilarityResponse(BaseModel):
    results: List[List[SimilarityHit]]  # one list per query
    info: Dict[str, Any]


# --- Explain API ---
class TokenContribution(BaseModel):
    token_id: int
    token: str
    query_weight: float
    doc_weight: float
    contribution: float


class ExplainRequest(BaseModel):
    query: str
    document: str
    normalize: bool = True
    max_active_dims: Optional[int] = 0
    top_k_tokens: int = 15


class ExplainResponse(BaseModel):
    score: float
    top_tokens: List[TokenContribution]
    info: Dict[str, Any]
# --------------------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------------------
def torch_sparse_batch_to_rows(t: torch.Tensor) -> List[Dict[str, Any]]:
    """Convert a 2D torch sparse tensor [batch, dim] to a list of {indices, weights} per row."""
    if not isinstance(t, torch.Tensor):
        raise TypeError("Expected a torch.Tensor from SparseEncoder")
    if not t.is_sparse:
        # Dense fallback (shouldn't happen with SparseEncoder). Convert per-row.
        t = t.to("cpu")
        rows = []
        for r in t:
            nz = torch.nonzero(r, as_tuple=True)[0]
            rows.append({"indices": nz.tolist(), "weights": r[nz].tolist()})
        return rows
    # COO expected; coalesce to merge duplicate entries, then split by row
    t = t.coalesce()
    idx = t.indices()   # [2, nnz]
    vals = t.values()   # [nnz]
    batch_size = t.size(0)
    rows_out: List[Dict[str, Any]] = []
    row_ids = idx[0]
    col_ids = idx[1]
    # For each row, mask and gather its entries
    for i in range(batch_size):
        m = row_ids == i
        if torch.count_nonzero(m) == 0:
            rows_out.append({"indices": [], "weights": []})
            continue
        cols_i = col_ids[m].to("cpu")
        vals_i = vals[m].to("cpu")
        rows_out.append({"indices": cols_i.tolist(), "weights": vals_i.tolist()})
    return rows_out
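
# A minimal sketch of what torch_sparse_batch_to_rows produces, using a tiny
# hand-built COO tensor (illustrative only; real tensors come from the encoder,
# and float32 weights are shown rounded):
#
#   t = torch.sparse_coo_tensor(
#       indices=torch.tensor([[0, 0, 1], [2, 5, 7]]),  # [row ids, col ids]
#       values=torch.tensor([0.4, 1.2, 0.9]),
#       size=(2, 10),
#   )
#   torch_sparse_batch_to_rows(t)
#   # -> [{"indices": [2, 5], "weights": [0.4, 1.2]},
#   #     {"indices": [7],    "weights": [0.9]}]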
def top_token_contributions(q_row: Dict[str, Any], d_row: Dict[str, Any], k: int) -> List[Dict[str, Any]]:
    """Intersect query/doc indices and score tokens by the product of their weights."""
    q_map = {int(i): float(w) for i, w in zip(q_row.get("indices", []), q_row.get("weights", []))}
    contribs = []
    for i, dw in zip(d_row.get("indices", []), d_row.get("weights", [])):
        i = int(i)
        dw = float(dw)
        qw = q_map.get(i)
        if qw is not None:
            contribs.append((i, qw, dw, qw * dw))
    contribs.sort(key=lambda t: t[3], reverse=True)
    # Fall back to 15 tokens when k is zero or negative
    top = contribs[: k if k > 0 else 15]
    out: List[Dict[str, Any]] = []
    for tok_id, qw, dw, c in top:
        try:
            # RobBERT uses RoBERTa/BPE-style tokens (Ġ denotes a leading space)
            tok = tokenizer.convert_ids_to_tokens([tok_id])[0]
            pretty = tok.replace("Ġ", " ").replace("▁", " ")
        except Exception:
            pretty = str(tok_id)
        out.append({
            "token_id": tok_id,
            "token": pretty,
            "query_weight": qw,
            "doc_weight": dw,
            "contribution": c,
        })
    return out
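
# Illustrative example with hypothetical weights: if the query row is
# {"indices": [12, 90], "weights": [1.5, 0.2]} and the document row is
# {"indices": [12, 44], "weights": [0.8, 2.0]}, only token id 12 overlaps,
# contributing 1.5 * 0.8 = 1.2 to the sparse dot-product score.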
# --------------------------------------------------------------------------------------
# Routes (decorator paths are inferred from the function names)
# --------------------------------------------------------------------------------------
@app.get("/")
async def root():
    return {
        "message": "Dutch SPLADE Embedding API",
        "docs": "https://moimobrian-py-api.hf.space/docs",
        "health": "https://moimobrian-py-api.hf.space/health",
    }


@app.get("/health")
async def health() -> HealthResponse:
    return HealthResponse(
        model_loaded=model is not None,
        model_name=MODEL_ID,
        device=DEVICE,
    )
@app.post("/embeddings")
async def embeddings(req: EmbeddingsRequest) -> EmbeddingsResponse:
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.texts:
        raise HTTPException(status_code=400, detail="'texts' must be a non-empty list")
    max_k = req.max_active_dims or None
    logger.info(f"Processing {len(req.texts)} texts in {req.mode} mode")

    def _encode(device: str) -> torch.Tensor:
        encode = model.encode_query if req.mode == "query" else model.encode_document
        return encode(
            req.texts,
            convert_to_tensor=True,
            device=device,
            normalize=req.normalize,
            max_active_dims=max_k,
        )

    try:
        embs = _encode(DEVICE)
        rows = torch_sparse_batch_to_rows(embs)
        # Model card states ~50k dims; read the 2nd dimension from the tensor
        dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
        return EmbeddingsResponse(
            data=[EmbeddingRow(**r) for r in rows],
            dim=dim,
            info={
                "mode": req.mode,
                "normalize": req.normalize,
                "max_active_dims": max_k,
                "device": DEVICE,
            },
        )
    except RuntimeError as e:
        # If anything MPS-related sneaks in, hard-move to CPU and retry once
        msg = str(e)
        if "MPS" in msg or "to_sparse" in msg:
            logger.warning("Encountered MPS/sparse op issue; retrying on CPU.")
            try:
                model.to("cpu")
                embs = _encode("cpu")
                rows = torch_sparse_batch_to_rows(embs)
                dim = int(embs.size(1)) if isinstance(embs, torch.Tensor) else 0
                return EmbeddingsResponse(
                    data=[EmbeddingRow(**r) for r in rows],
                    dim=dim,
                    info={
                        "mode": req.mode,
                        "normalize": req.normalize,
                        "max_active_dims": max_k,
                        "device": "cpu",
                        "retry": True,
                    },
                )
            except Exception:
                logger.exception("CPU retry failed")
                raise HTTPException(status_code=500, detail=msg)
        # Unknown runtime error
        logger.exception("Error generating embeddings")
        raise HTTPException(status_code=500, detail=msg)
    except Exception as e:
        logger.exception("Error generating embeddings")
        raise HTTPException(status_code=500, detail=str(e))
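
# Example client call (a sketch, assuming a local run on port 8000):
#
#   import requests
#   r = requests.post(
#       "http://localhost:8000/embeddings",
#       json={"texts": ["Waar is het station?"], "mode": "query"},
#   )
#   r.json()  # {"data": [{"indices": [...], "weights": [...]}], "dim": ..., "info": {...}}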
@app.post("/similarity")
async def similarity(req: SimilarityRequest) -> SimilarityResponse:
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.queries:
        raise HTTPException(status_code=400, detail="'queries' must be a non-empty list")
    if not req.documents:
        raise HTTPException(status_code=400, detail="'documents' must be a non-empty list")
    max_k = req.max_active_dims or None
    try:
        q = model.encode_query(
            req.queries,
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        d = model.encode_document(
            req.documents,
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        scores = model.similarity(q, d).to("cpu")  # [num_queries, num_docs]
        results: List[List[SimilarityHit]] = []
        k = min(req.top_k or 5, len(req.documents))
        for i in range(scores.size(0)):
            vals, idxs = torch.topk(scores[i], k=k)
            q_hits: List[SimilarityHit] = []
            for v, j in zip(vals.tolist(), idxs.tolist()):
                q_hits.append(SimilarityHit(doc_index=j, score=float(v), text=req.documents[j]))
            results.append(q_hits)
        return SimilarityResponse(
            results=results,
            info={
                "normalize": req.normalize,
                "max_active_dims": max_k,
                "device": DEVICE,
            },
        )
    except Exception as e:
        logger.exception("Error computing similarity")
        raise HTTPException(status_code=500, detail=str(e))
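
# Example client call (a sketch with a hypothetical payload):
#
#   requests.post(
#       "http://localhost:8000/similarity",
#       json={
#           "queries": ["hoofdstad van Nederland"],
#           "documents": ["Amsterdam is de hoofdstad.", "Parijs ligt in Frankrijk."],
#           "top_k": 2,
#       },
#   ).json()
#   # -> {"results": [[{"doc_index": 0, "score": ..., "text": ...}, ...]], "info": {...}}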
@app.post("/explain")
async def explain(req: ExplainRequest) -> ExplainResponse:
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model/tokenizer not loaded")
    max_k = req.max_active_dims or None
    try:
        q = model.encode_query(
            [req.query],
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        d = model.encode_document(
            [req.document],
            convert_to_tensor=True,
            device=DEVICE,
            normalize=req.normalize,
            max_active_dims=max_k,
        )
        score = float(model.similarity(q, d)[0, 0].item())
        q_row = torch_sparse_batch_to_rows(q)[0]
        d_row = torch_sparse_batch_to_rows(d)[0]
        tokens = top_token_contributions(q_row, d_row, req.top_k_tokens)
        return ExplainResponse(
            score=score,
            top_tokens=[TokenContribution(**t) for t in tokens],
            info={
                "normalize": req.normalize,
                "max_active_dims": max_k,
                "device": DEVICE,
            },
        )
    except Exception as e:
        logger.exception("Error explaining match")
        raise HTTPException(status_code=500, detail=str(e))
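
# Example client call (a sketch with a hypothetical query/document pair):
#
#   requests.post(
#       "http://localhost:8000/explain",
#       json={"query": "fiets kopen", "document": "Tweedehands fietsen te koop."},
#   ).json()
#   # -> {"score": ..., "top_tokens": [{"token": ..., "contribution": ...}, ...], "info": {...}}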
# --------------------------------------------------------------------------------------
# Local dev runner
# --------------------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )