Proff12 committed (verified) · Commit e618a4f · Parent(s): c57d186
backend/app/__init__.py ADDED
(empty file)
backend/app/main.py ADDED
import os
from typing import List, Literal, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

APP_TITLE = "HF Chat (Fathom-R1-14B)"
APP_VERSION = "0.2.0"

# ---- Config via ENV ----
MODEL_ID = os.getenv("MODEL_ID", "FractalAIResearch/Fathom-R1-14B")
PIPELINE_TASK = os.getenv("PIPELINE_TASK", "text-generation")
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "8192"))  # keep prompt reasonable
STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "")
QUANTIZE = os.getenv("QUANTIZE", "auto")  # auto|4bit|8bit|none

app = FastAPI(title=APP_TITLE, version=APP_VERSION)

if ALLOWED_ORIGINS:
    origins = [o.strip() for o in ALLOWED_ORIGINS.split(",") if o.strip()]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

class Message(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str

class ChatRequest(BaseModel):
    messages: List[Message]
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None

class ChatResponse(BaseModel):
    reply: str
    model: str

tokenizer = None
model = None
generator = None

def load_pipeline():
    global tokenizer, model, generator
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Determine load strategy
    load_kwargs = {}
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    if device == "cuda":
        # Try quantization if requested
        if QUANTIZE.lower() in ("4bit", "8bit", "auto"):
            try:
                from transformers import BitsAndBytesConfig
                import bitsandbytes as bnb  # noqa: F401
                # Recent transformers expects quantization options via
                # BitsAndBytesConfig rather than bare load_in_*bit kwargs.
                if QUANTIZE.lower() == "8bit":
                    load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
                else:
                    # 4bit or auto (prefer 4bit)
                    load_kwargs["quantization_config"] = BitsAndBytesConfig(
                        load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
                    )
            except Exception:
                # bitsandbytes not available; fall back to full precision on GPU
                pass
        load_kwargs.setdefault("torch_dtype", dtype)
        load_kwargs.setdefault("device_map", "auto")
    else:
        # CPU fallback
        load_kwargs.setdefault("torch_dtype", dtype)

    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)

    # The model is already placed (device_map="auto" on GPU, CPU otherwise),
    # so do not pass device/device_map to the pipeline as well; it would try
    # to move an accelerate-dispatched model and fail.
    generator = pipeline(
        PIPELINE_TASK,
        model=model,
        tokenizer=tokenizer,
    )

@app.on_event("startup")
def _startup():
    load_pipeline()

def messages_to_prompt(messages: List[Message]) -> str:
    """
    Prefer the tokenizer's chat template (Qwen-based models ship one).
    Fall back to a simple transcript.
    """
    try:
        # Convert to HF chat format: list of dicts with role/content
        chat = [{"role": m.role, "content": m.content} for m in messages]
        return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Fallback formatting
        parts = []
        for m in messages:
            if m.role == "system":
                parts.append(f"System: {m.content}")
            elif m.role == "user":
                parts.append(f"User: {m.content}")
            else:
                parts.append(f"Assistant: {m.content}")
        parts.append("Assistant:")
        return "\n".join(parts)

def truncate_prompt(prompt: str, max_tokens: int) -> str:
    ids = tokenizer(prompt, return_tensors="pt", truncation=False)["input_ids"][0]
    if len(ids) <= max_tokens:
        return prompt
    trimmed = ids[-max_tokens:]  # keep the most recent tokens
    return tokenizer.decode(trimmed, skip_special_tokens=True)

@app.get("/api/health")
def health():
    device = next(model.parameters()).device.type if model is not None else "N/A"
    return {"status": "ok", "model": MODEL_ID, "task": PIPELINE_TASK, "device": device}

@app.post("/api/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages cannot be empty")

    raw_prompt = messages_to_prompt(req.messages)
    prompt = truncate_prompt(raw_prompt, MAX_INPUT_TOKENS)

    gen_kwargs = {
        "max_new_tokens": req.max_new_tokens,
        "do_sample": req.temperature > 0,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "repetition_penalty": req.repetition_penalty,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "return_full_text": True,
    }

    outputs = generator(prompt, **gen_kwargs)
    if isinstance(outputs, list) and outputs and "generated_text" in outputs[0]:
        full = outputs[0]["generated_text"]
        reply = full[len(prompt):].strip() if full.startswith(prompt) else full
    else:
        reply = str(outputs)

    # The text-generation pipeline has no `stop` kwarg, so honor stop
    # sequences by trimming the decoded reply after generation instead.
    if req.stop:
        for s in req.stop:
            idx = reply.find(s)
            if idx != -1:
                reply = reply[:idx].rstrip()

    if not reply:
        reply = "(No response generated.)"
    return ChatResponse(reply=reply, model=MODEL_ID)

# Serve frontend build (if present)
if os.path.isdir(STATIC_DIR):
    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
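A quick way to sanity-check the two endpoints above is a small client script. A minimal sketch, assuming the server is running on localhost:8000 and that `requests` is installed (it is not in requirements.txt):

# client_example.py -- hypothetical smoke test for /api/health and /api/chat
import requests

BASE = "http://localhost:8000"

# Health check: reports model id, pipeline task, and device
print(requests.get(f"{BASE}/api/health").json())

# Chat request mirroring the ChatRequest schema defined above
payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain what a tokenizer does in one sentence."},
    ],
    "max_new_tokens": 128,
    "temperature": 0.7,
    "top_p": 0.95,
}
resp = requests.post(f"{BASE}/api/chat", json=payload)
resp.raise_for_status()
print(resp.json()["reply"])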
backend/requirements.txt ADDED
fastapi>=0.115,<1
uvicorn[standard]>=0.30,<1
torch>=2.1  # imported by app/main.py; not pulled in by transformers
transformers>=4.44.0
accelerate>=0.33.0
bitsandbytes>=0.43.0
pydantic>=2.8,<3
safetensors
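For local development the backend is typically started with uvicorn. A hypothetical launcher, not part of this commit, assuming it sits in backend/ next to app/:

# run.py -- optional launcher; equivalent to `uvicorn app.main:app --host 0.0.0.0 --port 8000`
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000)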
frontend/index.html ADDED
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Fathom R1 Chat</title>
  </head>
  <body>
    <div id="root"></div>
    <script type="module" src="/src/main.jsx"></script>
  </body>
</html>
frontend/package.json ADDED
{
  "name": "hf-fathom-chat-frontend",
  "version": "0.1.0",
  "private": true,
  "type": "module",
  "scripts": {
    "dev": "vite",
    "build": "vite build",
    "preview": "vite preview --port 5173"
  },
  "dependencies": {
    "react": "^18.3.1",
    "react-dom": "^18.3.1"
  },
  "devDependencies": {
    "@vitejs/plugin-react": "^4.3.1",
    "vite": "^5.4.9"
  }
}
frontend/src/App.jsx ADDED
import React, { useEffect, useRef, useState } from 'react'

export default function App() {
  const [messages, setMessages] = useState([
    { role: 'system', content: 'You are a helpful assistant that explains your reasoning clearly and concisely.' }
  ])
  const [input, setInput] = useState('')
  const [loading, setLoading] = useState(false)
  const [model, setModel] = useState('')
  const endRef = useRef(null)

  // Keep the newest message in view
  useEffect(() => {
    endRef.current?.scrollIntoView({ behavior: 'smooth' })
  }, [messages, loading])

  async function sendChat(nextMessages) {
    const res = await fetch('/api/chat', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        messages: nextMessages,
        max_new_tokens: 512,
        temperature: 0.7,
        top_p: 0.95
      })
    })
    if (!res.ok) {
      const t = await res.text()
      throw new Error(`API ${res.status}: ${t}`)
    }
    return res.json()
  }

  const onSend = async () => {
    const text = input.trim()
    if (!text || loading) return
    const next = [...messages, { role: 'user', content: text }]
    setMessages(next)
    setInput('')
    setLoading(true)
    try {
      const { reply, model } = await sendChat(next)
      setModel(model)
      setMessages([...next, { role: 'assistant', content: reply }])
    } catch (e) {
      setMessages([...next, { role: 'assistant', content: `(Error) ${e.message}` }])
    } finally {
      setLoading(false)
    }
  }

  // Enter sends; Shift+Enter inserts a newline
  const onKeyDown = (e) => {
    if (e.key === 'Enter' && !e.shiftKey) {
      e.preventDefault()
      onSend()
    }
  }

  return (
    <div className="app">
      <header className="header">
        <div className="brand">Fathom R1 Chat</div>
        {model && <div className="model">{model}</div>}
      </header>

      <main className="chat">
        {messages.filter(m => m.role !== 'system').map((m, i) => (
          <div key={i} className={`bubble ${m.role}`}>
            <div className="sender">{m.role === 'user' ? 'You' : 'Assistant'}</div>
            <div className="content">{m.content}</div>
          </div>
        ))}
        {loading && <div className="bubble assistant"><div className="content">Thinking…</div></div>}
        <div ref={endRef} />
      </main>

      <footer className="composer">
        <textarea
          value={input}
          onChange={(e) => setInput(e.target.value)}
          onKeyDown={onKeyDown}
          placeholder="Ask a question…"
          rows={2}
        />
        <button onClick={onSend} disabled={loading || !input.trim()}>Send</button>
      </footer>
    </div>
  )
}
frontend/src/main.jsx ADDED
import React from 'react'
import { createRoot } from 'react-dom/client'
import App from './App.jsx'
import './styles.css'

createRoot(document.getElementById('root')).render(<App />)
frontend/src/styles.css ADDED
:root {
  color-scheme: light dark;
  --bg: #0b0f1a;
  --panel: #0f172a;
  --border: #1f2937;
  --text: #e5e7eb;
  --muted: #94a3b8;
  --user: #2563eb;
  --assistant: #374151;
  --accent: #22c55e;
}
* { box-sizing: border-box; }
html, body, #root { height: 100%; margin: 0; }
body { background: var(--bg); color: var(--text); font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Arial; }
.app { display: grid; grid-template-rows: auto 1fr auto; height: 100%; max-width: 900px; margin: 0 auto; }
.header { display: flex; align-items: center; gap: 12px; padding: 12px 16px; border-bottom: 1px solid var(--border); background: var(--panel); }
.brand { font-weight: 700; }
.model { font-size: 12px; color: var(--muted); margin-left: auto; }
.chat { padding: 16px; display: flex; flex-direction: column; gap: 12px; overflow-y: auto; }
.bubble { max-width: 80%; padding: 10px 12px; border-radius: 12px; }
.bubble .sender { font-size: 11px; color: var(--muted); margin-bottom: 4px; }
.bubble .content { white-space: pre-wrap; line-height: 1.4; }
.bubble.user { margin-left: auto; background: var(--user); color: white; }
.bubble.assistant { margin-right: auto; background: var(--assistant); color: #f3f4f6; }
.composer { display: flex; gap: 8px; padding: 12px; border-top: 1px solid var(--border); background: var(--panel); }
textarea { flex: 1; resize: none; padding: 10px; border-radius: 8px; border: 1px solid #263144; background: #0b1220; color: var(--text); }
button { padding: 10px 16px; border-radius: 8px; background: var(--accent); border: none; color: #062010; font-weight: 600; cursor: pointer; }
button:disabled { opacity: 0.6; cursor: default; }
frontend/vite.config.js ADDED
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'

export default defineConfig({
  plugins: [react()],
  server: {
    port: 5173,
    // Forward API calls to the FastAPI backend during development
    proxy: {
      '/api': 'http://localhost:8000'
    }
  },
  build: { outDir: 'dist' }
})
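The Vite build lands in frontend/dist, while the backend serves whatever sits in STATIC_DIR (default /app/static). A hypothetical helper sketching that copy step; the script name and layout are illustrative, not part of this commit:

# deploy_static.py -- hypothetical: copy the Vite build to where FastAPI serves it
import os
import shutil

SRC = "frontend/dist"  # matches build.outDir in vite.config.js
DST = os.getenv("STATIC_DIR", "/app/static")  # matches the backend default

shutil.rmtree(DST, ignore_errors=True)  # drop any stale build
shutil.copytree(SRC, DST)  # StaticFiles(html=True) then serves index.html at /
print(f"Copied {SRC} -> {DST}")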