Commit
·
ef19caa
1
Parent(s):
9915c6f
Refactor app.py: Import modules, update function parameters, and improve logging
Browse files
app.py
CHANGED
|
@@ -1,16 +1,15 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
from httpx import Client
|
| 3 |
-
import random
|
| 4 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import fasttext
|
| 6 |
-
|
| 7 |
-
from typing import Union
|
| 8 |
-
from typing import Iterator
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
-
from httpx import Timeout
|
| 13 |
from huggingface_hub.utils import logging
|
|
|
|
| 14 |
|
| 15 |
logger = logging.get_logger(__name__)
|
| 16 |
load_dotenv()
|
|
@@ -24,6 +23,7 @@ headers = {
|
|
| 24 |
}
|
| 25 |
timeout = Timeout(60, read=120)
|
| 26 |
client = Client(headers=headers, timeout=timeout)
|
|
|
|
| 27 |
# non-exhaustive list of columns that might contain text which can be used for language detection
|
| 28 |
# we prefer to use columns in this order, i.e. if there is a column named "text", we will use it first
|
| 29 |
TARGET_COLUMN_NAMES = {
|
|
@@ -73,10 +73,10 @@ def get_dataset_info(hub_id: str, config: str | None = None):
|
|
| 73 |
|
| 74 |
|
| 75 |
def get_random_rows(
|
| 76 |
-
hub_id,
|
| 77 |
-
total_length,
|
| 78 |
-
number_of_rows,
|
| 79 |
-
max_request_calls,
|
| 80 |
config="default",
|
| 81 |
split="train",
|
| 82 |
):
|
|
@@ -88,8 +88,9 @@ def get_random_rows(
|
|
| 88 |
for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
|
| 89 |
offset = random.randint(0, total_length - rows_per_call)
|
| 90 |
url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
|
|
|
|
|
|
|
| 91 |
response = client.get(url)
|
| 92 |
-
|
| 93 |
if response.status_code == 200:
|
| 94 |
data = response.json()
|
| 95 |
batch_rows = data.get("rows")
|
|
@@ -107,10 +108,6 @@ def load_model(repo_id: str) -> fasttext.FastText._FastText:
|
|
| 107 |
return fasttext.load_model(model_path)
|
| 108 |
|
| 109 |
|
| 110 |
-
# def predict_language_for_rows(rows: list[dict], target_column_names: list[str] | str):
|
| 111 |
-
# pass
|
| 112 |
-
|
| 113 |
-
|
| 114 |
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
|
| 115 |
for row in rows:
|
| 116 |
if isinstance(row, str):
|
|
@@ -186,7 +183,8 @@ def predict_language(
|
|
| 186 |
config: str | None = None,
|
| 187 |
split: str | None = None,
|
| 188 |
max_request_calls: int = 10,
|
| 189 |
-
|
|
|
|
| 190 |
is_valid = datasets_server_valid_rows(hub_id)
|
| 191 |
if not is_valid:
|
| 192 |
gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
|
|
@@ -202,7 +200,7 @@ def predict_language(
|
|
| 202 |
logger.info(f"Column names: {column_names}")
|
| 203 |
if not set(column_names).intersection(TARGET_COLUMN_NAMES):
|
| 204 |
raise gr.Error(
|
| 205 |
-
f"Dataset {hub_id}
|
| 206 |
)
|
| 207 |
for column in TARGET_COLUMN_NAMES:
|
| 208 |
if column in column_names:
|
|
@@ -210,7 +208,12 @@ def predict_language(
|
|
| 210 |
logger.info(f"Using column {target_column} for language detection")
|
| 211 |
break
|
| 212 |
random_rows = get_random_rows(
|
| 213 |
-
hub_id,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
)
|
| 215 |
logger.info(f"Predicting language for {len(random_rows)} rows")
|
| 216 |
predictions = predict_rows(random_rows, target_column)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import random
|
| 3 |
+
from statistics import mean
|
| 4 |
+
from typing import Iterator, Union
|
| 5 |
+
|
| 6 |
import fasttext
|
| 7 |
+
import gradio as gr
|
|
|
|
|
|
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
+
from httpx import Client, Timeout
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
|
|
|
| 11 |
from huggingface_hub.utils import logging
|
| 12 |
+
from toolz import concat, groupby, valmap
|
| 13 |
|
| 14 |
logger = logging.get_logger(__name__)
|
| 15 |
load_dotenv()
|
|
|
|
| 23 |
}
|
| 24 |
timeout = Timeout(60, read=120)
|
| 25 |
client = Client(headers=headers, timeout=timeout)
|
| 26 |
+
# async_client = AsyncClient(headers=headers, timeout=timeout)
|
| 27 |
# non-exhaustive list of columns that might contain text which can be used for language detection
|
| 28 |
# we prefer to use columns in this order, i.e. if there is a column named "text", we will use it first
|
| 29 |
TARGET_COLUMN_NAMES = {
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
def get_random_rows(
|
| 76 |
+
hub_id: str,
|
| 77 |
+
total_length: int,
|
| 78 |
+
number_of_rows: int,
|
| 79 |
+
max_request_calls: int,
|
| 80 |
config="default",
|
| 81 |
split="train",
|
| 82 |
):
|
|
|
|
| 88 |
for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
|
| 89 |
offset = random.randint(0, total_length - rows_per_call)
|
| 90 |
url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
|
| 91 |
+
logger.info(f"Fetching {url}")
|
| 92 |
+
print(url)
|
| 93 |
response = client.get(url)
|
|
|
|
| 94 |
if response.status_code == 200:
|
| 95 |
data = response.json()
|
| 96 |
batch_rows = data.get("rows")
|
|
|
|
| 108 |
return fasttext.load_model(model_path)
|
| 109 |
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
|
| 112 |
for row in rows:
|
| 113 |
if isinstance(row, str):
|
|
|
|
| 183 |
config: str | None = None,
|
| 184 |
split: str | None = None,
|
| 185 |
max_request_calls: int = 10,
|
| 186 |
+
number_of_rows: int = 1000,
|
| 187 |
+
) -> dict[str, float | str]:
|
| 188 |
is_valid = datasets_server_valid_rows(hub_id)
|
| 189 |
if not is_valid:
|
| 190 |
gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
|
|
|
|
| 200 |
logger.info(f"Column names: {column_names}")
|
| 201 |
if not set(column_names).intersection(TARGET_COLUMN_NAMES):
|
| 202 |
raise gr.Error(
|
| 203 |
+
f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
|
| 204 |
)
|
| 205 |
for column in TARGET_COLUMN_NAMES:
|
| 206 |
if column in column_names:
|
|
|
|
| 208 |
logger.info(f"Using column {target_column} for language detection")
|
| 209 |
break
|
| 210 |
random_rows = get_random_rows(
|
| 211 |
+
hub_id,
|
| 212 |
+
total_rows_for_split,
|
| 213 |
+
number_of_rows,
|
| 214 |
+
max_request_calls,
|
| 215 |
+
config,
|
| 216 |
+
split,
|
| 217 |
)
|
| 218 |
logger.info(f"Predicting language for {len(random_rows)} rows")
|
| 219 |
predictions = predict_rows(random_rows, target_column)
|