import os

# ✅ Hugging Face cache directory.
# These must be set BEFORE transformers / huggingface_hub are imported: the hub
# library resolves HF_HOME into module-level constants at import time, so
# assigning it after the import (the previous ordering) had no effect.
# TRANSFORMERS_CACHE is kept for older transformers releases.
os.environ["HF_HOME"] = "/tmp"
os.environ["TRANSFORMERS_CACHE"] = "/tmp"

import torch  # noqa: E402
from fastapi import FastAPI, HTTPException, Query  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from fastapi.responses import HTMLResponse  # noqa: E402
from fastapi.staticfiles import StaticFiles  # noqa: E402
from pydantic import BaseModel, Field  # noqa: E402
from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402

# -----------------------
# Model Setup
# -----------------------
model_id = "LLM360/K2-Think"

# float16 is the right choice on GPU, but half-precision kernels are missing or
# very slow on CPU in many PyTorch builds; bfloat16 keeps the same memory
# footprint and is better supported there.
_model_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16

# Load tokenizer and model once at import time; every request reuses them.
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/tmp")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir="/tmp",
    device_map="auto",  # Automatically select GPU/CPU
    torch_dtype=_model_dtype,
)
print("Model loaded successfully!")

# -----------------------
# FastAPI Setup
# -----------------------
app = FastAPI(
    title="K2-Think QA API",
    description="Serving K2-Think Hugging Face model with FastAPI",
)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# website make credentialed cross-origin calls; restrict allow_origins before
# exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static folder (path is relative to the process working directory).
app.mount("/static", StaticFiles(directory="static"), name="static")


# -----------------------
# Request Schema
# -----------------------
class QueryRequest(BaseModel):
    """Body of ``POST /predict``: the prompt plus sampling parameters."""

    # An empty prompt tokenizes to no input ids, which generate() cannot use.
    question: str = Field(..., min_length=1)
    # Bounds match what sampling-mode generate() accepts, so bad values come
    # back as 422 validation errors instead of opaque 500s.
    max_new_tokens: int = Field(50, ge=1)
    temperature: float = Field(0.7, gt=0)
    top_p: float = Field(0.9, ge=0, le=1)


# -----------------------
# Generation helper
# -----------------------
def _generate_answer(
    question: str, max_new_tokens: int, temperature: float, top_p: float
) -> str:
    """Sample a continuation of *question* and return only the newly generated text.

    Args:
        question: Raw prompt text fed to the tokenizer.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature (must be > 0 with do_sample=True).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded completion, with special tokens removed and the prompt excluded.

    Raises:
        Whatever the tokenizer or ``model.generate`` raise; the endpoints map
        these to HTTP 500.
    """
    # NOTE(review): the prompt is passed raw, without a chat template. If the
    # model expects chat formatting, consider tokenizer.apply_chat_template.
    # Tokenizer output is created on CPU; move it to the model's device so
    # generate() does not mix devices when device_map placed the model on GPU.
    inputs = tokenizer(question, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )
    # A causal LM's output sequence begins with the prompt tokens; strip them
    # so "answer" holds only the model's reply rather than echoing the question.
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = outputs.sequences[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


# -----------------------
# Endpoints
# -----------------------
@app.get("/")
def home():
    """Return a static welcome message."""
    return {"message": "Welcome to K2-Think QA API 🚀"}


@app.get("/ui", response_class=HTMLResponse)
def serve_ui():
    """Serve ``static/index.html`` (relative to the working directory)."""
    html_path = os.path.join("static", "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())


@app.get("/health")
def health():
    """Liveness probe; always reports ok once the module has loaded."""
    return {"status": "ok"}


@app.get("/ask")
def ask(
    question: str = Query(..., min_length=1),
    max_new_tokens: int = Query(50, ge=1),
    temperature: float = Query(0.7, gt=0),
    top_p: float = Query(0.9, ge=0, le=1),
):
    """Answer *question* passed as query parameters.

    ``temperature`` and ``top_p`` are optional; their defaults match the
    values this endpoint previously hard-coded.
    """
    try:
        answer = _generate_answer(question, max_new_tokens, temperature, top_p)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"question": question, "answer": answer}


@app.post("/predict")
def predict(request: QueryRequest):
    """Answer the question in a JSON body using its sampling parameters."""
    try:
        answer = _generate_answer(
            request.question,
            request.max_new_tokens,
            request.temperature,
            request.top_p,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"question": request.question, "answer": answer}