"""Gradio app that post-edits German-to-English machine translations with
GPT-4 and evaluates the draft and improved translations against references
using ChrF and COMET."""
import gradio as gr
import requests
import json
import os
from evaluator.chrf import calculate_chrf
from evaluator.comet import calculate_comet # Import the COMET function
from pathlib import Path
# OpenAI chat-completions endpoint used for all post-editing requests.
OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"
# Fail fast at import time if the API key is not configured, so the app
# does not start and then fail on the first request.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY not found. Please set this environment variable.")
# Model aliases selectable by the app (currently only GPT-4).
CHATGPT_MODELS = {
    "GPT-4": "gpt-4"
}
def _percent_change(baseline, improved):
    """Return the relative change of *improved* over *baseline*, in percent.

    Guards against a zero baseline (which would raise ZeroDivisionError);
    in that edge case the change is reported as 0.0.
    """
    if baseline == 0:
        return 0.0
    return (improved - baseline) / baseline * 100.0


def _post_edit(source, target, system_prompt, temperature, top_p, headers):
    """Ask GPT-4 to post-edit one draft translation; return the improved text.

    Raises:
        requests.HTTPError: If the OpenAI API returns an error status.
    """
    user_prompt = f"""
As an expert translation post editor, your task is to improve the English translation (Target) for the below German text (Source)
Source: {source}
Target: {target}
Your output should be your improved version of the target text only. Do not add any comments or explanations before or after the improved version of the target text.
"""
    payload = {
        "model": CHATGPT_MODELS["GPT-4"],
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": 512,
    }
    # A timeout keeps one stuck request from hanging the whole batch.
    response = requests.post(OPENAI_API_URL, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    data = response.json()
    output = data["choices"][0]["message"]["content"]
    # Defensive: strip a leading "Improved Translation:" label in case the
    # model adds one despite the instructions.
    return output.split("Improved Translation:")[-1].strip()


def improve_translations(system_prompt, temperature, top_p):
    """Post-edit machine translations with GPT-4 and score the results.

    Reads German source sentences, draft (beam-search) English translations,
    and reference translations from ``evaluator/mt_data``, asks GPT-4 to
    improve each draft, then scores both the drafts and the improved
    versions with ChrF and COMET.

    Args:
        system_prompt: System message steering the model's behavior.
        temperature: Sampling temperature forwarded to the OpenAI API.
        top_p: Nucleus-sampling parameter forwarded to the OpenAI API.

    Returns:
        A 3-tuple of (sentence_pairs, scores, evaluation_message):
        ``sentence_pairs`` is a list of [source, draft 1, draft 2, reference]
        rows, ``scores`` is a list of per-metric average rows, and
        ``evaluation_message`` summarizes the relative score changes.

    Raises:
        ValueError: If any of the mt_data files is empty.
        requests.HTTPError: If an OpenAI API call fails.
    """
    # Load data.
    data_dir = Path(__file__).parent / "evaluator" / "mt_data"
    source_sentences = (data_dir / "source_sentences.txt").read_text(encoding="utf-8").splitlines()
    beam_search_translations = (data_dir / "beam_search_translations.txt").read_text(encoding="utf-8").splitlines()
    reference_translations = (data_dir / "reference_translations.txt").read_text(encoding="utf-8").splitlines()

    # Every average below divides by the sentence count, so fail early with
    # a clear message instead of a ZeroDivisionError later.
    if not (source_sentences and beam_search_translations and reference_translations):
        raise ValueError("One or more mt_data files are empty.")

    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json",
    }

    improved_translations = []
    sentence_pairs = []  # Rows of [source, draft 1, draft 2, reference].
    for source, target, reference in zip(source_sentences, beam_search_translations, reference_translations):
        improved_translation = _post_edit(source, target, system_prompt, temperature, top_p, headers)
        improved_translations.append(improved_translation)
        sentence_pairs.append([source, target, improved_translation, reference])

    # Sentence-level ChrF scores, averaged over the corpus below.
    beam_chrf_scores = [
        calculate_chrf(beam_translation, reference)
        for beam_translation, reference in zip(beam_search_translations, reference_translations)
    ]
    improved_chrf_scores = [
        calculate_chrf(improved_translation, reference)
        for improved_translation, reference in zip(improved_translations, reference_translations)
    ]
    # COMET scores (computed in one batch per system).
    beam_comet_scores = calculate_comet(source_sentences, beam_search_translations, reference_translations)
    improved_comet_scores = calculate_comet(source_sentences, improved_translations, reference_translations)

    # Corpus-level averages.
    average_beam_chrf = sum(beam_chrf_scores) / len(beam_chrf_scores)
    average_improved_chrf = sum(improved_chrf_scores) / len(improved_chrf_scores)
    average_beam_comet = sum(beam_comet_scores) / len(beam_comet_scores)
    average_improved_comet = sum(improved_comet_scores) / len(improved_comet_scores)

    # Absolute score changes (shown in the scores table).
    chrf_change = average_improved_chrf - average_beam_chrf
    comet_change = average_improved_comet - average_beam_comet

    scores = [
        ["ChrF", round(average_beam_chrf, 2), round(average_improved_chrf, 2), round(chrf_change, 2)],
        ["COMET", round(average_beam_comet, 2), round(average_improved_comet, 2), round(comet_change, 2)],
    ]
    # BUGFIX: the message previously computed `average / change`, which is
    # not a percentage and raised ZeroDivisionError when the change was 0.
    # Report the relative change over the baseline instead.
    evaluation_message = (
        f"ChrF Change: {_percent_change(average_beam_chrf, average_improved_chrf):.2f}%, "
        f"COMET Change: {_percent_change(average_beam_comet, average_improved_comet):.2f}%"
    )
    return sentence_pairs, scores, evaluation_message
# Gradio wiring: three tuning controls in, two tables plus a summary out.
_input_components = [
    gr.Textbox(label="System Prompt", placeholder="Define the assistant's behavior here..."),
    gr.Slider(value=1, minimum=0, maximum=1.9, step=0.1, label="Temperature"),
    gr.Slider(value=1, minimum=0, maximum=1, step=0.01, label="Top P"),
]
_output_components = [
    gr.Dataframe(headers=["Source text", "Draft 1", "Draft 2", "Reference"], label="Sentence Pairs"),
    gr.Dataframe(headers=["Metric", "Draft 1", "Draft 2", "Change"], label="Scores"),
    gr.Textbox(label="Evaluation Results"),
]
app = gr.Interface(
    fn=improve_translations,
    inputs=_input_components,
    outputs=_output_components,
    title="Translation Post-Editing and Evaluation",
    description="Improve translations using GPT-4 and evaluate the results with ChrF and COMET.",
)
if __name__ == "__main__":
    # Walk candidate ports until one is free; 7860 is the conventional default.
    candidate_ports = range(7860, 7870)
    for port in candidate_ports:
        try:
            app.launch(
                server_name="127.0.0.1",  # bind to localhost only
                server_port=port,
                share=False,  # no public tunnel
                debug=True,  # surface detailed errors
            )
        except OSError:
            # Port is occupied: try the next one, or give up after the last.
            if port == candidate_ports[-1]:
                print("Could not find an available port between 7860-7869")
                raise
        else:
            break