# Source: emotion-chatbot-app/scripts/evaluate_step1.py
# Commit e221c83 (author: hfexample): "Deploy clean snapshot of the repository"
# File name: evaluate_step1.py
# Script that loads the already-trained step-1 model and regenerates only the
# evaluation results and the confusion matrix.
import os
import pandas as pd
from dataclasses import dataclass
import json
import re
import torch
import numpy as np
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
Trainer,
TrainingArguments
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import platform
import matplotlib.pyplot as plt
import seaborn as sns
# Classes and functions identical to those in train_final.py
# (data loading, metric computation, etc.)
# -----------------------------------------------------------------
@dataclass
class TrainingConfig:
    """Settings used to re-evaluate the step-1 (3-group) emotion classifier.

    Within this file only ``data_dir``, ``output_dir``, ``eval_batch_size``
    and ``max_length`` are read; ``mode`` and ``base_model_name`` are kept so
    the fields mirror the training configuration.
    """

    # Task mode tag (not read by this script).
    mode: str = "emotion"
    # Folder containing the Excel/JSON evaluation inputs.
    data_dir: str = "./data"
    # Root folder under which training saved its model folders.
    output_dir: str = "./results1024"
    # Path of the first-stage NSMC model; a machine-specific absolute path set
    # by the user (not read by this script).
    base_model_name: str = r"E:\Emotion\results\nsmc_model"
    # Per-device batch size used for prediction.
    eval_batch_size: int = 64
    # Maximum token length; longer inputs are truncated by the tokenizer.
    max_length: int = 128

    def get_step1_model_dir(self) -> str:
        """Return the ``best_model`` folder of the already-trained step-1 model."""
        sub_dirs = ('emotion_model_step1_3class', 'best_model')
        return os.path.join(self.output_dir, *sub_dirs)
class EmotionDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with class labels.

    ``encodings`` is a mapping of feature name -> row-indexable tensor (in this
    file it is the tokenizer output produced with ``return_tensors="pt"``);
    ``labels`` is a sequence of integer class ids aligned with those rows.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # One sample per label.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for feature_name, feature_tensor in self.encodings.items():
            # Copy the row so the returned tensor is detached from the batch tensor.
            sample[feature_name] = feature_tensor[idx].clone().detach()
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
def compute_metrics(pred):
    """Return accuracy and weighted F1 for a Trainer prediction output.

    ``pred`` must expose ``label_ids`` (true class ids) and ``predictions``
    (per-class scores; the arg-max over the last axis is taken as the
    predicted id). Returns ``{'accuracy': ..., 'f1': ...}``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    # Only F1 is reported: 'weighted' averages per-class F1 by class support,
    # and zero_division=0 scores never-predicted classes as 0 instead of warning.
    _, _, weighted_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )
    return {'accuracy': accuracy_score(y_true, y_pred), 'f1': weighted_f1}
# Compiled once at import. Matches every character that is NOT a precomposed
# Hangul syllable (U+AC00..U+D7A3), an ASCII letter, an ASCII digit or a plain
# space. The Hangul range is written as escapes so the pattern cannot be broken
# by a source-encoding mishap (a mis-decoded literal range is an invalid regex).
_CLEAN_TEXT_PATTERN = re.compile(r'[^\uac00-\ud7a3a-zA-Z0-9 ]')


def clean_text(text: str) -> str:
    """Return ``text`` with every disallowed character removed.

    Allowed characters are precomposed Hangul syllables, ASCII letters, ASCII
    digits and the space character; everything else (punctuation, tabs,
    newlines, standalone jamo, emoji, ...) is deleted. Non-string input is
    converted with ``str()`` first, so ``None`` becomes ``"None"``.
    """
    return _CLEAN_TEXT_PATTERN.sub('', str(text))
def map_ecode_to_6class(e_code_str):
    """Translate an emotion code such as ``'E18'`` into one of six major emotions.

    A valid code is a string starting with ``'E'`` followed by an integer in
    0-59. Each block of ten numbers maps to one label: E00-E09, E10-E19, ...,
    E50-E59, in the order of ``buckets`` below. The literals appear
    mis-encoded in this copy; they presumably read joy, anger, sadness,
    anxiety, hurt, embarrassment (confirm against the dataset spec).
    Non-strings, a missing ``'E'`` prefix, a non-integer suffix or a number
    outside 0-59 all yield ``None``.
    """
    if not (isinstance(e_code_str, str) and e_code_str.startswith('E')):
        return None
    try:
        code_num = int(e_code_str[1:])
    except (ValueError, TypeError):
        return None
    if not 0 <= code_num <= 59:
        return None
    # One label per decade of codes.
    buckets = ('๊ธฐ์จ', '๋ถ„๋…ธ', '์Šฌํ””', '๋ถˆ์•ˆ', '์ƒ์ฒ˜', '๋‹นํ™ฉ')
    return buckets[code_num // 10]
def map_6_to_3_groups(emotion_6_class):
    """Collapse a six-way major emotion label into one of three coarse groups.

    Groups are checked in order: group 1 holds one label, group 2 two labels
    and group 3 three labels (the literals appear mis-encoded in this copy;
    presumably sadness / anxiety+hurt / anger+embarrassment+joy — confirm).
    Any value not listed, including ``None``, yields ``None``.
    """
    groupings = (
        ('๊ทธ๋ฃน1(์Šฌํ””)', ('์Šฌํ””',)),
        ('๊ทธ๋ฃน2(๋ถˆ์•ˆ,์ƒ์ฒ˜)', ('๋ถˆ์•ˆ', '์ƒ์ฒ˜')),
        ('๊ทธ๋ฃน3(๋ถ„๋…ธ,๋‹นํ™ฉ,๊ธฐ์จ)', ('๋ถ„๋…ธ', '๋‹นํ™ฉ', '๊ธฐ์จ')),
    )
    for group_name, members in groupings:
        if emotion_6_class in members:
            return group_name
    return None
def load_and_process(text_file, label_file, data_dir):
    """Load one dialogue split and attach its emotion labels.

    Reads the Excel sheet ``text_file`` and the JSON list ``label_file``
    (emotion code at ``record['profile']['emotion']['type']``), both located
    in ``data_dir``. Rows are paired positionally; if the two sources differ
    in length, both are truncated to the shorter one.

    Every column whose name contains the sentence-column marker string below
    is treated as dialogue text. These columns are joined with single spaces
    into ``text`` and cleaned into ``cleaned_text``. The code is mapped to
    ``major_emotion`` (6 classes) and ``group_emotion`` (3 groups). Rows
    without a usable label or whose cleaned text is blank are dropped.

    Returns the processed DataFrame, or an empty DataFrame when either input
    file is missing.
    """
    text_path = os.path.join(data_dir, text_file)
    label_path = os.path.join(data_dir, label_file)
    try:
        df_text = pd.read_excel(text_path, header=0)
        with open(label_path, 'r', encoding='utf-8') as label_fh:
            labels_raw = json.load(label_fh)
    except FileNotFoundError as exc:
        print(f"์˜ค๋ฅ˜: ํ•„์ˆ˜ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {exc}")
        return pd.DataFrame()

    def _emotion_code(record):
        # Records lacking the nested profile/emotion/type keys get no code.
        try:
            return record['profile']['emotion']['type']
        except KeyError:
            return None

    e_codes = [_emotion_code(record) for record in labels_raw]

    # Pair rows by position; the surplus of the longer source is dropped.
    n_rows = min(len(df_text), len(e_codes))
    if len(df_text) != len(e_codes):
        df_text = df_text.iloc[:n_rows]
        e_codes = e_codes[:n_rows]

    merged = pd.concat([df_text, pd.DataFrame({'e_code': e_codes})], axis=1)

    sentence_cols = [name for name in merged.columns if '๋ฌธ์žฅ' in str(name)]
    for name in sentence_cols:
        # NOTE(review): astype(str) runs before fillna(''), so empty cells become
        # the literal text 'nan' and survive cleaning. Kept as-is on purpose:
        # evaluation preprocessing must match train_final.py; fix both together.
        merged[name] = merged[name].astype(str).fillna('')
    merged['text'] = merged[sentence_cols].apply(lambda cells: ' '.join(cells), axis=1)
    merged['cleaned_text'] = merged['text'].apply(clean_text)

    merged['major_emotion'] = merged['e_code'].apply(map_ecode_to_6class)
    merged.dropna(subset=['major_emotion'], inplace=True)
    merged['group_emotion'] = merged['major_emotion'].apply(map_6_to_3_groups)
    merged.dropna(subset=['group_emotion', 'cleaned_text'], inplace=True)

    has_text = merged['cleaned_text'].str.strip() != ''
    return merged[has_text]
def get_test_data(config: TrainingConfig) -> pd.DataFrame:
    """Load only the held-out test split.

    The split is built from the ``validation-origin.xlsx`` dialogue sheet
    paired with the ``test.json`` labels under ``config.data_dir``. Returns
    the processed DataFrame, or a fresh empty DataFrame if loading yields
    no rows.
    """
    print("Loading TEST set (from validation-origin.xlsx + test.json)...")
    df_test = load_and_process("validation-origin.xlsx", "test.json", config.data_dir)
    if not df_test.empty:
        print(f"Test data loaded: {len(df_test)} rows")
        return df_test
    print("์˜ค๋ฅ˜: Test ๋ฐ์ดํ„ฐ ๋กœ๋”ฉ ์‹คํŒจ.")
    return pd.DataFrame()
# -----------------------------------------------------------------
def _configure_korean_font():
    """Pick a platform-specific Korean font so Hangul plot labels can render."""
    try:
        system = platform.system()
        if system == 'Windows':
            plt.rc('font', family='Malgun Gothic')
        elif system == 'Darwin':  # macOS
            plt.rc('font', family='AppleGothic')
        else:  # Linux (e.g. Colab)
            plt.rc('font', family='NanumBarunGothic')
        # Use an ASCII hyphen for minus signs so the Korean font needs no U+2212 glyph.
        plt.rcParams['axes.unicode_minus'] = False
        print("ํ•œ๊ธ€ ํฐํŠธ ์„ค์ • ์™„๋ฃŒ.")
    except Exception as e:
        # NOTE(review): matplotlib usually only warns (at draw time) about a
        # missing font family, so this branch is a last-resort safety net.
        print(f"ํ•œ๊ธ€ ํฐํŠธ ์„ค์ • ๊ฒฝ๊ณ : {e}. ํ˜ผ๋™ ํ–‰๋ ฌ์˜ ๋ผ๋ฒจ์ด ๊นจ์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")


def _plot_confusion_matrix(y_true, y_pred, id_to_label, output_dir):
    """Save a confusion-matrix heat map to output_dir and return the file path.

    Rows/columns are ordered by class id, and ``id_to_label`` supplies the
    tick labels. The figure is always closed, even if plotting or saving
    fails.
    """
    label_ids = sorted(id_to_label)
    label_names = [id_to_label[i] for i in label_ids]
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=label_names, yticklabels=label_names)
        plt.xlabel('์˜ˆ์ธก ๋ผ๋ฒจ (Predicted Label)')
        plt.ylabel('์‹ค์ œ ๋ผ๋ฒจ (True Label)')
        plt.title('Confusion Matrix (TEST Set - 3 Groups)')
        cm_path = os.path.join(output_dir, "TEST_confusion_matrix.png")
        plt.savefig(cm_path)
    finally:
        plt.close(fig)
    return cm_path


def run_evaluation():
    """Re-evaluate the saved step-1 (3-group) model on the test split.

    Steps:
    1. Configure a Korean plot font.
    2. Load the fine-tuned model and tokenizer from
       ``TrainingConfig().get_step1_model_dir()``.
    3. Load the test split, encode its labels with the model's own
       ``label2id`` and tokenize it.
    4. Run ``Trainer.predict``.
    5. Write ``TEST_evaluation_results.json`` and ``TEST_confusion_matrix.png``
       into the model's parent folder.

    Progress and problems are reported with ``print``. The function returns
    ``None`` early, without raising, when the model folder, the test data or
    any label-compatible test row is missing.
    """
    # --- 1. Korean font setup ---
    _configure_korean_font()
    config = TrainingConfig()

    # --- 2. Load model and tokenizer ---
    model_dir = config.get_step1_model_dir()
    output_dir = os.path.dirname(model_dir)  # .../emotion_model_step1_3class
    if not os.path.exists(model_dir):
        print(f"์˜ค๋ฅ˜: ํ›ˆ๋ จ๋œ ๋ชจ๋ธ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {model_dir}")
        print("train_final.py๋ฅผ ๋จผ์ € ์‹คํ–‰ํ•˜์„ธ์š”.")
        return
    print(f"์ €์žฅ๋œ 1๋‹จ๊ณ„ ๋ชจ๋ธ ๋กœ๋“œ: {model_dir}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForSequenceClassification.from_pretrained(model_dir).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # --- 3. Load and preprocess the test data ---
    df_test = get_test_data(config)
    if df_test.empty:
        return

    # Encode labels with the map saved in the model config so ids match training.
    label_to_id = model.config.label2id
    id_to_label = model.config.id2label
    print(f"๋ชจ๋ธ์˜ ๋ผ๋ฒจ ๋งต ๋กœ๋“œ: {label_to_id}")
    df_test['label'] = df_test['group_emotion'].map(label_to_id)

    # Labels absent from training map to NaN; drop those rows.
    if df_test['label'].isnull().any():
        print("๊ฒฝ๊ณ : Test set์— ํ›ˆ๋ จ ์‹œ ์—†๋˜ ๋ผ๋ฒจ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ํ•ด๋‹น ๋ฐ์ดํ„ฐ๋Š” ํ‰๊ฐ€์—์„œ ์ œ์™ธ๋ฉ๋‹ˆ๋‹ค.")
        df_test.dropna(subset=['label'], inplace=True)
    if df_test.empty:
        # Without this guard, tokenizing/predicting an empty set fails obscurely.
        print("Error: no test rows match the model's label map (model.config.label2id); nothing to evaluate.")
        return
    df_test['label'] = df_test['label'].astype(int)

    test_encodings = tokenizer(
        list(df_test['cleaned_text']),
        max_length=config.max_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    test_dataset = EmotionDataset(test_encodings, df_test['label'].tolist())

    # --- 4. Trainer setup (evaluation only) ---
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_eval_batch_size=config.eval_batch_size,
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
    )

    # --- 5. Final test-set evaluation and confusion matrix ---
    print("\n--- (2) Test Set ์ตœ์ข… ํ‰๊ฐ€ (์ˆ˜๋Šฅ) ---")
    test_predictions = trainer.predict(test_dataset)
    test_metrics = test_predictions.metrics
    print(f"์ตœ์ข… Test ํ‰๊ฐ€ ๊ฒฐ๊ณผ: {test_metrics}")

    # Metrics are (re)written even if an earlier run already saved them.
    results_path = os.path.join(output_dir, "TEST_evaluation_results.json")
    with open(results_path, "w", encoding='utf-8') as f:
        json.dump(test_metrics, f, indent=4, ensure_ascii=False)
    print(f"*** ์ตœ์ข… Test ํ‰๊ฐ€ ๊ฒฐ๊ณผ๊ฐ€ {results_path}์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ***")

    print("\n--- ํ˜ผ๋™ ํ–‰๋ ฌ ์ƒ์„ฑ (Test Set) ---")
    try:
        y_pred = test_predictions.predictions.argmax(-1)
        y_true = test_predictions.label_ids
        cm_path = _plot_confusion_matrix(y_true, y_pred, id_to_label, output_dir)
        print(f"Test Set ํ˜ผ๋™ ํ–‰๋ ฌ์ด {cm_path}์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
        print("--- ํ‰๊ฐ€ ๋ฐ ํ˜ผ๋™ ํ–‰๋ ฌ ์ƒ์„ฑ ์™„๋ฃŒ ---")
    except Exception as e:
        # Plotting is best-effort: the metrics JSON above is already saved.
        print("\n!!! ์น˜๋ช…์  ์˜ค๋ฅ˜: ํ˜ผ๋™ ํ–‰๋ ฌ ์ƒ์„ฑ ์‹คํŒจ !!!")
        print(f"์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€: {e}")
        print("matplotlib, seaborn, ๋˜๋Š” ํ•œ๊ธ€ ํฐํŠธ ์„ค์ •์„ ํ™•์ธํ•˜์„ธ์š”.")
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, never on import.
    run_evaluation()