Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from transformers import AutoImageProcessor, Dinov2Model | |
| from PIL import Image | |
| import base64 | |
| from io import BytesIO | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import imagehash | |
| MODEL_NAME = "facebook/dinov2-small" | |
| processor = AutoImageProcessor.from_pretrained(MODEL_NAME) | |
| print(f"Resize Strategy: {processor.size}") | |
| print(f"Do Center Crop?: {processor.do_center_crop}") | |
| print(f"Crop Size: {processor.crop_size}") | |
| model = Dinov2Model.from_pretrained(MODEL_NAME) | |
| # A camada de projeção para 512 dimensões agora é criada dentro da função, | |
| # para permitir a escolha entre 384 e 512. | |
| app = FastAPI( | |
| title="API de Embedding de Imagem", | |
| description="Endpoint para obter o embedding e pHash de uma imagem.", | |
| version="1.0.0" | |
| ) | |
| class ImageRequest(BaseModel): | |
| image: str | |
| target_dim: int = 512 # <-- NOVO: Parâmetro opcional para a dimensão do embedding | |
| use_float16: bool = False # <-- NOVO: Parâmetro opcional para usar float16 | |
| async def get_embedding(request: ImageRequest): | |
| try: | |
| header, img_base64 = request.image.split(",", 1) | |
| image_data = base64.b64decode(img_base64) | |
| image = Image.open(BytesIO(image_data)).convert("RGB") | |
| inputs = processor(images=image, return_tensors="pt") | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| last_hidden_state = outputs.last_hidden_state | |
| embedding = last_hidden_state[:, 0] | |
| # Lógica para definir a dimensão | |
| if request.target_dim == 384: | |
| final_embedding = embedding | |
| elif request.target_dim == 512: | |
| projection = nn.Linear(model.config.hidden_size, 512) | |
| final_embedding = projection(embedding) | |
| else: | |
| raise HTTPException(status_code=400, detail="Dimensão inválida. Escolha entre 384 ou 512.") | |
| # Lógica para a conversão para float16 | |
| if request.use_float16: | |
| final_embedding = final_embedding.half() | |
| phash = str(imagehash.phash(image)) | |
| return { | |
| "embedding": final_embedding.squeeze().tolist(), | |
| "phash": phash | |
| } | |
| except Exception as e: | |
| raise HTTPException(status_code=400, detail=f"Erro ao processar a imagem: {e}") |