# app.py — FastAPI service that proxies text and vision chat requests to the NVIDIA API.
import base64
import os
import re
from typing import List, Optional

import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
app = FastAPI()

# NVIDIA chat-completions endpoint for the Llama 3.1 Nemotron Nano VL model.
NVIDIA_API_URL = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.1-nemotron-nano-vl-8b-v1/chat/completions"

# SECURITY: prefer the NVIDIA_API_KEY environment variable. The hard-coded
# fallback below is a committed credential — it must be rotated and removed
# before any production use.
API_KEY = os.getenv(
    "NVIDIA_API_KEY",
    "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",
)
class ChatMessage(BaseModel):
    """One message in a chat conversation (OpenAI-style message schema)."""

    role: str  # "user", "assistant", or "system"
    content: str  # the message text
class TextRequest(BaseModel):
    """Request body for the /chat endpoints: a message history plus sampling knobs."""

    messages: List[ChatMessage]  # conversation history, oldest first
    max_tokens: Optional[int] = 1024  # generation cap forwarded to the NVIDIA API
    temperature: Optional[float] = 1.0  # sampling temperature
    top_p: Optional[float] = 0.01  # nucleus-sampling cutoff
# System pre-prompt prepended to every conversation before it is forwarded
# to the NVIDIA API (both /chat/text and /chat/vision use it).
PRE_PROMPT_MESSAGES = [
    {"role": "system", "content": "You are a helpful multimodal assistant "},
]
def call_nvidia_api(payload: dict) -> dict:
    """POST *payload* to the NVIDIA chat-completions endpoint and return parsed JSON.

    Raises:
        HTTPException: mirrors the upstream status code on any non-200
            response, with the upstream response body in the detail so
            failures are diagnosable.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Accept": "application/json",
    }
    # Timeout prevents a hung upstream from blocking this worker indefinitely.
    response = requests.post(NVIDIA_API_URL, headers=headers, json=payload, timeout=60)
    if response.status_code != 200:
        # Include the upstream error body — a bare "request failed" hides the cause.
        raise HTTPException(
            status_code=response.status_code,
            detail=f"NVIDIA API request failed: {response.text}",
        )
    return response.json()
@app.post("/chat/text")
async def chat_with_text(request: TextRequest):
    """Text-only chat: forward the history (plus system pre-prompt) to the NVIDIA API.

    Returns:
        {"response": <assistant message content>} from the first completion choice.
    """
    # Combine the fixed system pre-prompt with the caller-supplied history.
    messages = PRE_PROMPT_MESSAGES + [msg.dict() for msg in request.messages]
    payload = {
        "model": "nvidia/llama-3.1-nemotron-nano-vl-8b-v1",
        "messages": messages,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
        # BUG FIX: must be False. The response is parsed below as a single
        # JSON document; stream=True makes the API emit SSE chunks, which
        # breaks response.json() in call_nvidia_api.
        "stream": False,
    }
    try:
        response = call_nvidia_api(payload)
        return {"response": response["choices"][0]["message"]["content"]}
    except HTTPException:
        # Preserve the upstream status code instead of collapsing it to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Matches http(s) URLs ending in a common image extension; compiled once
# instead of re-parsing the pattern on every request.
_IMAGE_URL_RE = re.compile(r'https?://\S+\.(jpg|jpeg|png|webp|gif)')


def _inline_image_urls(text: str) -> str:
    """Replace every image URL in *text* with an inline base64 ``<img>`` tag.

    Best-effort: URLs that cannot be fetched are left untouched.
    """
    def _replace(match) -> str:
        url = match.group(0)
        try:
            # Timeout prevents a slow image host from hanging the request.
            img_resp = requests.get(url, timeout=30)
            img_resp.raise_for_status()
            b64 = base64.b64encode(img_resp.content).decode("utf-8")
            # NOTE(review): the data URI always claims image/png regardless of
            # the actual format — presumably the model endpoint sniffs the
            # real type; confirm if non-PNG images misbehave.
            return f'<img src="data:image/png;base64,{b64}" />'
        except Exception:
            return url  # fallback to original URL if fetch fails

    return _IMAGE_URL_RE.sub(_replace, text)


@app.post("/chat/vision")
async def chat_from_text_with_image_url(request: TextRequest):
    """Vision chat: inline image URLs in user messages as base64, then forward.

    Only messages with role "user" are scanned for image URLs; other roles
    are passed through unchanged.
    """
    new_messages = []
    for msg in request.messages:
        if msg.role == "user":
            new_messages.append(
                {"role": "user", "content": _inline_image_urls(msg.content)}
            )
        else:
            new_messages.append(msg.dict())
    messages = PRE_PROMPT_MESSAGES + new_messages
    payload = {
        "model": "nvidia/llama-3.1-nemotron-nano-vl-8b-v1",
        "messages": messages,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
        # BUG FIX: must be False — the result is parsed as one JSON document
        # (stream=True would return SSE chunks and break response.json()).
        "stream": False,
    }
    try:
        response = call_nvidia_api(payload)
        return {"response": response["choices"][0]["message"]["content"]}
    except HTTPException:
        # Preserve the upstream status code instead of collapsing it to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))