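"""Gradio app for GPT-4 translation post-editing and evaluation.

Loads German source sentences, their beam-search English translations (Draft 1),
and reference translations, asks GPT-4 to post-edit each draft (Draft 2), and
compares the two drafts with ChrF and COMET scores.
"""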
import gradio as gr
import requests
import os
from evaluator.chrf import calculate_chrf
from evaluator.comet import calculate_comet  # Import the COMET function
from pathlib import Path

# OpenAI API URL and key
OPENAI_API_URL = "https://api.openai.com/v1/chat/completions"
# Check for required environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("OPENAI_API_KEY not found. Please set this environment variable.")


CHATGPT_MODELS = {
    "GPT-4": "gpt-4"
}

def improve_translations(system_prompt, temperature, top_p):
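    """Post-edit the beam-search translations with GPT-4 and score both drafts.

    Returns the sentence-level rows ([source, draft 1, draft 2, reference]),
    a score table with the average ChrF and COMET for each draft plus the change,
    and a short summary string with the relative change of each metric.
    """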
    # Load data
    data_dir = Path(__file__).parent / "evaluator" / "mt_data"
    source_sentences = (data_dir / "source_sentences.txt").read_text(encoding="utf-8").splitlines()
    beam_search_translations = (data_dir / "beam_search_translations.txt").read_text(encoding="utf-8").splitlines()
    reference_translations = (data_dir / "reference_translations.txt").read_text(encoding="utf-8").splitlines()

    improved_translations = []
    sentence_pairs = []  # To store source, draft 1, draft 2, and reference

    for source, target, reference in zip(source_sentences, beam_search_translations, reference_translations):
        # Construct the user prompt; the string is built without surrounding
        # indentation so no stray whitespace is sent to the API
        user_prompt = (
            "As an expert translation post-editor, your task is to improve the English "
            "translation (Target) of the German text below (Source).\n"
            f"Source: {source}\n"
            f"Target: {target}\n"
            "Return only your improved version of the target text. Do not add any "
            "comments or explanations before or after it."
        )

        # Prepare API payload
        payload = {
            "model": CHATGPT_MODELS["GPT-4"],
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": 512
        }

        headers = {
            "Authorization": f"Bearer {OPENAI_API_KEY}",
            "Content-Type": "application/json"
        }

        # Call the OpenAI API (timeout keeps a stuck request from hanging the app)
        response = requests.post(OPENAI_API_URL, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        data = response.json()

        # Extract the improved translation; the prompt asks the model to return
        # only the improved target text, so the content is used as-is
        improved_translation = data["choices"][0]["message"]["content"].strip()
        improved_translations.append(improved_translation)

        # Add sentence pair to the list
        sentence_pairs.append([source, target, improved_translation, reference])

    # Calculate ChrF scores
    beam_chrf_scores = [
        calculate_chrf(beam_translation, reference)
        for beam_translation, reference in zip(beam_search_translations, reference_translations)
    ]
    improved_chrf_scores = [
        calculate_chrf(improved_translation, reference)
        for improved_translation, reference in zip(improved_translations, reference_translations)
    ]

    # Calculate COMET scores
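    # calculate_comet (local evaluator module) is assumed to return one score per
    # (source, hypothesis, reference) triple, in the same order as the inputs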
    beam_comet_scores = calculate_comet(source_sentences, beam_search_translations, reference_translations)
    improved_comet_scores = calculate_comet(source_sentences, improved_translations, reference_translations)

    # Calculate average scores
    average_beam_chrf = sum(beam_chrf_scores) / len(beam_chrf_scores)
    average_improved_chrf = sum(improved_chrf_scores) / len(improved_chrf_scores)
    average_beam_comet = sum(beam_comet_scores) / len(beam_comet_scores)
    average_improved_comet = sum(improved_comet_scores) / len(improved_comet_scores)

    # Calculate score changes
    chrf_change = average_improved_chrf - average_beam_chrf
    comet_change = average_improved_comet - average_beam_comet

    # Prepare rows for the Gradio Dataframe components
    sentence_pairs_df = sentence_pairs  # rows of [source, draft 1, draft 2, reference]
    scores_df = [
        ["ChrF", round(average_beam_chrf, 2), round(average_improved_chrf, 2), round(chrf_change, 2)],
        ["COMET", round(average_beam_comet, 2), round(average_improved_comet, 2), round(comet_change, 2)]
    ]

    # Report the relative change of each metric as a percentage of the Draft 1 average
    evaluation_message = (
        f"ChrF Change: {(chrf_change / average_beam_chrf * 100):+.2f}%, "
        f"COMET Change: {(comet_change / average_beam_comet * 100):+.2f}%"
    )
    return sentence_pairs_df, scores_df, evaluation_message

# Gradio interface
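# Example system prompt a user might enter (illustrative only):
#   "You are a professional German-to-English translation post-editor. Preserve
#    meaning and terminology; improve fluency, grammar, and style."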
app = gr.Interface(
    fn=improve_translations, 
    inputs=[
        gr.Textbox(label="System Prompt", placeholder="Define the assistant's behavior here..."),
        gr.Slider(value=1, minimum=0, maximum=1.9, step=0.1, label="Temperature"),
        gr.Slider(value=1, minimum=0, maximum=1, step=0.01, label="Top P")
    ],
    outputs=[
        gr.Dataframe(headers=["Source text", "Draft 1", "Draft 2", "Reference"], label="Sentence Pairs"),
        gr.Dataframe(headers=["Metric", "Draft 1", "Draft 2", "Change"], label="Scores"),
        gr.Textbox(label="Evaluation Results")
    ],
    title="Translation Post-Editing and Evaluation",
    description="Improve translations using GPT-4 and evaluate the results with ChrF and COMET."
)



if __name__ == "__main__":
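    # Run locally (module file name assumed to be app.py):
    #   export OPENAI_API_KEY=...
    #   python app.py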
    # Try different ports in case 7860 is occupied
    for port in range(7860, 7870):
        try:
            app.launch(
                server_name="127.0.0.1",  # localhost
                server_port=port,
                share=False,  # Don't create public URL
                debug=True    # Show detailed errors
            )
            break
        except OSError:
            if port == 7869:  # Last attempt
                print("Could not find an available port between 7860-7869")
                raise
            continue