vinithius commited on
Commit
162ac4b
·
verified ·
1 Parent(s): 08d8ceb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from torch import nn
3
- from transformers import AutoImageProcessor, Dinov2Model
4
  from PIL import Image
5
  import base64
6
  from io import BytesIO
@@ -10,12 +10,16 @@ from pydantic import BaseModel
10
  # Nome do modelo no Hugging Face Hub
11
  MODEL_NAME = "facebook/dinov2-small"
12
 
13
- # Carregando processador e modelo
14
- processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
15
- model = Dinov2Model.from_pretrained(MODEL_NAME)
 
 
16
 
17
  # Projeção para 512D
18
- projection = nn.Linear(model.config.hidden_size, 512)
 
 
19
 
20
  # Inicializa o FastAPI
21
  app = FastAPI(
@@ -32,25 +36,19 @@ class ImageRequest(BaseModel):
32
  @app.post("/embed")
33
  async def get_embedding(request: ImageRequest):
34
  try:
35
- # Extrai a string Base64 do formato "data:image/png;base64,..."
36
  header, img_base64 = request.image.split(",", 1)
37
-
38
- # Decodifica a string Base64
39
  image_data = base64.b64decode(img_base64)
40
-
41
- # Abre a imagem com Pillow
42
  image = Image.open(BytesIO(image_data))
43
 
44
- # Preprocessamento
45
- inputs = processor(images=image, return_tensors="pt")
 
 
 
46
 
47
  with torch.no_grad():
48
- outputs = model(**inputs)
49
- last_hidden_state = outputs.last_hidden_state
50
- embedding = last_hidden_state[:, 0]
51
- embedding_512 = projection(embedding)
52
 
53
- # Converte para lista Python e retorna
54
  return {"embedding": embedding_512.squeeze().tolist()}
55
 
56
  except Exception as e:
 
1
  import torch
2
  from torch import nn
3
+ from transformers import pipeline
4
  from PIL import Image
5
  import base64
6
  from io import BytesIO
 
10
  # Nome do modelo no Hugging Face Hub
11
  MODEL_NAME = "facebook/dinov2-small"
12
 
13
+ # Usando um pipeline para carregar o modelo e o processador
14
+ feature_extractor = pipeline(
15
+ "feature-extraction",
16
+ model=MODEL_NAME
17
+ )
18
 
19
  # Projeção para 512D
20
+ # O pipeline retorna um tensor, então a projeção ainda é necessária
21
+ # Você pode remover isso se o embedding de 768D for suficiente
22
+ projection = nn.Linear(768, 512)
23
 
24
  # Inicializa o FastAPI
25
  app = FastAPI(
 
36
  @app.post("/embed")
37
  async def get_embedding(request: ImageRequest):
38
  try:
 
39
  header, img_base64 = request.image.split(",", 1)
 
 
40
  image_data = base64.b64decode(img_base64)
 
 
41
  image = Image.open(BytesIO(image_data))
42
 
43
+ # Gera o embedding usando o pipeline
44
+ embedding_list = feature_extractor(images=[image])[0][0]
45
+
46
+ # Converte a lista de embeddings para um tensor PyTorch para a projeção
47
+ embedding_tensor = torch.tensor(embedding_list)
48
 
49
  with torch.no_grad():
50
+ embedding_512 = projection(embedding_tensor)
 
 
 
51
 
 
52
  return {"embedding": embedding_512.squeeze().tolist()}
53
 
54
  except Exception as e: