"""Visualization utilities for Teacher Agent system.""" import matplotlib.pyplot as plt import numpy as np from typing import Dict, List from teacher_agent import TeacherAgent def plot_learning_curves(history: Dict, save_path: str = 'learning_curves.png'): """ Plot student accuracy and teacher reward over time. Args: history: Dictionary with 'iterations', 'student_accuracies', 'teacher_rewards' save_path: Where to save the plot """ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8)) iterations = history['iterations'] # Plot student accuracy ax1.plot(iterations, history['student_accuracies'], label='Student Accuracy', linewidth=2) ax1.set_xlabel('Iteration') ax1.set_ylabel('Accuracy') ax1.set_title('Student Learning Curve') ax1.grid(True, alpha=0.3) ax1.legend() ax1.set_ylim([0, 1]) # Plot teacher reward (smoothed) rewards = np.array(history['teacher_rewards']) window = 50 if len(rewards) > window: smoothed = np.convolve(rewards, np.ones(window)/window, mode='valid') smoothed_iterations = iterations[window-1:] ax2.plot(smoothed_iterations, smoothed, label=f'Smoothed Reward (window={window})', linewidth=2) ax2.plot(iterations, rewards, alpha=0.3, label='Raw Reward', linewidth=0.5) else: ax2.plot(iterations, rewards, label='Reward', linewidth=2) ax2.set_xlabel('Iteration') ax2.set_ylabel('Reward') ax2.set_title('Teacher Reward Over Time') ax2.grid(True, alpha=0.3) ax2.legend() plt.tight_layout() plt.savefig(save_path, dpi=150) print(f"Saved learning curves to {save_path}") plt.close() def plot_curriculum_heatmap(history: Dict, save_path: str = 'curriculum_heatmap.png'): """ Visualize teacher's curriculum choices over time. Args: history: Dictionary with 'iterations', 'topics', 'difficulties', 'is_reviews' save_path: Where to save the plot """ topics = list(set(history['topics'])) topics.sort() # Create grid: time (iterations) vs topics num_iterations = len(history['iterations']) num_topics = len(topics) # Map difficulty to numeric value difficulty_map = {'easy': 1, 'medium': 2, 'hard': 3} # Create heatmap data heatmap_data = np.zeros((num_topics, num_iterations)) for i, (topic, difficulty, is_review) in enumerate(zip( history['topics'], history['difficulties'], history['is_reviews'] )): topic_idx = topics.index(topic) diff_value = difficulty_map[difficulty] if is_review: diff_value = 0.5 # Mark reviews differently heatmap_data[topic_idx, i] = diff_value fig, ax = plt.subplots(figsize=(14, 6)) im = ax.imshow(heatmap_data, aspect='auto', cmap='viridis', interpolation='nearest') ax.set_yticks(range(num_topics)) ax.set_yticklabels(topics) ax.set_xlabel('Iteration') ax.set_ylabel('Topic') ax.set_title('Curriculum Heatmap (Light=Easy/Review, Dark=Hard)') # Add colorbar cbar = plt.colorbar(im, ax=ax) cbar.set_label('Difficulty (0.5=Review, 1=Easy, 2=Medium, 3=Hard)') # Sample iterations for x-axis labels if num_iterations > 20: step = num_iterations // 10 ax.set_xticks(range(0, num_iterations, step)) ax.set_xticklabels(range(0, num_iterations, step)) plt.tight_layout() plt.savefig(save_path, dpi=150) print(f"Saved curriculum heatmap to {save_path}") plt.close() def plot_action_distributions(teacher: TeacherAgent, save_path: str = 'action_dist.png'): """ Show which actions teacher prefers. Args: teacher: Trained TeacherAgent save_path: Where to save the plot """ stats = teacher.get_statistics() fig, axes = plt.subplots(2, 2, figsize=(14, 10)) # 1. Topic distribution topic_counts = {} for idx, count in enumerate(stats['action_counts']): if count > 0: action = teacher._index_to_action(idx) topic_counts[action.topic] = topic_counts.get(action.topic, 0) + count ax = axes[0, 0] topics = list(topic_counts.keys()) counts = list(topic_counts.values()) ax.bar(topics, counts) ax.set_xlabel('Topic') ax.set_ylabel('Count') ax.set_title('Topic Selection Distribution') ax.tick_params(axis='x', rotation=45) # 2. Difficulty distribution difficulty_counts = {'easy': 0, 'medium': 0, 'hard': 0} for idx, count in enumerate(stats['action_counts']): if count > 0: action = teacher._index_to_action(idx) difficulty_counts[action.difficulty] += count ax = axes[0, 1] difficulties = list(difficulty_counts.keys()) counts = list(difficulty_counts.values()) ax.bar(difficulties, counts) ax.set_xlabel('Difficulty') ax.set_ylabel('Count') ax.set_title('Difficulty Selection Distribution') # 3. Review vs New review_counts = {'New': 0, 'Review': 0} for idx, count in enumerate(stats['action_counts']): if count > 0: action = teacher._index_to_action(idx) key = 'Review' if action.is_review else 'New' review_counts[key] += count ax = axes[1, 0] labels = list(review_counts.keys()) sizes = list(review_counts.values()) ax.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90) ax.set_title('New vs Review Distribution') # 4. Average reward per topic topic_rewards = {} for idx in range(len(stats['action_counts'])): if stats['action_counts'][idx] > 0: action = teacher._index_to_action(idx) avg_reward = stats['action_rewards'][idx] / stats['action_counts'][idx] topic_rewards[action.topic] = topic_rewards.get(action.topic, []) + [avg_reward] # Compute mean reward per topic topic_avg_rewards = {topic: np.mean(rewards) for topic, rewards in topic_rewards.items()} ax = axes[1, 1] topics = list(topic_avg_rewards.keys()) rewards = list(topic_avg_rewards.values()) ax.bar(topics, rewards) ax.set_xlabel('Topic') ax.set_ylabel('Average Reward') ax.set_title('Average Reward per Topic') ax.tick_params(axis='x', rotation=45) plt.tight_layout() plt.savefig(save_path, dpi=150) print(f"Saved action distributions to {save_path}") plt.close() def plot_comparison(histories: Dict[str, Dict], save_path: str = 'comparison.png'): """ Compare teacher vs baselines. Args: histories: Dictionary mapping strategy name to history dict e.g., {'teacher': history1, 'random': history2, 'fixed': history3} save_path: Where to save the plot """ fig, axes = plt.subplots(2, 1, figsize=(12, 8)) # Plot accuracy comparison ax = axes[0] for name, history in histories.items(): iterations = history['iterations'] accuracies = history['student_accuracies'] ax.plot(iterations, accuracies, label=name, linewidth=2) ax.set_xlabel('Iteration') ax.set_ylabel('Accuracy') ax.set_title('Student Accuracy Comparison') ax.legend() ax.grid(True, alpha=0.3) ax.set_ylim([0, 1]) # Plot reward comparison (smoothed) ax = axes[1] window = 50 for name, history in histories.items(): rewards = np.array(history['teacher_rewards']) iterations = history['iterations'] if len(rewards) > window: smoothed = np.convolve(rewards, np.ones(window)/window, mode='valid') smoothed_iterations = iterations[window-1:] ax.plot(smoothed_iterations, smoothed, label=f'{name} (smoothed)', linewidth=2) else: ax.plot(iterations, rewards, label=name, linewidth=2) ax.set_xlabel('Iteration') ax.set_ylabel('Reward') ax.set_title('Teacher Reward Comparison') ax.legend() ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(save_path, dpi=150) print(f"Saved comparison plot to {save_path}") plt.close() if __name__ == "__main__": # Example usage print("This module provides visualization functions.") print("Import and use them with training results:") print() print(" from train_teacher import train_teacher") print(" from visualize import *") print() print(" history, teacher, student = train_teacher(num_iterations=500)") print(" plot_learning_curves(history)") print(" plot_curriculum_heatmap(history)") print(" plot_action_distributions(teacher)")