CorneliusWang committed (verified) · Commit d06d2e6 · 1 Parent(s): eb20fef

Update teacher_agent_dev/compare_strategies.py

Files changed (1):
  1. teacher_agent_dev/compare_strategies.py +131 -425

teacher_agent_dev/compare_strategies.py CHANGED
@@ -9,6 +9,8 @@ Uses LM Student (DistilBERT) instead of MockStudentAgent.
 
 import sys
 import os
+import random  # Added for global seeding
+import numpy as np  # Added for global seeding
 from pathlib import Path
 
 # Add student_agent_dev to path for LM student import
@@ -46,9 +48,6 @@ from train_teacher import train_teacher
 def evaluate_difficult_questions(student, generator: MockTaskGenerator, num_questions: int = 20) -> float:
     """
     Evaluate student on difficult questions from all topics.
-
-    Returns:
-        Accuracy on difficult questions (0.0 to 1.0)
     """
     topics = generator.get_available_topics()
     eval_tasks = []
@@ -65,82 +64,58 @@ def evaluate_difficult_questions(student, generator: MockTaskGenerator, num_ques
 def train_strategy_random(num_iterations: int = 500, seed: int = 42, target_accuracy: float = 0.75) -> Dict:
     """
     Strategy 1: Random questions until student can confidently pass difficult questions.
-
-    Selection strategy:
-    - Randomly chooses a topic (uniform across all topics)
-    - Randomly chooses a difficulty (uniform across all difficulties)
-    - No curriculum structure - completely random
-
-    Args:
-        num_iterations: Maximum iterations to train
-        seed: Random seed
-        target_accuracy: Target accuracy on difficult questions to consider "passing"
-
-    Returns:
-        Training history dictionary
     """
-    import random
+    # Set global seeds to ensure MockTaskGenerator behaves deterministically
+    random.seed(seed)
+    np.random.seed(seed)
+
     rng = random.Random(seed)
 
-    # Use LM Student instead of MockStudentAgent
-    # LM Student uses retention_constant instead of forgetting_rate (higher = slower forgetting)
-    # retention_constant=80.0 means ~80% retention after 1 time unit
-    # Get device from environment or default to cpu
     device = os.environ.get("CUDA_DEVICE", "cpu")
     if device == "cuda":
         try:
             import torch
             if torch.cuda.is_available():
-                try:
-                    # Verify GPU actually works
-                    gpu_name = torch.cuda.get_device_name(0)
-                    print(f"✅ Using GPU: {gpu_name}")
-                except Exception as e:
-                    print(f"⚠️ GPU access failed: {e}, using CPU")
-                    device = "cpu"
+                print(f"✅ Using GPU: {torch.cuda.get_device_name(0)}")
             else:
                 device = "cpu"
-                print("⚠️ CUDA not available, using CPU")
-        except ImportError:
-            device = "cpu"
-            print("⚠️ PyTorch not available, using CPU")
-        except Exception as e:
+        except:
             device = "cpu"
-            print(f"⚠️ GPU check error: {e}, using CPU")
 
     print(f"🔧 LM Student device: {device}")
 
     student = LMStudentAgent(
-        learning_rate=5e-5,       # LM fine-tuning learning rate
-        retention_constant=80.0,  # Slower forgetting than mock student
-        device=device,            # Use GPU if available
+        learning_rate=5e-5,
+        retention_constant=80.0,
+        device=device,
         max_length=256,
        gradient_accumulation_steps=4
     ) if USE_LM_STUDENT else MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)
-    generator = MockTaskGenerator(seed=seed)
+
+    # --- FIX 1: REMOVED seed=seed ---
+    generator = MockTaskGenerator()
 
     topics = generator.get_available_topics()
     difficulties = generator.get_available_difficulties()
 
     # Evaluation on difficult questions - CREATE FIXED SET ONCE
-    # Use 'expert' or 'master' for truly difficult questions (with expanded difficulty levels)
     hard_eval_tasks = []
-    eval_difficulty = 'expert' if 'expert' in difficulties else 'hard'  # Use expert level for challenging eval
+    eval_difficulty = 'expert' if 'expert' in difficulties else 'hard'
     for topic in topics:
-        for _ in range(5):  # 5 difficult questions per topic
+        for _ in range(5):
             hard_eval_tasks.append(generator.generate_task(topic, eval_difficulty))
 
-    # Create FIXED general eval set (medium difficulty, all topics)
+    # Create FIXED general eval set
     general_eval_tasks = [
         generator.generate_task(topic, 'medium')
         for topic in topics
-        for _ in range(3)  # 3 tasks per topic
+        for _ in range(3)
     ]
 
     history = {
         'iterations': [],
         'student_accuracies': [],
-        'difficult_accuracies': [],  # Accuracy on hard questions
+        'difficult_accuracies': [],
         'teacher_rewards': [],
         'topics': [],
         'difficulties': [],
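The seeding change above is the heart of this commit: rather than passing `seed=seed` into `MockTaskGenerator`, each strategy now seeds the global `random` and NumPy generators once up front. A minimal sketch of the mechanism, assuming (as the added comment indicates) that `MockTaskGenerator` draws from the module-level generators rather than owning a private `random.Random`:

```python
import random
import numpy as np

def seeded_draws(seed: int, n: int = 3):
    """Seed both global generators, then sample; a fixed seed replays the sequence."""
    random.seed(seed)     # shared generator behind random.random(), random.choice(), ...
    np.random.seed(seed)  # NumPy's legacy global RandomState
    return [(random.random(), float(np.random.rand())) for _ in range(n)]

# Any library code that draws from the global generators is now reproducible:
assert seeded_draws(42) == seeded_draws(42)
```

The trade-off is process-wide: every caller of the global generators is affected. The local `rng = random.Random(seed)` kept for topic and difficulty choice remains isolated from it.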
@@ -152,25 +127,19 @@ def train_strategy_random(num_iterations: int = 500, seed: int = 42, target_accu
         iterator = tqdm(iterator, desc="Random Strategy", unit="iter")
 
     for iteration in iterator:
-        # Random strategy: choose random topic AND random difficulty independently
-        topic = rng.choice(topics)  # Random topic
-        difficulty = rng.choice(difficulties)  # Random difficulty
+        topic = rng.choice(topics)
+        difficulty = rng.choice(difficulties)
 
         task = generator.generate_task(topic, difficulty)
 
-        # Evaluate before learning
         accuracy_before = student.evaluate(hard_eval_tasks)
-
-        # Student learns
         student.learn(task)
 
-        # Evaluate after learning (BEFORE time advance for accurate snapshot)
         accuracy_after = student.evaluate(hard_eval_tasks)
-        general_accuracy = student.evaluate(general_eval_tasks)  # Use FIXED eval set
+        general_accuracy = student.evaluate(general_eval_tasks)
 
         student.advance_time(1.0)
 
-        # Track metrics
         history['iterations'].append(iteration)
         history['student_accuracies'].append(general_accuracy)
         history['difficult_accuracies'].append(accuracy_after)
@@ -178,8 +147,7 @@ def train_strategy_random(num_iterations: int = 500, seed: int = 42, target_accu
         history['topics'].append(topic)
         history['difficulties'].append(difficulty)
 
-        # Check if we've reached target (optional early stopping)
-        if accuracy_after >= target_accuracy and iteration > 50:  # Give at least 50 iterations
+        if accuracy_after >= target_accuracy and iteration > 50:
             if 'reached_target' not in locals():
                 print(f" Random strategy reached target accuracy {target_accuracy:.2f} at iteration {iteration}")
                 reached_target = True
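The surviving early-stop check prints its message only once by probing `'reached_target' not in locals()`. That works, but an explicit flag is easier to read; a hypothetical equivalent (not part of the commit):

```python
# Sketch: same print-once behavior with a plain boolean flag.
target_accuracy = 0.75
reached_target = False  # initialized before the loop instead of probed via locals()

for iteration, accuracy_after in enumerate([0.50, 0.70, 0.80, 0.90]):
    if accuracy_after >= target_accuracy and not reached_target:
        print(f"reached target {target_accuracy:.2f} at iteration {iteration}")
        reached_target = True  # later crossings stay silent
```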
@@ -190,20 +158,10 @@ def train_strategy_random(num_iterations: int = 500, seed: int = 42, target_accu
 def train_strategy_progressive(num_iterations: int = 500, seed: int = 42) -> Dict:
     """
     Strategy 2: Progressive difficulty within each family.
-    Easy → Medium → Hard for each topic, then move to next topic.
-
-    Args:
-        num_iterations: Number of iterations
-        seed: Random seed
-
-    Returns:
-        Training history dictionary
     """
-    # Reduce forgetting rate OR use periodic time reset for long training
-    # Option 1: Lower forgetting rate (better for long training)
-    # Option 2: Reset time periodically (keeps forgetting realistic but prevents complete loss)
-    # Using Option 1: lower forgetting rate
-    # Use LM Student instead of MockStudentAgent
+    random.seed(seed)
+    np.random.seed(seed)
+
     student = LMStudentAgent(
         learning_rate=5e-5,
         retention_constant=80.0,
@@ -211,26 +169,24 @@ def train_strategy_progressive(num_iterations: int = 500, seed: int = 42) -> Dic
         max_length=256,
         gradient_accumulation_steps=4
     ) if USE_LM_STUDENT else MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)
-    generator = MockTaskGenerator(seed=seed)
+
+    # --- FIX 2: REMOVED seed=seed ---
+    generator = MockTaskGenerator()
 
     topics = generator.get_available_topics()
     all_difficulties = generator.get_available_difficulties()
-    # Progressive: use all difficulties in order
-    difficulties = all_difficulties  # Use all 7 difficulty levels
+    difficulties = all_difficulties
 
-    # Evaluation on difficult questions - CREATE FIXED SET ONCE
-    # Use 'expert' or 'master' for truly difficult questions
     hard_eval_tasks = []
     eval_difficulty = 'expert' if 'expert' in all_difficulties else 'hard'
     for topic in topics:
         for _ in range(5):
             hard_eval_tasks.append(generator.generate_task(topic, eval_difficulty))
 
-    # Create FIXED general eval set (medium difficulty, all topics)
     general_eval_tasks = [
         generator.generate_task(topic, 'medium')
         for topic in topics
-        for _ in range(3)  # 3 tasks per topic
+        for _ in range(3)
     ]
 
     history = {
@@ -243,8 +199,6 @@ def train_strategy_progressive(num_iterations: int = 500, seed: int = 42) -> Dic
         'strategy': 'progressive'
     }
 
-    # Progressive curriculum: cycle through topics, increase difficulty over time
-    # Structure: For each topic, do easy → medium → hard
     questions_per_difficulty = max(1, num_iterations // (len(topics) * len(difficulties)))
 
     iterator = range(num_iterations)
@@ -252,7 +206,6 @@ def train_strategy_progressive(num_iterations: int = 500, seed: int = 42) -> Dic
         iterator = tqdm(iterator, desc="Progressive Strategy", unit="iter")
 
     for iteration in iterator:
-        # Determine current phase
         phase = iteration // questions_per_difficulty if questions_per_difficulty > 0 else iteration
         topic_idx = (phase // len(difficulties)) % len(topics)
         diff_idx = phase % len(difficulties)
@@ -262,19 +215,14 @@ def train_strategy_progressive(num_iterations: int = 500, seed: int = 42) -> Dic
 
         task = generator.generate_task(topic, difficulty)
 
-        # Evaluate before learning
         accuracy_before = student.evaluate(hard_eval_tasks)
-
-        # Student learns
         student.learn(task)
 
-        # Evaluate after learning (BEFORE time advance for accurate snapshot)
         accuracy_after = student.evaluate(hard_eval_tasks)
-        general_accuracy = student.evaluate(general_eval_tasks)  # Use FIXED eval set
+        general_accuracy = student.evaluate(general_eval_tasks)
 
         student.advance_time(1.0)
 
-        # Track metrics
         history['iterations'].append(iteration)
         history['student_accuracies'].append(general_accuracy)
         history['difficult_accuracies'].append(accuracy_after)
@@ -288,18 +236,15 @@ def train_strategy_progressive(num_iterations: int = 500, seed: int = 42) -> Dic
 def train_strategy_teacher(num_iterations: int = 500, seed: int = 42) -> Dict:
     """
     Strategy 3: RL Teacher Agent learns optimal curriculum.
-
-    Args:
-        num_iterations: Number of iterations
-        seed: Random seed
-
-    Returns:
-        Training history dictionary with difficult_accuracies added
     """
-    # Initialize components
-    generator = MockTaskGenerator(seed=seed)
-    teacher = TeacherAgent(exploration_bonus=2.0, task_generator=generator)  # Dynamic action space
-    # Use LM Student instead of MockStudentAgent
+    random.seed(seed)
+    np.random.seed(seed)
+
+    # --- FIX 3: REMOVED seed=seed ---
+    generator = MockTaskGenerator()
+
+    teacher = TeacherAgent(exploration_bonus=2.0, task_generator=generator)
+
     student = LMStudentAgent(
         learning_rate=5e-5,
         retention_constant=80.0,
@@ -310,14 +255,12 @@ def train_strategy_teacher(num_iterations: int = 500, seed: int = 42) -> Dict:
 
 
     topics = generator.get_available_topics()
 
-    # Create evaluation sets
     eval_tasks = [
         generator.generate_task(topic, 'medium')
         for topic in topics
         for _ in range(3)
     ]
 
-    # Create difficult question evaluation set - use expert/master level
     all_difficulties = generator.get_available_difficulties()
     eval_difficulty = 'expert' if 'expert' in all_difficulties else 'hard'
     hard_eval_tasks = [
@@ -326,7 +269,6 @@ def train_strategy_teacher(num_iterations: int = 500, seed: int = 42) -> Dict:
         for _ in range(5)
     ]
 
-    # Track metrics
     history = {
         'iterations': [],
         'student_accuracies': [],
@@ -344,30 +286,22 @@ def train_strategy_teacher(num_iterations: int = 500, seed: int = 42) -> Dict:
         iterator = tqdm(iterator, desc="Teacher Strategy", unit="iter")
 
     for iteration in iterator:
-        # 1. Get student state
         student_state = student.get_state()
-
-        # 2. Teacher selects action
         action = teacher.select_action(student_state)
 
-        # 3. Generate task
         if action.is_review:
             task = generator.generate_task(action.topic, 'medium')
         else:
             task = generator.generate_task(action.topic, action.difficulty)
 
-        # 4. Evaluate student BEFORE learning
         accuracy_before = student.evaluate(eval_tasks)
         difficult_acc_before = student.evaluate(hard_eval_tasks)
 
-        # 5. Student learns from task
         student.learn(task)
 
-        # 6. Evaluate student AFTER learning
         accuracy_after = student.evaluate(eval_tasks)
         difficult_acc_after = student.evaluate(hard_eval_tasks)
 
-        # 7. Compute reward for teacher
         reward = compute_reward(
             accuracy_before,
             accuracy_after,
@@ -375,13 +309,9 @@ def train_strategy_teacher(num_iterations: int = 500, seed: int = 42) -> Dict:
             action.is_review
         )
 
-        # 8. Update teacher's policy
         teacher.update(action, reward)
-
-        # 9. Time passes (for forgetting)
         student.advance_time(1.0)
 
-        # 10. Log metrics
         history['iterations'].append(iteration)
         history['student_accuracies'].append(accuracy_after)
         history['difficult_accuracies'].append(difficult_acc_after)
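`compute_reward` is imported from `train_teacher`, and one of its arguments is not visible in this hunk, so the diff does not show its definition. Curriculum teachers of this kind commonly reward learning progress, i.e. the change in evaluation accuracy, sometimes discounted for review actions; the sketch below is purely illustrative and is not the repo's `compute_reward`:

```python
def learning_progress_reward(acc_before: float, acc_after: float,
                             is_review: bool, review_discount: float = 0.5) -> float:
    """Hypothetical reward: accuracy improvement, discounted for review tasks."""
    progress = acc_after - acc_before
    return progress * (review_discount if is_review else 1.0)

print(learning_progress_reward(0.50, 0.58, is_review=False))  # ~0.08
print(learning_progress_reward(0.50, 0.58, is_review=True))   # ~0.04
```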
@@ -397,231 +327,116 @@ def train_strategy_teacher(num_iterations: int = 500, seed: int = 42) -> Dict:
 def plot_comparison(histories: Dict[str, Dict], save_path: str = 'teacher_agent_dev/comparison_all_strategies.png'):
     """
     Create comprehensive comparison plots of all three strategies.
-
-    Args:
-        histories: Dictionary mapping strategy name to history
-                   e.g., {'Random': history1, 'Progressive': history2, 'Teacher': history3}
-        save_path: Where to save the plot
     """
     import matplotlib.pyplot as plt
 
+    # Ensure directory exists
+    os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
     fig, axes = plt.subplots(4, 1, figsize=(16, 14))
 
-    # Define colors and styles for each strategy
     colors = {
         'Random': '#FF6B6B',       # Red
         'Progressive': '#4ECDC4',  # Teal
-        'Teacher': '#2ECC71'       # Green (highlight teacher as best)
+        'Teacher': '#2ECC71'       # Green
     }
 
     line_styles = {
-        'Random': '--',       # Dashed = stochastic/erratic
-        'Progressive': '-.',  # Dash-dot = linear/rigid
-        'Teacher': '-'        # Solid = smooth/exponential
+        'Random': '--',
+        'Progressive': '-.',
+        'Teacher': '-'
     }
 
     line_widths = {
         'Random': 2.0,
         'Progressive': 2.0,
-        'Teacher': 3.5  # Much thicker line for teacher to emphasize exponential growth
+        'Teacher': 3.5
     }
 
-    # 1. Plot 1: General Accuracy Over Time - Emphasize Exponential vs Stochastic
+    # 1. Plot 1: General Accuracy
     ax = axes[0]
-
-    # Plot raw data with different styles to show stochasticity vs smoothness
     for name, history in histories.items():
         iterations = history['iterations']
         accuracies = history['student_accuracies']
 
-        if name == 'Teacher':
-            # Teacher: Show exponential growth clearly with smooth curve
-            # Less smoothing to show actual exponential curve
-            window = 10 if len(accuracies) > 50 else 5
+        if len(accuracies) > 50:
+            # Smooth curves
+            window = 10
             smoothed = np.convolve(accuracies, np.ones(window)/window, mode='same')
             ax.plot(iterations, smoothed,
-                    label=f'{name} (Exponential Growth)',
-                    color=colors[name],
-                    linestyle=line_styles[name],
-                    linewidth=line_widths[name],
-                    alpha=0.95,
-                    zorder=10)  # On top
+                    label=name,
+                    color=colors[name],
+                    linestyle=line_styles[name],
+                    linewidth=line_widths[name],
+                    alpha=0.9)
         else:
-            # Random/Progressive: Show stochastic/erratic nature
-            # Plot raw noisy data with some transparency to show variance
-            if len(accuracies) > 50:
-                # Show variance with raw data (more stochastic)
-                ax.plot(iterations, accuracies,
-                        label=f'{name} (Stochastic/Erratic)',
-                        color=colors[name],
-                        linestyle=line_styles[name],
-                        linewidth=line_widths[name],
-                        alpha=0.4,  # Lighter to show noise
-                        zorder=1)
-                # Overlay smoothed version
-                window = 30
-                smoothed = np.convolve(accuracies, np.ones(window)/window, mode='same')
-                ax.plot(iterations, smoothed,
-                        color=colors[name],
-                        linestyle=line_styles[name],
-                        linewidth=line_widths[name],
-                        alpha=0.8)
-            else:
-                ax.plot(iterations, accuracies,
-                        label=f'{name} (Stochastic)',
-                        color=colors[name],
-                        linestyle=line_styles[name],
-                        linewidth=line_widths[name],
-                        alpha=0.8)
-
-    ax.set_xlabel('Training Iteration', fontsize=12, fontweight='bold')
-    ax.set_ylabel('General Accuracy', fontsize=12, fontweight='bold')
-    ax.set_title('Learning Curves: Exponential (Teacher) vs Stochastic (Baselines)', fontsize=14, fontweight='bold')
-    ax.legend(loc='lower right', fontsize=11, framealpha=0.9)
-    ax.grid(True, alpha=0.3, linestyle='--')
-    ax.set_ylim([0.2, 1.0])
-
-    # Add text annotation highlighting exponential vs stochastic
-    ax.text(0.02, 0.98,
-            '📈 Teacher: Smooth exponential growth\n📉 Baselines: Erratic, stochastic learning',
-            transform=ax.transAxes,
-            fontsize=10,
-            verticalalignment='top',
-            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
-
-    # Add final accuracy annotations
-    for name, history in histories.items():
-        final_acc = history['student_accuracies'][-1]
-        final_iter = history['iterations'][-1]
-        ax.annotate(f'{final_acc:.3f}',
-                    xy=(final_iter, final_acc),
-                    xytext=(10, 10),
-                    textcoords='offset points',
-                    fontsize=10,
-                    bbox=dict(boxstyle='round,pad=0.3', facecolor=colors[name], alpha=0.5),
-                    arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
-
-    # 2. Plot 2: Difficult Question Accuracy - Show Exponential Growth Clearly
+            ax.plot(iterations, accuracies,
+                    label=name,
+                    color=colors[name],
+                    linestyle=line_styles[name],
+                    linewidth=line_widths[name])
+
+    ax.set_xlabel('Training Iteration')
+    ax.set_ylabel('General Accuracy')
+    ax.set_title('Learning Curves')
+    ax.legend(loc='lower right')
+    ax.grid(True, alpha=0.3)
+    ax.set_ylim([0.0, 1.0])
+
+    # 2. Plot 2: Difficult Question Accuracy
     ax = axes[1]
-
     for name, history in histories.items():
         iterations = history['iterations']
         difficult_accuracies = history['difficult_accuracies']
 
-        if name == 'Teacher':
-            # Teacher: Emphasize exponential growth
-            window = 8  # Less smoothing to show exponential shape
+        if len(difficult_accuracies) > 50:
+            window = 10
             smoothed = np.convolve(difficult_accuracies, np.ones(window)/window, mode='same')
             ax.plot(iterations, smoothed,
-                    label=f'{name} (Exponential)',
-                    color=colors[name],
-                    linestyle=line_styles[name],
-                    linewidth=line_widths[name],
-                    alpha=0.95,
-                    zorder=10)
+                    label=name,
+                    color=colors[name],
+                    linestyle=line_styles[name],
+                    linewidth=line_widths[name])
         else:
-            # Baselines: Show stochastic nature
-            if len(difficult_accuracies) > 50:
-                # Show raw noisy data
-                ax.plot(iterations, difficult_accuracies,
-                        label=f'{name} (Erratic)',
-                        color=colors[name],
-                        linestyle=line_styles[name],
-                        linewidth=line_widths[name],
-                        alpha=0.3,
-                        zorder=1)
-                # Overlay smoothed
-                window = 25
-                smoothed = np.convolve(difficult_accuracies, np.ones(window)/window, mode='same')
-                ax.plot(iterations, smoothed,
-                        color=colors[name],
-                        linestyle=line_styles[name],
-                        linewidth=line_widths[name],
-                        alpha=0.8)
-            else:
-                ax.plot(iterations, difficult_accuracies,
-                        label=name,
-                        color=colors[name],
-                        linestyle=line_styles[name],
-                        linewidth=line_widths[name],
-                        alpha=0.8)
-
-    ax.set_xlabel('Training Iteration', fontsize=12, fontweight='bold')
-    ax.set_ylabel('Accuracy on Difficult Questions', fontsize=12, fontweight='bold')
-    ax.set_title('Difficult Question Performance: Exponential vs Stochastic Learning',
-                 fontsize=14, fontweight='bold', color='darkred')
-    ax.legend(loc='lower right', fontsize=11, framealpha=0.9)
-    ax.grid(True, alpha=0.3, linestyle='--')
-    ax.set_ylim([0.2, 1.0])
-
-    # Highlight target accuracy line (75%)
-    ax.axhline(y=0.75, color='gray', linestyle=':', linewidth=1, alpha=0.5)
-
-    # Add final accuracy annotations
-    for name, history in histories.items():
-        final_acc = history['difficult_accuracies'][-1]
-        final_iter = history['iterations'][-1]
-        ax.annotate(f'{final_acc:.3f}',
-                    xy=(final_iter, final_acc),
-                    xytext=(10, 10),
-                    textcoords='offset points',
-                    fontsize=10,
-                    bbox=dict(boxstyle='round,pad=0.3', facecolor=colors[name], alpha=0.3),
-                    arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
-
-    # 3. Plot 3: Curriculum Efficiency - Topic Coverage Over Time
+            ax.plot(iterations, difficult_accuracies,
+                    label=name,
+                    color=colors[name],
+                    linestyle=line_styles[name],
+                    linewidth=line_widths[name])
+
+    ax.set_xlabel('Training Iteration')
+    ax.set_ylabel('Accuracy on Hard Questions')
+    ax.set_title('Performance on Difficult Content')
+    ax.legend(loc='lower right')
+    ax.grid(True, alpha=0.3)
+    ax.set_ylim([0.0, 1.0])
+
+    # 3. Plot 3: Topic Coverage
     ax = axes[2]
-
-    # Track unique topics seen over time to show curriculum diversity
     for name, history in histories.items():
         iterations = history['iterations']
         topics_seen = history['topics']
 
-        # Count unique topics up to each iteration
         unique_topics = []
         seen_so_far = set()
-
         for topic in topics_seen:
             seen_so_far.add(topic)
             unique_topics.append(len(seen_so_far))
 
-        if name == 'Teacher':
-            ax.plot(iterations, unique_topics,
-                    label=f'{name} (Diverse Curriculum)',
-                    color=colors[name],
-                    linestyle=line_styles[name],
-                    linewidth=line_widths[name],
-                    alpha=0.9,
-                    zorder=10,
-                    marker='o', markersize=3)
-        else:
-            ax.plot(iterations, unique_topics,
-                    label=f'{name}',
-                    color=colors[name],
-                    linestyle=line_styles[name],
-                    linewidth=line_widths[name],
-                    alpha=0.8,
-                    marker='s', markersize=2)
-
-    ax.set_xlabel('Training Iteration', fontsize=12, fontweight='bold')
-    ax.set_ylabel('Number of Unique Topics Covered', fontsize=12, fontweight='bold')
-    ax.set_title('Curriculum Diversity: Topic Coverage Over Time',
-                 fontsize=14, fontweight='bold')
-    ax.legend(loc='lower right', fontsize=11, framealpha=0.9)
-    ax.grid(True, alpha=0.3, linestyle='--')
-
-    # Add total topics line if available
-    if histories:
-        first_history = list(histories.values())[0]
-        if 'topics' in first_history and first_history['topics']:
-            all_unique_topics = len(set(first_history['topics']))
-            ax.axhline(y=all_unique_topics, color='gray', linestyle=':',
-                       alpha=0.5, label=f'Total topics: {all_unique_topics}')
-            ax.legend(loc='lower right', fontsize=11, framealpha=0.9)
-
-    # 4. Plot 4: Learning Speed Comparison (Iterations to reach 75% on difficult)
-    ax = axes[3]
+        ax.plot(iterations, unique_topics,
+                label=name,
+                color=colors[name],
+                linestyle=line_styles[name],
+                linewidth=line_widths[name])
+
+    ax.set_xlabel('Training Iteration')
+    ax.set_ylabel('Unique Topics Seen')
+    ax.set_title('Curriculum Diversity')
+    ax.legend(loc='lower right')
+    ax.grid(True, alpha=0.3)
 
+    # 4. Plot 4: Learning Efficiency
+    ax = axes[3]
     target_acc = 0.75
     strategy_stats = {}
 
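Both before and after this change, curves are smoothed with a box filter via `np.convolve(..., mode='same')`. A self-contained sketch of that helper; note that `mode='same'` zero-pads, so roughly the first and last `window // 2` points are biased toward zero:

```python
import numpy as np

def moving_average(values, window: int = 10) -> np.ndarray:
    """Box-filter smoothing as used in the plots (zero-padded at the edges)."""
    return np.convolve(values, np.ones(window) / window, mode='same')

accuracies = np.linspace(0.3, 0.9, 100)  # stand-in for a learning curve
smoothed = moving_average(accuracies)
print(smoothed[0], smoothed[50])  # the edge value dips below the raw curve
```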
@@ -629,7 +444,6 @@ def plot_comparison(histories: Dict[str, Dict], save_path: str = 'teacher_agent_
         difficult_accuracies = history['difficult_accuracies']
         iterations = history['iterations']
 
-        # Find when target is reached
         reached_target = False
         target_iteration = len(iterations) - 1
 
@@ -645,7 +459,6 @@ def plot_comparison(histories: Dict[str, Dict], save_path: str = 'teacher_agent_
             'final_acc': difficult_accuracies[-1]
         }
 
-    # Create bar plot
     names = list(strategy_stats.keys())
     iterations_to_target = [
         strategy_stats[n]['iteration'] if strategy_stats[n]['reached'] else len(histories[n]['iterations'])
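The bar chart built after this hunk scans each history for the first iteration where difficult-question accuracy crosses the 0.75 target. The same logic, factored into a standalone sketch with a hypothetical helper name:

```python
from typing import List, Optional

def first_iteration_at_target(accuracies: List[float],
                              target: float = 0.75) -> Optional[int]:
    """Index of the first accuracy >= target, or None if it is never reached."""
    for i, acc in enumerate(accuracies):
        if acc >= target:
            return i
    return None

print(first_iteration_at_target([0.40, 0.60, 0.76, 0.80]))  # 2
print(first_iteration_at_target([0.40, 0.60]))              # None
```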
@@ -656,169 +469,62 @@ def plot_comparison(histories: Dict[str, Dict], save_path: str = 'teacher_agent_
     x = np.arange(len(names))
     width = 0.35
 
-    bars1 = ax.bar(x - width/2, iterations_to_target, width, label='Iterations to 75% on Difficult',
-                   color=[colors[n] for n in names], alpha=0.7)
-    bars2 = ax.bar(x + width/2, [acc * max(iterations_to_target) for acc in final_accs], width,
-                   label='Final Difficult Accuracy (scaled)',
-                   color=[colors[n] for n in names], alpha=0.5)
+    ax.bar(x - width/2, iterations_to_target, width, label='Iterations to 75% on Hard',
+           color=[colors[n] for n in names], alpha=0.7)
+    ax.bar(x + width/2, [acc * max(iterations_to_target) for acc in final_accs], width,
+           label='Final Hard Accuracy (scaled)',
+           color=[colors[n] for n in names], alpha=0.5)
 
-    ax.set_xlabel('Strategy', fontsize=12, fontweight='bold')
-    ax.set_ylabel('Iterations / Scaled Accuracy', fontsize=12, fontweight='bold')
-    ax.set_title('Learning Efficiency: Iterations to Reach Target vs Final Performance',
-                 fontsize=14, fontweight='bold')
+    ax.set_title('Learning Efficiency')
     ax.set_xticks(x)
     ax.set_xticklabels(names)
-    ax.legend(fontsize=10, framealpha=0.9)
-    ax.grid(True, alpha=0.3, linestyle='--', axis='y')
-
-    # Add value labels on bars
-    for i, (bar1, bar2, name) in enumerate(zip(bars1, bars2, names)):
-        height1 = bar1.get_height()
-        height2 = bar2.get_height()
-
-        # Label for iterations
-        if strategy_stats[name]['reached']:
-            ax.text(bar1.get_x() + bar1.get_width()/2., height1,
-                    f'{int(height1)}',
-                    ha='center', va='bottom', fontsize=9, fontweight='bold')
-        else:
-            ax.text(bar1.get_x() + bar1.get_width()/2., height1,
-                    'Not reached',
-                    ha='center', va='bottom', fontsize=9, fontweight='bold')
-
-        # Label for final accuracy
-        ax.text(bar2.get_x() + bar2.get_width()/2., height2,
-                f'{final_accs[i]:.2f}',
-                ha='center', va='bottom', fontsize=9, fontweight='bold')
+    ax.legend()
 
     plt.tight_layout()
-    plt.savefig(save_path, dpi=150, bbox_inches='tight')
+    plt.savefig(save_path, dpi=150)
     print(f"\n✅ Saved comparison plot to {save_path}")
     plt.close()
-
-    # Print summary statistics
-    print("\n" + "=" * 70)
-    print("STRATEGY COMPARISON SUMMARY")
-    print("=" * 70)
-    for name, stats in strategy_stats.items():
-        status = "✅ Reached" if stats['reached'] else "❌ Not reached"
-        print(f"{name:15s} | {status:15s} | Iterations: {stats['iteration']:4d} | Final Acc: {stats['final_acc']:.3f}")
-    print("=" * 70)
 
 
 if __name__ == "__main__":
     import argparse
     import time
 
-    parser = argparse.ArgumentParser(description='Compare training strategies with configurable randomness')
-    parser.add_argument('--seed', type=int, default=None,
-                        help='Random seed for reproducibility (default: None = use current time)')
-    parser.add_argument('--iterations', type=int, default=500,
-                        help='Number of training iterations (default: 500)')
-    parser.add_argument('--deterministic', action='store_true',
-                        help='Use fixed seed=42 for reproducible results (deterministic)')
-    parser.add_argument('--runs', type=int, default=1,
-                        help='Number of runs for variance analysis (default: 1)')
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--seed', type=int, default=None)
+    parser.add_argument('--iterations', type=int, default=500)
+    parser.add_argument('--deterministic', action='store_true')
+    parser.add_argument('--runs', type=int, default=1)
 
     args = parser.parse_args()
 
-    # Determine seed
     if args.deterministic:
         seed = 42
-        print("⚠️ Using deterministic mode (seed=42) - results will be identical every run")
+        print("⚠️ Using deterministic mode (seed=42)")
     elif args.seed is not None:
         seed = args.seed
-        print(f"Using specified seed: {seed}")
     else:
-        seed = int(time.time()) % 10000  # Use current time as seed
-        print(f"Using random seed: {seed} (results will vary each run)")
+        seed = int(time.time()) % 10000
+
+    print(f"Using seed: {seed}")
 
     num_iterations = args.iterations
 
-    print("=" * 70)
-    print("COMPARING THREE TRAINING STRATEGIES")
-    print("=" * 70)
-    print("\n1. Random: Random questions until student can pass difficult")
-    print("2. Progressive: Easy → Medium → Hard within each family")
-    print("3. Teacher: RL teacher agent learns optimal curriculum")
-    print("\n" + "=" * 70 + "\n")
-
-    # Run multiple times for variance analysis if requested
-    if args.runs > 1:
-        print(f"Running {args.runs} times for variance analysis...\n")
-        all_results = {
-            'Random': [],
-            'Progressive': [],
-            'Teacher': []
-        }
-
-        for run in range(args.runs):
-            run_seed = seed + run  # Different seed for each run
-            print(f"Run {run + 1}/{args.runs} (seed={run_seed})...")
-
-            history_random = train_strategy_random(num_iterations=num_iterations, seed=run_seed)
-            history_progressive = train_strategy_progressive(num_iterations=num_iterations, seed=run_seed)
-            history_teacher = train_strategy_teacher(num_iterations=num_iterations, seed=run_seed)
-
-            all_results['Random'].append(history_random)
-            all_results['Progressive'].append(history_progressive)
-            all_results['Teacher'].append(history_teacher)
-
-        # Compute statistics across runs
-        print("\n" + "=" * 70)
-        print("VARIANCE ANALYSIS ACROSS RUNS")
-        print("=" * 70)
-
-        for strategy_name in ['Random', 'Progressive', 'Teacher']:
-            final_accs = [h['difficult_accuracies'][-1] for h in all_results[strategy_name]]
-            iterations_to_target = []
-            for h in all_results[strategy_name]:
-                target_acc = 0.75
-                reached = False
-                for i, acc in enumerate(h['difficult_accuracies']):
-                    if acc >= target_acc:
-                        iterations_to_target.append(i)
-                        reached = True
-                        break
-                if not reached:
-                    iterations_to_target.append(len(h['difficult_accuracies']))
-
-            mean_final = np.mean(final_accs)
-            std_final = np.std(final_accs)
-            mean_iters = np.mean(iterations_to_target)
-            std_iters = np.std(iterations_to_target)
-
-            print(f"\n{strategy_name}:")
-            print(f"  Final Accuracy: {mean_final:.3f} ± {std_final:.3f} (range: {min(final_accs):.3f} - {max(final_accs):.3f})")
-            print(f"  Iterations to Target: {mean_iters:.1f} ± {std_iters:.1f} (range: {min(iterations_to_target)} - {max(iterations_to_target)})")
-
-        # Use first run for plotting (or could average)
-        history_random = all_results['Random'][0]
-        history_progressive = all_results['Progressive'][0]
-        history_teacher = all_results['Teacher'][0]
-    else:
-        # Single run
-        # Train all three strategies
-        print("Training Random Strategy...")
-        history_random = train_strategy_random(num_iterations=num_iterations, seed=seed)
-
-        print("\nTraining Progressive Strategy...")
-        history_progressive = train_strategy_progressive(num_iterations=num_iterations, seed=seed)
-
-        print("\nTraining Teacher Strategy...")
-        history_teacher = train_strategy_teacher(num_iterations=num_iterations, seed=seed)
+    # Run strategies
+    print("Training Random Strategy...")
+    history_random = train_strategy_random(num_iterations=num_iterations, seed=seed)
+
+    print("\nTraining Progressive Strategy...")
+    history_progressive = train_strategy_progressive(num_iterations=num_iterations, seed=seed)
+
+    print("\nTraining Teacher Strategy...")
+    history_teacher = train_strategy_teacher(num_iterations=num_iterations, seed=seed)
 
-    # Create comparison plots
-    print("\nGenerating comparison plots...")
     histories = {
         'Random': history_random,
         'Progressive': history_progressive,
         'Teacher': history_teacher
     }
 
-    plot_comparison(histories, save_path='comparison_all_strategies.png')
-
-    print("\n✅ Comparison complete! Check 'comparison_all_strategies.png'")
-    if not args.deterministic and args.seed is None:
-        print(f"💡 Tip: Results vary each run. Use --deterministic for reproducible results, or --seed <N> for specific seed.")
+    plot_comparison(histories, save_path='teacher_agent_dev/comparison_all_strategies.png')
+    print("\n✅ Comparison complete!")
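With the multi-run variance analysis removed, the `__main__` block reduces to one run per strategy plus the comparison plot. The same flow can be driven programmatically; a sketch assuming `compare_strategies.py` is importable from the working directory (the actual import path depends on how `teacher_agent_dev` is laid out):

```python
# Hypothetical driver script; mirrors the simplified __main__ block above.
from compare_strategies import (train_strategy_random,
                                train_strategy_progressive,
                                train_strategy_teacher,
                                plot_comparison)

seed, iters = 42, 500
histories = {
    'Random': train_strategy_random(num_iterations=iters, seed=seed),
    'Progressive': train_strategy_progressive(num_iterations=iters, seed=seed),
    'Teacher': train_strategy_teacher(num_iterations=iters, seed=seed),
}
plot_comparison(histories, save_path='teacher_agent_dev/comparison_all_strategies.png')
```

From the shell, the retained flags give the same control: `--deterministic` pins seed=42, `--seed N` selects a specific seed, and the default derives one from the current time.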
 