vinithius committed on
Commit
41a3f77
·
verified ·
1 Parent(s): 162ac4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  from torch import nn
3
- from transformers import pipeline
4
  from PIL import Image
5
  import base64
6
  from io import BytesIO
@@ -10,16 +10,13 @@ from pydantic import BaseModel
10
  # Nome do modelo no Hugging Face Hub
11
  MODEL_NAME = "facebook/dinov2-small"
12
 
13
- # Usando um pipeline para carregar o modelo e o processador
14
- feature_extractor = pipeline(
15
- "feature-extraction",
16
- model=MODEL_NAME
17
- )
18
 
19
  # Projeção para 512D
20
- # O pipeline retorna um tensor, então a projeção ainda é necessária
21
- # Você pode remover isso se o embedding de 768D for suficiente
22
- projection = nn.Linear(768, 512)
23
 
24
  # Inicializa o FastAPI
25
  app = FastAPI(
@@ -36,19 +33,25 @@ class ImageRequest(BaseModel):
36
  @app.post("/embed")
37
  async def get_embedding(request: ImageRequest):
38
  try:
 
39
  header, img_base64 = request.image.split(",", 1)
 
 
40
  image_data = base64.b64decode(img_base64)
41
- image = Image.open(BytesIO(image_data))
42
 
43
- # Gera o embedding usando o pipeline
44
- embedding_list = feature_extractor(images=[image])[0][0]
45
 
46
- # Converte a lista de embeddings para um tensor PyTorch para a projeção
47
- embedding_tensor = torch.tensor(embedding_list)
48
 
49
  with torch.no_grad():
50
- embedding_512 = projection(embedding_tensor)
 
 
 
51
 
 
52
  return {"embedding": embedding_512.squeeze().tolist()}
53
 
54
  except Exception as e:
 
1
  import torch
2
  from torch import nn
3
+ from transformers import AutoImageProcessor, AutoModel
4
  from PIL import Image
5
  import base64
6
  from io import BytesIO
 
10
# Model name on the Hugging Face Hub
MODEL_NAME = "facebook/dinov2-small"

# Load the image processor and the backbone model.
# NOTE: from_pretrained() already returns the model in eval mode, but we
# make it explicit so inference behavior (dropout disabled) is obvious.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model.eval()

# Projection head reducing the backbone embedding to 512 dimensions.
# NOTE(review): this Linear layer has NO trained weights. Seeding before
# construction makes its random init deterministic, so embeddings stay
# comparable across process restarts; a trained projection (or persisted
# weights) should eventually replace this — TODO confirm downstream needs.
torch.manual_seed(0)
projection = nn.Linear(model.config.hidden_size, 512)
projection.eval()
 
 
20
 
21
  # Inicializa o FastAPI
22
  app = FastAPI(
 
33
  @app.post("/embed")
34
  async def get_embedding(request: ImageRequest):
35
  try:
36
+ # Extrai a string Base64 do formato "data:image/png;base64,..."
37
  header, img_base64 = request.image.split(",", 1)
38
+
39
+ # Decodifica a string Base64
40
  image_data = base64.b64decode(img_base64)
 
41
 
42
+ # Abre a imagem com Pillow
43
+ image = Image.open(BytesIO(image_data))
44
 
45
+ # --- Lógica de Inferência do Seu Script Original ---
46
+ inputs = processor(images=image, return_tensors="pt")
47
 
48
  with torch.no_grad():
49
+ outputs = model(**inputs)
50
+ last_hidden_state = outputs.last_hidden_state
51
+ embedding = last_hidden_state[:, 0]
52
+ embedding_512 = projection(embedding)
53
 
54
+ # Converte para lista Python e retorna
55
  return {"embedding": embedding_512.squeeze().tolist()}
56
 
57
  except Exception as e: