Spaces:
Paused
Paused
File size: 5,306 Bytes
a52f96d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
"""
Main training script for student agent.
Integrates student with mock teacher/task generator and generates
comprehensive visualizations.
"""
import torch
from student_agent import StudentAgent
from student_metrics import StudentMetrics
from mock_teacher import MockTeacherAgent
from mock_task_generator import MockTaskGenerator
from visualize_student import create_comprehensive_report
def compute_teacher_reward(
    accuracy_before: float,
    accuracy_after: float,
    difficulty: str,
    is_review: bool
) -> float:
    """Reward function for teacher (shared with teacher agent).

    The reward is the sum of four terms:

    * the change in accuracy (``accuracy_after - accuracy_before``),
    * a flat bonus chosen by ``difficulty`` (0.5 / 1.0 / 2.0 for
      'easy' / 'medium' / 'hard'; any other label gets 1.0),
    * +1.0 when a review task produced a strictly positive change,
    * -0.5 when a review task was given although ``accuracy_after``
      is already above 0.9.

    Args:
        accuracy_before: Accuracy measured before the student learned.
        accuracy_after: Accuracy measured after the student learned.
        difficulty: Difficulty label of the task that was assigned.
        is_review: Whether the assigned task was a review task.

    Returns:
        The scalar reward.
    """
    gain = accuracy_after - accuracy_before

    bonus_by_level = {'easy': 0.5, 'medium': 1.0, 'hard': 2.0}
    level_bonus = bonus_by_level.get(difficulty, 1.0)

    # Both review terms apply only to review tasks; they are independent
    # of each other, so a single review can earn the bonus and the penalty.
    if is_review and gain > 0:
        review_reward = 1.0
    else:
        review_reward = 0.0

    if is_review and accuracy_after > 0.9:
        mastered_penalty = -0.5
    else:
        mastered_penalty = 0.0

    return gain + level_bonus + review_reward + mastered_penalty
def train_student(
    num_iterations: int = 500,
    device: str = 'cpu',
    learning_rate: float = 5e-5,
    retention_constant: float = 80.0,
    verbose: bool = True
):
    """
    Train student agent with mock teacher and task generator.

    Each iteration asks the teacher for an action, generates a task for
    that action's topic and difficulty, measures accuracy on a fixed
    evaluation set before and after the student learns from the task,
    passes the resulting reward to the teacher, advances the student's
    clock by one time unit, and logs the iteration to the metrics tracker.

    Args:
        num_iterations: Number of training iterations
        device: 'cpu' or 'cuda'
        learning_rate: Student LM learning rate
        retention_constant: Memory decay rate (higher = slower forgetting)
        verbose: Print progress

    Returns:
        Tuple of (metrics, student, teacher, generator)
    """
    # Initialize components
    if verbose:
        print("Initializing student agent...")
    student = StudentAgent(
        learning_rate=learning_rate,
        retention_constant=retention_constant,
        device=device
    )
    teacher = MockTeacherAgent()
    generator = MockTaskGenerator()

    # Create evaluation set: built once up front and reused unchanged for
    # every before/after measurement, so accuracies are comparable across
    # iterations.
    eval_tasks = []
    for topic in generator.get_available_topics():
        for difficulty in ['easy', 'medium', 'hard']:
            for _ in range(2):  # 2 tasks per (topic, difficulty)
                eval_tasks.append(generator.generate_task(topic, difficulty))

    if verbose:
        print(f"Created evaluation set: {len(eval_tasks)} tasks")
        print(f"Training for {num_iterations} iterations...\n")

    # Initialize metrics tracker
    metrics = StudentMetrics()

    # Training loop
    for iteration in range(num_iterations):
        # 1. Get student state
        student_state = student.get_state()

        # 2. Teacher selects action
        action = teacher.select_action(student_state)

        # 3. Generate task
        task = generator.generate_task(action.topic, action.difficulty)

        # 4. Evaluate BEFORE learning.
        # This is re-measured every iteration rather than reusing the
        # previous iteration's "after" value, because advance_time() runs
        # in between the two measurements.
        accuracy_before = student.evaluate(eval_tasks)

        # 5. Student learns from task
        was_correct = student.learn(task)

        # 6. Evaluate AFTER learning
        accuracy_after = student.evaluate(eval_tasks)

        # 7. Compute teacher reward (for compatibility with teacher agent)
        reward = compute_teacher_reward(
            accuracy_before, accuracy_after,
            action.difficulty, action.is_review
        )

        # 8. Update teacher (mock doesn't use this)
        teacher.update(action, reward)

        # 9. Time passes (for forgetting)
        student.advance_time(1.0)

        # 10. Log metrics. Per-topic values are read from the student's
        # memory after time has advanced, for every topic currently present
        # in student.topic_base_skills.
        topic_accuracies = {
            topic: student.memory.get_effective_skill(topic)
            for topic in student.topic_base_skills
        }
        retention_factors = {
            topic: student.memory.get_retention_factor(topic)
            for topic in student.topic_base_skills
        }
        metrics.log_iteration(
            iteration=iteration,
            overall_acc=accuracy_after,
            topic_accs=topic_accuracies,
            task=task,
            correct=was_correct,
            retention_factors=retention_factors
        )

        # 11. Print progress every 50 iterations.
        # NOTE(review): the original indicator glyphs were lost to an
        # encoding error and both branches printed the same character;
        # the check/cross marks below are the presumed originals.
        if verbose and iteration % 50 == 0:
            print(f"Iteration {iteration:3d} | "
                  f"Accuracy: {accuracy_after:.3f} | "
                  f"Topics: {len(student.topic_base_skills)} | "
                  f"Correct: {'✓' if was_correct else '✗'}")

    if verbose:
        print("\n✅ Training complete!")

    return metrics, student, teacher, generator
def main():
    """Main entry point.

    Picks CUDA when available (otherwise CPU), trains the student for 500
    iterations, writes the visualization report to
    ``student_visualizations`` and saves the student checkpoint to
    ``student_checkpoint.pt``.
    """
    # `verbose` must be bound locally: it is passed to train_student and
    # also gates the final status message. Without this binding the
    # `if verbose:` check below raised NameError after the checkpoint had
    # already been written.
    verbose = True

    # Check if CUDA available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}\n")

    # Train student (teacher and generator are returned but unused here)
    metrics, student, teacher, generator = train_student(
        num_iterations=500,
        device=device,
        learning_rate=5e-5,
        retention_constant=80.0,
        verbose=verbose
    )

    # Generate visualizations
    create_comprehensive_report(metrics, output_dir='student_visualizations')

    # Save model checkpoint
    student.save('student_checkpoint.pt')
    if verbose:
        print("\n💾 Saved student checkpoint to student_checkpoint.pt")
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|