# Hugging Face Space file view (Space status: Sleeping)
# File size: 4,868 Bytes · commit e221c83
import json
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import platform
# --- Matplotlib Korean (Hangul) font setup ---
def _korean_font_family(system_name):
    """Return the Hangul-capable font family to request on ``system_name``.

    ``system_name`` is a ``platform.system()`` value. Anything other than
    'Windows' or 'Darwin' (macOS) is treated as Linux (e.g. Colab).
    """
    if system_name == 'Windows':
        return 'Malgun Gothic'
    if system_name == 'Darwin':  # macOS
        return 'AppleGothic'
    # Linux (Colab etc.): the font must be installed first, e.g.
    #   !sudo apt-get install -y fonts-nanum
    return 'NanumBarunGothic'


try:
    # Use ASCII '-' for negative tick labels instead of U+2212, which Hangul
    # fonts commonly lack (it would render as a box).
    plt.rcParams['axes.unicode_minus'] = False
    _font_family = _korean_font_family(platform.system())
    # plt.rc() accepts any family name without checking that it is installed;
    # a missing font only surfaces later as findfont warnings at draw time.
    # Look it up explicitly with fallback disabled so a missing font raises
    # ValueError here and the failure message below is actually reachable.
    fm.findfont(fm.FontProperties(family=_font_family), fallback_to_default=False)
    plt.rc('font', family=_font_family)
    print("한글 폰트가 설정되었습니다.")
except Exception as e:
    print(f"한글 폰트 설정에 실패했습니다: {e}")
# --- 1. Load and parse the JSON file ---
# [Changed] Point at the trainer_state.json of the v2 training run.
file_path = r'E:\Emotion\results\emotion_model_v2_manual\checkpoint-29050\trainer_state.json'


def _split_logs(log_history):
    """Split a Trainer ``log_history`` list into ``(train_logs, eval_logs)``.

    Entries containing ``'eval_loss'`` are evaluation logs; otherwise entries
    containing ``'loss'`` are training logs. Entries with neither key are
    ignored. Input order is preserved in both lists.
    """
    train_logs = []
    eval_logs = []
    for item in log_history:
        if 'eval_loss' in item:
            eval_logs.append(item)
        elif 'loss' in item:
            train_logs.append(item)
    return train_logs, eval_logs


def _find_best_eval_log(eval_logs, best_step):
    """Return the first eval log whose ``'step'`` equals ``best_step``, else None.

    Returns None when ``best_step`` is None (no best step recorded in the state
    file) or when no eval log matches. The explicit None guard matters because
    ``item.get('step')`` is also None for logs that carry no step.
    """
    if best_step is None:
        return None
    return next((item for item in eval_logs if item.get('step') == best_step), None)


try:
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    log_history = data.get('log_history', [])
    # Separate training logs from evaluation logs.
    train_logs, eval_logs = _split_logs(log_history)
    if not eval_logs:
        print("파일에서 평가 로그('eval_loss' 포함)를 찾을 수 없습니다.")
    else:
        # Convert to DataFrames ordered by epoch. A DataFrame built from an
        # empty list has no 'epoch' column, so only sort non-empty train logs.
        df_train = pd.DataFrame(train_logs)
        if not df_train.empty:
            df_train = df_train.sort_values(by='epoch')
        df_eval = pd.DataFrame(eval_logs).sort_values(by='epoch')

        # --- 2. Locate the best-model eval entry ---
        best_step = data.get('best_global_step')
        best_epoch_log = _find_best_eval_log(eval_logs, best_step)
        best_epoch = best_epoch_log.get('epoch') if best_epoch_log is not None else None

        # --- 3. Create the figure (1 row x 2 columns) ---
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7))

        # --- 3-1. Loss plot (train vs validation) ---
        if not df_train.empty:
            ax1.plot(df_train['epoch'], df_train['loss'], 'o-',
                     label='Train Loss (학습 손실)', alpha=0.7)
        ax1.plot(df_eval['epoch'], df_eval['eval_loss'], 's--',
                 label='Validation Loss (검증 손실)', linewidth=2, markersize=8)
        ax1.set_title('모델 학습 과정 (Loss)', fontsize=16)
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.grid(True, linestyle=':')
        if best_epoch is not None:
            ax1.axvline(x=best_epoch, color='red', linestyle=':', linewidth=2,
                        label=f'Best Model (Epoch {best_epoch:g})')
            # Annotate the best point: the Best Model's epoch, not the minimum
            # loss. Read it from the already-matched log instead of re-searching
            # by float epoch equality.
            best_val_loss = best_epoch_log.get('eval_loss')
            if best_val_loss is not None:
                ax1.plot(best_epoch, best_val_loss, 'r*', markersize=15)  # best-point marker
                ax1.text(best_epoch, best_val_loss * 1.05, f'Best @ Epoch {best_epoch:g}',
                         color='red', horizontalalignment='center', fontsize=12)
        # Build the legend last: legend() only collects artists that already
        # exist, so calling it earlier would omit the "Best Model" line.
        ax1.legend(fontsize=11)

        # --- 3-2. Metrics plot (validation accuracy vs F1-score) ---
        # Skip metrics the run did not log instead of aborting the whole figure.
        metric_specs = [
            ('eval_accuracy', 'o-', 'Validation Accuracy (검증 정확도)'),
            ('eval_f1', 's--', 'Validation F1-Score (검증 F1)'),
        ]
        for column, fmt, label in metric_specs:
            if column in df_eval.columns:
                ax2.plot(df_eval['epoch'], df_eval[column] * 100, fmt,
                         label=label, linewidth=2, markersize=8)
        ax2.set_title('모델 평가 지표 (Accuracy & F1-Score)', fontsize=16)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Score (%)', fontsize=12)
        ax2.grid(True, linestyle=':')
        if best_epoch is not None:
            ax2.axvline(x=best_epoch, color='red', linestyle=':', linewidth=2,
                        label=f'Best Model (Epoch {best_epoch:g})')
            # Annotate the best point (accuracy at the Best Model's epoch).
            best_acc_fraction = best_epoch_log.get('eval_accuracy')
            if best_acc_fraction is not None:
                best_acc = best_acc_fraction * 100
                ax2.plot(best_epoch, best_acc, 'r*', markersize=15)  # best-point marker
                ax2.text(best_epoch, best_acc * 0.99, f'Best Acc: {best_acc:.2f}%',
                         color='red', horizontalalignment='center',
                         verticalalignment='top', fontsize=12)
        # Legend last so the "Best Model" line is included (see ax1 above).
        ax2.legend(fontsize=11)

        # Tidy the layout and save the figure to a file.
        plt.tight_layout()
        plt.savefig('training_visualization_graph.png', dpi=300)
        print("\n'training_visualization_graph.png' 파일로 그래프가 저장되었습니다.")
except FileNotFoundError:
    print(f"오류: '{file_path}' 파일을 찾을 수 없습니다. 경로를 다시 확인해주세요.")
except Exception as e:
    print(f"데이터 처리 중 오류가 발생했습니다: {e}")