fugthchat committed on
Commit a867634 · verified · 1 Parent(s): c8d6e8a

Update app.py

Files changed (1)
  1. app.py +73 -112
app.py CHANGED
@@ -1,19 +1,18 @@
  import os
- import uuid
- import threading
- import logging
  from fastapi import FastAPI, Request
  from fastapi.responses import JSONResponse
  from pydantic import BaseModel
  from llama_cpp import Llama
  from fastapi.middleware.cors import CORSMiddleware
  from huggingface_hub import hf_hub_download
  from contextlib import asynccontextmanager

- # --- Setup ---
  logging.basicConfig(level=logging.INFO)

- # --- Model Map (Using the smarter Phi-3) ---
  MODEL_MAP = {
      "light": {
          "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
@@ -29,17 +28,41 @@ MODEL_MAP = {
      }
  }

- # --- Global Caches & Locks ---
  llm_cache = {}
- model_lock = threading.Lock()  # Ensures only one model loads at a time
- llm_lock = threading.Lock()  # Ensures only one generation job runs at a time

- # This is our new in-memory "database" for jobs
- # It will hold the status and results of background tasks
- JOBS = {}

- # --- Helper: Load Model ---
  def get_llm_instance(choice: str) -> Llama:
      with model_lock:
          if choice not in MODEL_MAP:
              logging.error(f"Invalid model choice: {choice}")
@@ -76,76 +99,16 @@ def get_llm_instance(choice: str) -> Llama:
          logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
          return None

- # --- Helper: The Background AI Task ---
- def run_generation_in_background(job_id: str, model_choice: str, prompt: str):
-     """
-     This function runs in a separate thread.
-     It performs the long-running AI generation.
-     """
-     global JOBS
-     try:
-         # Acquire the lock. If another job is running, this will wait.
-         logging.info(f"Job {job_id}: Waiting to acquire LLM lock...")
-         with llm_lock:
-             logging.info(f"Job {job_id}: Lock acquired. Loading model.")
-             llm = get_llm_instance(model_choice)
-             if llm is None:
-                 raise Exception("Model could not be loaded.")
-
-             JOBS[job_id]["status"] = "processing"
-             logging.info(f"Job {job_id}: Processing prompt...")
-
-             output = llm(
-                 prompt,
-                 max_tokens=512,
-                 stop=["<|user|>", "<|endoftext|>", "user:"],
-                 echo=False
-             )
-
-             generated_text = output["choices"][0]["text"].strip()
-
-             # Save the result and mark as complete
-             JOBS[job_id]["status"] = "complete"
-             JOBS[job_id]["result"] = generated_text
-             logging.info(f"Job {job_id}: Complete.")
-
-     except Exception as e:
-         logging.error(f"Job {job_id}: Failed. Error: {e}")
-         JOBS[job_id]["status"] = "error"
-         JOBS[job_id]["result"] = str(e)
-     finally:
-         # The lock is automatically released by the 'with' statement
-         logging.info(f"Job {job_id}: LLM lock released.")
-
-
- # --- FastAPI App & Lifespan ---
- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     logging.info("Server starting up... Pre-loading 'light' model.")
-     get_llm_instance("light")
-     logging.info("Server is ready and 'light' model is loaded.")
-     yield
-     logging.info("Server shutting down...")
-     llm_cache.clear()
-
- app = FastAPI(lifespan=lifespan)
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # --- API Data Models ---
- class SubmitPrompt(BaseModel):
      prompt: str
      model_choice: str

  # --- API Endpoints ---
  @app.get("/")
  def get_status():
-     """This is the 'wake up' and status check endpoint."""
      loaded_model = list(llm_cache.keys())[0] if llm_cache else "None"
      return {
          "status": "AI server is online",
@@ -153,42 +116,40 @@ def get_status():
          "models": list(MODEL_MAP.keys())
      }

- @app.post("/submit_job")
- async def submit_job(prompt: SubmitPrompt):
      """
-     NEW: Instantly accepts a job and starts it in the background.
      """
-     job_id = str(uuid.uuid4())
-
-     # Store the job as "pending"
-     JOBS[job_id] = {"status": "pending", "result": None}
-
-     # Start the background thread
-     thread = threading.Thread(
-         target=run_generation_in_background,
-         args=(job_id, prompt.model_choice, prompt.prompt)
-     )
-     thread.start()
-
-     logging.info(f"Job {job_id} submitted.")
-     # Return the Job ID to the user immediately
-     return {"job_id": job_id}

- @app.get("/get_job_status/{job_id}")
- async def get_job_status(job_id: str):
-     """
-     NEW: Allows the frontend to check on a job.
-     """
-     job = JOBS.get(job_id)
-
-     if job is None:
-         return JSONResponse(status_code=404, content={"error": "Job not found."})
-
-     # If the job is done, send the result and remove it from memory
-     if job["status"] in ["complete", "error"]:
-         result = job
-         del JOBS[job_id]  # Clean up
-         return result
-
-     # If not done, just send the current status
-     return {"status": job["status"]}
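For reference, the removed endpoints implemented a submit-then-poll protocol: POST /submit_job returned a job_id immediately while a background thread ran the generation, and the client polled GET /get_job_status/{job_id} until the status became "complete" or "error". A minimal client sketch of that flow (hypothetical; it assumes the `requests` package, a placeholder BASE_URL, and a 2-second polling interval):

import time
import requests

BASE_URL = "http://localhost:7860"  # placeholder; point this at the deployed server

def generate_via_job_api(prompt: str, model_choice: str = "light") -> str:
    # Submit the job; the server responds immediately with a job_id.
    job = requests.post(
        f"{BASE_URL}/submit_job",
        json={"prompt": prompt, "model_choice": model_choice},
        timeout=30,
    ).json()
    job_id = job["job_id"]

    # Poll until the background thread marks the job complete (or failed).
    while True:
        status = requests.get(f"{BASE_URL}/get_job_status/{job_id}", timeout=30).json()
        if status.get("status") == "complete":
            return status["result"]
        if status.get("status") == "error":
            raise RuntimeError(status["result"])
        time.sleep(2)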

New version of app.py (added lines are prefixed with "+"):
  import os
  from fastapi import FastAPI, Request
  from fastapi.responses import JSONResponse
  from pydantic import BaseModel
  from llama_cpp import Llama
  from fastapi.middleware.cors import CORSMiddleware
  from huggingface_hub import hf_hub_download
+ import logging
+ import threading
  from contextlib import asynccontextmanager

+ # Set up logging
  logging.basicConfig(level=logging.INFO)

+ # --- MODEL MAP (Using the smarter Phi-3) ---
  MODEL_MAP = {
      "light": {
          "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",

      }
  }

+ # --- GLOBAL CACHE & LOCK ---
  llm_cache = {}
+ model_lock = threading.Lock()  # For loading models
+ llm_lock = threading.Lock()  # For running generation

+ # --- LIFESPAN FUNCTION ---
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # This code runs ON STARTUP
+     logging.info("Server starting up... Pre-loading 'light' model (Phi-3).")
+     # get_llm_instance() acquires model_lock internally; do not take the lock
+     # here as well, since threading.Lock is not re-entrant and nesting it
+     # would deadlock startup.
+     get_llm_instance("light")
+     logging.info("Server is ready and 'light' model (Phi-3) is loaded.")
+
+     yield
+
+     # This code runs ON SHUTDOWN
+     logging.info("Server shutting down...")
+     llm_cache.clear()

+ # Pass the lifespan function to FastAPI
+ app = FastAPI(lifespan=lifespan)
+
+ # --- CORS ---
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # --- Helper Function to Load Model ---
  def get_llm_instance(choice: str) -> Llama:
+     # Use the *model* lock for loading
      with model_lock:
          if choice not in MODEL_MAP:
              logging.error(f"Invalid model choice: {choice}")

          logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
          return None

+ # --- API Data Models (SIMPLIFIED) ---
+ class StoryPrompt(BaseModel):
      prompt: str
      model_choice: str
+     feedback: str = ""
+     story_memory: str = ""

  # --- API Endpoints ---
  @app.get("/")
  def get_status():
      loaded_model = list(llm_cache.keys())[0] if llm_cache else "None"
      return {
          "status": "AI server is online",

          "models": list(MODEL_MAP.keys())
      }

+ @app.post("/generate")
+ async def generate_story(prompt: StoryPrompt):
      """
+     Main generation endpoint.
+     This is simple and stable.
      """
+     logging.info("Request received. Waiting to acquire LLM lock...")
+     # Use the *generation* lock
+     with llm_lock:
+         logging.info("Lock acquired. Processing request.")
+         try:
+             llm = get_llm_instance(prompt.model_choice)
+             if llm is None:
+                 logging.error(f"Failed to get model for choice: {prompt.model_choice}")
+                 return JSONResponse(status_code=503, content={"error": "The AI model is not available or failed to load."})
+
+             # We trust the frontend to build the full prompt
+             final_prompt = prompt.prompt
+
+             logging.info(f"Generating with {prompt.model_choice}...")
+             output = llm(
+                 final_prompt,
+                 max_tokens=512,
+                 stop=["<|user|>", "<|endoftext|>", "user:"],
+                 echo=False
+             )
+
+             generated_text = output["choices"][0]["text"].strip()
+             logging.info("Generation complete.")
+
+             return {"story_text": generated_text}
+
+         except Exception as e:
+             logging.error(f"An internal error occurred during generation: {e}", exc_info=True)
+             return JSONResponse(status_code=500, content={"error": "An unexpected error occurred."})
+         finally:
+             logging.info("Releasing LLM lock.")
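With this change, the frontend calls a single synchronous endpoint: POST /generate with a StoryPrompt body, and the generated text comes back under "story_text". A minimal usage sketch (hypothetical; it assumes the `requests` package and a placeholder BASE_URL, and note that the server passes `prompt` to the model verbatim, so any chat template is the caller's responsibility):

import requests

BASE_URL = "http://localhost:7860"  # placeholder; point this at the deployed server

payload = {
    "prompt": "Write a short story about a lighthouse keeper.",
    "model_choice": "light",
    "feedback": "",       # accepted by StoryPrompt, not referenced by /generate above
    "story_memory": "",   # accepted by StoryPrompt, not referenced by /generate above
}

# Generation holds llm_lock for its full duration, so allow a generous timeout.
response = requests.post(f"{BASE_URL}/generate", json=payload, timeout=300)
response.raise_for_status()
print(response.json()["story_text"])

Because llm_lock serializes generation, a second request made while one is in progress simply waits for the lock rather than failing.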