File size: 1,747 Bytes
e47948f
 
4086272
e47948f
 
 
 
 
 
 
 
 
41a3f77
0c9eb21
41a3f77
0c9eb21
e47948f
 
41a3f77
e47948f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41a3f77
162ac4b
0c9eb21
41a3f77
e47948f
 
41a3f77
 
 
 
e47948f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from torch import nn
from transformers import AutoImageProcessor, Dinov2Model
from PIL import Image
import base64
from io import BytesIO
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Nome do modelo no Hugging Face Hub
MODEL_NAME = "facebook/dinov2-small"

# Carregando processador e modelo
# Usamos a classe específica Dinov2Model para garantir que o modelo seja carregado corretamente
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = Dinov2Model.from_pretrained(MODEL_NAME)

# Projeção para 512D
projection = nn.Linear(model.config.hidden_size, 512)

# Inicializa o FastAPI
app = FastAPI(
    title="API de Embedding de Imagem",
    description="Endpoint para obter o embedding de uma imagem usando o modelo DINOv2.",
    version="1.0.0"
)

# Define o modelo de dados para a requisição
class ImageRequest(BaseModel):
    image: str

# Define o endpoint para o embedding da imagem
@app.post("/embed")
async def get_embedding(request: ImageRequest):
    try:
        header, img_base64 = request.image.split(",", 1)
        image_data = base64.b64decode(img_base64)
        image = Image.open(BytesIO(image_data))
        
        # --- Lógica de Inferência do seu script original ---
        inputs = processor(images=image, return_tensors="pt")
        
        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden_state = outputs.last_hidden_state
            embedding = last_hidden_state[:, 0]
            embedding_512 = projection(embedding)
            
        return {"embedding": embedding_512.squeeze().tolist()}
        
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Erro ao processar a imagem: {e}")