Proff12 committed
Commit edbf0e3 · verified · 1 Parent(s): bfcd528

Upload main.py

Files changed (1)
  1. backend/app/main.py +48 -105
backend/app/main.py CHANGED
@@ -1,22 +1,19 @@
 import os
 from typing import List, Literal, Optional
+import requests
 
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-import torch
 
-APP_TITLE = "HF Chat (Fathom-R1-14B)"
+APP_TITLE = "HF Chat (Fathom-R1-14B via API)"
 APP_VERSION = "0.2.0"
 
 MODEL_ID = os.getenv("MODEL_ID", "FractalAIResearch/Fathom-R1-14B")
-PIPELINE_TASK = os.getenv("PIPELINE_TASK", "text-generation")
-MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "8192"))
 STATIC_DIR = os.getenv("STATIC_DIR", "/app/static")
 ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "")
-QUANTIZE = os.getenv("QUANTIZE", "auto")
+HF_API_TOKEN = os.getenv("HF_API_TOKEN")
 
 app = FastAPI(title=APP_TITLE, version=APP_VERSION)
 
@@ -46,121 +43,67 @@ class ChatResponse(BaseModel):
     reply: str
     model: str
 
-tokenizer = None
-model = None
-generator = None
-
-def load_pipeline():
-    global tokenizer, model, generator
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
-    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    # Determine load strategy
-    load_kwargs = {}
-    dtype = torch.bfloat16 if device == "cuda" else torch.float32
-
-    if device == "cuda":
-        # try quantization if requested
-        if QUANTIZE.lower() in ("4bit", "8bit", "auto"):
-            try:
-                import bitsandbytes as bnb  # noqa: F401
-                if QUANTIZE.lower() == "8bit":
-                    load_kwargs.update(dict(load_in_8bit=True))
-                else:
-                    # 4bit or auto (prefer 4bit)
-                    load_kwargs.update(dict(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16))
-            except Exception:
-                # bitsandbytes not available; fall back to full precision on GPU
-                pass
-        load_kwargs.setdefault("torch_dtype", dtype)
-        load_kwargs.setdefault("device_map", "auto")
-    else:
-        # CPU fallback
-        load_kwargs.setdefault("torch_dtype", dtype)
-
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
-
-    generator = pipeline(
-        PIPELINE_TASK,
-        model=model,
-        tokenizer=tokenizer,
-        device_map=load_kwargs.get("device_map", None) or (0 if device == "cuda" else -1),
-    )
-
-@app.on_event("startup")
-def _startup():
-    load_pipeline()
-
 def messages_to_prompt(messages: List[Message]) -> str:
-    """
-    Prefer tokenizer chat template (Qwen-based models ship one). Fallback to a simple transcript.
-    """
-    try:
-        # Convert to HF chat format: list of dicts with role/content
-        chat = [{"role": m.role, "content": m.content} for m in messages]
-        return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
-    except Exception:
-        # Fallback formatting
-        parts = []
-        for m in messages:
-            if m.role == "system":
-                parts.append(f"System: {m.content}")
-            elif m.role == "user":
-                parts.append(f"User: {m.content}")
-            else:
-                parts.append(f"Assistant: {m.content}")
-        parts.append("Assistant:")
-        return "".join(parts)
-
-def truncate_prompt(prompt: str, max_tokens: int) -> str:
-    ids = tokenizer(prompt, return_tensors="pt", truncation=False)["input_ids"][0]
-    if len(ids) <= max_tokens:
-        return prompt
-    trimmed = ids[-max_tokens:]
-    return tokenizer.decode(trimmed, skip_special_tokens=True)
+    parts = []
+    for m in messages:
+        if m.role == "system":
+            parts.append(f"System: {m.content}")
+        elif m.role == "user":
+            parts.append(f"User: {m.content}")
+        else:
+            parts.append(f"Assistant: {m.content}")
+    parts.append("Assistant:")
+    return "\n".join(parts)
 
 @app.get("/api/health")
 def health():
-    device = next(model.parameters()).device.type if model is not None else "N/A"
-    return {"status": "ok", "model": MODEL_ID, "task": PIPELINE_TASK, "device": device}
+    return {"status": "ok", "model": MODEL_ID, "source": "huggingface-inference-api"}
 
 @app.post("/api/chat", response_model=ChatResponse)
 def chat(req: ChatRequest):
-    if generator is None:
-        raise HTTPException(status_code=503, detail="Model not loaded")
+    if not HF_API_TOKEN:
+        raise HTTPException(status_code=500, detail="HF_API_TOKEN not set")
+
     if not req.messages:
         raise HTTPException(status_code=400, detail="messages cannot be empty")
 
-    raw_prompt = messages_to_prompt(req.messages)
-    prompt = truncate_prompt(raw_prompt, MAX_INPUT_TOKENS)
-
-    gen_kwargs = {
-        "max_new_tokens": req.max_new_tokens,
-        "do_sample": req.temperature > 0,
-        "temperature": req.temperature,
-        "top_p": req.top_p,
-        "repetition_penalty": req.repetition_penalty,
-        "eos_token_id": tokenizer.eos_token_id,
-        "pad_token_id": tokenizer.pad_token_id,
-        "return_full_text": True,
+    prompt = messages_to_prompt(req.messages)
+
+    headers = {
+        "Authorization": f"Bearer {HF_API_TOKEN}"
     }
-    if req.stop:
-        gen_kwargs["stop"] = req.stop
 
-    outputs = generator(prompt, **gen_kwargs)
-    if isinstance(outputs, list) and outputs and "generated_text" in outputs[0]:
-        full = outputs[0]["generated_text"]
+    payload = {
+        "inputs": prompt,
+        "parameters": {
+            "max_new_tokens": req.max_new_tokens,
+            "temperature": req.temperature,
+            "top_p": req.top_p,
+            "repetition_penalty": req.repetition_penalty,
+            "return_full_text": True,
+        }
+    }
+
+    response = requests.post(
+        f"https://api-inference.huggingface.co/models/{MODEL_ID}",
+        headers=headers,
+        json=payload
+    )
+
+    if response.status_code != 200:
+        raise HTTPException(status_code=response.status_code, detail=response.text)
+
+    result = response.json()
+    if isinstance(result, list) and result and "generated_text" in result[0]:
+        full = result[0]["generated_text"]
         reply = full[len(prompt):].strip() if full.startswith(prompt) else full
     else:
-        reply = str(outputs)
+        reply = str(result)
+
     if not reply:
         reply = "(No response generated.)"
+
     return ChatResponse(reply=reply, model=MODEL_ID)
 
-# Serve frontend build (if present)
 if os.path.isdir(STATIC_DIR):
-    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
+    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
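
After this change the backend no longer loads the model locally and instead proxies generation to the Hugging Face Inference API, so exercising it only requires calling the FastAPI routes. Below is a minimal client sketch; the base URL, port, and concrete sampling values are assumptions for illustration, while the field names follow the req.* attributes and ChatResponse fields visible in the diff, not a schema confirmed by this commit.

# Hypothetical client for the updated backend (not part of this commit).
# Assumes the app is running locally, e.g. `uvicorn app.main:app --port 8000`.
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port

# Health check now reports the API-backed source instead of a local device.
print(requests.get(f"{BASE_URL}/api/health").json())

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain what this service does in one sentence."},
    ],
    # Sampling fields referenced via `req.*` in main.py; values here are illustrative.
    "max_new_tokens": 256,
    "temperature": 0.7,
    "top_p": 0.9,
    "repetition_penalty": 1.1,
}

resp = requests.post(f"{BASE_URL}/api/chat", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()  # ChatResponse: {"reply": "...", "model": "..."}
print(data["model"], data["reply"], sep="\n")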