Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,7 +11,8 @@ import imagehash
|
|
| 11 |
MODEL_NAME = "facebook/dinov2-small"
|
| 12 |
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
|
| 13 |
model = Dinov2Model.from_pretrained(MODEL_NAME)
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
app = FastAPI(
|
| 17 |
title="API de Embedding de Imagem",
|
|
@@ -21,7 +22,8 @@ app = FastAPI(
|
|
| 21 |
|
| 22 |
class ImageRequest(BaseModel):
|
| 23 |
image: str
|
| 24 |
-
|
|
|
|
| 25 |
|
| 26 |
@app.post("/embed")
|
| 27 |
async def get_embedding(request: ImageRequest):
|
|
@@ -30,19 +32,28 @@ async def get_embedding(request: ImageRequest):
|
|
| 30 |
image_data = base64.b64decode(img_base64)
|
| 31 |
image = Image.open(BytesIO(image_data)).convert("RGB")
|
| 32 |
inputs = processor(images=image, return_tensors="pt")
|
|
|
|
| 33 |
with torch.no_grad():
|
| 34 |
outputs = model(**inputs)
|
| 35 |
last_hidden_state = outputs.last_hidden_state
|
| 36 |
embedding = last_hidden_state[:, 0]
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
phash = str(imagehash.phash(image))
|
| 44 |
return {
|
| 45 |
-
"embedding":
|
| 46 |
"phash": phash
|
| 47 |
}
|
| 48 |
except Exception as e:
|
|
|
|
| 11 |
MODEL_NAME = "facebook/dinov2-small"
|
| 12 |
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
|
| 13 |
model = Dinov2Model.from_pretrained(MODEL_NAME)
|
| 14 |
+
# A camada de projeção para 512 dimensões agora é criada dentro da função,
|
| 15 |
+
# para permitir a escolha entre 384 e 512.
|
| 16 |
|
| 17 |
app = FastAPI(
|
| 18 |
title="API de Embedding de Imagem",
|
|
|
|
| 22 |
|
| 23 |
class ImageRequest(BaseModel):
|
| 24 |
image: str
|
| 25 |
+
target_dim: int = 512 # <-- NOVO: Parâmetro opcional para a dimensão do embedding
|
| 26 |
+
use_float16: bool = False # <-- NOVO: Parâmetro opcional para usar float16
|
| 27 |
|
| 28 |
@app.post("/embed")
|
| 29 |
async def get_embedding(request: ImageRequest):
|
|
|
|
| 32 |
image_data = base64.b64decode(img_base64)
|
| 33 |
image = Image.open(BytesIO(image_data)).convert("RGB")
|
| 34 |
inputs = processor(images=image, return_tensors="pt")
|
| 35 |
+
|
| 36 |
with torch.no_grad():
|
| 37 |
outputs = model(**inputs)
|
| 38 |
last_hidden_state = outputs.last_hidden_state
|
| 39 |
embedding = last_hidden_state[:, 0]
|
| 40 |
+
|
| 41 |
+
# Lógica para definir a dimensão
|
| 42 |
+
if request.target_dim == 384:
|
| 43 |
+
final_embedding = embedding
|
| 44 |
+
elif request.target_dim == 512:
|
| 45 |
+
projection = nn.Linear(model.config.hidden_size, 512)
|
| 46 |
+
final_embedding = projection(embedding)
|
| 47 |
+
else:
|
| 48 |
+
raise HTTPException(status_code=400, detail="Dimensão inválida. Escolha entre 384 ou 512.")
|
| 49 |
+
|
| 50 |
+
# Lógica para a conversão para float16
|
| 51 |
+
if request.use_float16:
|
| 52 |
+
final_embedding = final_embedding.half()
|
| 53 |
|
| 54 |
phash = str(imagehash.phash(image))
|
| 55 |
return {
|
| 56 |
+
"embedding": final_embedding.squeeze().tolist(),
|
| 57 |
"phash": phash
|
| 58 |
}
|
| 59 |
except Exception as e:
|