vinithius's picture
Update app.py
0c9eb21 verified
raw
history blame
1.75 kB
import torch
from torch import nn
from transformers import AutoImageProcessor, Dinov2Model
from PIL import Image
import base64
from io import BytesIO
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Nome do modelo no Hugging Face Hub
MODEL_NAME = "facebook/dinov2-small"
# Carregando processador e modelo
# Usamos a classe específica Dinov2Model para garantir que o modelo seja carregado corretamente
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = Dinov2Model.from_pretrained(MODEL_NAME)
# Projeção para 512D
projection = nn.Linear(model.config.hidden_size, 512)
# Inicializa o FastAPI
app = FastAPI(
title="API de Embedding de Imagem",
description="Endpoint para obter o embedding de uma imagem usando o modelo DINOv2.",
version="1.0.0"
)
# Define o modelo de dados para a requisição
class ImageRequest(BaseModel):
image: str
# Define o endpoint para o embedding da imagem
@app.post("/embed")
async def get_embedding(request: ImageRequest):
try:
header, img_base64 = request.image.split(",", 1)
image_data = base64.b64decode(img_base64)
image = Image.open(BytesIO(image_data))
# --- Lógica de Inferência do seu script original ---
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
embedding = last_hidden_state[:, 0]
embedding_512 = projection(embedding)
return {"embedding": embedding_512.squeeze().tolist()}
except Exception as e:
raise HTTPException(status_code=400, detail=f"Erro ao processar a imagem: {e}")