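"""Gradio app for post-editing machine translations with GPT-4.

Loads German source sentences, their beam-search English drafts, and reference
translations from evaluator/mt_data, asks GPT-4 to post-edit each draft, and
reports ChrF and COMET scores before and after post-editing.
"""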
import gradio as gr
import requests
import os
from evaluator.chrf import calculate_chrf
from evaluator.comet_hf import calculate_comet
from pathlib import Path

# OpenAI API URL and key
OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

CHATGPT_MODELS = {
    "GPT-4": "gpt-4"
}

def improve_translations(system_prompt, temperature, top_p):
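    """Post-edit beam-search translations with GPT-4 and score the results.

    Args:
        system_prompt: System message defining the assistant's behavior.
        temperature: Sampling temperature passed to the OpenAI API.
        top_p: Nucleus-sampling parameter passed to the OpenAI API.

    Returns:
        A list of [source, draft 1, draft 2, reference] rows, a table of
        average ChrF/COMET scores before and after post-editing, and a short
        evaluation message (or an error message on failure).
    """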
    if not OPENAI_API_KEY:
        return [], [], "Error: OpenAI API key not found"
        
    try:
        # Load data
        data_dir = Path(__file__).parent / "evaluator" / "mt_data"
        source_sentences = (data_dir / "source_sentences.txt").read_text(encoding="utf-8").splitlines()
        beam_search_translations = (data_dir / "beam_search_translations.txt").read_text(encoding="utf-8").splitlines()
        reference_translations = (data_dir / "reference_translations.txt").read_text(encoding="utf-8").splitlines()
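        # The three files are expected to be line-aligned: line i of each file
        # holds the source sentence, draft translation, and reference for
        # sentence pair i.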

        improved_translations = []
        sentence_pairs = []  # To store source, draft 1, draft 2, and reference

        for source, target, reference in zip(source_sentences, beam_search_translations, reference_translations):
            # Construct the post-editing prompt without indentation so that no
            # stray whitespace is sent to the model
            user_prompt = (
                "As an expert translation post-editor, your task is to improve the English translation (Target) "
                "of the German text below (Source).\n"
                f"Source: {source}\n"
                f"Target: {target}\n"
                "Your output should be your improved version of the target text only. "
                "Do not add any comments or explanations before or after the improved version of the target text."
            )

            # Prepare API payload
            payload = {
                "model": CHATGPT_MODELS["GPT-4"],
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": 512
            }

            headers = {
                "Authorization": f"Bearer {OPENAI_API_KEY}",
                "Content-Type": "application/json"
            }

            # Call OpenAI API
            response = requests.post(OPENAI_API_URL, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            data = response.json()

            # Extract improved translation
            output = data["choices"][0]["message"]["content"]
            improved_translation = output.strip()
            improved_translations.append(improved_translation)

            # Add sentence pair to the list
            sentence_pairs.append([source, target, improved_translation, reference])

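        # ChrF is a character n-gram F-score computed against the reference;
        # higher values indicate a closer match.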
        # Calculate ChrF scores
        beam_chrf_scores = [
            calculate_chrf(beam_translation, reference)
            for beam_translation, reference in zip(beam_search_translations, reference_translations)
        ]
        improved_chrf_scores = [
            calculate_chrf(improved_translation, reference)
            for improved_translation, reference in zip(improved_translations, reference_translations)
        ]

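        # COMET is a neural evaluation metric that scores each
        # (source, hypothesis, reference) triplet with a pretrained model.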
        # Calculate COMET scores
        beam_comet_scores = calculate_comet(source_sentences, beam_search_translations, reference_translations)
        improved_comet_scores = calculate_comet(source_sentences, improved_translations, reference_translations)

        # Calculate average scores
        average_beam_chrf = sum(beam_chrf_scores) / len(beam_chrf_scores)
        average_improved_chrf = sum(improved_chrf_scores) / len(improved_chrf_scores)
        average_beam_comet = sum(beam_comet_scores) / len(beam_comet_scores)
        average_improved_comet = sum(improved_comet_scores) / len(improved_comet_scores)

        # Calculate score changes
        chrf_change = average_improved_chrf - average_beam_chrf
        comet_change = average_improved_comet - average_beam_comet

        # Prepare dataframes
        sentence_pairs_df = sentence_pairs
        scores_df = [
            ["ChrF", round(average_beam_chrf, 2), round(average_improved_chrf, 2), round(chrf_change, 2)],
            ["COMET", round(average_beam_comet, 2), round(average_improved_comet, 2), round(comet_change, 2)]
        ]

        # Summarize the average score changes in a short evaluation message
        evaluation_message = f"ChrF Change: {chrf_change:.2f}, COMET Change: {comet_change:.2f}"
        return sentence_pairs_df, scores_df, evaluation_message

    except Exception as e:
        return [], [], f"Error: {str(e)}"

# Gradio interface
app = gr.Interface(
    fn=improve_translations, 
    inputs=[
        gr.Textbox(label="System Prompt", placeholder="Define the assistant's behavior here..."),
        gr.Slider(value=1, minimum=0, maximum=1.9, step=0.1, label="Temperature"),
        gr.Slider(value=1, minimum=0, maximum=1, step=0.01, label="Top P")
    ],
    outputs=[
        gr.Dataframe(headers=["Source text", "Draft 1", "Draft 2", "Reference"], label="Sentence Pairs"),
        gr.Dataframe(headers=["Metric", "Draft 1", "Draft 2", "Change"], label="Scores"),
        gr.Textbox(label="Evaluation Results")
    ],
    title="Translation Post-Editing and Evaluation",
    description="Improve translations using GPT-4 and evaluate the results with ChrF and COMET."
)

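# Bind to all interfaces on Gradio's default port 7860, which suits
# containerized deployments such as Hugging Face Spaces.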
if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)