from fastapi import FastAPI, HTTPException
import requests
import base64
import os
from pydantic import BaseModel
from typing import Optional, List
import re

app = FastAPI()

# NVIDIA API endpoint for the llama-3.1-nemotron-nano-vl-8b-v1 model
NVIDIA_API_URL = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.1-nemotron-nano-vl-8b-v1/chat/completions"
# Read the key from the environment; never hardcode real API keys in source
API_KEY = os.environ.get("NVIDIA_API_KEY", "")

class ChatMessage(BaseModel):
    role: str  # "user", "assistant", or "system"
    content: str


class TextRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 0.01
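
# Example request body that validates against TextRequest (values are
# illustrative only):
#
# {
#   "messages": [
#     {"role": "user", "content": "Summarize this image: https://example.com/cat.png"}
#   ],
#   "max_tokens": 512,
#   "temperature": 0.7
# }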

PRE_PROMPT_MESSAGES = [
    {"role": "system", "content": "You are a helpful multimodal assistant."},
]

def call_nvidia_api(payload: dict):
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Accept": "application/json",
    }
    # A timeout keeps the endpoint from hanging indefinitely on a stalled upstream request
    response = requests.post(NVIDIA_API_URL, headers=headers, json=payload, timeout=60)
    if response.status_code != 200:
        raise HTTPException(status_code=response.status_code, detail="NVIDIA API request failed")
    return response.json()
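
# The handlers in this file request non-streaming responses and parse a single
# JSON body. If token streaming is wanted instead, a separate helper is needed.
# The sketch below is a minimal assumption based on OpenAI-style SSE chat APIs
# ("data: {...}" lines terminated by "data: [DONE]"); confirm the exact event
# format against NVIDIA's API docs before relying on it:
def call_nvidia_api_streaming(payload: dict):
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Accept": "text/event-stream",
    }
    # Generator: the status check runs on first iteration, not at call time
    with requests.post(NVIDIA_API_URL, headers=headers, json=payload, stream=True, timeout=60) as response:
        if response.status_code != 200:
            raise HTTPException(status_code=response.status_code, detail="NVIDIA API request failed")
        for line in response.iter_lines(decode_unicode=True):
            if line and line.startswith("data: ") and line != "data: [DONE]":
                yield line[len("data: "):]  # raw JSON chunk for the caller to parse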

@app.post("/chat")  # route path is illustrative; without a decorator FastAPI never exposes this handler
async def chat_with_text(request: TextRequest):
    # Combine the pre-prompt with the user-provided history
    messages = PRE_PROMPT_MESSAGES + [msg.dict() for msg in request.messages]
    payload = {
        "model": "nvidia/llama-3.1-nemotron-nano-vl-8b-v1",
        "messages": messages,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
        # Must be False: this handler parses a single JSON body below, which a
        # streamed (SSE) response would break
        "stream": False,
    }
    try:
        response = call_nvidia_api(payload)
        return {"response": response["choices"][0]["message"]["content"]}
    except HTTPException:
        raise  # preserve upstream status codes instead of collapsing them to 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
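
# Minimal client-side sketch for the endpoint above (the localhost URL and the
# "/chat" path are assumptions for illustration):
#
# import requests
# resp = requests.post(
#     "http://localhost:8000/chat",
#     json={"messages": [{"role": "user", "content": "Hello!"}]},
#     timeout=60,
# )
# print(resp.json()["response"])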

@app.post("/chat-with-image-url")  # route path is illustrative
async def chat_from_text_with_image_url(request: TextRequest):
    # Inline any image URLs found in user messages as base64 <img> tags
    new_messages = []
    for msg in request.messages:
        if msg.role == "user":
            def replace_img_url(match):
                url = match.group(0)
                ext = match.group(1).lower()
                mime = "jpeg" if ext == "jpg" else ext  # data-URI type must match the actual format
                try:
                    img_resp = requests.get(url, timeout=30)
                    img_resp.raise_for_status()
                    b64 = base64.b64encode(img_resp.content).decode("utf-8")
                    return f'<img src="data:image/{mime};base64,{b64}" />'
                except Exception:
                    return url  # fall back to the original URL if the fetch fails
            content_with_img = re.sub(r'https?://\S+\.(jpg|jpeg|png|webp|gif)', replace_img_url, msg.content)
            new_messages.append({"role": "user", "content": content_with_img})
        else:
            new_messages.append(msg.dict())
    messages = PRE_PROMPT_MESSAGES + new_messages
    payload = {
        "model": "nvidia/llama-3.1-nemotron-nano-vl-8b-v1",
        "messages": messages,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "stream": False,  # see note in chat_with_text: the handler expects one JSON body
    }
    try:
        response = call_nvidia_api(payload)
        return {"response": response["choices"][0]["message"]["content"]}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
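
# To run locally with uvicorn, FastAPI's usual ASGI server (the module name
# "main" is an assumption about this file's name):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example call to the image-URL endpoint, assuming the "/chat-with-image-url"
# route registered above:
#
#   curl -X POST http://localhost:8000/chat-with-image-url \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Describe https://example.com/cat.png"}]}'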