File size: 4,868 Bytes
e221c83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import platform

# --- Matplotlib Korean font setup ---
# Pick a per-OS font that can render Hangul, then disable the Unicode
# minus glyph so negative axis ticks render with the chosen font.
try:
    os_name = platform.system()
    if os_name == 'Windows':
        hangul_font = 'Malgun Gothic'
    elif os_name == 'Darwin':  # macOS
        hangul_font = 'AppleGothic'
    else:  # Linux (e.g. Colab) — run `!sudo apt-get install -y fonts-nanum` first
        hangul_font = 'NanumBarunGothic'

    plt.rc('font', family=hangul_font)
    plt.rcParams['axes.unicode_minus'] = False
    print("ํ•œ๊ธ€ ํฐํŠธ๊ฐ€ ์„ค์ •๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
except Exception as e:
    print(f"ํ•œ๊ธ€ ํฐํŠธ ์„ค์ •์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค: {e}")

# --- 1. Load and parse the trainer_state.json ---
# [Changed] path points at the v2 training run's trainer_state.json
file_path = r'E:\Emotion\results\emotion_model_v2_manual\checkpoint-29050\trainer_state.json'

try:
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    log_history = data.get('log_history', [])

    # Split the mixed log_history: evaluation entries carry 'eval_loss',
    # training-step entries carry only 'loss'.
    train_logs = []
    eval_logs = []

    for item in log_history:
        if 'eval_loss' in item:
            eval_logs.append(item)
        elif 'loss' in item:
            train_logs.append(item)

    if not eval_logs:
        print("ํŒŒ์ผ์—์„œ ํ‰๊ฐ€ ๋กœ๊ทธ('eval_loss' ํฌํ•จ)๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
    else:
        # Convert to DataFrames, ordered by epoch for clean line plots.
        df_train = pd.DataFrame(train_logs).sort_values(by='epoch')
        df_eval = pd.DataFrame(eval_logs).sort_values(by='epoch')

        # --- 2. Locate the best-model checkpoint ---
        # NOTE(review): 'best_global_step' only exists in recent transformers
        # versions of trainer_state.json — confirm it is present in this file,
        # otherwise best_step is None and the annotations are skipped.
        best_step = data.get('best_global_step')
        # .get('step') instead of ['step']: a malformed eval entry without
        # 'step' should not abort the whole plot with a KeyError.
        best_epoch_log = next((item for item in eval_logs if item.get('step') == best_step), None)
        best_epoch = best_epoch_log['epoch'] if best_epoch_log else None

        # --- 3. Build the figure (1 row x 2 columns) ---
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7))

        # --- 3-1. Loss curves (train vs validation) ---
        ax1.plot(df_train['epoch'], df_train['loss'], 'o-', label='Train Loss (ํ•™์Šต ์†์‹ค)', alpha=0.7)
        ax1.plot(df_eval['epoch'], df_eval['eval_loss'], 's--', label='Validation Loss (๊ฒ€์ฆ ์†์‹ค)', linewidth=2, markersize=8)
        ax1.set_title('๋ชจ๋ธ ํ•™์Šต ๊ณผ์ • (Loss)', fontsize=16)
        ax1.set_xlabel('Epoch', fontsize=12)
        ax1.set_ylabel('Loss', fontsize=12)
        ax1.grid(True, linestyle=':')

        # FIX: `is not None` — a legitimate epoch value of 0 is falsy and
        # would previously have skipped the best-model annotation.
        if best_epoch is not None:
            ax1.axvline(x=best_epoch, color='red', linestyle=':', linewidth=2,
                        label=f'Best Model (Epoch {best_epoch:g})')
            # FIX: take eval_loss straight from the matched log entry instead
            # of re-searching eval_logs with a float-equality test on epoch.
            best_val_loss = best_epoch_log.get('eval_loss')
            if best_val_loss is not None:
                ax1.plot(best_epoch, best_val_loss, 'r*', markersize=15)  # best-point marker
                ax1.text(best_epoch, best_val_loss * 1.05, f'Best @ Epoch {best_epoch:g}',
                         color='red', horizontalalignment='center', fontsize=12)

        # FIX: legend AFTER axvline so the 'Best Model' entry actually
        # appears (legend() only includes artists that already exist).
        ax1.legend(fontsize=11)

        # --- 3-2. Metric curves (validation accuracy vs F1-score) ---
        ax2.plot(df_eval['epoch'], df_eval['eval_accuracy'] * 100, 'o-',
                 label='Validation Accuracy (๊ฒ€์ฆ ์ •ํ™•๋„)', linewidth=2, markersize=8)
        ax2.plot(df_eval['epoch'], df_eval['eval_f1'] * 100, 's--',
                 label='Validation F1-Score (๊ฒ€์ฆ F1)', linewidth=2, markersize=8)
        ax2.set_title('๋ชจ๋ธ ํ‰๊ฐ€ ์ง€ํ‘œ (Accuracy & F1-Score)', fontsize=16)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Score (%)', fontsize=12)
        ax2.grid(True, linestyle=':')

        if best_epoch is not None:
            ax2.axvline(x=best_epoch, color='red', linestyle=':', linewidth=2,
                        label=f'Best Model (Epoch {best_epoch:g})')
            # Annotate the best point using accuracy.
            best_acc = best_epoch_log['eval_accuracy'] * 100
            ax2.plot(best_epoch, best_acc, 'r*', markersize=15)  # best-point marker
            ax2.text(best_epoch, best_acc * 0.99, f'Best Acc: {best_acc:.2f}%',
                     color='red', horizontalalignment='center', verticalalignment='top', fontsize=12)

        # FIX: same legend-ordering fix as ax1.
        ax2.legend(fontsize=11)

        # Tidy the layout and save to file.
        plt.tight_layout()
        plt.savefig('training_visualization_graph.png', dpi=300)
        print("\n'training_visualization_graph.png' ํŒŒ์ผ๋กœ ๊ทธ๋ž˜ํ”„๊ฐ€ ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")

except FileNotFoundError:
    print(f"์˜ค๋ฅ˜: '{file_path}' ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๊ฒฝ๋กœ๋ฅผ ๋‹ค์‹œ ํ™•์ธํ•ด์ฃผ์„ธ์š”.")
except Exception as e:
    print(f"๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}")