"""Gradio app for post-editing machine translations with GPT-4 and evaluating
the results against reference translations using ChrF and COMET."""

import os
from pathlib import Path

import gradio as gr
import requests

from evaluator.chrf import calculate_chrf
from evaluator.comet_hf import calculate_comet

OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

CHATGPT_MODELS = {
    "GPT-4": "gpt-4"
}


def improve_translations(system_prompt, temperature, top_p):
    if not OPENAI_API_KEY:
        return [], [], "Error: OpenAI API key not found"

    try:
        # Load the source sentences, the draft (beam search) translations,
        # and the reference translations.
        data_dir = Path(__file__).parent / "evaluator" / "mt_data"
        source_sentences = (data_dir / "source_sentences.txt").read_text(encoding="utf-8").splitlines()
        beam_search_translations = (data_dir / "beam_search_translations.txt").read_text(encoding="utf-8").splitlines()
        reference_translations = (data_dir / "reference_translations.txt").read_text(encoding="utf-8").splitlines()

        improved_translations = []
        sentence_pairs = []

        # Post-edit each draft translation with GPT-4.
        for source, target, reference in zip(source_sentences, beam_search_translations, reference_translations):
user_prompt = f""" |
|
|
As an expert translation post editor, your task is to improve the English translation (Target) for the below German text (Source) |
|
|
Source: {source} |
|
|
Target: {target} |
|
|
Your output should be your improved version of the target text only. Do not add any comments or explanations before or after the improved version of the target text. |
|
|
""" |

            payload = {
                "model": CHATGPT_MODELS["GPT-4"],
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": 512
            }

            headers = {
                "Authorization": f"Bearer {OPENAI_API_KEY}",
                "Content-Type": "application/json"
            }

            # 60-second timeout so a stalled request doesn't hang the app.
            response = requests.post(OPENAI_API_URL, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            data = response.json()

            # Keep only the model's post-edited text.
            output = data["choices"][0]["message"]["content"]
            improved_translation = output.strip()
            improved_translations.append(improved_translation)

            sentence_pairs.append([source, target, improved_translation, reference])

        # Sentence-level ChrF scores for the draft and post-edited translations.
        beam_chrf_scores = [
            calculate_chrf(beam_translation, reference)
            for beam_translation, reference in zip(beam_search_translations, reference_translations)
        ]
        improved_chrf_scores = [
            calculate_chrf(improved_translation, reference)
            for improved_translation, reference in zip(improved_translations, reference_translations)
        ]

        # Segment-level COMET scores for the draft and post-edited translations.
        beam_comet_scores = calculate_comet(source_sentences, beam_search_translations, reference_translations)
        improved_comet_scores = calculate_comet(source_sentences, improved_translations, reference_translations)

        # Corpus-level averages and the change from the draft to the post-edited translations.
        average_beam_chrf = sum(beam_chrf_scores) / len(beam_chrf_scores)
        average_improved_chrf = sum(improved_chrf_scores) / len(improved_chrf_scores)
        average_beam_comet = sum(beam_comet_scores) / len(beam_comet_scores)
        average_improved_comet = sum(improved_comet_scores) / len(improved_comet_scores)

        chrf_change = average_improved_chrf - average_beam_chrf
        comet_change = average_improved_comet - average_beam_comet

        sentence_pairs_df = sentence_pairs
        scores_df = [
            ["ChrF", round(average_beam_chrf, 2), round(average_improved_chrf, 2), round(chrf_change, 2)],
            ["COMET", round(average_beam_comet, 2), round(average_improved_comet, 2), round(comet_change, 2)]
        ]

        evaluation_message = f"ChrF Change: {chrf_change:.2f}, COMET Change: {comet_change:.2f}"
        return sentence_pairs_df, scores_df, evaluation_message

    except Exception as e:
        return [], [], f"Error: {str(e)}"


app = gr.Interface(
    fn=improve_translations,
    inputs=[
        gr.Textbox(label="System Prompt", placeholder="Define the assistant's behavior here..."),
        gr.Slider(value=1, minimum=0, maximum=1.9, step=0.1, label="Temperature"),
        gr.Slider(value=1, minimum=0, maximum=1, step=0.01, label="Top P")
    ],
    outputs=[
        gr.Dataframe(headers=["Source text", "Draft 1", "Draft 2", "Reference"], label="Sentence Pairs"),
        gr.Dataframe(headers=["Metric", "Draft 1", "Draft 2", "Change"], label="Scores"),
        gr.Textbox(label="Evaluation Results")
    ],
    title="Translation Post-Editing and Evaluation",
    description="Improve translations using GPT-4 and evaluate the results with ChrF and COMET."
)
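
# Expected usage (assuming OPENAI_API_KEY is set in the environment): run this
# file directly and open http://localhost:7860 in a browser.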

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)