File size: 2,383 Bytes
e47948f
 
4086272
e47948f
 
 
 
 
96fcd47
e47948f
 
41a3f77
62fd394
 
 
 
 
0c9eb21
1bdcb1f
 
e47948f
 
 
39e2ab2
e47948f
 
 
 
 
1bdcb1f
 
e47948f
 
 
 
 
 
39e2ab2
41a3f77
1bdcb1f
e47948f
41a3f77
 
 
1bdcb1f
 
 
 
 
 
 
 
 
 
 
 
 
96fcd47
 
39e2ab2
1bdcb1f
39e2ab2
 
e47948f
96fcd47
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from torch import nn
from transformers import AutoImageProcessor, Dinov2Model
from PIL import Image
import base64
from io import BytesIO
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import imagehash

MODEL_NAME = "facebook/dinov2-small"
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)

print(f"Resize Strategy: {processor.size}")
print(f"Do Center Crop?: {processor.do_center_crop}")
print(f"Crop Size: {processor.crop_size}")

model = Dinov2Model.from_pretrained(MODEL_NAME)
# A camada de projeção para 512 dimensões agora é criada dentro da função,
# para permitir a escolha entre 384 e 512.

app = FastAPI(
    title="API de Embedding de Imagem",
    description="Endpoint para obter o embedding e pHash de uma imagem.",
    version="1.0.0"
)

class ImageRequest(BaseModel):
    image: str
    target_dim: int = 512  # <-- NOVO: Parâmetro opcional para a dimensão do embedding
    use_float16: bool = False # <-- NOVO: Parâmetro opcional para usar float16

@app.post("/embed")
async def get_embedding(request: ImageRequest):
    try:
        header, img_base64 = request.image.split(",", 1)
        image_data = base64.b64decode(img_base64)
        image = Image.open(BytesIO(image_data)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)
            last_hidden_state = outputs.last_hidden_state
            embedding = last_hidden_state[:, 0]
            
            # Lógica para definir a dimensão
            if request.target_dim == 384:
                final_embedding = embedding
            elif request.target_dim == 512:
                projection = nn.Linear(model.config.hidden_size, 512)
                final_embedding = projection(embedding)
            else:
                raise HTTPException(status_code=400, detail="Dimensão inválida. Escolha entre 384 ou 512.")
            
            # Lógica para a conversão para float16
            if request.use_float16:
                final_embedding = final_embedding.half()

        phash = str(imagehash.phash(image))
        return {
            "embedding": final_embedding.squeeze().tolist(),
            "phash": phash
        }
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Erro ao processar a imagem: {e}")