Update app.py
app.py CHANGED
@@ -12,12 +12,14 @@ import psutil
 import gc
 import threading
 from queue import Queue
-from concurrent.futures import ThreadPoolExecutor
 
 # Set environment variables to optimize CPU performance
 os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
 os.environ["MKL_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
 
+# Set device globally
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
 # Load models with caching and CPU optimization
 @st.cache_resource()
 def load_model(model_name, model_class, is_bc=False, device=None):
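The first hunk drops the ThreadPoolExecutor import and hoists the global DEVICE selection up next to the thread-count setup, so the device is resolved once at import time. A minimal sketch of that pattern, using generic transformers Auto* classes as stand-ins for the Space's own QATCForQuestionAnswering / ClaimModelForClassification, with a placeholder checkpoint name:

import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Resolve the device once at import time; every loader below shares it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

@st.cache_resource()  # keep one live copy of each model per process
def load_model(model_name, device=DEVICE):
    # Stand-in loading logic; the Space passes its own model classes instead.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    model.to(device)
    model.eval()  # inference only
    return tokenizer, model

st.cache_resource (rather than st.cache_data) matters here: it returns the same model object on every rerun instead of trying to serialize the weights.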
@@ -43,9 +45,6 @@ def load_model(model_name, model_class, is_bc=False, device=None):
     model.to(device)
     return tokenizer, model
 
-# Set device globally
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
 # Pre-process text function to avoid doing it multiple times
 @st.cache_data
 def preprocess_text(text):
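For orientation on the two caching decorators in play: st.cache_data memoizes by input value and hands callers a copy of the cached result, which suits a pure text-preprocessing function, while st.cache_resource above returns the live object. A tiny sketch with a placeholder body, since the real preprocessing is outside this hunk:

import streamlit as st

@st.cache_data  # memoize by input value; callers get a copy of the cached result
def preprocess_text(text):
    # Placeholder normalization; the Space's actual preprocessing is not shown here.
    return " ".join(text.split())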
@@ -271,41 +270,19 @@ with st.container():
     os.environ["OMP_NUM_THREADS"] = str(num_threads)
     os.environ["MKL_NUM_THREADS"] = str(num_threads)
 
-    #
-    if 'history' not in st.session_state:
-        st.session_state.history = []
-    if 'latest_result' not in st.session_state:
-        st.session_state.latest_result = None
+    # Load models once and keep them in memory
     if 'models_loaded' not in st.session_state:
-        st.session_state.models_loaded = False
-
-    # Load the selected models - only reload if model selection changes
-    if not st.session_state.models_loaded or 'prev_models' not in st.session_state or (
-            st.session_state.prev_models['qatc'] != qatc_model_name or
-            st.session_state.prev_models['bc'] != bc_model_name or
-            st.session_state.prev_models['tc'] != tc_model_name):
-
         with st.spinner("Loading models..."):
-            # Clear memory before loading new models
-            gc.collect()
-            if DEVICE == "cpu":
-                torch.set_num_threads(num_threads)
-
             tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
             tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
             tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
-
-            st.session_state.prev_models = {
-                'qatc': qatc_model_name,
-                'bc': bc_model_name,
-                'tc': tc_model_name
-            }
             st.session_state.models_loaded = True
-
-
-
-
-
+
+    # Store verification history
+    if 'history' not in st.session_state:
+        st.session_state.history = []
+    if 'latest_result' not in st.session_state:
+        st.session_state.latest_result = None
 
     # Icons for results
     verdict_icons = {
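The third hunk replaces the reload-on-selection-change bookkeeping (the prev_models dict, the gc.collect() pass, the explicit torch.set_num_threads call) with a single load-once guard, and moves the history initialization after it. A minimal sketch of the resulting flow, assuming the simplified load_model from the earlier sketch and a placeholder checkpoint name:

import streamlit as st

# Load-once guard: models are fetched a single time per session.
if "models_loaded" not in st.session_state:
    with st.spinner("Loading models..."):
        tokenizer, model = load_model("model-name")  # placeholder checkpoint
        st.session_state.models_loaded = True

# Per-session containers are created only after the models are ready.
if "history" not in st.session_state:
    st.session_state.history = []
if "latest_result" not in st.session_state:
    st.session_state.latest_result = None

With @st.cache_resource in place, repeated calls for the same checkpoint are cheap, so the session_state flag's main visible effect is keeping the spinner from reappearing on every rerun.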