File size: 1,390 Bytes
3c397fd d130b8e 3c397fd d130b8e 3c397fd d130b8e 3c397fd d130b8e 3c397fd d130b8e 3c397fd d130b8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import os
import torch
from comet import download_model, load_from_checkpoint
def calculate_comet(source_sentences, translations, references):
    """
    Calculate COMET scores using the local COMET installation.

    :param source_sentences: List of source sentences
    :param translations: List of machine-translated sentences
    :param references: List of reference translations
    :return: List of COMET scores, one per input triple. On any error the
             function does not raise; it prints the error and returns a list
             of ``0.0`` with the same length as ``source_sentences``.
    """
    # Nothing to score: avoid downloading/loading the model entirely.
    if not source_sentences:
        return []
    try:
        # zip() below would silently truncate mismatched inputs, producing
        # fewer scores than sentences — fail loudly instead (caught below,
        # so the caller still gets a full-length list of zeros).
        if not (len(source_sentences) == len(translations) == len(references)):
            raise ValueError(
                "source_sentences, translations and references must have equal lengths"
            )
        # Point COMET's download cache at a writable location.
        os.environ["COMET_CACHE"] = "/tmp"
        # Download (or reuse the cached) checkpoint and load the model.
        model_path = download_model("Unbabel/wmt22-comet-da")
        model = load_from_checkpoint(model_path)
        # Prefer GPU when available; COMET inference is slow on CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        # COMET expects a list of {"src", "mt", "ref"} dicts.
        data = [
            {"src": src, "mt": mt, "ref": ref}
            for src, mt, ref in zip(source_sentences, translations, references)
        ]
        # Batched prediction; gpus=1 enables GPU inference when available.
        results = model.predict(data, batch_size=8, gpus=1 if device == "cuda" else 0)
        return results["scores"]
    except Exception as e:
        # Best-effort boundary: never propagate — report and return neutral
        # scores of the expected length so downstream aggregation still works.
        print(f"COMET Error: {str(e)}")
        return [0.0] * len(source_sentences)