diminch's picture
Deploy V15 Clean (Removed binary files history)
d939bae
import json
import numpy as np
from datasets import Dataset, DatasetDict
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
EvalPrediction
)
from sklearn.metrics import mean_squared_error, mean_absolute_error
from huggingface_hub import HfFolder, notebook_login
MODEL_NAME = "roberta-base"
DATASET_PATH = "/content/data/dataset_for_scorer.json"
MODEL_OUTPUT_DIR = "./ielts_grader_model"
HUB_MODEL_ID = "diminch/ielts-grader-ai"
def load_and_prepare_data(dataset_path):
print(f"Đang tải dữ liệu từ {dataset_path}...")
with open(dataset_path, "r", encoding="utf-8") as f:
raw_data = json.load(f)
processed_data = []
for item in raw_data:
text = item['prompt_text'] + " [SEP] " + item['essay_text']
labels = [
float(item['scores']['task_response']),
float(item['scores']['coherence_cohesion']),
float(item['scores']['lexical_resource']),
float(item['scores']['grammatical_range'])
]
processed_data.append({"text": text, "label": labels})
print(f"Tổng cộng {len(processed_data)} mẫu.")
dataset = Dataset.from_list(processed_data)
train_test_split = dataset.train_test_split(test_size=0.1)
dataset_dict = DatasetDict({
'train': train_test_split['train'],
'test': train_test_split['test']
})
return dataset_dict
def tokenize_data(dataset_dict, tokenizer):
print("Đang tokenize dữ liệu...")
def tokenize_function(examples):
return tokenizer(
examples['text'],
padding="max_length",
truncation=True,
max_length=512
)
tokenized_datasets = dataset_dict.map(tokenize_function, batched=True)
return tokenized_datasets
def compute_metrics(p: EvalPrediction):
preds = p.predictions
labels = p.label_ids
rmse_tr = np.sqrt(mean_squared_error(labels[:, 0], preds[:, 0]))
rmse_cc = np.sqrt(mean_squared_error(labels[:, 1], preds[:, 1]))
rmse_lr = np.sqrt(mean_squared_error(labels[:, 2], preds[:, 2]))
rmse_gra = np.sqrt(mean_squared_error(labels[:, 3], preds[:, 3]))
mae_tr = mean_absolute_error(labels[:, 0], preds[:, 0])
mae_cc = mean_absolute_error(labels[:, 1], preds[:, 1])
mae_lr = mean_absolute_error(labels[:, 2], preds[:, 2])
mae_gra = mean_absolute_error(labels[:, 3], preds[:, 3])
avg_rmse = np.mean([rmse_tr, rmse_cc, rmse_lr, rmse_gra])
return {
"avg_rmse": avg_rmse,
"rmse_task_response": rmse_tr,
"rmse_coherence_cohesion": rmse_cc,
"rmse_lexical_resource": rmse_lr,
"rmse_grammatical_range": rmse_gra,
"mae_task_response": mae_tr,
"mae_coherence_cohesion": mae_cc,
# ... có thể thêm các MAE khác
}
def main():
print("Vui lòng dán token Hugging Face (quyền 'write') của bạn:")
# (Nếu chạy trên Colab, nó sẽ hiện ô input)
# notebook_login()
# Hoặc nếu chạy local, dùng 'huggingface-cli login' trước
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
dataset_dict = load_and_prepare_data(DATASET_PATH)
tokenized_datasets = tokenize_data(dataset_dict, tokenizer)
print("Đang tải mô hình nền tảng...")
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME,
num_labels=4,
problem_type="regression"
)
training_args = TrainingArguments(
output_dir=MODEL_OUTPUT_DIR,
learning_rate=2e-5,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=3,
weight_decay=0.01,
eval_strategy="epoch", # Changed evaluation_strategy to eval_strategy
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="avg_rmse",
greater_is_better=False,
push_to_hub=True,
hub_model_id=HUB_MODEL_ID,
hub_strategy="end",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["test"],
compute_metrics=compute_metrics,
tokenizer=tokenizer,
)
print("--- BẮT ĐẦU HUẤN LUYỆN ---")
trainer.train()
print("--- HUẤN LUYỆN HOÀN TẤT ---")
print("--- ĐÁNH GIÁ TRÊN TẬP TEST ---")
eval_results = trainer.evaluate()
print(json.dumps(eval_results, indent=2))
print("Đang đẩy model tốt nhất lên Hugging Face Hub...")
trainer.push_to_hub()
print(f"Hoàn tất! Model của bạn đã ở trên Hub: https://huggingface.co/{HUB_MODEL_ID}")
if __name__ == "__main__":
import os
if not os.path.exists(DATASET_PATH):
print(f"LỖI: Không tìm thấy file {DATASET_PATH}.")
else:
main()