Spaces:
Running
Running
| import json | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import matplotlib.font_manager as fm | |
| import platform | |
# --- Matplotlib Korean (Hangul) font setup ---
# Hangul-capable font family to use for each platform.system() value.
KOREAN_FONT_BY_SYSTEM = {
    'Windows': 'Malgun Gothic',
    'Darwin': 'AppleGothic',  # macOS
}
# Fallback for Linux (e.g. Colab) and any other platform. On Colab and similar
# environments, install it first with: !sudo apt-get install -y fonts-nanum
DEFAULT_KOREAN_FONT = 'NanumBarunGothic'


def select_korean_font(system_name):
    """Return the Hangul font family to use for a ``platform.system()`` value.

    Args:
        system_name: the string returned by ``platform.system()``.

    Returns:
        The mapped family from ``KOREAN_FONT_BY_SYSTEM``, or
        ``DEFAULT_KOREAN_FONT`` for any unmapped platform.
    """
    return KOREAN_FONT_BY_SYSTEM.get(system_name, DEFAULT_KOREAN_FONT)


try:
    korean_font = select_korean_font(platform.system())
    # plt.rc() accepts any family name without raising (matplotlib only warns
    # once text is drawn), so verify the font is really installed before
    # reporting success.
    if korean_font not in {entry.name for entry in fm.fontManager.ttflist}:
        raise RuntimeError(f"font family '{korean_font}' is not installed")
    plt.rc('font', family=korean_font)
    # Render negative tick labels with an ASCII hyphen; presumably because the
    # Korean fonts may lack the Unicode minus glyph — confirm per font.
    plt.rcParams['axes.unicode_minus'] = False
    print("한글 폰트가 설정되었습니다.")
except Exception as e:
    # Font setup is best-effort: report the problem and keep plotting.
    print(f"한글 폰트 설정에 실패했습니다: {e}")
# --- 1. Load and parse the JSON file ---
# [Changed] Path updated to the trainer_state.json of the v2 training run.
file_path = r'E:\Emotion\results\emotion_model_v2_manual\checkpoint-29050\trainer_state.json'
# File the rendered figure is written to (relative to the working directory).
OUTPUT_PATH = 'training_visualization_graph.png'


def split_log_history(log_history):
    """Split a Trainer ``log_history`` list into training and evaluation entries.

    Entries containing ``'eval_loss'`` are evaluation logs; of the remaining
    entries, those containing ``'loss'`` are training logs. Everything else
    is ignored.

    Returns:
        ``(train_logs, eval_logs)``: two lists preserving the input order.
    """
    train_logs, eval_logs = [], []
    for item in log_history:
        if 'eval_loss' in item:
            eval_logs.append(item)
        elif 'loss' in item:
            train_logs.append(item)
    return train_logs, eval_logs


def logs_to_frame(logs):
    """Return *logs* as a DataFrame sorted by its ``'epoch'`` column.

    An empty list yields an empty DataFrame; sorting it would raise
    ``KeyError`` because it has no ``'epoch'`` column.
    """
    frame = pd.DataFrame(logs)
    return frame.sort_values(by='epoch') if 'epoch' in frame.columns else frame


def find_best_eval_log(state, eval_logs):
    """Return the evaluation log entry recorded at the best checkpoint, or ``None``.

    The best step is read from ``state['best_global_step']``. When that key is
    absent or null, it falls back to the step number in the last path
    component of ``state['best_model_checkpoint']`` (``checkpoint-<step>``).

    Args:
        state: the parsed trainer_state.json dictionary.
        eval_logs: evaluation entries, e.g. from ``split_log_history``.
    """
    best_step = state.get('best_global_step')
    if best_step is None:
        checkpoint = state.get('best_model_checkpoint') or ''
        # Normalise Windows separators so the last path component can be taken.
        name = checkpoint.replace('\\', '/').rstrip('/').rsplit('/', 1)[-1]
        head, sep, step_text = name.partition('checkpoint-')
        if sep and not head and step_text.isdigit():
            best_step = int(step_text)
    if best_step is None:
        return None
    return next((item for item in eval_logs if item.get('step') == best_step), None)


def plot_training_curves(df_train, df_eval, best_log, output_path=OUTPUT_PATH):
    """Plot loss (left) and validation accuracy/F1 (right) per epoch and save a PNG.

    Args:
        df_train: training rows with ``'epoch'`` and ``'loss'`` columns (may be empty).
        df_eval: evaluation rows with ``'epoch'`` and ``'eval_loss'`` columns and,
            optionally, ``'eval_accuracy'`` / ``'eval_f1'`` (multiplied by 100 for display).
        best_log: evaluation entry of the best checkpoint, or ``None`` to skip
            the best-model markers.
        output_path: image file to write.
    """
    best_epoch = best_log.get('epoch') if best_log else None

    # --- 2. Figure with one row and two columns ---
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7))

    # --- 2-1. Loss plot (Train vs Validation) ---
    if not df_train.empty:
        ax1.plot(df_train['epoch'], df_train['loss'], 'o-', label='Train Loss (학습 손실)', alpha=0.7)
    ax1.plot(df_eval['epoch'], df_eval['eval_loss'], 's--',
             label='Validation Loss (검증 손실)', linewidth=2, markersize=8)
    ax1.set_title('모델 학습 과정 (Loss)', fontsize=16)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    if best_epoch is not None:
        ax1.axvline(x=best_epoch, color='red', linestyle=':', linewidth=2,
                    label=f'Best Model (Epoch {best_epoch:g})')
        # Mark the best checkpoint's validation loss (the best checkpoint is
        # chosen by the trainer, not necessarily the lowest-loss epoch).
        best_val_loss = best_log['eval_loss']
        ax1.plot(best_epoch, best_val_loss, 'r*', markersize=15)
        ax1.text(best_epoch, best_val_loss * 1.05, f'Best @ Epoch {best_epoch:g}',
                 color='red', horizontalalignment='center', fontsize=12)
    # legend() only collects artists that already exist, so it must come after axvline().
    ax1.legend(fontsize=11)
    ax1.grid(True, linestyle=':')

    # --- 2-2. Metrics plot (Validation Accuracy vs F1-Score) ---
    metric_series = (
        ('eval_accuracy', 'o-', 'Validation Accuracy (검증 정확도)'),
        ('eval_f1', 's--', 'Validation F1-Score (검증 F1)'),
    )
    for column, fmt, label in metric_series:
        if column in df_eval.columns:  # skip metrics that were never logged
            ax2.plot(df_eval['epoch'], df_eval[column] * 100, fmt,
                     label=label, linewidth=2, markersize=8)
    ax2.set_title('모델 평가 지표 (Accuracy & F1-Score)', fontsize=16)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Score (%)', fontsize=12)
    if best_epoch is not None:
        ax2.axvline(x=best_epoch, color='red', linestyle=':', linewidth=2,
                    label=f'Best Model (Epoch {best_epoch:g})')
        # Annotate the best checkpoint's accuracy, when it was logged.
        best_acc_fraction = best_log.get('eval_accuracy')
        if best_acc_fraction is not None:
            best_acc = best_acc_fraction * 100
            ax2.plot(best_epoch, best_acc, 'r*', markersize=15)
            ax2.text(best_epoch, best_acc * 0.99, f'Best Acc: {best_acc:.2f}%',
                     color='red', horizontalalignment='center', verticalalignment='top', fontsize=12)
    ax2.legend(fontsize=11)
    ax2.grid(True, linestyle=':')

    # Tidy the layout and save to disk (the figure stays open for inline display).
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    print(f"\n'{output_path}' 파일로 그래프가 저장되었습니다.")


def visualize_training(path=file_path, output_path=OUTPUT_PATH):
    """Load the trainer_state.json at *path* and save the training-curve figure.

    A missing file, a file without evaluation logs, and processing errors are
    reported with ``print`` instead of being raised.
    """
    try:
        state_file = open(path, 'r', encoding='utf-8')
    except FileNotFoundError:
        print(f"오류: '{path}' 파일을 찾을 수 없습니다. 경로를 다시 확인해주세요.")
        return
    except OSError as e:
        print(f"데이터 처리 중 오류가 발생했습니다: {e}")
        return
    try:
        with state_file:
            data = json.load(state_file)
        train_logs, eval_logs = split_log_history(data.get('log_history', []))
        if not eval_logs:
            print("파일에서 평가 로그('eval_loss' 포함)를 찾을 수 없습니다.")
            return
        plot_training_curves(logs_to_frame(train_logs), logs_to_frame(eval_logs),
                             find_best_eval_log(data, eval_logs), output_path)
    except Exception as e:  # script boundary: report instead of crashing
        print(f"데이터 처리 중 오류가 발생했습니다: {e}")


visualize_training(file_path)