# Hugging Face cache locations must be configured before transformers is imported.
import os

os.environ["HF_HOME"] = "/data"
# Recent transformers releases deprecate TRANSFORMERS_CACHE in favor of HF_HOME;
# it is kept here so older versions resolve to the same volume.
os.environ["TRANSFORMERS_CACHE"] = "/data/transformers"
os.environ["HF_HUB_CACHE"] = "/data/hub"
# Avoid tokenizer thread contention with the server's worker threads.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Allow CPU fallback for ops missing on Apple MPS; harmless on CPU-only hosts.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import json
import threading
from typing import List, Optional, Dict, Any, Iterator

import torch
from fastapi import FastAPI, Body
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

MODEL_ID = "unsloth/Qwen2.5-1.5B-Instruct"

# Use every available core for CPU inference; ignore failures on restricted runtimes.
try:
    torch.set_num_threads(max(1, os.cpu_count() or 1))
except Exception:
    pass

print(f"[BOOT] Loading {MODEL_ID} on CPU (float32)...")

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=False,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cpu",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()
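
# Rough sizing note: at float32 the ~1.5B parameters alone occupy about 6 GB of RAM
# (1.5e9 params x 4 bytes), before activations and KV cache. If memory is tight,
# loading in torch.bfloat16 is a possible alternative on CPUs that handle it well.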


class ChatMessage(BaseModel):
    role: str = Field(..., description="system | user | assistant")
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    max_new_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.95
    repetition_penalty: float = 1.1


class ChatResponse(BaseModel):
    text: str


app = FastAPI(title="Qwen2.5-1.5B CPU API")


@app.get("/")
def health():
    return {"status": "ok", "model": MODEL_ID}


def build_prompt(messages: List[Dict[str, str]]) -> str:
    """Render the chat history with the model's chat template, ready for generation."""
    return tokenizer.apply_chat_template(
        [{"role": m["role"], "content": m["content"]} for m in messages],
        tokenize=False,
        add_generation_prompt=True,
    )
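
# For reference, Qwen2.5's chat template renders ChatML-style markup; a system + user
# exchange becomes roughly the following string (illustrative, not byte-exact):
#
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>assistant
#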


@app.post("/v1/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    prompt = build_prompt([m.dict() for m in req.messages])
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            **{k: v.to("cpu") for k, v in inputs.items()},
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=req.temperature,
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return ChatResponse(text=text)
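
# Example request (assumes the server is running locally on port 8000 via uvicorn;
# adjust host and port for your deployment):
#
#   curl -s http://localhost:8000/v1/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}], "max_new_tokens": 64}'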


def stream_generate(req: ChatRequest) -> Iterator[str]:
    prompt = build_prompt([m.dict() for m in req.messages])
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        **{k: v.to("cpu") for k, v in inputs.items()},
        max_new_tokens=req.max_new_tokens,
        do_sample=True,
        temperature=req.temperature,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    # Run generation in a background thread; the streamer yields decoded text chunks
    # as they are produced.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Emit one JSON object per line (NDJSON). json.dumps keeps each line valid JSON;
    # repr() would produce single-quoted Python literals that JSON parsers reject.
    for token_text in streamer:
        yield json.dumps({"delta": token_text}) + "\n"

    thread.join()


@app.post("/v1/chat/stream")
def chat_stream(req: ChatRequest = Body(...)):
    return StreamingResponse(
        stream_generate(req),
        media_type="application/x-ndjson",
    )
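

# Minimal local entry point: a sketch assuming uvicorn is installed and port 8000 is
# free; adjust host/port (or run via `uvicorn <module>:app`) for your deployment.
# The streaming endpoint can then be exercised with, for example:
#
#   curl -N http://localhost:8000/v1/chat/stream \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Tell me a short joke"}]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)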