# hfexample's picture
# Deploy clean snapshot of the repository
# e221c83
import json
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import platform
# --- Matplotlib Korean font setup ---
# Select a Korean-capable font for the current OS so Hangul labels render.
# plt.rc() does not raise when the requested family is missing (matplotlib only
# warns at draw time and falls back to another font), so the font's presence is
# checked explicitly against matplotlib's font registry before it is selected;
# otherwise the success message would be printed even when the font is absent.
try:
    _system = platform.system()
    if _system == 'Windows':
        _korean_font = 'Malgun Gothic'
    elif _system == 'Darwin':  # macOS
        _korean_font = 'AppleGothic'
    else:  # Linux (Colab, etc.)
        # On Colab and similar environments, first run:
        #   !sudo apt-get install -y fonts-nanum
        # matplotlib builds its font list at import time, so a font installed
        # afterwards may require restarting the runtime before it is detected.
        _korean_font = 'NanumBarunGothic'
    # Draw the minus sign as an ASCII hyphen; commonly set because some Korean
    # fonts lack the Unicode minus glyph.
    plt.rcParams['axes.unicode_minus'] = False
    _installed_fonts = {font_entry.name for font_entry in fm.fontManager.ttflist}
    if _korean_font not in _installed_fonts:
        raise RuntimeError(f"font family '{_korean_font}' is not installed")
    plt.rc('font', family=_korean_font)
    print("ํ•œ๊ธ€ ํฐํŠธ๊ฐ€ ์„ค์ •๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
except Exception as e:
    print(f"ํ•œ๊ธ€ ํฐํŠธ ์„ค์ •์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค: {e}")
# --- 1. Load and parse the JSON file ---
# [Changed] Points at the trainer_state.json of the v2 training run.
file_path = r'E:\Emotion\results\emotion_model_v2_manual\checkpoint-29050\trainer_state.json'
try:
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    log_history = data.get('log_history', [])

    # Split log_history into evaluation entries (have 'eval_loss') and training
    # entries (have 'loss'); entries with neither key are ignored.
    train_logs = []
    eval_logs = []
    for item in log_history:
        if 'eval_loss' in item:
            eval_logs.append(item)
        elif 'loss' in item:
            train_logs.append(item)

    if not eval_logs:
        print("ํŒŒ์ผ์—์„œ ํ‰๊ฐ€ ๋กœ๊ทธ('eval_loss' ํฌํ•จ)๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
    else:
        # Convert to DataFrames ordered by epoch. An empty train_logs list would
        # produce a DataFrame with no 'epoch' column (sort_values would raise),
        # so the training curve is simply skipped in that case.
        df_train = pd.DataFrame(train_logs).sort_values(by='epoch') if train_logs else None
        df_eval = pd.DataFrame(eval_logs).sort_values(by='epoch')

        # --- 2. Find the best-model information ---
        best_step = data.get('best_global_step')
        if best_step is None:
            # Fallback when 'best_global_step' is absent: take the step number from
            # the trailing "-<digits>" of 'best_model_checkpoint'
            # (e.g. ".../checkpoint-29050" -> 29050).
            best_ckpt = data.get('best_model_checkpoint')
            if best_ckpt:
                tail = str(best_ckpt).rstrip('/\\').rsplit('-', 1)[-1]
                if tail.isdigit():
                    best_step = int(tail)
        best_epoch_log = None
        if best_step is not None:
            # Guarded so that entries lacking 'step' never "match" a None best_step.
            best_epoch_log = next(
                (item for item in eval_logs if item.get('step') == best_step), None
            )
        best_epoch = best_epoch_log.get('epoch') if best_epoch_log else None

        # --- 3. Create the figure (1 row x 2 columns) ---
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7))

        # --- 3-1. Loss plot (train vs validation) ---
        if df_train is not None:
            ax1.plot(df_train['epoch'], df_train['loss'], 'o-', label='Train Loss (ํ•™์Šต ์†์‹ค)', alpha=0.7)
        ax1.plot(df_eval['epoch'], df_eval['eval_loss'], 's--', label='Validation Loss (๊ฒ€์ฆ ์†์‹ค)', linewidth=2, markersize=8)
        ax1.set_title('๋ชจ๋ธ ํ•™์Šต ๊ณผ์ • (Loss)', fontsize=16)
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.grid(True, linestyle=':')
        if best_epoch is not None:
            ax1.axvline(x=best_epoch, color='red', linestyle=':', linewidth=2,
                        label=f'Best Model (Epoch {best_epoch:g})')
            # Annotate the validation loss of the checkpoint recorded as best in
            # trainer_state.json (not necessarily the lowest-loss epoch). Every
            # eval log contains 'eval_loss', so it is read straight from the best log.
            best_val_loss = best_epoch_log['eval_loss']
            if best_val_loss is not None:
                ax1.plot(best_epoch, best_val_loss, 'r*', markersize=15)  # best-point marker
                ax1.text(best_epoch, best_val_loss * 1.05, f'Best @ Epoch {best_epoch:g}',
                         color='red', horizontalalignment='center', fontsize=12)
        # Build the legend only after every labelled artist (including the
        # best-model line) exists; artists added after legend() are not listed.
        ax1.legend(fontsize=11)

        # --- 3-2. Metrics plot (validation accuracy vs F1-score) ---
        # Each metric is drawn only if it was logged, so a run without
        # accuracy/F1 still produces the loss plot instead of aborting.
        if 'eval_accuracy' in df_eval.columns:
            ax2.plot(df_eval['epoch'], df_eval['eval_accuracy'] * 100, 'o-',
                     label='Validation Accuracy (๊ฒ€์ฆ ์ •ํ™•๋„)', linewidth=2, markersize=8)
        if 'eval_f1' in df_eval.columns:
            ax2.plot(df_eval['epoch'], df_eval['eval_f1'] * 100, 's--',
                     label='Validation F1-Score (๊ฒ€์ฆ F1)', linewidth=2, markersize=8)
        ax2.set_title('๋ชจ๋ธ ํ‰๊ฐ€ ์ง€ํ‘œ (Accuracy & F1-Score)', fontsize=16)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Score (%)', fontsize=12)
        ax2.grid(True, linestyle=':')
        if best_epoch is not None:
            ax2.axvline(x=best_epoch, color='red', linestyle=':', linewidth=2,
                        label=f'Best Model (Epoch {best_epoch:g})')
            # Annotate the best checkpoint's accuracy, when accuracy was logged.
            best_acc_fraction = best_epoch_log.get('eval_accuracy')
            if best_acc_fraction is not None:
                best_acc = best_acc_fraction * 100
                ax2.plot(best_epoch, best_acc, 'r*', markersize=15)  # best-point marker
                ax2.text(best_epoch, best_acc * 0.99, f'Best Acc: {best_acc:.2f}%',
                         color='red', horizontalalignment='center', verticalalignment='top', fontsize=12)
        ax2.legend(fontsize=11)

        # Tidy the layout and save the figure to a PNG file.
        plt.tight_layout()
        plt.savefig('training_visualization_graph.png', dpi=300)
        print("\n'training_visualization_graph.png' ํŒŒ์ผ๋กœ ๊ทธ๋ž˜ํ”„๊ฐ€ ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
except FileNotFoundError:
    print(f"์˜ค๋ฅ˜: '{file_path}' ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๊ฒฝ๋กœ๋ฅผ ๋‹ค์‹œ ํ™•์ธํ•ด์ฃผ์„ธ์š”.")
except Exception as e:
    print(f"๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}")