"""FastAPI inference server for Qwen2.5-1.5B-Instruct with a LoRA adapter (CPU, float32)."""

import os

# Hugging Face cache locations; set before transformers is imported so they take effect.
os.environ["HF_HOME"] = "/data"
os.environ["TRANSFORMERS_CACHE"] = "/data/transformers"
os.environ["HF_HUB_CACHE"] = "/data/hub"
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# PyTorch reads PYTORCH_ENABLE_MPS_FALLBACK; harmless on a CPU-only deployment.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import json
import threading
from typing import List, Dict, Iterator

import torch
from fastapi import FastAPI, Body
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from peft import PeftModel

# Quantization is disabled for this CPU deployment; compute in float32.
USE_4BIT = False
COMPUTE_DTYPE = torch.float32

# Base model and LoRA adapter, both overridable via environment variables.
MODEL_ID = os.environ.get("MODEL_ID", "unsloth/Qwen2.5-1.5B-Instruct")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "WildOjisan/qwen2_5_lora_adapter_test1")

# Use all available CPU cores for generation.
try:
    torch.set_num_threads(max(1, os.cpu_count() or 1))
except Exception:
    pass

print(f"[BOOT] Base: {MODEL_ID}")
print(f"[BOOT] LoRA: {ADAPTER_ID}")

device_map = "cpu"

# Prefer the tokenizer saved with the adapter (it may carry added tokens);
# fall back to the base model's tokenizer.
try:
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False, trust_remote_code=True)
    print("[BOOT] Tokenizer loaded from ADAPTER_ID.")
except Exception:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
    print("[BOOT] Tokenizer loaded from MODEL_ID.")

# Ensure a pad token exists so generate() can pad; reuse EOS if none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the base model on CPU, then attach the LoRA adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device_map,
    trust_remote_code=True,
    torch_dtype=COMPUTE_DTYPE,
    low_cpu_mem_usage=True,
)

model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model.eval()
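
# Optional (not done here): merging the adapter into the base weights can speed up
# CPU inference slightly, at the cost of no longer being able to swap adapters:
#   model = model.merge_and_unload()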


class ChatMessage(BaseModel):
    role: str = Field(..., description="system | user | assistant")
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    max_new_tokens: int = 128
    temperature: float = 0.7
    top_p: float = 0.9
    repetition_penalty: float = 1.1


class ChatResponse(BaseModel):
    text: str


app = FastAPI(title="Qwen2.5-1.5B + LoRA API")


@app.get("/")
def health():
    return {"status": "ok", "base": MODEL_ID, "adapter": ADAPTER_ID, "use_4bit": USE_4BIT}


def build_prompt(messages: List[Dict[str, str]]) -> str:
    # Render the conversation with the model's chat template and append the
    # assistant generation prompt.
    return tokenizer.apply_chat_template(
        [{"role": m["role"], "content": m["content"]} for m in messages],
        tokenize=False,
        add_generation_prompt=True,
    )
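
# For reference, Qwen2.5-Instruct uses a ChatML-style template, so the rendered
# prompt looks roughly like this (illustrative, not byte-exact):
#   <|im_start|>system
#   ...system message...<|im_end|>
#   <|im_start|>user
#   ...user message...<|im_end|>
#   <|im_start|>assistant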


@app.post("/v1/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
    prompt = build_prompt([m.dict() for m in req.messages])
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=req.temperature,
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the prompt.
    text = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return ChatResponse(text=text)
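
# Example request (sketch; host and port depend on how the server is launched):
#   curl -X POST http://localhost:8000/v1/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "max_new_tokens": 64}'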


def stream_generate(req: ChatRequest) -> Iterator[str]:
    prompt = build_prompt([m.dict() for m in req.messages])
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        do_sample=True,
        temperature=req.temperature,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    # Generation runs in a background thread; the streamer yields decoded text
    # pieces as they become available.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Emit newline-delimited JSON. json.dumps keeps each line valid JSON
    # (str.__repr__ does not, e.g. when the text contains quotes).
    for token_text in streamer:
        yield json.dumps({"delta": token_text}) + "\n"

    thread.join()


@app.post("/v1/chat/stream")
def chat_stream(req: ChatRequest = Body(...)):
    return StreamingResponse(stream_generate(req), media_type="application/x-ndjson")
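

# Example streaming client (sketch; assumes the `requests` package and that the
# server is reachable at localhost:8000; adjust to your deployment):
#   import requests
#   with requests.post(
#       "http://localhost:8000/v1/chat/stream",
#       json={"messages": [{"role": "user", "content": "Hello"}]},
#       stream=True,
#   ) as resp:
#       for line in resp.iter_lines():
#           if line:
#               print(line.decode("utf-8"))


if __name__ == "__main__":
    # Optional local entry point; the port is an arbitrary choice for local testing.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)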