Cornelius committed on
Commit
85ce009
·
1 Parent(s): c775d45

Fix GPU error - default to CPU and enhance GPU detection

Browse files
Files changed (2) hide show
  1. app.py +39 -20
  2. teacher_agent_dev/compare_strategies.py +16 -2
app.py CHANGED
@@ -3,17 +3,13 @@ Gradio app for MentorFlow - Teacher-Student RL System
3
  Deployed on Hugging Face Spaces with GPU support
4
  """
5
 
6
- import gradio as gr
7
  import sys
8
  import os
9
  import subprocess
10
  from pathlib import Path
11
 
12
- # Monkey-patch to fix Gradio 4.44.x schema generation bug
13
  # Prevents TypeError: argument of type 'bool' is not iterable
14
- import sys
15
-
16
- # Patch BEFORE importing gradio to ensure it takes effect
17
  def _patch_gradio_schema_bug():
18
  """Patch Gradio's buggy schema generation."""
19
  try:
@@ -25,12 +21,10 @@ def _patch_gradio_schema_bug():
25
 
26
  def _patched_get_type(schema):
27
  """Handle bool schemas that cause the bug."""
28
- # Bug fix: schema is sometimes a bool
29
  if isinstance(schema, bool):
30
  return "bool"
31
  if schema is None:
32
  return "Any"
33
- # Must be dict to check membership
34
  if not isinstance(schema, dict):
35
  return "Any"
36
  try:
@@ -42,7 +36,7 @@ def _patch_gradio_schema_bug():
42
 
43
  gradio_client_utils.get_type = _patched_get_type
44
 
45
- # Also patch the wrapper function that calls get_type
46
  if hasattr(gradio_client_utils, '_json_schema_to_python_type'):
47
  _original_json_to_type = gradio_client_utils._json_schema_to_python_type
48
 
@@ -51,18 +45,20 @@ def _patch_gradio_schema_bug():
51
  try:
52
  return _original_json_to_type(schema, defs)
53
  except (TypeError, AttributeError) as e:
54
- if "is not iterable" in str(e) or "bool" in str(type(e)):
55
  return "Any"
56
  raise
57
 
58
  gradio_client_utils._json_schema_to_python_type = _patched_json_to_type
59
-
60
  except (ImportError, AttributeError):
61
  pass
62
 
63
- # Apply patch immediately
64
  _patch_gradio_schema_bug()
65
 
 
 
 
66
  # Add project paths
67
  sys.path.insert(0, str(Path(__file__).parent))
68
  sys.path.insert(0, str(Path(__file__).parent / "teacher_agent_dev"))
@@ -80,19 +76,32 @@ def run_comparison(iterations: int, seed: int, use_deterministic: bool, device:
80
  """
81
 
82
  # Set device environment variable for subprocess
83
- # Check if CUDA is actually available before using
84
  if device == "cuda":
85
  try:
86
  import torch
87
- if not torch.cuda.is_available():
 
 
 
 
 
 
 
 
 
 
88
  device = "cpu"
89
  except ImportError:
 
90
  device = "cpu"
91
- except Exception:
 
92
  device = "cpu"
93
 
94
  # Set environment variable for subprocess to pick up
95
  os.environ["CUDA_DEVICE"] = device
 
96
 
97
  # Prepare command
98
  cmd = [
@@ -163,11 +172,21 @@ def check_gpu():
163
  try:
164
  import torch
165
  if torch.cuda.is_available():
166
- return f"✅ GPU Available: {torch.cuda.get_device_name(0)}"
 
 
 
 
 
167
  else:
 
 
 
168
  return "⚠️ No GPU available, using CPU"
169
- except:
170
- return "⚠️ Could not check GPU status"
 
 
171
 
172
 
173
  # Create Gradio interface
@@ -221,10 +240,10 @@ with gr.Blocks(title="MentorFlow - Strategy Comparison") as demo:
221
  )
222
 
223
  device = gr.Radio(
224
- choices=["cuda", "cpu"],
225
- value="cuda",
226
  label="Device",
227
- info="Use GPU (cuda) if available, CPU otherwise"
228
  )
229
 
230
  with gr.Column():
 
3
  Deployed on Hugging Face Spaces with GPU support
4
  """
5
 
 
6
  import sys
7
  import os
8
  import subprocess
9
  from pathlib import Path
10
 
11
+ # Monkey-patch to fix Gradio schema generation bug BEFORE importing gradio
12
  # Prevents TypeError: argument of type 'bool' is not iterable
 
 
 
13
  def _patch_gradio_schema_bug():
14
  """Patch Gradio's buggy schema generation."""
15
  try:
 
21
 
22
  def _patched_get_type(schema):
23
  """Handle bool schemas that cause the bug."""
 
24
  if isinstance(schema, bool):
25
  return "bool"
26
  if schema is None:
27
  return "Any"
 
28
  if not isinstance(schema, dict):
29
  return "Any"
30
  try:
 
36
 
37
  gradio_client_utils.get_type = _patched_get_type
38
 
39
+ # Patch the wrapper function too
40
  if hasattr(gradio_client_utils, '_json_schema_to_python_type'):
41
  _original_json_to_type = gradio_client_utils._json_schema_to_python_type
42
 
 
45
  try:
46
  return _original_json_to_type(schema, defs)
47
  except (TypeError, AttributeError) as e:
48
+ if "is not iterable" in str(e):
49
  return "Any"
50
  raise
51
 
52
  gradio_client_utils._json_schema_to_python_type = _patched_json_to_type
 
53
  except (ImportError, AttributeError):
54
  pass
55
 
56
+ # Apply patch BEFORE importing gradio
57
  _patch_gradio_schema_bug()
58
 
59
+ # Now import gradio (patch will be in effect)
60
+ import gradio as gr
61
+
62
  # Add project paths
63
  sys.path.insert(0, str(Path(__file__).parent))
64
  sys.path.insert(0, str(Path(__file__).parent / "teacher_agent_dev"))
 
76
  """
77
 
78
  # Set device environment variable for subprocess
79
+ # On Hugging Face Spaces, check GPU availability more carefully
80
  if device == "cuda":
81
  try:
82
  import torch
83
+ # Check if CUDA is available
84
+ if torch.cuda.is_available():
85
+ try:
86
+ # Try to get device name to verify GPU works
87
+ gpu_name = torch.cuda.get_device_name(0)
88
+ print(f"✅ GPU available: {gpu_name}")
89
+ except Exception as e:
90
+ print(f"⚠️ GPU detection failed: {e}, falling back to CPU")
91
+ device = "cpu"
92
+ else:
93
+ print("⚠️ CUDA not available, using CPU")
94
  device = "cpu"
95
  except ImportError:
96
+ print("⚠️ PyTorch not available, using CPU")
97
  device = "cpu"
98
+ except Exception as e:
99
+ print(f"⚠️ GPU check error: {e}, using CPU")
100
  device = "cpu"
101
 
102
  # Set environment variable for subprocess to pick up
103
  os.environ["CUDA_DEVICE"] = device
104
+ print(f"🔧 Using device: {device}")
105
 
106
  # Prepare command
107
  cmd = [
 
172
  try:
173
  import torch
174
  if torch.cuda.is_available():
175
+ try:
176
+ gpu_name = torch.cuda.get_device_name(0)
177
+ gpu_count = torch.cuda.device_count()
178
+ return f"✅ GPU Available: {gpu_name} (Count: {gpu_count})"
179
+ except Exception as e:
180
+ return f"⚠️ GPU detected but error accessing: {str(e)}"
181
  else:
182
+ # Check if we're on Hugging Face Spaces
183
+ if os.getenv("SPACE_ID"):
184
+ return "⚠️ No GPU available on this Space. Please upgrade to GPU tier."
185
  return "⚠️ No GPU available, using CPU"
186
+ except ImportError:
187
+ return "⚠️ PyTorch not installed"
188
+ except Exception as e:
189
+ return f"⚠️ Could not check GPU status: {str(e)}"
190
 
191
 
192
  # Create Gradio interface
 
240
  )
241
 
242
  device = gr.Radio(
243
+ choices=["cpu", "cuda"],
244
+ value="cpu", # Default to CPU for reliability on HF Spaces
245
  label="Device",
246
+ info="CPU (recommended) or CUDA/GPU if available on your Space"
247
  )
248
 
249
  with gr.Column():
teacher_agent_dev/compare_strategies.py CHANGED
@@ -90,11 +90,25 @@ def train_strategy_random(num_iterations: int = 500, seed: int = 42, target_accu
90
  if device == "cuda":
91
  try:
92
  import torch
93
- if not torch.cuda.is_available():
 
 
 
 
 
 
 
 
94
  device = "cpu"
95
  print("⚠️ CUDA not available, using CPU")
96
- except:
 
 
 
97
  device = "cpu"
 
 
 
98
 
99
  student = LMStudentAgent(
100
  learning_rate=5e-5, # LM fine-tuning learning rate
 
90
  if device == "cuda":
91
  try:
92
  import torch
93
+ if torch.cuda.is_available():
94
+ try:
95
+ # Verify GPU actually works
96
+ gpu_name = torch.cuda.get_device_name(0)
97
+ print(f"✅ Using GPU: {gpu_name}")
98
+ except Exception as e:
99
+ print(f"⚠️ GPU access failed: {e}, using CPU")
100
+ device = "cpu"
101
+ else:
102
  device = "cpu"
103
  print("⚠️ CUDA not available, using CPU")
104
+ except ImportError:
105
+ device = "cpu"
106
+ print("⚠️ PyTorch not available, using CPU")
107
+ except Exception as e:
108
  device = "cpu"
109
+ print(f"⚠️ GPU check error: {e}, using CPU")
110
+
111
+ print(f"🔧 LM Student device: {device}")
112
 
113
  student = LMStudentAgent(
114
  learning_rate=5e-5, # LM fine-tuning learning rate