# post_editing_evaluator / interface.py
# Author: morgankavanagh
# Last change (commit d130b8e): "Fixes of Docker, chrf, comet_hf, interface"
import gradio as gr
import requests
import json
import os
from evaluator.chrf import calculate_chrf
from evaluator.comet_hf import calculate_comet
from pathlib import Path
# --- OpenAI configuration ---------------------------------------------------
# Chat Completions endpoint that every post-editing request is sent to.
OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"
# Secret read from the environment once, at import time; None when it is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# UI-facing model label -> model identifier placed in the request payload.
CHATGPT_MODELS = {"GPT-4": "gpt-4"}
_REQUEST_TIMEOUT_SECONDS = 60  # per-request cap so a stalled API call cannot hang the UI forever
_MAX_COMPLETION_TOKENS = 512


def _load_mt_data():
    """Load the aligned evaluation corpus from ``evaluator/mt_data`` next to this file.

    Returns:
        ``(source_sentences, beam_search_translations, reference_translations)``,
        each a list with one entry per line of the corresponding UTF-8 file.

    Raises:
        OSError: if a data file cannot be read.
        ValueError: if the files are empty or do not have the same number of
            lines. Without this check ``zip`` would silently truncate the
            per-sentence loop while COMET received misaligned full-length
            lists, and empty files would surface as "division by zero".
    """
    data_dir = Path(__file__).parent / "evaluator" / "mt_data"

    def read_lines(name):
        return (data_dir / name).read_text(encoding="utf-8").splitlines()

    source_sentences = read_lines("source_sentences.txt")
    beam_search_translations = read_lines("beam_search_translations.txt")
    reference_translations = read_lines("reference_translations.txt")

    counts = (len(source_sentences), len(beam_search_translations), len(reference_translations))
    if len(set(counts)) != 1:
        raise ValueError(
            "Data files must have the same number of lines "
            f"(source={counts[0]}, beam search={counts[1]}, reference={counts[2]})"
        )
    if counts[0] == 0:
        raise ValueError("Data files contain no sentences")
    return source_sentences, beam_search_translations, reference_translations


def _build_user_prompt(source, target):
    """Return the post-editing instruction sent as the user message for one sentence pair."""
    return (
        "\n"
        "As an expert translation post editor, your task is to improve the English "
        "translation (Target) for the below German text (Source)\n"
        f"Source: {source}\n"
        f"Target: {target}\n"
        "Your output should be your improved version of the target text only. Do not "
        "add any comments or explanations before or after the improved version of the "
        "target text.\n"
    )


def _request_post_edit(session, headers, system_prompt, user_prompt, temperature, top_p):
    """Send one chat-completion request and return the reply text with surrounding whitespace stripped.

    Raises:
        requests.RequestException: on network errors, timeouts or non-2xx responses.
        KeyError / IndexError: if the response JSON lacks the expected ``choices`` structure.
    """
    payload = {
        "model": CHATGPT_MODELS["GPT-4"],
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": _MAX_COMPLETION_TOKENS,
    }
    response = session.post(
        OPENAI_API_URL, headers=headers, json=payload, timeout=_REQUEST_TIMEOUT_SECONDS
    )
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"].strip()


def _mean(scores):
    """Return the arithmetic mean of *scores* (any non-empty iterable of numbers)."""
    scores = list(scores)
    return sum(scores) / len(scores)


def improve_translations(system_prompt, temperature, top_p):
    """Post-edit the beam-search drafts with GPT-4 and compare ChrF/COMET against the drafts.

    Args:
        system_prompt: System message sent with every request.
        temperature: Sampling temperature forwarded to the Chat Completions API.
        top_p: Nucleus-sampling value forwarded to the Chat Completions API.

    Returns:
        A 3-tuple ``(sentence_pairs, scores, message)``:

        - ``sentence_pairs``: rows ``[source, beam-search draft, GPT-4 draft, reference]``.
        - ``scores``: rows ``[metric, draft 1 mean, draft 2 mean, change]`` for ChrF
          and COMET, rounded to 2 decimals.
        - ``message``: summary string with the unrounded changes formatted to 2 decimals.

        On any failure it returns ``([], [], "Error: ...")`` instead of raising,
        so the UI shows the message rather than crashing.
    """
    if not OPENAI_API_KEY:
        return [], [], "Error: OpenAI API key not found"

    try:
        source_sentences, beam_search_translations, reference_translations = _load_mt_data()

        # Identical for every request, so build it once.
        headers = {
            "Authorization": f"Bearer {OPENAI_API_KEY}",
            "Content-Type": "application/json",
        }

        improved_translations = []
        sentence_pairs = []  # [source, draft 1, draft 2, reference] per sentence
        # One Session reuses the HTTPS connection across all sentence requests.
        with requests.Session() as session:
            for source, target, reference in zip(
                source_sentences, beam_search_translations, reference_translations
            ):
                improved = _request_post_edit(
                    session,
                    headers,
                    system_prompt,
                    _build_user_prompt(source, target),
                    temperature,
                    top_p,
                )
                improved_translations.append(improved)
                sentence_pairs.append([source, target, improved, reference])

        # Sentence-level ChrF, averaged over the corpus.
        average_beam_chrf = _mean(
            calculate_chrf(hypothesis, reference)
            for hypothesis, reference in zip(beam_search_translations, reference_translations)
        )
        average_improved_chrf = _mean(
            calculate_chrf(hypothesis, reference)
            for hypothesis, reference in zip(improved_translations, reference_translations)
        )

        # calculate_comet is called on whole corpora and its result is averaged
        # here. NOTE(review): assumes it returns one score per sentence; confirm.
        average_beam_comet = _mean(
            calculate_comet(source_sentences, beam_search_translations, reference_translations)
        )
        average_improved_comet = _mean(
            calculate_comet(source_sentences, improved_translations, reference_translations)
        )

        # Positive change means the GPT-4 post-edit scored higher than the beam-search draft.
        chrf_change = average_improved_chrf - average_beam_chrf
        comet_change = average_improved_comet - average_beam_comet

        scores = [
            ["ChrF", round(average_beam_chrf, 2), round(average_improved_chrf, 2), round(chrf_change, 2)],
            ["COMET", round(average_beam_comet, 2), round(average_improved_comet, 2), round(comet_change, 2)],
        ]
        evaluation_message = f"ChrF Change: {chrf_change:.2f}, COMET Change: {comet_change:.2f}"
        return sentence_pairs, scores, evaluation_message
    except Exception as e:  # UI boundary: report every failure as a message instead of crashing
        return [], [], f"Error: {e}"
# --- Gradio UI ----------------------------------------------------------------
# The three input components map, in order, onto
# improve_translations(system_prompt, temperature, top_p). The three output
# components receive its (sentence_pairs, scores, message) return tuple in order.
app = gr.Interface(
    title="Translation Post-Editing and Evaluation",
    description="Improve translations using GPT-4 and evaluate the results with ChrF and COMET.",
    fn=improve_translations,
    inputs=[
        gr.Textbox(
            placeholder="Define the assistant's behavior here...",
            label="System Prompt",
        ),
        gr.Slider(minimum=0, maximum=1.9, value=1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0, maximum=1, value=1, step=0.01, label="Top P"),
    ],
    outputs=[
        # "Draft 1" is the beam-search translation; "Draft 2" is the GPT-4 post-edit.
        gr.Dataframe(
            label="Sentence Pairs",
            headers=["Source text", "Draft 1", "Draft 2", "Reference"],
        ),
        gr.Dataframe(
            label="Scores",
            headers=["Metric", "Draft 1", "Draft 2", "Change"],
        ),
        gr.Textbox(label="Evaluation Results"),
    ],
)

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    app.launch(**launch_options)