# app.py — Hugging Face Space file by rabiyulfahim
# Commit ff74138 "Update app.py" (verified) · 3.22 kB
import os

import torch
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
# ✅ Hugging Face cache directory
CACHE_DIR = "/tmp"
# NOTE(review): `transformers` is imported above, before these variables are set,
# so they probably do not change its default cache location for this process;
# the explicit cache_dir arguments below are what route downloads into CACHE_DIR.
os.environ.update(HF_HOME=CACHE_DIR, TRANSFORMERS_CACHE=CACHE_DIR)

# -----------------------
# Model Setup
# -----------------------
model_id = "LLM360/K2-Think"

# Both objects are created once at import time and shared by every request handler.
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # load weights in half precision
    device_map="auto",  # let the loader choose GPU/CPU placement
    cache_dir=CACHE_DIR,
)
print("Model loaded successfully!")
# -----------------------
# FastAPI Setup
# -----------------------
app = FastAPI(title="K2-Think QA API", description="Serving K2-Think Hugging Face model with FastAPI")

# CORS is left open so any front-end origin can call the API.
# Credentials are deliberately disabled. The endpoints in this file read no
# cookies or auth headers. A wildcard origin combined with
# allow_credentials=True is a configuration the FastAPI docs disallow; with it,
# Starlette echoes back whichever origin sent a credentialed request.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve ./static (relative to the working directory) under the /static URL prefix.
# StaticFiles raises at startup if that directory does not exist.
app.mount("/static", StaticFiles(directory="static"), name="static")
# -----------------------
# Request Schema
# -----------------------
class QueryRequest(BaseModel):
    """JSON body accepted by ``POST /predict``.

    The bounds match what sampled generation accepts: at least one new token,
    a strictly positive temperature, and a nucleus ``top_p`` within [0, 1].
    FastAPI's request validation rejects out-of-range values with a 422 before
    they reach ``model.generate``, where they would otherwise fail as a 500.
    """

    question: str  # prompt text, tokenized as-is
    # NOTE: no upper cap is enforced; request cost grows with this value.
    max_new_tokens: int = Field(50, ge=1)
    temperature: float = Field(0.7, gt=0)
    top_p: float = Field(0.9, ge=0, le=1)
# -----------------------
# Endpoints
# -----------------------
@app.get("/")
def home():
    """Root route: return a fixed greeting showing the API is reachable."""
    greeting = "Welcome to K2-Think QA API 🚀"
    return {"message": greeting}
@app.get("/ui", response_class=HTMLResponse)
def serve_ui():
    """Return the contents of ``static/index.html`` as an HTML response.

    The path is resolved relative to the process working directory. A missing
    file raises ``FileNotFoundError`` from ``open``.
    """
    page_file = os.path.join("static", "index.html")
    with open(page_file, "r", encoding="utf-8") as handle:
        markup = handle.read()
    return HTMLResponse(markup)
@app.get("/health")
def health():
    """Lightweight health check; always reports ``{"status": "ok"}``."""
    return dict(status="ok")
def _generate_answer(question: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Sample a continuation of ``question`` and return only the newly generated text.

    Args:
        question: Raw prompt text; it is tokenized as-is (no chat template).
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        The decoded generated tokens, with special tokens stripped and the
        prompt itself excluded.

    Raises:
        Whatever the tokenizer or ``model.generate`` raises (e.g. invalid
        sampling parameters, out-of-memory); the endpoints below translate it
        into an HTTP 500.
    """
    # The tokenizer always produces CPU tensors. With device_map="auto" the model
    # may live on a GPU, so move the inputs to the model's device first.
    inputs = tokenizer(question, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        # Reuse EOS as the padding id, presumably to silence generate()'s missing-pad warning.
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )
    # For a causal LM, sequences[0] begins with the prompt tokens; drop them so
    # the returned answer does not echo the question back.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_ids = outputs.sequences[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)


@app.get("/ask")
def ask(question: str = Query(...), max_new_tokens: int = Query(50, ge=1)):
    """GET variant: answer ``question`` with fixed sampling (temperature 0.7, top_p 0.9).

    Returns ``{"question": ..., "answer": ...}``. ``max_new_tokens`` below 1 is
    rejected by request validation. Generation failures become an HTTP 500
    whose ``detail`` is the exception message.
    """
    try:
        answer = _generate_answer(question, max_new_tokens, temperature=0.7, top_p=0.9)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"question": question, "answer": answer}


@app.post("/predict")
def predict(request: QueryRequest):
    """POST variant: answer ``request.question`` using the sampling settings in the body.

    Returns ``{"question": ..., "answer": ...}``. Generation failures become an
    HTTP 500 whose ``detail`` is the exception message.
    """
    try:
        answer = _generate_answer(
            request.question,
            request.max_new_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"question": request.question, "answer": answer}