fugthchat committed
Commit 4b99165 · verified · 1 parent: a867634

Update app.py

Files changed (1):
  1. app.py +110 -74
app.py CHANGED
@@ -1,15 +1,16 @@
 import os
+import uuid
+import threading
+import logging
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel
 from llama_cpp import Llama
 from fastapi.middleware.cors import CORSMiddleware
 from huggingface_hub import hf_hub_download
-import logging
-import threading
 from contextlib import asynccontextmanager

-# Set up logging
+# --- Setup ---
 logging.basicConfig(level=logging.INFO)

 # --- MODEL MAP (Using the smarter Phi-3) ---
@@ -28,46 +29,18 @@ MODEL_MAP = {
     }
 }

-# --- GLOBAL CACHE & LOCK ---
+# --- GLOBAL CACHE & LOCKS ---
 llm_cache = {}
-model_lock = threading.Lock()  # For loading models
-llm_lock = threading.Lock()  # For running generation
-
-# --- LIFESPAN FUNCTION ---
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    # This code runs ON STARTUP
-    logging.info("Server starting up... Acquiring lock to pre-load 'light' model (Phi-3).")
-    with model_lock:
-        get_llm_instance("light")
-    logging.info("Server is ready and 'light' model (Phi-3) is loaded.")
-
-    yield
-
-    # This code runs ON SHUTDOWN
-    logging.info("Server shutting down...")
-    llm_cache.clear()
-
-# Pass the lifespan function to FastAPI
-app = FastAPI(lifespan=lifespan)
-
-# --- CORS ---
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+model_lock = threading.Lock()  # Ensures only one model loads at a time
+llm_lock = threading.Lock()  # Ensures only one generation job runs at a time
+JOBS = {}  # Our in-memory "database" for background jobs

-# --- Helper Function to Load Model ---
+# --- Helper: Load Model ---
 def get_llm_instance(choice: str) -> Llama:
-    # Use the *model* lock for loading
     with model_lock:
         if choice not in MODEL_MAP:
             logging.error(f"Invalid model choice: {choice}")
             return None
-
         if choice in llm_cache:
             logging.info(f"Using cached model: {choice}")
             return llm_cache[choice]
@@ -99,16 +72,83 @@ def get_llm_instance(choice: str) -> Llama:
             logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
             return None

-# --- API Data Models (SIMPLIFIED) ---
-class StoryPrompt(BaseModel):
+# --- Helper: The Background AI Task ---
+def run_generation_in_background(job_id: str, model_choice: str, prompt: str):
+    """
+    This function runs in a separate thread.
+    It performs the long-running AI generation.
+    """
+    global JOBS
+    try:
+        logging.info(f"Job {job_id}: Waiting to acquire LLM lock...")
+        with llm_lock:
+            logging.info(f"Job {job_id}: Lock acquired. Loading model.")
+            llm = get_llm_instance(model_choice)
+            if llm is None:
+                raise Exception("Model could not be loaded.")
+
+            JOBS[job_id]["status"] = "processing"
+            logging.info(f"Job {job_id}: Processing prompt...")
+
+            output = llm(
+                prompt,
+                max_tokens=512,
+                stop=["<|user|>", "<|endoftext|>", "user:"],
+                echo=False
+            )
+
+            generated_text = output["choices"][0]["text"].strip()
+
+            JOBS[job_id]["status"] = "complete"
+            JOBS[job_id]["result"] = generated_text
+            logging.info(f"Job {job_id}: Complete.")
+
+    except Exception as e:
+        logging.error(f"Job {job_id}: Failed. Error: {e}")
+        JOBS[job_id]["status"] = "error"
+        JOBS[job_id]["result"] = str(e)
+    finally:
+        logging.info(f"Job {job_id}: LLM lock released.")
+
+
+# --- FastAPI App & Lifespan ---
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    logging.info("Server starting up... Pre-loading 'light' model.")
+    get_llm_instance("light")
+    logging.info("Server is ready and 'light' model is loaded.")
+    yield
+    logging.info("Server shutting down...")
+    llm_cache.clear()
+
+app = FastAPI(lifespan=lifespan)
+
+# --- !!! THIS IS THE CORS FIX !!! ---
+# We are explicitly adding your GitHub Pages URL
+origins = [
+    "https://fugthchat.github.io",  # <-- YOUR LIVE SITE
+    "http://localhost",  # For local testing
+    "http://127.0.0.1:5500"  # For local testing
+]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+# --- END OF CORS FIX ---
+
+# --- API Data Models ---
+class SubmitPrompt(BaseModel):
     prompt: str
     model_choice: str
-    feedback: str = ""
-    story_memory: str = ""

 # --- API Endpoints ---
 @app.get("/")
 def get_status():
+    """This is the 'wake up' and status check endpoint."""
     loaded_model = list(llm_cache.keys())[0] if llm_cache else "None"
     return {
         "status": "AI server is online",
@@ -116,40 +156,36 @@ def get_status():
         "models": list(MODEL_MAP.keys())
     }

-@app.post("/generate")
-async def generate_story(prompt: StoryPrompt):
+@app.post("/submit_job")
+async def submit_job(prompt: SubmitPrompt):
     """
-    Main generation endpoint.
-    This is simple and stable.
+    Instantly accepts a job and starts it in the background.
     """
-    logging.info("Request received. Waiting to acquire LLM lock...")
-    # Use the *generation* lock
-    with llm_lock:
-        logging.info("Lock acquired. Processing request.")
-        try:
-            llm = get_llm_instance(prompt.model_choice)
-            if llm is None:
-                logging.error(f"Failed to get model for choice: {prompt.model_choice}")
-                return JSONResponse(status_code=503, content={"error": "The AI model is not available or failed to load."})
-
-            # We trust the frontend to build the full prompt
-            final_prompt = prompt.prompt
-
-            logging.info(f"Generating with {prompt.model_choice}...")
-            output = llm(
-                final_prompt,
-                max_tokens=512,
-                stop=["<|user|>", "<|endoftext|>", "user:"],
-                echo=False
-            )
-
-            generated_text = output["choices"][0]["text"].strip()
-            logging.info("Generation complete.")
-
-            return {"story_text": generated_text}
+    job_id = str(uuid.uuid4())
+    JOBS[job_id] = {"status": "pending", "result": None}
+
+    thread = threading.Thread(
+        target=run_generation_in_background,
+        args=(job_id, prompt.model_choice, prompt.prompt)
+    )
+    thread.start()
+
+    logging.info(f"Job {job_id} submitted.")
+    return {"job_id": job_id}

-        except Exception as e:
-            logging.error(f"An internal error occurred during generation: {e}", exc_info=True)
-            return JSONResponse(status_code=500, content={"error": "An unexpected error occurred."})
-        finally:
-            logging.info("Releasing LLM lock.")
+@app.get("/get_job_status/{job_id}")
+async def get_job_status(job_id: str):
+    """
+    Allows the frontend to check on a job.
+    """
+    job = JOBS.get(job_id)
+
+    if job is None:
+        return JSONResponse(status_code=404, content={"error": "Job not found."})
+
+    if job["status"] in ["complete", "error"]:
+        result = job
+        del JOBS[job_id]  # Clean up
+        return result
+
+    return {"status": job["status"]}