Commit 0db2cde
Parent(s): dcdf99d

internal comet

Files changed:
- evaluator/comet.py  +18 -12
- evaluator/comet_internal.py  +28 -0
evaluator/comet.py
CHANGED
@@ -1,28 +1,34 @@
  import os
+ import requests

+ # Set the Hugging Face Inference API URL and token
+ HF_API_URL = "https://api-inference.huggingface.co/models/Unbabel/wmt20-comet-da"
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # Ensure this is set in your environment

  def calculate_comet(source_sentences, translations, references):
      """
+     Calculate COMET scores using the Hugging Face Inference API.
      :param source_sentences: List of source sentences.
      :param translations: List of translated sentences (hypotheses).
      :param references: List of reference translations.
      :return: List of COMET scores (one score per sentence pair).
      """
+     headers = {
+         "Authorization": f"Bearer {HF_API_TOKEN}",
+         "Content-Type": "application/json"
+     }

+     # Prepare data for the API
      data = [
+         {"source": src, "translation": mt, "reference": ref}
          for src, mt, ref in zip(source_sentences, translations, references)
      ]

+     # Make the API call
+     response = requests.post(HF_API_URL, headers=headers, json={"inputs": data})
+     response.raise_for_status()  # Raise an error for bad responses
+
+     # Parse the response
+     results = response.json()
+     scores = [item["score"] for item in results]  # Extract scores from the response
      return scores
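A minimal usage sketch of the API-backed scorer, assuming HF_API_TOKEN is exported in the environment and that the inference endpoint returns one {"score": ...} object per input pair; the sentence lists below are purely illustrative:

from evaluator.comet import calculate_comet

# Illustrative inputs; any three parallel lists of equal length work.
sources = ["Der Hund bellt.", "Ich trinke gerne Kaffee."]
translations = ["The dog barks.", "I like drinking coffee."]
references = ["The dog is barking.", "I enjoy drinking coffee."]

scores = calculate_comet(sources, translations, references)
for src, score in zip(sources, scores):
    print(f"{score:.4f}  {src}")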
evaluator/comet_internal.py
ADDED
@@ -0,0 +1,28 @@
+ import os
+ from comet import download_model, load_from_checkpoint
+
+ # Set a custom cache directory for COMET
+ os.environ["COMET_CACHE"] = "/tmp"
+
+ def calculate_comet(source_sentences, translations, references):
+     """
+     Calculate COMET scores for a list of translations.
+     :param source_sentences: List of source sentences.
+     :param translations: List of translated sentences (hypotheses).
+     :param references: List of reference translations.
+     :return: List of COMET scores (one score per sentence pair).
+     """
+     # Download and load the COMET model
+     model_path = download_model("Unbabel/wmt22-comet-da")  # Use a supported model
+     model = load_from_checkpoint(model_path)
+
+     # Prepare data for COMET
+     data = [
+         {"src": src, "mt": mt, "ref": ref}
+         for src, mt, ref in zip(source_sentences, translations, references)
+     ]
+
+     # Compute COMET scores
+     results = model.predict(data, batch_size=8, gpus=0)
+     scores = results["scores"]  # Extract the scores from the results
+     return scores
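A hedged sketch of how a caller might choose between the two implementations; the USE_INTERNAL_COMET flag and the selection logic below are hypothetical and not part of this commit:

import os

# Hypothetical switch: use the local checkpoint (comet_internal) when requested,
# otherwise the Hugging Face Inference API version (comet).
if os.getenv("USE_INTERNAL_COMET") == "1":
    from evaluator.comet_internal import calculate_comet
else:
    from evaluator.comet import calculate_comet

scores = calculate_comet(
    ["Bonjour tout le monde."],
    ["Hello everyone."],
    ["Hello, everyone."],
)
print(scores)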