Commit
·
bc828c5
1
Parent(s):
a2d40e3
Add language parsing functionality and update dependencies
Browse files
- main.py +33 -10
- requirements.in +2 -1
- requirements.txt +2 -0
main.py
CHANGED
|
@@ -15,6 +15,7 @@ from starlette.responses import RedirectResponse
|
|
| 15 |
from cashews import cache
|
| 16 |
from datetime import timedelta
|
| 17 |
import logging
|
|
|
|
| 18 |
|
| 19 |
cache.setup("mem://")
|
| 20 |
|
|
@@ -93,6 +94,19 @@ async def get_dataset_info(hub_id: str, config: str | None = None):
|
|
| 93 |
return resp.json()
|
| 94 |
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
async def get_random_rows(
|
| 97 |
hub_id: str,
|
| 98 |
total_length: int,
|
|
@@ -110,15 +124,8 @@ async def get_random_rows(
|
|
| 110 |
offset = random.randint(0, total_length - rows_per_call)
|
| 111 |
url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
|
| 112 |
logger.info(f"Fetching {url}")
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
if response.status_code == 200:
|
| 116 |
-
data = response.json()
|
| 117 |
-
batch_rows = data.get("rows")
|
| 118 |
-
rows.extend(batch_rows)
|
| 119 |
-
else:
|
| 120 |
-
print(f"Failed to fetch data: {response.status_code}")
|
| 121 |
-
print(url)
|
| 122 |
if len(rows) >= number_of_rows:
|
| 123 |
break
|
| 124 |
return [row.get("row") for row in rows]
|
|
@@ -181,6 +188,17 @@ def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
|
|
| 181 |
return {k for k, v in counts_dict.items() if v >= threshold}
|
| 182 |
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
def predict_rows(
|
| 185 |
rows, target_column, language_threshold_percent=0.2, return_raw_predictions=False
|
| 186 |
):
|
|
@@ -196,8 +214,13 @@ def predict_rows(
|
|
| 196 |
langues_counts, threshold_percent=language_threshold_percent
|
| 197 |
)
|
| 198 |
filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
default_data = {
|
| 200 |
-
"
|
|
|
|
| 201 |
"hub_id": "hub_id",
|
| 202 |
"config": "config",
|
| 203 |
}
|
|
|
|
| 15 |
from cashews import cache
|
| 16 |
from datetime import timedelta
|
| 17 |
import logging
|
| 18 |
+
from iso639 import Lang
|
| 19 |
|
| 20 |
cache.setup("mem://")
|
| 21 |
|
|
|
|
| 94 |
return resp.json()
|
| 95 |
|
| 96 |
|
| 97 |
+
@cache(ttl=timedelta(minutes=5))
|
| 98 |
+
async def fetch_rows(url: str) -> list[dict]:
|
| 99 |
+
response = await async_client.get(url)
|
| 100 |
+
if response.status_code == 200:
|
| 101 |
+
data = response.json()
|
| 102 |
+
return data.get("rows")
|
| 103 |
+
else:
|
| 104 |
+
print(f"Failed to fetch data: {response.status_code}")
|
| 105 |
+
print(url)
|
| 106 |
+
return []
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# Function to get random rows from the dataset
|
| 110 |
async def get_random_rows(
|
| 111 |
hub_id: str,
|
| 112 |
total_length: int,
|
|
|
|
| 124 |
offset = random.randint(0, total_length - rows_per_call)
|
| 125 |
url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
|
| 126 |
logger.info(f"Fetching {url}")
|
| 127 |
+
batch_rows = await fetch_rows(url)
|
| 128 |
+
rows.extend(batch_rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
if len(rows) >= number_of_rows:
|
| 130 |
break
|
| 131 |
return [row.get("row") for row in rows]
|
|
|
|
| 188 |
return {k for k, v in counts_dict.items() if v >= threshold}
|
| 189 |
|
| 190 |
|
| 191 |
+
def try_parse_language(lang: str) -> str | None:
|
| 192 |
+
try:
|
| 193 |
+
split = lang.split("_")
|
| 194 |
+
lang = split[0]
|
| 195 |
+
lang = Lang(lang)
|
| 196 |
+
return lang.pt1
|
| 197 |
+
except Exception as e:
|
| 198 |
+
logger.error(f"Failed to parse language {lang}: {e}")
|
| 199 |
+
return None
|
| 200 |
+
|
| 201 |
+
|
| 202 |
def predict_rows(
|
| 203 |
rows, target_column, language_threshold_percent=0.2, return_raw_predictions=False
|
| 204 |
):
|
|
|
|
| 214 |
langues_counts, threshold_percent=language_threshold_percent
|
| 215 |
)
|
| 216 |
filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
|
| 217 |
+
raw_model_prediction_summary = dict(valmap(get_mean_score, filtered_dict))
|
| 218 |
+
parsed_langs = {
|
| 219 |
+
try_parse_language(k): v for k, v in raw_model_prediction_summary.items()
|
| 220 |
+
}
|
| 221 |
default_data = {
|
| 222 |
+
"language_prediction_summary": parsed_langs,
|
| 223 |
+
"raw_model_prediction_summary": raw_model_prediction_summary,
|
| 224 |
"hub_id": "hub_id",
|
| 225 |
"config": "config",
|
| 226 |
}
|
requirements.in
CHANGED
|
@@ -8,4 +8,5 @@ huggingface_hub
|
|
| 8 |
python-dotenv
|
| 9 |
rich
|
| 10 |
toolz
|
| 11 |
-
uvicorn[standard]
|
|
|
|
|
|
| 8 |
python-dotenv
|
| 9 |
rich
|
| 10 |
toolz
|
| 11 |
+
uvicorn[standard]
|
| 12 |
+
iso639-lang
|
requirements.txt
CHANGED
|
@@ -51,6 +51,8 @@ idna==3.6
|
|
| 51 |
# anyio
|
| 52 |
# httpx
|
| 53 |
# requests
|
|
|
|
|
|
|
| 54 |
markdown-it-py==3.0.0
|
| 55 |
# via rich
|
| 56 |
mdurl==0.1.2
|
|
|
|
| 51 |
# anyio
|
| 52 |
# httpx
|
| 53 |
# requests
|
| 54 |
+
iso639-lang==2.2.2
|
| 55 |
+
# via -r requirements.in
|
| 56 |
markdown-it-py==3.0.0
|
| 57 |
# via rich
|
| 58 |
mdurl==0.1.2
|