morgankavanagh committed
Commit 0db2cde · 1 Parent(s): dcdf99d

internal comet

Files changed (2)
  1. evaluator/comet.py +18 -12
  2. evaluator/comet_internal.py +28 -0
evaluator/comet.py CHANGED
@@ -1,28 +1,34 @@
 import os
-from comet import download_model, load_from_checkpoint
+import requests
 
-# Set a custom cache directory for COMET
-os.environ["COMET_CACHE"] = "/tmp"
+# Set the Hugging Face Inference API URL and token
+HF_API_URL = "https://api-inference.huggingface.co/models/Unbabel/wmt20-comet-da"
+HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # Ensure this is set in your environment
 
 def calculate_comet(source_sentences, translations, references):
     """
-    Calculate COMET scores for a list of translations.
+    Calculate COMET scores using the Hugging Face Inference API.
     :param source_sentences: List of source sentences.
     :param translations: List of translated sentences (hypotheses).
     :param references: List of reference translations.
     :return: List of COMET scores (one score per sentence pair).
     """
-    # Download and load the COMET model
-    model_path = download_model("Unbabel/wmt22-comet-da")  # Use a supported model
-    model = load_from_checkpoint(model_path)
+    headers = {
+        "Authorization": f"Bearer {HF_API_TOKEN}",
+        "Content-Type": "application/json"
+    }
 
-    # Prepare data for COMET
+    # Prepare data for the API
     data = [
-        {"src": src, "mt": mt, "ref": ref}
+        {"source": src, "translation": mt, "reference": ref}
         for src, mt, ref in zip(source_sentences, translations, references)
     ]
 
-    # Compute COMET scores
-    results = model.predict(data, batch_size=8, gpus=0)
-    scores = results["scores"]  # Extract the scores from the results
+    # Make the API call
+    response = requests.post(HF_API_URL, headers=headers, json={"inputs": data})
+    response.raise_for_status()  # Raise an error for bad responses
+
+    # Parse the response
+    results = response.json()
+    scores = [item["score"] for item in results]  # Extract scores from the response
     return scores
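
For reference, a minimal usage sketch of the new API-backed calculate_comet (not part of the commit; the example sentences are hypothetical, and it assumes HF_API_TOKEN is set, the evaluator directory is importable as a package, and the endpoint returns a JSON list of {"score": ...} objects, which is the shape the code above parses):

# Usage sketch for evaluator/comet.py (API-backed version).
# Assumes: HF_API_TOKEN is set in the environment, evaluator is importable,
# and the endpoint responds with a list of {"score": ...} objects as parsed above.
from evaluator.comet import calculate_comet

sources = ["Das ist ein Test."]      # hypothetical example data
translations = ["This is a test."]
references = ["This is a test."]

scores = calculate_comet(sources, translations, references)
print(scores)  # one COMET score per sentence pair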
evaluator/comet_internal.py ADDED
@@ -0,0 +1,28 @@
+import os
+from comet import download_model, load_from_checkpoint
+
+# Set a custom cache directory for COMET
+os.environ["COMET_CACHE"] = "/tmp"
+
+def calculate_comet(source_sentences, translations, references):
+    """
+    Calculate COMET scores for a list of translations.
+    :param source_sentences: List of source sentences.
+    :param translations: List of translated sentences (hypotheses).
+    :param references: List of reference translations.
+    :return: List of COMET scores (one score per sentence pair).
+    """
+    # Download and load the COMET model
+    model_path = download_model("Unbabel/wmt22-comet-da")  # Use a supported model
+    model = load_from_checkpoint(model_path)
+
+    # Prepare data for COMET
+    data = [
+        {"src": src, "mt": mt, "ref": ref}
+        for src, mt, ref in zip(source_sentences, translations, references)
+    ]
+
+    # Compute COMET scores
+    results = model.predict(data, batch_size=8, gpus=0)
+    scores = results["scores"]  # Extract the scores from the results
+    return scores
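
Likewise, a minimal usage sketch for the local evaluator kept in evaluator/comet_internal.py (not part of the commit; the example data is hypothetical, and it assumes the unbabel-comet package is installed so download_model and load_from_checkpoint resolve, and that /tmp is writable, since the module sets COMET_CACHE to /tmp before downloading):

# Usage sketch for evaluator/comet_internal.py (local model version).
# Assumes the unbabel-comet package is installed and /tmp is writable;
# the module sets COMET_CACHE to /tmp before fetching the checkpoint.
from evaluator.comet_internal import calculate_comet

scores = calculate_comet(
    ["Das ist ein Test."],   # hypothetical source sentences
    ["This is a test."],     # hypothetical hypotheses
    ["This is a test."],     # hypothetical references
)
print(scores)  # one COMET score per sentence pair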