davanstrien · Refactor predict_language function and update interface inputs and outputs · cee1d41
import os
import random
from statistics import mean
from typing import Iterator, Union

import fasttext
import gradio as gr
from dotenv import load_dotenv
from httpx import Client, Timeout
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import logging
from toolz import concat, groupby, valmap

logger = logging.get_logger(__name__)
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
DEFAULT_FAST_TEXT_MODEL = "laurievb/OpenLID"

headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}
timeout = Timeout(60, read=120)
client = Client(headers=headers, timeout=timeout)
# Non-exhaustive list of columns that might contain text usable for language detection.
# Columns are listed in order of preference: if a column named "text" exists it is used first.
# A tuple (rather than a set) is used so the preference order is preserved when iterating below.
TARGET_COLUMN_NAMES = (
    "text",
    "input",
    "tokens",
    "prompt",
    "instruction",
    "sentence_1",
    "question",
    "sentence2",
    "answer",
    "sentence",
    "response",
    "context",
    "query",
    "chosen",
    "rejected",
)
def datasets_server_valid_rows(hub_id: str):
    """Check whether the dataset viewer is available for `hub_id` via the datasets server."""
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/is-valid?dataset={hub_id}")
    resp.raise_for_status()
    return resp.json()["viewer"]
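
# Illustrative shape of the /is-valid response this function relies on (the exact field set may
# vary by datasets-server version); only the "viewer" flag is used here:
#   {"viewer": true, "preview": true}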
def get_first_config_and_split_name(hub_id: str):
    """Return the (config, split) pair of the first split reported by the datasets server."""
    resp = client.get(f"{BASE_DATASETS_SERVER_URL}/splits?dataset={hub_id}")
    resp.raise_for_status()
    data = resp.json()
    return data["splits"][0]["config"], data["splits"][0]["split"]
def get_dataset_info(hub_id: str, config: str | None = None):
    """Fetch dataset info from the datasets server, resolving the first config if none is given."""
    if config is None:
        config = get_first_config_and_split_name(hub_id)
        if config is None:
            return None
        else:
            config = config[0]
    resp = client.get(
        f"{BASE_DATASETS_SERVER_URL}/info?dataset={hub_id}&config={config}"
    )
    resp.raise_for_status()
    return resp.json()
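
# Illustrative /info response shape, as consumed by predict_language below (other fields omitted;
# values are hypothetical):
#   {"dataset_info": {"features": {"text": {...}, ...},
#                     "splits": {"train": {"num_examples": 25000, ...}, ...}}}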
def get_random_rows(
    hub_id: str,
    total_length: int,
    number_of_rows: int,
    max_request_calls: int,
    config="default",
    split="train",
):
    """Sample roughly `number_of_rows` rows by calling the /rows endpoint at random offsets."""
    rows = []
    rows_per_call = min(
        number_of_rows // max_request_calls, total_length // max_request_calls
    )
    rows_per_call = min(rows_per_call, 100)  # ensure rows_per_call is not more than 100
    rows_per_call = max(rows_per_call, 1)  # guard against zero rows per call for tiny datasets
    for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
        offset = random.randint(0, total_length - rows_per_call)
        url = f"{BASE_DATASETS_SERVER_URL}/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
        response = client.get(url)
        if response.status_code == 200:
            data = response.json()
            batch_rows = data.get("rows")
            rows.extend(batch_rows)
        else:
            print(f"Failed to fetch data: {response.status_code}")
            print(url)
        if len(rows) >= number_of_rows:
            break
    return [row.get("row") for row in rows]
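
# Worked example of the sampling arithmetic: with the defaults used below (number_of_rows=1000,
# max_request_calls=10) on a large dataset, rows_per_call = min(1000 // 10, total // 10) capped at
# 100, so the function makes 10 requests of 100 rows each at random offsets before stopping.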
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Download a fastText model from the Hub and load it."""
    model_path = hf_hub_download(repo_id, filename="model.bin")
    return fasttext.load_model(model_path)
# def predict_language_for_rows(rows: list[dict], target_column_names: list[str] | str):
#     pass
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
    """Yield cleaned text lines from rows, skipping empty or very short entries."""
    for row in rows:
        if isinstance(row, str):
            # split on newlines and drop empty or too-short lines
            lines = row.split("\n")
            for line in lines:
                if line and len(line) >= min_length:
                    yield line
        elif isinstance(row, list):
            try:
                line = " ".join(row)
                if len(line) < min_length:
                    continue
                else:
                    yield line
            except TypeError:
                continue
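
# Illustrative behaviour with hypothetical inputs (string rows are split on newlines, list rows
# are joined with spaces):
#   list(yield_clean_rows(["hello\n\nworld", ["a", "b", "c"]]))
#   -> ["hello", "world", "a b c"]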
FASTTEXT_PREFIX_LENGTH = len("__label__")  # fasttext labels are formatted like "__label__eng_Latn"

# model = load_model(DEFAULT_FAST_TEXT_MODEL)
model = fasttext.load_model(
    hf_hub_download("facebook/fasttext-language-identification", "model.bin")
)
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
    """Run fastText language identification and strip the "__label__" prefix from labels."""
    predictions = model.predict(inputs, k=k)
    return [
        {"label": label[FASTTEXT_PREFIX_LENGTH:], "score": prob}
        for label, prob in zip(predictions[0], predictions[1])
    ]
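
# Illustrative output for a single English sentence (the score is hypothetical):
#   model_predict("The quick brown fox jumps over the lazy dog.")
#   -> [{"label": "eng_Latn", "score": 0.98}]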
def get_label(x):
    return x.get("label")


def get_mean_score(preds):
    return mean([pred.get("score") for pred in preds])
def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
    """Return the keys of `counts_dict` whose count is at least `threshold_percent` of the total."""
    total = sum(counts_dict.values())
    threshold = total * threshold_percent
    return {k for k, v in counts_dict.items() if v >= threshold}
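
# Worked example: with counts {"eng_Latn": 80, "fra_Latn": 15, "deu_Latn": 5} and the default
# threshold of 0.2, the total is 100 and the cut-off is 20, so only {"eng_Latn"} is kept.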
def predict_rows(rows, target_column, language_threshold_percent=0.2):
    """Predict the language of each cleaned row and aggregate mean scores per language."""
    rows = (row.get(target_column) for row in rows)
    rows = (row for row in rows if row is not None)
    rows = list(yield_clean_rows(rows))
    predictions = [model_predict(row) for row in rows]
    predictions = [pred for pred in predictions if pred is not None]
    predictions = list(concat(predictions))
    predictions_by_lang = groupby(get_label, predictions)
    language_counts = valmap(len, predictions_by_lang)
    keys_to_keep = filter_by_frequency(
        language_counts, threshold_percent=language_threshold_percent
    )
    filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
    return {
        "predictions": dict(valmap(get_mean_score, filtered_dict)),
        "pred": predictions,
    }
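
# Illustrative return value (hypothetical scores): languages below the frequency threshold are
# dropped from "predictions", while "pred" keeps every per-row prediction:
#   {"predictions": {"eng_Latn": 0.97},
#    "pred": [{"label": "eng_Latn", "score": 0.99}, ...]}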
def predict_language(
    hub_id: str,
    config: str | None = None,
    split: str | None = None,
    max_request_calls: int = 10,
):
    """Sample rows from a Hub dataset and predict the language(s) of its text column."""
    is_valid = datasets_server_valid_rows(hub_id)
    if not is_valid:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if not config:
        config, split = get_first_config_and_split_name(hub_id)
    info = get_dataset_info(hub_id, config)
    if info is None:
        raise gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
    if dataset_info := info.get("dataset_info"):
        total_rows_for_split = dataset_info.get("splits").get(split).get("num_examples")
        features = dataset_info.get("features")
        column_names = set(features.keys())
        logger.info(f"Column names: {column_names}")
        if not column_names.intersection(TARGET_COLUMN_NAMES):
            raise gr.Error(
                f"Dataset {hub_id} does not contain any of the target columns {TARGET_COLUMN_NAMES}"
            )
        for column in TARGET_COLUMN_NAMES:
            if column in column_names:
                target_column = column
                logger.info(f"Using column {target_column} for language detection")
                break
        random_rows = get_random_rows(
            hub_id, total_rows_for_split, 1000, max_request_calls, config, split
        )
        logger.info(f"Predicting language for {len(random_rows)} rows")
        predictions = predict_rows(random_rows, target_column)
        predictions["hub_id"] = hub_id
        predictions["config"] = config
        predictions["split"] = split
        return predictions
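
# Example call (hypothetical dataset id and values), mirroring what the Gradio interface below does:
#   predict_language("user/some-dataset", config="default", split="train")
#   -> {"predictions": {...}, "pred": [...],
#       "hub_id": "user/some-dataset", "config": "default", "split": "train"}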
inputs = [
    gr.Text(label="dataset id"),
    gr.Textbox(
        None,
        label="config",
    ),
    gr.Textbox(None, label="split"),
]

interface = gr.Interface(predict_language, inputs=inputs, outputs="json")
interface.queue()
interface.launch()