Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,888 Bytes
5c395b2 24f37c6 5c395b2 24f37c6 5c395b2 24f37c6 5c395b2 24f37c6 5c395b2 24f37c6 5c395b2 24f37c6 5c395b2 24f37c6 5c395b2 24f37c6 5c395b2 24f37c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import gc
from functools import partial
import gradio as gr
import torch
from langdetect import detect, LangDetectException
from transformers import MarianMTModel, MarianTokenizer
from utils import get_pytorch_device, spaces_gpu, get_torch_dtype
# Maps a source language code to a Helsinki-NLP OPUS-MT model that translates
# that language into English. Every model follows the naming scheme
# "Helsinki-NLP/opus-mt-<code>-en", so the table is generated from the codes.
# Languages not listed here are translated with the caller-supplied fallback
# model (see get_translation_model).
LANGUAGE_TO_MODEL_MAP = {
    code: f"Helsinki-NLP/opus-mt-{code}-en"
    for code in (
        "fr", "de", "es", "it", "pt", "ru", "zh",
        "ja", "ko", "ar", "nl", "pl", "tr", "vi",
        "hi", "cs", "sv", "fi", "uk", "ro", "th",
    )
}
def detect_language(text: str) -> str:
    """Detect the language of the input text using the langdetect library.

    langdetect is a Python port of Google's language-detection library. Its
    algorithm is non-deterministic for short or ambiguous input unless the
    detector seed is fixed, so this function pins ``DetectorFactory.seed`` to 0
    to make repeated calls on the same text return the same code (and hence
    select the same translation model).

    Args:
        text: Input text to detect the language of.

    Returns:
        The language code reported by langdetect (e.g. "en", "fr", "de", "ko",
        "ja"). Some codes carry a region suffix, notably "zh-cn" / "zh-tw" for
        Chinese. If langdetect cannot determine a language (e.g. empty text or
        text without detectable features) this returns "en", so callers treat
        undetectable input as already English. No exception is propagated.
    """
    # Local import keeps the module's import block unchanged.
    from langdetect import DetectorFactory

    # Without a fixed seed, langdetect may return different codes for the
    # same short/ambiguous text on different calls.
    DetectorFactory.seed = 0
    try:
        return detect(text)
    except LangDetectException:
        # Detection failed: default to English (translation is then skipped).
        return "en"
def get_translation_model(language_code: str, fallback_model: str) -> str:
"""Get the appropriate translation model for a given language code.
Args:
language_code: ISO 639-1 language code (e.g., "fr", "de", "en").
fallback_model: Fallback model to use if no specific model is available.
Returns:
Model ID for translation, or fallback model if language not in mapping.
"""
if language_code == "en":
return None # Already in English
return LANGUAGE_TO_MODEL_MAP.get(language_code, fallback_model)
@spaces_gpu
def translate_to_english(fallback_translation_model: str, text: str) -> str:
    """Translate text to English using automatic language detection.

    The source language is detected with ``detect_language``. Text detected as
    English is returned unchanged; this includes text whose language cannot be
    detected, which ``detect_language`` reports as "en". Otherwise a MarianMT
    model is chosen with ``get_translation_model``, loaded locally, and used to
    translate the text.

    Args:
        fallback_translation_model: Model ID used when no language-specific
            model is mapped for the detected language.
        text: Input text to translate to English.

    Returns:
        The English translation, or the original text if it is already in
        English.

    Note:
        - Model weights are loaded with ``use_safetensors=True``.
        - Runs on the device from ``get_pytorch_device()`` with the dtype from
          ``get_torch_dtype()``.
        - Input is tokenized with ``truncation=True``, so text longer than the
          tokenizer's maximum length is cut off.
        - Model, tokenizer and tensors are released after inference, even if
          loading or generation raises. The CUDA allocator cache is emptied
          when running on a CUDA device.
    """
    detected_lang = detect_language(text)
    if detected_lang == "en":
        return text

    translation_model = get_translation_model(detected_lang, fallback_translation_model)
    if translation_model is None:
        # get_translation_model signals "already English" with None.
        return text

    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    # Pre-bind so the cleanup in `finally` works no matter where a failure occurs.
    tokenizer = model = inputs = translated = None
    try:
        tokenizer = MarianTokenizer.from_pretrained(translation_model)
        model = MarianMTModel.from_pretrained(
            translation_model,
            use_safetensors=True,
            dtype=dtype,
        ).to(pytorch_device)

        inputs = tokenizer(
            [text], return_tensors="pt", padding=True, truncation=True
        ).to(pytorch_device)
        # Inference only: torch.no_grad() skips storing autograd state,
        # which reduces memory use during generation.
        with torch.no_grad():
            translated = model.generate(**inputs)
        translation = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
    finally:
        # Drop references first, then collect, then empty the CUDA cache, so
        # memory freed by gc.collect() is also returned by empty_cache().
        del tokenizer, model, inputs, translated
        gc.collect()
        # str() handles both a "cuda"/"cuda:N" string and a torch.device.
        if str(pytorch_device).startswith("cuda"):
            torch.cuda.empty_cache()
    return translation
def create_translation_tab(fallback_translation_model: str):
    """Build the "translate to English" tab of the Gradio interface.

    Creates, in order, a short description, an input textbox, a Translate
    button and a read-only output textbox. It then wires the button so that a
    click runs ``translate_to_english`` on the input text, with
    ``fallback_translation_model`` bound as its first argument.

    Args:
        fallback_translation_model: Model ID forwarded to
            ``translate_to_english`` for languages without a dedicated model.
    """
    on_translate = partial(translate_to_english, fallback_translation_model)

    gr.Markdown("Translate text to English. The source language will be automatically detected.")
    source_box = gr.Textbox(label="Input Text", lines=5)
    run_button = gr.Button("Translate")
    result_box = gr.Textbox(label="Translated Text", lines=5, interactive=False)

    run_button.click(fn=on_translate, inputs=source_box, outputs=result_box)
|