Spaces:
Running
Running
# File name: train_final.py
| import os | |
| import pandas as pd | |
| import json | |
| import re | |
| import torch | |
| import numpy as np | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| Trainer, | |
| TrainingArguments | |
| ) | |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.utils import resample | |
| from typing import Dict, List, Tuple | |
| from dataclasses import dataclass | |
| import platform | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
# --- Matplotlib Korean font setup ---
# Select a font family per operating system so that Korean plot labels can be
# rendered, and make Matplotlib use the ASCII hyphen for minus signs.
try:
    _os_name = platform.system()
    if _os_name == 'Windows':
        _font_family = 'Malgun Gothic'
    elif _os_name == 'Darwin':  # macOS
        _font_family = 'AppleGothic'
    else:  # Linux (Colab, etc.)
        _font_family = 'NanumBarunGothic'
    plt.rc('font', family=_font_family)
    plt.rcParams['axes.unicode_minus'] = False
except Exception as e:
    # Best effort only: a font problem must not prevent training from running.
    print(f"ํ๊ธ ํฐํธ ์ค์ ๊ฒฝ๊ณ : {e}. ํผ๋ ํ๋ ฌ์ ๋ผ๋ฒจ์ด ๊นจ์ง ์ ์์ต๋๋ค.")
# --- 1. Configuration ---
@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for the emotion-classification training run.

    Declared as a dataclass so individual settings can be overridden per
    instance (``TrainingConfig(num_train_epochs=5)``); calling it with no
    arguments yields exactly the defaults listed below.
    """

    mode: str = "emotion"                      # only 'emotion' gets special handling in get_model_name
    data_dir: str = "./data"                   # directory passed to the data loaders
    output_dir: str = "./results1024"          # root directory for checkpoints and reports
    base_model_name: str = "klue/roberta-base" # checked as a filesystem path in 'emotion' mode
    eval_batch_size: int = 64
    num_train_epochs: int = 3
    learning_rate: float = 2e-5
    train_batch_size: int = 16
    weight_decay: float = 0.01
    max_length: int = 128                      # tokenizer truncation length
    warmup_ratio: float = 0.1

    def get_model_name(self) -> str:
        """Return the model identifier or path that should be loaded.

        In 'emotion' mode ``base_model_name`` is treated as a local path: if
        it exists on disk it is returned, otherwise a warning is printed and
        the stock "klue/roberta-base" identifier is returned. With the default
        value this means the local path is only used when a directory named
        ``klue/roberta-base`` exists relative to the working directory.
        Any other mode always returns "klue/roberta-base".
        """
        if self.mode != 'emotion':
            return "klue/roberta-base"
        if os.path.exists(self.base_model_name):
            print(f"1์ฐจ ํ์ต๋ ๋ชจ๋ธ ๋ก๋: {self.base_model_name}")
            return self.base_model_name
        print(f"๊ฒฝ๊ณ : 1์ฐจ ํ์ต๋ NSMC ๋ชจ๋ธ์ ์ฐพ์ ์ ์์ต๋๋ค ({self.base_model_name})")
        print("๊ธฐ๋ณธ 'klue/roberta-base' ๋ชจ๋ธ๋ก ๋์ ํ์ต์ ์๋ํฉ๋๋ค.")
        return "klue/roberta-base"

    def get_output_dir(self) -> str:
        """Return the directory, below ``output_dir``, where this model is saved."""
        return os.path.join(self.output_dir, 'emotion_model_6class_oversampled')
# --- 2. Custom classes and functions ---
class EmotionDataset(torch.utils.data.Dataset):
    """Dataset that pairs tokenizer encodings with one label per example.

    ``encodings`` is a mapping from field name to a tensor indexed by example
    (each value must support ``[idx].clone().detach()``); ``labels`` is a
    sequence with one entry per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The number of examples is defined by the label sequence.
        return len(self.labels)

    def __getitem__(self, idx):
        # Copy each field's row so the returned tensors do not alias the
        # stored encodings.
        sample = {}
        for field_name, field_tensor in self.encodings.items():
            sample[field_name] = field_tensor[idx].clone().detach()
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample
def compute_metrics(pred):
    """Compute accuracy and weighted F1 for a prediction object.

    ``pred`` must expose ``label_ids`` (true class ids) and ``predictions``
    (scores whose last axis is arg-maxed to obtain the predicted class id).
    Returns a dict with the keys 'accuracy' and 'f1'.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    # Precision, recall and support are computed by the same call but are not
    # reported; only the weighted F1 is kept.
    _, _, weighted_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted', zero_division=0
    )
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': weighted_f1,
    }
# Matches every character that is NOT a precomposed Hangul syllable
# (U+AC00-U+D7A3), an ASCII letter, an ASCII digit, or a space. The Hangul
# range is written with escapes so it survives any source-file re-encoding.
_DISALLOWED_CHARS = re.compile(r'[^\uAC00-\uD7A3a-zA-Z0-9 ]')


def clean_text(text: str) -> str:
    """Strip everything except Hangul syllables, ASCII alphanumerics and spaces.

    Args:
        text: Value to clean; it is converted with ``str()`` first, so
            non-string inputs (e.g. ``None`` or numbers) are accepted.

    Returns:
        The input with all disallowed characters removed.
    """
    return _DISALLOWED_CHARS.sub('', str(text))
# --- 3. Data loading ---
def map_ecode_to_6class(e_code_str):
    """Map an E-code string (for example "E18") to one of six class labels.

    The numeric part of the code selects the class by its tens digit: codes
    10-19, 20-29, 30-39, 40-49, 50-59 and 60-69 each map to one label.

    Returns None when the input is not a string, does not start with 'E',
    has a non-integer remainder, or lies outside 10-69.
    """
    if not (isinstance(e_code_str, str) and e_code_str.startswith('E')):
        return None
    try:
        code_num = int(e_code_str[1:])
    except (ValueError, TypeError):
        return None
    # NOTE(review): the label strings below look mis-encoded; they are kept
    # verbatim because other code in this file compares against these values.
    class_by_decade = {
        1: '๋ถ๋ ธ',
        2: '์ฌํ',
        3: '๋ถ์',
        4: '์์ฒ',
        5: '๋นํฉ',
        6: '๊ธฐ์จ',
    }
    # Floor division puts 10-19 in bucket 1 ... 60-69 in bucket 6; any other
    # value (including negatives) falls outside the table and yields None.
    return class_by_decade.get(code_num // 10)
def load_and_process(text_file, label_file, data_dir):
    """Merge an Excel sheet of dialogue text with a JSON label file and preprocess it.

    Args:
        text_file: Name of the Excel file (read without a header row).
        label_file: Name of the JSON file; it is iterated as a sequence of
            records and the E-code is read from
            ``record['profile']['emotion']['type']``.
        data_dir: Directory that contains both files.

    Returns:
        A DataFrame holding the Excel columns plus 'e_code', 'text' (columns
        8-11 joined by spaces), 'cleaned_text' and 'major_emotion'. Rows whose
        E-code maps to no class or whose cleaned text is blank are dropped.
        An empty DataFrame is returned when either file is missing.
    """
    text_path = os.path.join(data_dir, text_file)
    label_path = os.path.join(data_dir, label_file)
    try:
        df_text = pd.read_excel(text_path, header=None)
        with open(label_path, 'r', encoding='utf-8') as f:
            labels_raw = json.load(f)
    except FileNotFoundError as e:
        print(f"์ค๋ฅ: ํ์ ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {e}")
        return pd.DataFrame()

    e_codes = []
    for dialogue in labels_raw:
        try:
            e_codes.append(dialogue['profile']['emotion']['type'])  # e.g. "E18"
        except (KeyError, TypeError):
            # Missing key, or an intermediate value that is not a mapping:
            # treat the record as unlabeled so the row is dropped later.
            e_codes.append(None)

    # Rows and labels are paired by position, so trim both to the shorter one.
    if len(df_text) != len(e_codes):
        min_len = min(len(df_text), len(e_codes))
        print(f"๊ฒฝ๊ณ : {text_file}๊ณผ {label_file} ์ค ์ ๋ถ์ผ์น. {min_len}๊ฐ๋ก ์ถ์ํฉ๋๋ค.")
        df_text = df_text.iloc[:min_len]
        e_codes = e_codes[:min_len]

    df_labels = pd.DataFrame({'e_code': e_codes})
    df_combined = pd.concat([df_text, df_labels], axis=1)

    dialogue_cols = [8, 9, 10, 11]
    for col in dialogue_cols:
        # Blank cells must become '' BEFORE string conversion; converting
        # first turns NaN into the literal text "nan", which would then leak
        # into the training text.
        df_combined[col] = df_combined[col].map(
            lambda cell: '' if pd.isna(cell) else str(cell)
        )
    # Join only the non-empty turns so missing cells add no stray spaces.
    df_combined['text'] = df_combined[dialogue_cols].apply(
        lambda row: ' '.join(part for part in row if part), axis=1
    )
    df_combined['cleaned_text'] = df_combined['text'].apply(clean_text)
    df_combined['major_emotion'] = df_combined['e_code'].apply(map_ecode_to_6class)

    df_combined = df_combined.dropna(subset=['major_emotion', 'cleaned_text'])
    has_text = df_combined['cleaned_text'].str.strip() != ''
    # Return an independent copy so callers can add columns without pandas
    # chained-assignment warnings.
    return df_combined[has_text].copy()
def get_data(config: TrainingConfig) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load the train/validation/test frames and oversample one class in train.

    The labelled training file is split 90/10 (stratified on 'major_emotion',
    seed 42) into a new train and validation set. Within the new train set,
    the class labelled '๊ธฐ์จ' is resampled with replacement up to the mean
    size of the other classes, then the train set is reshuffled.

    Oversampling is skipped when that class is absent, when there are no
    other classes to derive a target from, or when the class already has at
    least the target number of rows (resampling then would shrink it).

    Args:
        config: Run configuration; ``mode`` must be 'emotion' and
            ``data_dir`` must contain the four input files named below.

    Returns:
        ``(df_train, df_val, df_test)``. Three empty DataFrames are returned
        when the mode is not 'emotion' or a data file could not be loaded.
    """
    if config.mode != 'emotion':
        print("์ด ์คํฌ๋ฆฝํธ๋ 'emotion' ๋ชจ๋ ์ ์ฉ์ ๋๋ค.")
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    print("--- ๊ฐ์ ๋ฐ์ดํฐ ๋ก๋ฉ (Train/Validation/Test) ---")
    df_full_train = load_and_process("training-origin.xlsx", "training-label.json", config.data_dir)
    df_test = load_and_process("validation-origin.xlsx", "test.json", config.data_dir)
    if df_full_train.empty or df_test.empty:
        print("์ค๋ฅ: ๋ฐ์ดํฐ ๋ก๋ฉ ์คํจ.")
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
    print(f"Full Train data loaded: {len(df_full_train)} rows")
    print(f"Test data loaded: {len(df_test)} rows")

    print("Splitting Full Train into New Train (90%) and New Validation (10%)...")
    df_train, df_val = train_test_split(
        df_full_train,
        test_size=0.1,
        random_state=42,
        stratify=df_full_train['major_emotion']
    )

    print("\n--- [Oversampling] New Train 6-Class (์๋ณธ ๋ถํฌ) ---")
    print(df_train['major_emotion'].value_counts())

    # --- oversampling start ---
    print("\n--- '๊ธฐ์จ' ํด๋์ค ์ค๋ฒ์ํ๋ง ์ ์ฉ ์ค ---")
    # 1. Separate the target class from all remaining classes.
    is_joy = df_train['major_emotion'] == '๊ธฐ์จ'
    df_train_joy = df_train[is_joy]
    df_train_others = df_train[~is_joy]

    if df_train_joy.empty or df_train_others.empty:
        # resample() cannot draw from an empty frame and the mean of no
        # classes is NaN, so there is nothing meaningful to do here.
        print("Oversampling skipped: target class or remaining classes are empty.")
    else:
        # 2. Target size = mean row count of the other classes.
        target_count = int(df_train_others['major_emotion'].value_counts().mean())
        print(f" '๊ธฐ์จ' ์๋ณธ: {len(df_train_joy)}๊ฐ")
        print(f" ๋ค๋ฅธ ํด๋์ค ํ๊ท (ํ๊ฒ): {target_count}๊ฐ")
        if len(df_train_joy) < target_count:
            # 3. Draw with replacement until the target size is reached.
            df_train_joy = resample(
                df_train_joy,
                replace=True,
                n_samples=target_count,
                random_state=42
            )
        else:
            print("Oversampling skipped: target class already meets the target size.")

    # 4. Recombine and 5. reshuffle so the target class is not clustered at the end.
    df_train = pd.concat([df_train_others, df_train_joy])
    df_train = df_train.sample(frac=1, random_state=42).reset_index(drop=True)

    print("\n--- [Oversampling] New Train 6-Class (์ต์ข ๋ถํฌ) ---")
    print(df_train['major_emotion'].value_counts())
    # --- oversampling end ---

    print(f"\nNew Train set (Oversampled) size: {len(df_train)}")
    print(f"New Validation set (Original) size: {len(df_val)}")
    return df_train, df_val, df_test
# --- 4. Main training entry point ---
def _build_dataset(tokenizer, texts, labels, max_length):
    """Tokenize ``texts`` (padded, truncated to ``max_length``) and wrap them with ``labels``."""
    encodings = tokenizer(
        list(texts),
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    return EmotionDataset(encodings, list(labels))


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented UTF-8 JSON, creating the directory if needed."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding='utf-8') as f:
        json.dump(payload, f, indent=4, ensure_ascii=False)


def run_training():
    """Run the whole pipeline: load data, fine-tune, evaluate and save artifacts.

    Steps:
      1. Load train/validation/test frames via ``get_data``; abort if any is empty.
      2. Build the label <-> id maps from the sorted training labels.
      3. Tokenize the three splits and fine-tune a sequence classifier with
         ``Trainer``, evaluating and checkpointing every epoch and reloading
         the checkpoint with the best validation accuracy at the end.
      4. Save the model and tokenizer under ``<output_dir>/best_model``.
      5. Write validation and test metrics as JSON and a test-set confusion
         matrix as PNG into the output directory.

    Returns None. Results are reported through printed messages and files.
    """
    config = TrainingConfig()
    df_train, df_val, df_test = get_data(config)
    if df_train.empty or df_val.empty or df_test.empty:
        print("\n์ค๋ฅ: ๋ฐ์ดํฐ๊ฐ ๋น์ด์์ด ํ๋ จ์ ์ค๋จํฉ๋๋ค.")
        return

    text_column = 'cleaned_text'
    label_column_str = 'major_emotion'
    model_name_to_load = config.get_model_name()
    tokenizer = AutoTokenizer.from_pretrained(model_name_to_load)

    # Ids follow the sorted label order so the mapping is reproducible.
    unique_labels = sorted(df_train[label_column_str].unique())
    label_to_id = {label: i for i, label in enumerate(unique_labels)}
    id_to_label = {i: label for label, i in label_to_id.items()}
    print(f"\n๋ผ๋ฒจ ์ธ์ฝ๋ฉ ๋งต (6-Class): {label_to_id}")

    print("๋ฐ์ดํฐ์ ํ ํฌ๋์ด์ง ์ค...")
    # Label ids are computed on the fly instead of being stored as a new
    # column, so frames that may be slices are never assigned to.
    train_dataset, val_dataset, test_dataset = (
        _build_dataset(
            tokenizer,
            frame[text_column],
            frame[label_column_str].map(label_to_id).tolist(),
            config.max_length,
        )
        for frame in (df_train, df_val, df_test)
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")
    print("๋ชจ๋ธ ๋ก๋ฉ ์ค...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name_to_load,
        num_labels=len(unique_labels),
        id2label=id_to_label,
        label2id=label_to_id,
        # Permit loading a checkpoint whose classifier head has a different size.
        ignore_mismatched_sizes=True
    ).to(device)

    output_dir = config.get_output_dir()
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.train_batch_size,
        per_device_eval_batch_size=config.eval_batch_size,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        report_to="none"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )

    print("\n 6-Class (Oversampled) ๋ชจ๋ธ ํ๋ จ์ ์์ํฉ๋๋ค...")
    trainer.train()
    print("\n ๋ชจ๋ธ ํ๋ จ ์๋ฃ!")

    final_model_path = os.path.join(output_dir, "best_model")
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"์ต์ข ๋ชจ๋ธ(Best)๊ณผ ํ ํฌ๋์ด์ ๊ฐ {final_model_path} ๊ฒฝ๋ก์ ์ ์ฅ๋์์ต๋๋ค.")

    print("\n--- (1) Validation Set ํ๊ฐ (๋ชจ์๊ณ ์ฌ) ---")
    val_results = trainer.evaluate(eval_dataset=val_dataset)
    print(f"์ต์ข Validation ํ๊ฐ ๊ฒฐ๊ณผ: {val_results}")
    results_path = os.path.join(output_dir, "validation_evaluation_results.json")
    _write_json(results_path, val_results)
    print(f"Validation ํ๊ฐ ๊ฒฐ๊ณผ๊ฐ {results_path}์ ์ ์ฅ๋์์ต๋๋ค.")

    print("\n--- (2) Test Set ์ต์ข ํ๊ฐ (์๋ฅ) ---")
    test_predictions = trainer.predict(test_dataset)
    test_metrics = test_predictions.metrics
    print(f"์ต์ข Test ํ๊ฐ ๊ฒฐ๊ณผ: {test_metrics}")
    results_path = os.path.join(output_dir, "TEST_evaluation_results.json")
    _write_json(results_path, test_metrics)
    print(f"*** ์ต์ข Test ํ๊ฐ ๊ฒฐ๊ณผ๊ฐ {results_path}์ ์ ์ฅ๋์์ต๋๋ค. ***")

    print("\n--- ํผ๋ ํ๋ ฌ ์์ฑ (Test Set) ---")
    # Plotting is best effort: a rendering failure must not discard the
    # already-saved model and metrics, so any error is reported and swallowed.
    try:
        y_pred = test_predictions.predictions.argmax(-1)
        y_true = test_predictions.label_ids
        class_ids = sorted(id_to_label)
        class_names = [id_to_label[i] for i in class_ids]
        cm = confusion_matrix(y_true, y_pred, labels=class_ids)
        fig = plt.figure(figsize=(10, 8))
        try:
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=class_names, yticklabels=class_names)
            plt.xlabel('์์ธก ๋ผ๋ฒจ (Predicted Label)')
            plt.ylabel('์ค์ ๋ผ๋ฒจ (True Label)')
            plt.title('Confusion Matrix (TEST Set - 6-Class Oversampled)')
            cm_path = os.path.join(output_dir, "TEST_confusion_matrix.png")
            plt.savefig(cm_path)
        finally:
            # Release the figure even if drawing or saving failed.
            plt.close(fig)
        print(f"Test Set ํผ๋ ํ๋ ฌ์ด {cm_path}์ ์ ์ฅ๋์์ต๋๋ค.")
    except Exception as e:
        print(f"\n!!! ์ค๋ฅ: ํผ๋ ํ๋ ฌ ์์ฑ ์คํจ: {e} !!!")
        print("matplotlib, seaborn ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์๋์ง ํ์ธํ์ธ์.")
# Script entry point: announce the run, then execute the training pipeline.
# Nothing happens when this module is imported.
if __name__ == "__main__":
    print("--- 6-Class (Oversampling) ๊ฐ์ ๋ถ๋ฅ ๋ชจ๋ธ ํ์ต ์์ ---")
    run_training()