Spaces:
Sleeping
Sleeping
File size: 2,383 Bytes
e47948f 4086272 e47948f 96fcd47 e47948f 41a3f77 62fd394 0c9eb21 1bdcb1f e47948f 39e2ab2 e47948f 1bdcb1f e47948f 39e2ab2 41a3f77 1bdcb1f e47948f 41a3f77 1bdcb1f 96fcd47 39e2ab2 1bdcb1f 39e2ab2 e47948f 96fcd47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
from torch import nn
from transformers import AutoImageProcessor, Dinov2Model
from PIL import Image
import base64
from io import BytesIO
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import imagehash
MODEL_NAME = "facebook/dinov2-small"
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
print(f"Resize Strategy: {processor.size}")
print(f"Do Center Crop?: {processor.do_center_crop}")
print(f"Crop Size: {processor.crop_size}")
model = Dinov2Model.from_pretrained(MODEL_NAME)
# A camada de projeção para 512 dimensões agora é criada dentro da função,
# para permitir a escolha entre 384 e 512.
app = FastAPI(
title="API de Embedding de Imagem",
description="Endpoint para obter o embedding e pHash de uma imagem.",
version="1.0.0"
)
class ImageRequest(BaseModel):
image: str
target_dim: int = 512 # <-- NOVO: Parâmetro opcional para a dimensão do embedding
use_float16: bool = False # <-- NOVO: Parâmetro opcional para usar float16
@app.post("/embed")
async def get_embedding(request: ImageRequest):
try:
header, img_base64 = request.image.split(",", 1)
image_data = base64.b64decode(img_base64)
image = Image.open(BytesIO(image_data)).convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
embedding = last_hidden_state[:, 0]
# Lógica para definir a dimensão
if request.target_dim == 384:
final_embedding = embedding
elif request.target_dim == 512:
projection = nn.Linear(model.config.hidden_size, 512)
final_embedding = projection(embedding)
else:
raise HTTPException(status_code=400, detail="Dimensão inválida. Escolha entre 384 ou 512.")
# Lógica para a conversão para float16
if request.use_float16:
final_embedding = final_embedding.half()
phash = str(imagehash.phash(image))
return {
"embedding": final_embedding.squeeze().tolist(),
"phash": phash
}
except Exception as e:
raise HTTPException(status_code=400, detail=f"Erro ao processar a imagem: {e}") |