File size: 1,390 Bytes
3c397fd
d130b8e
 
3c397fd
 
 
d130b8e
 
 
 
 
3c397fd
d130b8e
 
 
 
 
 
 
 
 
 
 
3c397fd
d130b8e
 
 
 
 
 
 
 
 
3c397fd
d130b8e
 
 
3c397fd
d130b8e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import functools
import os

import torch
from comet import download_model, load_from_checkpoint

@functools.lru_cache(maxsize=1)
def _load_comet_model(model_name, device):
    """
    Download (if needed) and load a COMET model, then move it to *device*.

    Cached so repeated ``calculate_comet`` calls reuse the loaded model instead
    of re-downloading and re-loading the checkpoint every time. ``maxsize=1``
    keeps at most one model resident. ``lru_cache`` does not cache exceptions,
    so a failed download/load is retried on the next call.

    :param model_name: COMET model identifier passed to ``download_model``
    :param device: Torch device string (``"cuda"`` or ``"cpu"``)
    :return: The loaded COMET model
    """
    # NOTE(review): this mutates the process-wide environment; presumably the
    # installed comet version reads COMET_CACHE for its download directory —
    # confirm against that version.
    os.environ["COMET_CACHE"] = "/tmp"
    model_path = download_model(model_name)
    model = load_from_checkpoint(model_path)
    model.to(device)
    return model


def calculate_comet(
    source_sentences,
    translations,
    references,
    *,
    model_name="Unbabel/wmt22-comet-da",
    batch_size=8,
):
    """
    Calculate segment-level COMET scores using the local COMET installation.

    :param source_sentences: Iterable of source sentences
    :param translations: Iterable of translated sentences, aligned with sources
    :param references: Iterable of reference translations, aligned with sources
    :param model_name: COMET model identifier (default: wmt22-comet-da)
    :param batch_size: Batch size passed to ``model.predict``
    :return: List of COMET scores, one per source sentence. If the model cannot
        be loaded or scoring fails, the error is printed and a list of ``0.0``
        of the same length is returned instead. Note that ``0.0`` is not
        distinguishable from a genuine score.
    :raises ValueError: If the three inputs do not have the same length.
    """
    # Materialize so generators work and len() is valid in the fallback path.
    source_sentences = list(source_sentences)
    translations = list(translations)
    references = list(references)

    # zip() would silently truncate and return fewer scores than sources.
    if not (len(source_sentences) == len(translations) == len(references)):
        raise ValueError(
            "Input lengths differ: "
            f"{len(source_sentences)} sources, "
            f"{len(translations)} translations, "
            f"{len(references)} references"
        )

    # Nothing to score: avoid downloading/loading a large model for no work.
    if not source_sentences:
        return []

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = _load_comet_model(model_name, device)

        # COMET expects one dict per segment with src/mt/ref keys.
        data = [
            {"src": src, "mt": mt, "ref": ref}
            for src, mt, ref in zip(source_sentences, translations, references)
        ]

        results = model.predict(
            data, batch_size=batch_size, gpus=1 if device == "cuda" else 0
        )
        return results["scores"]

    except Exception as e:
        # Deliberate best-effort: report the failure and return neutral scores
        # so callers iterating over many evaluations are not interrupted.
        print(f"COMET Error: {str(e)}")
        return [0.0] * len(source_sentences)