vinithius committed on
Commit
1bdcb1f
·
verified ·
1 Parent(s): 96fcd47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -8
app.py CHANGED
@@ -11,7 +11,8 @@ import imagehash
11
  MODEL_NAME = "facebook/dinov2-small"
12
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
13
  model = Dinov2Model.from_pretrained(MODEL_NAME)
14
- projection = nn.Linear(model.config.hidden_size, 512)
 
15
 
16
  app = FastAPI(
17
  title="API de Embedding de Imagem",
@@ -21,7 +22,8 @@ app = FastAPI(
21
 
22
  class ImageRequest(BaseModel):
23
  image: str
24
- use_float16: bool = False # <-- NOVO: Parâmetro opcional com valor padrão False
 
25
 
26
  @app.post("/embed")
27
  async def get_embedding(request: ImageRequest):
@@ -30,19 +32,28 @@ async def get_embedding(request: ImageRequest):
30
  image_data = base64.b64decode(img_base64)
31
  image = Image.open(BytesIO(image_data)).convert("RGB")
32
  inputs = processor(images=image, return_tensors="pt")
 
33
  with torch.no_grad():
34
  outputs = model(**inputs)
35
  last_hidden_state = outputs.last_hidden_state
36
  embedding = last_hidden_state[:, 0]
37
- embedding_512 = projection(embedding)
38
-
39
- # <-- NOVA LÓGICA: Conversão condicional para float16
40
- if request.use_float16:
41
- embedding_512 = embedding_512.half()
 
 
 
 
 
 
 
 
42
 
43
  phash = str(imagehash.phash(image))
44
  return {
45
- "embedding": embedding_512.squeeze().tolist(),
46
  "phash": phash
47
  }
48
  except Exception as e:
 
11
  MODEL_NAME = "facebook/dinov2-small"
12
  processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
13
  model = Dinov2Model.from_pretrained(MODEL_NAME)
14
+ # A camada de projeção para 512 dimensões agora é criada dentro da função,
15
+ # para permitir a escolha entre 384 e 512.
16
 
17
  app = FastAPI(
18
  title="API de Embedding de Imagem",
 
22
 
23
  class ImageRequest(BaseModel):
24
  image: str
25
+ target_dim: int = 512 # <-- NOVO: Parâmetro opcional para a dimensão do embedding
26
+ use_float16: bool = False # <-- NOVO: Parâmetro opcional para usar float16
27
 
28
  @app.post("/embed")
29
  async def get_embedding(request: ImageRequest):
 
32
  image_data = base64.b64decode(img_base64)
33
  image = Image.open(BytesIO(image_data)).convert("RGB")
34
  inputs = processor(images=image, return_tensors="pt")
35
+
36
  with torch.no_grad():
37
  outputs = model(**inputs)
38
  last_hidden_state = outputs.last_hidden_state
39
  embedding = last_hidden_state[:, 0]
40
+
41
+ # Lógica para definir a dimensão
42
+ if request.target_dim == 384:
43
+ final_embedding = embedding
44
+ elif request.target_dim == 512:
45
+ projection = nn.Linear(model.config.hidden_size, 512)
46
+ final_embedding = projection(embedding)
47
+ else:
48
+ raise HTTPException(status_code=400, detail="Dimensão inválida. Escolha entre 384 ou 512.")
49
+
50
+ # Lógica para a conversão para float16
51
+ if request.use_float16:
52
+ final_embedding = final_embedding.half()
53
 
54
  phash = str(imagehash.phash(image))
55
  return {
56
+ "embedding": final_embedding.squeeze().tolist(),
57
  "phash": phash
58
  }
59
  except Exception as e: