import os
# ✅ Point the Hugging Face cache/token paths at a writable location (/data is safe on Spaces)
os.environ["HF_HOME"] = "/data"
os.environ["TRANSFORMERS_CACHE"] = "/data/transformers"
os.environ["HF_HUB_CACHE"] = "/data/hub"
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("PYTORCH_FORCE_MPS_FALLBACK", "1")
import json
import threading
from typing import List, Dict, Iterator
import torch
from fastapi import FastAPI, Body
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from peft import PeftModel
# ----------------- Environment defaults -----------------
os.environ.setdefault("PYTORCH_FORCE_MPS_FALLBACK", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# CPU-only: disable 4-bit quantization, use float32
USE_4BIT = False
COMPUTE_DTYPE = torch.float32
# Base model / adapter paths
MODEL_ID = os.environ.get("MODEL_ID", "unsloth/Qwen2.5-1.5B-Instruct")
ADAPTER_ID = os.environ.get("ADAPTER_ID", "WildOjisan/qwen2_5_lora_adapter_test1")
# Thread count
try:
    torch.set_num_threads(max(1, os.cpu_count() or 1))
except Exception:
    pass
# ----------------- Loading -----------------
print(f"[BOOT] Base: {MODEL_ID}")
print(f"[BOOT] LoRA: {ADAPTER_ID}")
device_map = "cpu"
# Tokenizer: the adapter may ship custom tokens/templates, so try it first
try:
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False, trust_remote_code=True)
    print("[BOOT] Tokenizer loaded from ADAPTER_ID.")
except Exception:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
    print("[BOOT] Tokenizer loaded from MODEL_ID.")
# Fix up the pad token (avoids the same warning as in the Colab code)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load the base model on CPU (float32)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device_map,
    trust_remote_code=True,
    torch_dtype=COMPUTE_DTYPE,
    low_cpu_mem_usage=True,
)
# Attach the LoRA adapter (no merge: same behavior as the Colab)
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
model.eval()
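# Note: merging the adapter (model = model.merge_and_unload()) would speed up CPU
# inference, but it is deliberately skipped here to keep behavior identical to the Colab run.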
# ----------------- API schema / app -----------------
class ChatMessage(BaseModel):
    role: str = Field(..., description="system | user | assistant")
    content: str
class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    max_new_tokens: int = 128
    temperature: float = 0.7  # matches the Colab default
    top_p: float = 0.9  # matches the Colab default
    repetition_penalty: float = 1.1
class ChatResponse(BaseModel):
    text: str
app = FastAPI(title="Qwen2.5-1.5B + LoRA API (CPU)")
@app.get("/")
def health():
return {"status": "ok", "base": MODEL_ID, "adapter": ADAPTER_ID, "use_4bit": USE_4BIT}
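# Quick smoke test (hypothetical host/port; Spaces serve on port 7860 by default):
#   curl -s http://localhost:7860/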
def build_prompt(messages: List[Dict[str, str]]) -> str:
    # Qwen's recommended chat template (same as the Colab)
    return tokenizer.apply_chat_template(
        [{"role": m["role"], "content": m["content"]} for m in messages],
        tokenize=False,
        add_generation_prompt=True,
    )
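# For Qwen2.5 the rendered prompt looks roughly like this (illustrative only;
# the exact template comes from the tokenizer config):
#   <|im_start|>user
#   Hello!<|im_end|>
#   <|im_start|>assistant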
@app.post("/v1/chat", response_model=ChatResponse)
def chat(req: ChatRequest):
prompt = build_prompt([m.dict() for m in req.messages])
inputs = tokenizer(prompt, return_tensors="pt")
# 모델의 디바이스로 이동
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=req.max_new_tokens,
do_sample=True,
temperature=req.temperature,
top_p=req.top_p,
repetition_penalty=req.repetition_penalty,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
text = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
return ChatResponse(text=text)
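# Example non-streaming call (hypothetical host/port):
#   curl -s -X POST http://localhost:7860/v1/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages":[{"role":"user","content":"Hello!"}],"max_new_tokens":64}'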
def stream_generate(req: ChatRequest) -> Iterator[str]:
    prompt = build_prompt([m.dict() for m in req.messages])
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=req.max_new_tokens,
        do_sample=True,
        temperature=req.temperature,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Run generation in a background thread; the streamer yields text as it is produced
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    for token_text in streamer:
        # json.dumps emits valid JSON (repr() would produce single-quoted, invalid JSON)
        yield json.dumps({"delta": token_text}, ensure_ascii=False) + "\n"
@app.post("/v1/chat/stream")
def chat_stream(req: ChatRequest = Body(...)):
return StreamingResponse(stream_generate(req), media_type="application/x-ndjson")
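# Example streaming call; each NDJSON line is {"delta": "<token text>"} (hypothetical host/port):
#   curl -N -X POST http://localhost:7860/v1/chat/stream \
#     -H "Content-Type: application/json" \
#     -d '{"messages":[{"role":"user","content":"Hi"}]}'

# Minimal local entry point, a sketch assuming uvicorn is installed; HF Spaces
# expects the server to listen on port 7860.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)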