File size: 3,478 Bytes
9bd5c0b
d2c732d
 
 
9bd5c0b
d301135
d2c732d
 
 
e8ef0cf
 
f74b029
d2c732d
9bd5c0b
e8ef0cf
9bd5c0b
 
d2c732d
9bd5c0b
e8ef0cf
d2c732d
e8ef0cf
d2c732d
d301135
12e1af5
d301135
 
d2c732d
 
 
 
 
 
 
 
 
 
 
 
9bd5c0b
 
 
d2c732d
e8ef0cf
d301135
d2c732d
 
 
e8ef0cf
d2c732d
 
 
 
 
 
 
b3e6729
 
9bd5c0b
 
 
 
 
 
 
 
 
 
 
 
 
 
b3e6729
9bd5c0b
 
 
 
b3e6729
9bd5c0b
4878be9
d2c732d
e8ef0cf
d301135
b3e6729
 
 
e8ef0cf
d2c732d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import base64
import os
import re
from typing import List, Optional

import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# NVIDIA chat-completions endpoint for the Llama 3.1 Nemotron Nano VL 8B model.
# May be overridden with the NVIDIA_API_URL environment variable.
NVIDIA_API_URL = os.environ.get(
    "NVIDIA_API_URL",
    "https://ai.api.nvidia.com/v1/gr/meta/llama-3.1-nemotron-nano-vl-8b-v1/chat/completions",
)
# The API key is read from the environment instead of being committed to source.
# SECURITY: the key that was previously hard-coded here has been published along
# with this file and must be revoked/rotated. If NVIDIA_API_KEY is unset, requests
# go out with an empty bearer token and will presumably be rejected upstream.
API_KEY = os.environ.get("NVIDIA_API_KEY", "")

class ChatMessage(BaseModel):
    """A single chat turn as accepted by the ``/chat/*`` endpoints."""

    role: str  # "user", "assistant", or "system" expected; any string is accepted and forwarded as-is
    content: str  # message text; /chat/vision rewrites image URLs found in "user" turns

class TextRequest(BaseModel):
    """Request body shared by ``/chat/text`` and ``/chat/vision``.

    The sampling fields below are copied unchanged into the NVIDIA API payload.
    """

    messages: List[ChatMessage]  # conversation history; the server prepends its own system pre-prompt
    # NOTE(review): Optional[...] lets a client send an explicit null, which is then
    # forwarded upstream as JSON null rather than falling back to the default.
    max_tokens: Optional[int] = 1024  # forwarded as "max_tokens"
    temperature: Optional[float] = 1.0  # forwarded as "temperature"
    top_p: Optional[float] = 0.01  # forwarded as "top_p"; very low default, presumably to keep output near-greedy — confirm intent

# System prompt injected ahead of every conversation sent to the model.
# (The trailing space is part of the original prompt text and is kept verbatim.)
_SYSTEM_PROMPT = "You are a helpful multimodal assistant "

# Messages prepended to each request's history. Endpoints build a new list with
# `PRE_PROMPT_MESSAGES + [...]`, so this list itself is never mutated.
PRE_PROMPT_MESSAGES = [dict(role="system", content=_SYSTEM_PROMPT)]

def call_nvidia_api(payload: dict, timeout: float = 60.0) -> dict:
    """POST *payload* to the NVIDIA chat-completions endpoint and return the JSON reply.

    Args:
        payload: Request body, sent as JSON to ``NVIDIA_API_URL``.
        timeout: Seconds to wait on connect/read. Without a timeout ``requests``
            can block indefinitely on a stalled upstream server.

    Returns:
        The decoded JSON response body.

    Raises:
        HTTPException: with the upstream status code when the API answers with
            a non-200 status; 504 when the request times out; 502 when the
            request fails at the transport level or the body is not valid JSON.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Accept": "application/json",
    }
    try:
        response = requests.post(NVIDIA_API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.Timeout as exc:
        raise HTTPException(status_code=504, detail="NVIDIA API request timed out") from exc
    except requests.RequestException as exc:
        raise HTTPException(status_code=502, detail="NVIDIA API request failed") from exc

    if response.status_code != 200:
        raise HTTPException(status_code=response.status_code, detail="NVIDIA API request failed")

    try:
        return response.json()
    except ValueError as exc:
        # e.g. a streamed (non-JSON) body or an HTML error page from a proxy.
        raise HTTPException(status_code=502, detail="NVIDIA API returned a non-JSON response") from exc

@app.post("/chat/text")
async def chat_with_text(request: TextRequest):
    """Forward a text-only conversation to the NVIDIA model and return its reply.

    The fixed system pre-prompt is prepended to the caller's history. The
    upstream call uses the blocking ``requests`` library, so it runs in
    FastAPI's thread pool instead of stalling the event loop (and every other
    in-flight request) while the model generates.

    Returns:
        ``{"response": <assistant message content>}``.

    Raises:
        HTTPException: re-raised unchanged from ``call_nvidia_api`` (its status
            code is preserved); 502 when the reply lacks the expected
            ``choices[0].message.content`` shape; 500 for any other error.
    """
    from fastapi.concurrency import run_in_threadpool

    # Combine pre-prompt with user-provided history.
    messages = PRE_PROMPT_MESSAGES + [msg.dict() for msg in request.messages]

    payload = {
        "model": "nvidia/llama-3.1-nemotron-nano-vl-8b-v1",
        "messages": messages,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
        # Must stay False: the reply is parsed below as one JSON document with
        # choices[0].message. With True the API streams incremental chunks,
        # which that parsing cannot handle.
        "stream": False,
    }

    try:
        response = await run_in_threadpool(call_nvidia_api, payload)
    except HTTPException:
        raise  # keep the upstream status code instead of collapsing it to 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        content = response["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        raise HTTPException(status_code=502, detail="Unexpected response format from NVIDIA API") from e
    return {"response": content}

@app.post("/chat/vision")
async def chat_from_text_with_image_url(request: TextRequest):
    """Forward a conversation whose user turns may reference images by URL.

    In every *user* message (all of them, not only the last), each ``http(s)``
    URL ending in .jpg/.jpeg/.png/.webp/.gif (any letter case) is downloaded
    and replaced by an inline ``<img src="data:<mime>;base64,..." />`` tag. A
    URL that cannot be fetched (network error, non-2xx status, larger than the
    size cap) is left in the text unchanged. The rewritten conversation,
    prefixed with the system pre-prompt, is then sent to the NVIDIA model.

    NOTE(security): the URLs come straight from the request body and are
    fetched by this server, so a caller can make it issue GET requests to
    arbitrary hosts, including internal ones (SSRF). Restrict allowed hosts or
    resolve-and-block private address ranges before exposing this publicly.

    Returns:
        ``{"response": <assistant message content>}``.

    Raises:
        HTTPException: re-raised unchanged from ``call_nvidia_api`` (its status
            code is preserved); 502 when the reply lacks the expected
            ``choices[0].message.content`` shape; 500 for any other error.
    """
    from fastapi.concurrency import run_in_threadpool

    image_url_re = re.compile(r"https?://\S+\.(jpg|jpeg|png|webp|gif)", re.IGNORECASE)
    # MIME type by extension, used when the host's Content-Type is missing or
    # not a known image type, so a JPEG is not mislabelled as PNG.
    mime_by_ext = {
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "png": "image/png",
        "webp": "image/webp",
        "gif": "image/gif",
    }
    # Only these values are ever written into the tag, so a hostile server
    # cannot inject arbitrary text through its Content-Type header.
    known_mimes = set(mime_by_ext.values())
    fetch_timeout = 10  # seconds per image; a stalled host must not hang the request
    max_image_bytes = 10 * 1024 * 1024  # cap on bytes buffered in memory per image
    tag_cache = {}  # url -> <img> tag (or None on failure); each URL is fetched once

    def fetch_img_tag(url, ext):
        """Download *url* and return it as an inline data-URI <img> tag, or None on failure."""
        try:
            with requests.get(url, timeout=fetch_timeout, stream=True) as img_resp:
                img_resp.raise_for_status()
                chunks = []
                received = 0
                for chunk in img_resp.iter_content(chunk_size=64 * 1024):
                    received += len(chunk)
                    if received > max_image_bytes:
                        return None  # too large: leave the URL in the text
                    chunks.append(chunk)
                content_type = img_resp.headers.get("Content-Type", "")
        except requests.RequestException:
            # Best-effort: an unreachable image leaves the original URL in place.
            return None
        mime = content_type.split(";", 1)[0].strip().lower()
        if mime not in known_mimes:
            mime = mime_by_ext[ext.lower()]
        b64 = base64.b64encode(b"".join(chunks)).decode("utf-8")
        return f'<img src="data:{mime};base64,{b64}" />'

    def replace_img_url(match):
        url = match.group(0)
        if url not in tag_cache:
            tag_cache[url] = fetch_img_tag(url, match.group(1))
        tag = tag_cache[url]
        return url if tag is None else tag  # fallback to original URL if fetch failed

    def build_messages():
        """Return the pre-prompt plus history, with image URLs in user turns inlined."""
        new_messages = []
        for msg in request.messages:
            if msg.role == "user":
                content_with_img = image_url_re.sub(replace_img_url, msg.content)
                new_messages.append({"role": "user", "content": content_with_img})
            else:
                new_messages.append(msg.dict())
        return PRE_PROMPT_MESSAGES + new_messages

    # Image downloads use blocking `requests`; keep them off the event loop.
    messages = await run_in_threadpool(build_messages)

    payload = {
        "model": "nvidia/llama-3.1-nemotron-nano-vl-8b-v1",
        "messages": messages,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
        # Must stay False: the reply is parsed below as one JSON document with
        # choices[0].message. With True the API streams incremental chunks,
        # which that parsing cannot handle.
        "stream": False,
    }

    try:
        response = await run_in_threadpool(call_nvidia_api, payload)
    except HTTPException:
        raise  # keep the upstream status code instead of collapsing it to 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        content = response["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        raise HTTPException(status_code=502, detail="Unexpected response format from NVIDIA API") from e
    return {"response": content}