emotion-chatbot-app / scripts /train_final.py
hfexample's picture
Deploy clean snapshot of the repository
e221c83
# ํŒŒ์ผ ์ด๋ฆ„: train_final.py
import os
import pandas as pd
import json
import re
import torch
import numpy as np
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
Trainer,
TrainingArguments
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from typing import Dict, List, Tuple
from dataclasses import dataclass
import platform
import matplotlib.pyplot as plt
import seaborn as sns
# --- Matplotlib Korean (Hangul) font setup ---
try:
    # Choose a font family per operating system; anything that is neither
    # Windows nor macOS falls through to the Linux choice (e.g. Colab).
    _hangul_font_by_os = {
        'Windows': 'Malgun Gothic',
        'Darwin': 'AppleGothic',  # macOS
    }
    plt.rc('font', family=_hangul_font_by_os.get(platform.system(), 'NanumBarunGothic'))
    # Turn off the Unicode minus glyph for axis tick labels.
    plt.rcParams['axes.unicode_minus'] = False
except Exception as e:
    # Best effort only: a font problem must not stop training, it only
    # affects how the confusion-matrix labels are rendered.
    print(f"ํ•œ๊ธ€ ํฐํŠธ ์„ค์ • ๊ฒฝ๊ณ : {e}. ํ˜ผ๋™ ํ–‰๋ ฌ์˜ ๋ผ๋ฒจ์ด ๊นจ์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
# --- 1. Configuration ---
@dataclass
class TrainingConfig:
    """Settings for one fine-tuning run: data/output locations and hyperparameters.

    The numeric fields are forwarded to ``TrainingArguments`` and the tokenizer
    by ``run_training``; ``mode`` gates data loading and model selection.
    """

    mode: str = "emotion"
    data_dir: str = "./data"
    output_dir: str = "./results1024"
    # Checked as a local path by get_model_name(); see the fallback there.
    base_model_name: str = "klue/roberta-base"
    eval_batch_size: int = 64
    num_train_epochs: int = 3
    learning_rate: float = 2e-5
    train_batch_size: int = 16
    weight_decay: float = 0.01
    max_length: int = 128
    warmup_ratio: float = 0.1

    def get_model_name(self) -> str:
        """Return the checkpoint identifier to load.

        In ``'emotion'`` mode ``base_model_name`` is returned when it exists as
        a path on the local filesystem; otherwise a warning is printed and the
        stock ``"klue/roberta-base"`` identifier is returned. Every other mode
        returns ``"klue/roberta-base"`` without printing anything.
        """
        fallback = "klue/roberta-base"
        if self.mode != 'emotion':
            return fallback
        if os.path.exists(self.base_model_name):
            print(f"1์ฐจ ํ•™์Šต๋œ ๋ชจ๋ธ ๋กœ๋“œ: {self.base_model_name}")
            return self.base_model_name
        # The configured checkpoint is not a local path: warn and fall back.
        print(f"๊ฒฝ๊ณ : 1์ฐจ ํ•™์Šต๋œ NSMC ๋ชจ๋ธ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค ({self.base_model_name})")
        print("๊ธฐ๋ณธ 'klue/roberta-base' ๋ชจ๋ธ๋กœ ๋Œ€์‹  ํ•™์Šต์„ ์‹œ๋„ํ•ฉ๋‹ˆ๋‹ค.")
        return fallback

    def get_output_dir(self) -> str:
        """Return the directory, nested under ``output_dir``, where this run is saved."""
        run_subdir = 'emotion_model_6class_oversampled'
        return os.path.join(self.output_dir, run_subdir)
# --- 2. Custom classes and functions ---
class EmotionDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with integer class labels.

    ``encodings`` is a mapping from field name to something indexable by
    sample position whose items support ``.clone().detach()`` (in this file:
    the tensors returned by the tokenizer with ``return_tensors="pt"``).
    ``labels`` is a sequence with one entry per sample.
    """

    def __init__(self, encodings, labels):
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        # The label list defines how many samples the dataset exposes.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one sample: every encoding field at ``idx`` plus a ``'labels'`` tensor."""
        sample = {}
        for field, values in self.encodings.items():
            # Copy so that callers can never mutate the stored encodings.
            sample[field] = values[idx].clone().detach()
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
def compute_metrics(pred):
    """Compute the evaluation metrics reported by the Trainer.

    Reads ``pred.label_ids`` (gold labels) and ``pred.predictions`` (scores
    whose last axis is reduced with argmax) and returns a dict with
    ``'accuracy'`` and the support-weighted ``'f1'``.
    """
    y_true = pred.label_ids
    y_hat = pred.predictions.argmax(-1)
    accuracy = accuracy_score(y_true, y_hat)
    # The same call also yields precision, recall and support; only F1 is
    # reported. zero_division=0 scores classes without predictions as 0.
    _, _, weighted_f1, _ = precision_recall_fscore_support(
        y_true,
        y_hat,
        average='weighted',
        zero_division=0,
    )
    return {'accuracy': accuracy, 'f1': weighted_f1}
# Matches every character that is NOT a precomposed Hangul syllable
# (U+AC00..U+D7A3), an ASCII letter, an ASCII digit, or a space.
# The range is spelled with \u escapes on purpose: the previous literal had
# been corrupted by a wrong-encoding round trip into a reversed character
# range, which made re.sub raise "bad character range" on every call.
_DISALLOWED_CHARS = re.compile(r'[^\uAC00-\uD7A3a-zA-Z0-9 ]')


def clean_text(text: str) -> str:
    """Strip everything except Hangul syllables, ASCII letters, digits and spaces.

    Args:
        text: Value to clean. Non-string values are converted with ``str()``
            first (so ``None`` becomes ``"None"``).

    Returns:
        The input with all other characters (punctuation, tabs, newlines,
        Hangul jamo, emoji, ...) removed. Spaces are kept as-is, not collapsed.
    """
    return _DISALLOWED_CHARS.sub('', str(text))
# --- 3. Data loader ---
def map_ecode_to_6class(e_code_str):
    """Map an E-code string such as ``"E18"`` to one of six emotion class labels.

    The number after the leading ``'E'`` selects the class by its tens digit:
    10-19, 20-29, 30-39, 40-49, 50-59 and 60-69 each map to one label.

    Returns:
        The class label string, or ``None`` when the input is not a string,
        does not start with ``'E'``, has a non-integer suffix, or falls
        outside 10-69.
    """
    if not (isinstance(e_code_str, str) and e_code_str.startswith('E')):
        return None
    try:
        code_num = int(e_code_str[1:])
    except (ValueError, TypeError):
        return None
    # Keyed by code_num // 10. The label strings are reproduced verbatim
    # because other parts of the file compare against them.
    # NOTE(review): the labels look like mis-encoded Korean, presumably
    # anger / sadness / anxiety / hurt / embarrassment / joy -- confirm.
    label_by_decade = {
        1: '๋ถ„๋…ธ',
        2: '์Šฌํ””',
        3: '๋ถˆ์•ˆ',
        4: '์ƒ์ฒ˜',
        5: '๋‹นํ™ฉ',
        6: '๊ธฐ์จ',
    }
    # Floor division sends 0-9, negatives and 70+ to keys that are absent.
    return label_by_decade.get(code_num // 10)
def load_and_process(text_file, label_file, data_dir):
    """Merge an Excel sheet of text with a JSON list of labels and preprocess it.

    Row ``i`` of the sheet is paired with entry ``i`` of the JSON list; when
    the two lengths differ, both are truncated to the shorter one and a
    warning is printed.

    Args:
        text_file: Excel file name inside ``data_dir``; read without a header
            row, so its columns are addressed by integer position.
        label_file: JSON file name inside ``data_dir``; a list of records whose
            ``['profile']['emotion']['type']`` holds an E-code such as "E18".
        data_dir: Directory containing both files.

    Returns:
        A DataFrame with the original sheet columns plus ``e_code``, ``text``,
        ``cleaned_text`` and ``major_emotion``. Rows without a mappable
        emotion or with empty cleaned text are dropped. An empty DataFrame is
        returned when either file is missing.
    """
    text_path = os.path.join(data_dir, text_file)
    label_path = os.path.join(data_dir, label_file)
    try:
        # NOTE(review): header=None reads a header row, if the sheet has one,
        # as data row 0, which would shift every label by one -- confirm the
        # sheets are header-less.
        df_text = pd.read_excel(text_path, header=None)
        with open(label_path, 'r', encoding='utf-8') as f:
            labels_raw = json.load(f)
    except FileNotFoundError as e:
        print(f"์˜ค๋ฅ˜: ํ•„์ˆ˜ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {e}")
        return pd.DataFrame()

    e_codes = []
    for dialogue in labels_raw:
        try:
            e_codes.append(dialogue['profile']['emotion']['type'])  # e.g. "E18"
        except (KeyError, TypeError):
            # Missing key, or an intermediate value that is not a mapping
            # (e.g. null): keep the slot so rows stay aligned, drop it later.
            e_codes.append(None)

    if len(df_text) != len(e_codes):
        min_len = min(len(df_text), len(e_codes))
        print(f"๊ฒฝ๊ณ : {text_file}๊ณผ {label_file} ์ค„ ์ˆ˜ ๋ถˆ์ผ์น˜. {min_len}๊ฐœ๋กœ ์ถ•์†Œํ•ฉ๋‹ˆ๋‹ค.")
        df_text = df_text.iloc[:min_len]
        e_codes = e_codes[:min_len]

    df_labels = pd.DataFrame({'e_code': e_codes})
    df_combined = pd.concat([df_text, df_labels], axis=1)

    # presumably columns 8-11 hold the dialogue utterances -- TODO confirm
    dialogue_cols = [8, 9, 10, 11]
    for col in dialogue_cols:
        # Blank out missing cells BEFORE stringifying. The previous
        # `.astype(str).fillna('')` turned NaN into the literal text "nan",
        # which then leaked into the training text.
        df_combined[col] = df_combined[col].map(
            lambda cell: '' if pd.isna(cell) else str(cell)
        )
    # Join only the non-empty parts so blank cells add no stray spaces.
    df_combined['text'] = df_combined[dialogue_cols].apply(
        lambda row: ' '.join(part for part in row if part), axis=1
    )
    df_combined['cleaned_text'] = df_combined['text'].apply(clean_text)
    df_combined['major_emotion'] = df_combined['e_code'].apply(map_ecode_to_6class)

    df_combined = df_combined.dropna(subset=['major_emotion', 'cleaned_text'])
    has_text = df_combined['cleaned_text'].str.strip() != ''
    # Return an independent copy so callers can add columns without pandas
    # chained-assignment warnings.
    return df_combined[has_text].copy()
def get_data(config: TrainingConfig) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load the train/validation/test frames and oversample one class in train.

    The training files are split 90/10 (stratified by ``major_emotion``,
    seed 42) into train and validation; the test frame is loaded from its own
    files. Only the train part is oversampled: rows of the '๊ธฐ์จ' class are
    resampled with replacement up to the mean size of the other classes.
    Validation and test keep their original distribution.

    Args:
        config: Run configuration; ``mode`` must be ``'emotion'`` and
            ``data_dir`` must contain the input files.

    Returns:
        ``(df_train, df_val, df_test)``. Three empty DataFrames are returned
        when the mode is not ``'emotion'`` or when either file pair fails to
        load.
    """
    if config.mode != 'emotion':
        print("์ด ์Šคํฌ๋ฆฝํŠธ๋Š” 'emotion' ๋ชจ๋“œ ์ „์šฉ์ž…๋‹ˆ๋‹ค.")
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    print("--- ๊ฐ์ • ๋ฐ์ดํ„ฐ ๋กœ๋”ฉ (Train/Validation/Test) ---")
    df_full_train = load_and_process("training-origin.xlsx", "training-label.json", config.data_dir)
    # NOTE(review): the validation sheet is paired with "test.json" -- confirm
    # that this label file really belongs to "validation-origin.xlsx".
    df_test = load_and_process("validation-origin.xlsx", "test.json", config.data_dir)
    if df_full_train.empty or df_test.empty:
        print("์˜ค๋ฅ˜: ๋ฐ์ดํ„ฐ ๋กœ๋”ฉ ์‹คํŒจ.")
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    print(f"Full Train data loaded: {len(df_full_train)} rows")
    print(f"Test data loaded: {len(df_test)} rows")
    print("Splitting Full Train into New Train (90%) and New Validation (10%)...")
    df_train, df_val = train_test_split(
        df_full_train,
        test_size=0.1,
        random_state=42,
        stratify=df_full_train['major_emotion']
    )

    print("\n--- [Oversampling] New Train 6-Class (์›๋ณธ ๋ถ„ํฌ) ---")
    print(df_train['major_emotion'].value_counts())

    # --- [oversampling start] ---
    print("\n--- '๊ธฐ์จ' ํด๋ž˜์Šค ์˜ค๋ฒ„์ƒ˜ํ”Œ๋ง ์ ์šฉ ์ค‘ ---")
    # 1. Separate the class to oversample from all other classes.
    is_joy = df_train['major_emotion'] == '๊ธฐ์จ'
    df_train_joy = df_train[is_joy]
    df_train_others = df_train[~is_joy]

    df_joy_final = df_train_joy
    if df_train_joy.empty or df_train_others.empty:
        # resample() cannot draw from an empty frame, and without other
        # classes there is no reference size: leave the data untouched.
        print("Oversampling skipped: the target class or all other classes are absent.")
    else:
        # 2. Target size = mean row count of the remaining classes.
        target_count = int(df_train_others['major_emotion'].value_counts().mean())
        print(f" '๊ธฐ์จ' ์›๋ณธ: {len(df_train_joy)}๊ฐœ")
        print(f" ๋‹ค๋ฅธ ํด๋ž˜์Šค ํ‰๊ท (ํƒ€๊ฒŸ): {target_count}๊ฐœ")
        if len(df_train_joy) < target_count:
            # 3. Draw with replacement until the class reaches the target.
            df_joy_final = resample(
                df_train_joy,
                replace=True,
                n_samples=target_count,
                random_state=42
            )
        else:
            # Resampling here would SHRINK the class and discard real rows.
            print("Oversampling skipped: the class already meets the target size.")

    # 4. Recombine, then 5. shuffle so the duplicated rows are not contiguous.
    df_train = pd.concat([df_train_others, df_joy_final])
    df_train = df_train.sample(frac=1, random_state=42).reset_index(drop=True)

    print("\n--- [Oversampling] New Train 6-Class (์ตœ์ข… ๋ถ„ํฌ) ---")
    print(df_train['major_emotion'].value_counts())
    # --- [oversampling end] ---

    print(f"\nNew Train set (Oversampled) size: {len(df_train)}")
    print(f"New Validation set (Original) size: {len(df_val)}")
    return df_train, df_val, df_test
# --- 4. Main entry function ---
def _encode_labels(frame, label_column, label_to_id, split_name):
    """Map one split's string labels to integer ids, failing fast on unknown labels."""
    ids = frame[label_column].map(label_to_id)
    missing = ids.isna()
    if missing.any():
        # A label absent from the mapping becomes NaN, which would turn the
        # whole label list into floats and break the loss much later.
        unseen = sorted(set(frame.loc[missing, label_column]))
        raise ValueError(
            f"{split_name} split contains labels absent from the training split: {unseen}"
        )
    return ids.astype(int).tolist()


def _build_dataset(tokenizer, texts, labels, max_length):
    """Tokenize one split (padded, truncated, PyTorch tensors) into an EmotionDataset."""
    encodings = tokenizer(
        list(texts),
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    return EmotionDataset(encodings, labels)


def _write_json(payload, path):
    """Write ``payload`` to ``path`` as indented UTF-8 JSON without ASCII escaping."""
    with open(path, "w", encoding='utf-8') as f:
        json.dump(payload, f, indent=4, ensure_ascii=False)


def _save_confusion_matrix(test_predictions, id_to_label, output_dir):
    """Plot the test-set confusion matrix to a PNG; failures are reported, not raised."""
    fig = None
    try:
        y_pred = test_predictions.predictions.argmax(-1)
        y_true = test_predictions.label_ids
        class_ids = sorted(id_to_label)
        labels = [id_to_label[i] for i in class_ids]
        cm = confusion_matrix(y_true, y_pred, labels=class_ids)
        fig = plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
        plt.xlabel('์˜ˆ์ธก ๋ผ๋ฒจ (Predicted Label)')
        plt.ylabel('์‹ค์ œ ๋ผ๋ฒจ (True Label)')
        plt.title('Confusion Matrix (TEST Set - 6-Class Oversampled)')
        cm_path = os.path.join(output_dir, "TEST_confusion_matrix.png")
        plt.savefig(cm_path)
        print(f"Test Set ํ˜ผ๋™ ํ–‰๋ ฌ์ด {cm_path}์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
    except Exception as e:
        # Deliberately best effort: the metrics are already saved at this
        # point, so a plotting problem must not fail the run.
        print(f"\n!!! ์˜ค๋ฅ˜: ํ˜ผ๋™ ํ–‰๋ ฌ ์ƒ์„ฑ ์‹คํŒจ: {e} !!!")
        print("matplotlib, seaborn ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๊ฐ€ ์„ค์น˜๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.")
    finally:
        # Release the figure so repeated runs (e.g. in a notebook) do not
        # accumulate open figures.
        if fig is not None:
            plt.close(fig)


def run_training():
    """Fine-tune the classifier end to end and persist the model and reports.

    Steps: load the three data splits, build the label <-> id mapping from the
    training split, tokenize, train with per-epoch evaluation (the best
    checkpoint by validation accuracy is restored at the end), save the model
    and tokenizer under ``<output_dir>/best_model``, then write validation
    metrics, test metrics and a test-set confusion matrix to ``output_dir``.

    Returns ``None``; prints a message and returns early when any split is
    empty.

    Raises:
        ValueError: If the validation or test split contains a label that
            does not occur in the training split.
    """
    config = TrainingConfig()
    df_train, df_val, df_test = get_data(config)
    if df_train.empty or df_val.empty or df_test.empty:
        print("\n์˜ค๋ฅ˜: ๋ฐ์ดํ„ฐ๊ฐ€ ๋น„์–ด์žˆ์–ด ํ›ˆ๋ จ์„ ์ค‘๋‹จํ•ฉ๋‹ˆ๋‹ค.")
        return

    text_column = 'cleaned_text'
    label_column_str = 'major_emotion'
    model_name_to_load = config.get_model_name()
    tokenizer = AutoTokenizer.from_pretrained(model_name_to_load)

    # Sorting makes the label -> id assignment deterministic across runs.
    unique_labels = sorted(df_train[label_column_str].unique())
    label_to_id = {label: i for i, label in enumerate(unique_labels)}
    id_to_label = {i: label for label, i in label_to_id.items()}
    print(f"\n๋ผ๋ฒจ ์ธ์ฝ”๋”ฉ ๋งต (6-Class): {label_to_id}")

    train_labels = _encode_labels(df_train, label_column_str, label_to_id, "train")
    val_labels = _encode_labels(df_val, label_column_str, label_to_id, "validation")
    test_labels = _encode_labels(df_test, label_column_str, label_to_id, "test")

    print("๋ฐ์ดํ„ฐ์…‹ ํ† ํฌ๋‚˜์ด์ง• ์ค‘...")
    train_dataset = _build_dataset(tokenizer, df_train[text_column], train_labels, config.max_length)
    val_dataset = _build_dataset(tokenizer, df_val[text_column], val_labels, config.max_length)
    test_dataset = _build_dataset(tokenizer, df_test[text_column], test_labels, config.max_length)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")
    print("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name_to_load,
        num_labels=len(unique_labels),
        id2label=id_to_label,
        label2id=label_to_id,
        # Allows loading a checkpoint whose classification head has a
        # different number of labels than this task.
        ignore_mismatched_sizes=True
    ).to(device)

    output_dir = config.get_output_dir()
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.train_batch_size,
        per_device_eval_batch_size=config.eval_batch_size,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        report_to="none"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )

    print(f"\n 6-Class (Oversampled) ๋ชจ๋ธ ํ›ˆ๋ จ์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค...")
    trainer.train()
    print("\n ๋ชจ๋ธ ํ›ˆ๋ จ ์™„๋ฃŒ!")

    final_model_path = os.path.join(output_dir, "best_model")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"์ตœ์ข… ๋ชจ๋ธ(Best)๊ณผ ํ† ํฌ๋‚˜์ด์ €๊ฐ€ {final_model_path} ๊ฒฝ๋กœ์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")

    print("\n--- (1) Validation Set ํ‰๊ฐ€ (๋ชจ์˜๊ณ ์‚ฌ) ---")
    val_results = trainer.evaluate(eval_dataset=val_dataset)
    print(f"์ตœ์ข… Validation ํ‰๊ฐ€ ๊ฒฐ๊ณผ: {val_results}")
    results_path = os.path.join(output_dir, "validation_evaluation_results.json")
    _write_json(val_results, results_path)
    print(f"Validation ํ‰๊ฐ€ ๊ฒฐ๊ณผ๊ฐ€ {results_path}์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")

    print("\n--- (2) Test Set ์ตœ์ข… ํ‰๊ฐ€ (์ˆ˜๋Šฅ) ---")
    test_predictions = trainer.predict(test_dataset)
    test_metrics = test_predictions.metrics
    print(f"์ตœ์ข… Test ํ‰๊ฐ€ ๊ฒฐ๊ณผ: {test_metrics}")
    results_path = os.path.join(output_dir, "TEST_evaluation_results.json")
    _write_json(test_metrics, results_path)
    print(f"*** ์ตœ์ข… Test ํ‰๊ฐ€ ๊ฒฐ๊ณผ๊ฐ€ {results_path}์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ***")

    print("\n--- ํ˜ผ๋™ ํ–‰๋ ฌ ์ƒ์„ฑ (Test Set) ---")
    _save_confusion_matrix(test_predictions, id_to_label, output_dir)
if __name__ == "__main__":
    # Script entry point: runs only when the file is executed directly, not
    # when it is imported. Prints a start banner, then runs the full
    # training and evaluation pipeline.
    print("--- 6-Class (Oversampling) ๊ฐ์ • ๋ถ„๋ฅ˜ ๋ชจ๋ธ ํ•™์Šต ์‹œ์ž‘ ---")
    run_training()