Fred808 committed on
Commit
9bd5c0b
·
verified ·
1 Parent(s): d301135

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -39
app.py CHANGED
@@ -1,8 +1,9 @@
1
- from fastapi import FastAPI, File, UploadFile, HTTPException
2
  import requests
3
  import base64
4
  from pydantic import BaseModel
5
- from typing import Optional
 
6
  import re
7
 
8
  app = FastAPI()
@@ -11,20 +12,20 @@ app = FastAPI()
11
  NVIDIA_API_URL = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions"
12
  API_KEY = "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" # Replace securely in production
13
 
14
- # Request model for single user message
 
 
 
15
  class TextRequest(BaseModel):
16
- message: str
17
  max_tokens: Optional[int] = 512
18
  temperature: Optional[float] = 1.0
19
  top_p: Optional[float] = 1.0
20
 
21
- # Common pre-prompts
22
  PRE_PROMPT_MESSAGES = [
23
  {"role": "system", "content": "You are a helpful multimodal assistant powered by LLaMA 3.2 Vision-Instruct."},
24
- {"role": "assistant", "content": "Hi! You can send text or image-based questions. What would you like to know?"}
25
  ]
26
 
27
- # Function to call the NVIDIA API
28
  def call_nvidia_api(payload: dict):
29
  headers = {
30
  "Authorization": f"Bearer {API_KEY}",
@@ -35,11 +36,11 @@ def call_nvidia_api(payload: dict):
35
  raise HTTPException(status_code=response.status_code, detail="NVIDIA API request failed")
36
  return response.json()
37
 
38
- # /chat/text endpoint: Adds new user message to pre-prompted context
39
  @app.post("/chat/text")
40
  async def chat_with_text(request: TextRequest):
41
- messages = PRE_PROMPT_MESSAGES + [{"role": "user", "content": request.message}]
42
-
 
43
  payload = {
44
  "model": "meta/llama-3.2-90b-vision-instruct",
45
  "messages": messages,
@@ -54,28 +55,31 @@ async def chat_with_text(request: TextRequest):
54
  except Exception as e:
55
  raise HTTPException(status_code=500, detail=str(e))
56
 
57
- # /chat/vision endpoint: Handles messages containing image URLs
58
  @app.post("/chat/vision")
59
  async def chat_from_text_with_image_url(request: TextRequest):
60
- # Detect image URL
61
- match = re.search(r'https?://\S+\.(jpg|jpeg|png|webp|gif)', request.message)
62
- if not match:
63
- raise HTTPException(status_code=400, detail="No image URL found")
 
 
 
 
 
 
 
 
 
 
64
 
65
- image_url = match.group(0)
66
- try:
67
- img_response = requests.get(image_url)
68
- img_response.raise_for_status()
69
- base64_image = base64.b64encode(img_response.content).decode("utf-8")
70
- img_tag = f'<img src="data:image/png;base64,{base64_image}" />'
71
- except Exception as e:
72
- raise HTTPException(status_code=400, detail=f"Failed to fetch image: {e}")
73
 
74
- # Replace image URL in message
75
- modified_message = request.message.replace(image_url, img_tag)
76
 
77
- messages = PRE_PROMPT_MESSAGES + [{"role": "user", "content": modified_message}]
78
-
79
  payload = {
80
  "model": "meta/llama-3.2-90b-vision-instruct",
81
  "messages": messages,
@@ -90,15 +94,3 @@ async def chat_from_text_with_image_url(request: TextRequest):
90
  return {"response": response["choices"][0]["message"]["content"]}
91
  except Exception as e:
92
  raise HTTPException(status_code=500, detail=str(e))
93
-
94
- # Root endpoint
95
- @app.get("/")
96
- async def root():
97
- return {
98
- "message": "Welcome to the NVIDIA Vision Chat API!",
99
- "endpoints": {
100
- "/chat/text": "Send plain text questions (just provide your message).",
101
- "/chat/vision": "Send a message with an image URL (e.g. 'What is this? https://example.com/cat.jpg')",
102
- },
103
- "note": "You do NOT need to include assistant history or system roles — it's pre-injected automatically."
104
- }
 
1
+ from fastapi import FastAPI, HTTPException
2
  import requests
3
  import base64
4
  from pydantic import BaseModel
5
+ from typing import Optional, List
6
+
7
  import re
8
 
9
  app = FastAPI()
 
12
  NVIDIA_API_URL = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions"
13
  API_KEY = "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" # Replace securely in production
14
 
15
class ChatMessage(BaseModel):
    """One turn of a conversation sent by the API client.

    Attributes are forwarded verbatim to the NVIDIA chat-completions payload.
    """

    # Expected to be "user", "assistant" or "system" — no validation is
    # enforced here; presumably the upstream API rejects other roles (verify).
    role: str
    content: str
18
+
19
class TextRequest(BaseModel):
    """Request body shared by the /chat/text and /chat/vision endpoints.

    Carries the caller-supplied message history plus the sampling parameters
    forwarded to the NVIDIA chat-completions API.
    """

    # Conversation history; the server prepends its own system pre-prompt.
    messages: List[ChatMessage]
    # Sampling knobs passed through to the model, with API-matching defaults.
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
24
 
 
25
# System pre-prompt injected ahead of every caller-supplied history, so API
# clients never need to send their own "system" role message.
PRE_PROMPT_MESSAGES = [
    {
        "role": "system",
        "content": "You are a helpful multimodal assistant powered by LLaMA 3.2 Vision-Instruct.",
    },
]
28
 
 
29
  def call_nvidia_api(payload: dict):
30
  headers = {
31
  "Authorization": f"Bearer {API_KEY}",
 
36
  raise HTTPException(status_code=response.status_code, detail="NVIDIA API request failed")
37
  return response.json()
38
 
 
39
  @app.post("/chat/text")
40
  async def chat_with_text(request: TextRequest):
41
+ # Combine pre-prompt with user-provided history
42
+ messages = PRE_PROMPT_MESSAGES + [msg.dict() for msg in request.messages]
43
+
44
  payload = {
45
  "model": "meta/llama-3.2-90b-vision-instruct",
46
  "messages": messages,
 
55
  except Exception as e:
56
  raise HTTPException(status_code=500, detail=str(e))
57
 
58
+
59
  @app.post("/chat/vision")
60
  async def chat_from_text_with_image_url(request: TextRequest):
61
+ # Find image URLs in the last user message(s)
62
+ new_messages = []
63
+ for msg in request.messages:
64
+ if msg.role == "user":
65
+ # Replace all image URLs in user messages with base64 img tags
66
+ def replace_img_url(match):
67
+ url = match.group(0)
68
+ try:
69
+ img_resp = requests.get(url)
70
+ img_resp.raise_for_status()
71
+ b64 = base64.b64encode(img_resp.content).decode("utf-8")
72
+ return f'<img src="data:image/png;base64,{b64}" />'
73
+ except Exception:
74
+ return url # fallback to original URL if fetch fails
75
 
76
+ content_with_img = re.sub(r'https?://\S+\.(jpg|jpeg|png|webp|gif)', replace_img_url, msg.content)
77
+ new_messages.append({"role": "user", "content": content_with_img})
78
+ else:
79
+ new_messages.append(msg.dict())
 
 
 
 
80
 
81
+ messages = PRE_PROMPT_MESSAGES + new_messages
 
82
 
 
 
83
  payload = {
84
  "model": "meta/llama-3.2-90b-vision-instruct",
85
  "messages": messages,
 
94
  return {"response": response["choices"][0]["message"]["content"]}
95
  except Exception as e:
96
  raise HTTPException(status_code=500, detail=str(e))