import gc
from functools import partial
import gradio as gr
import torch
from langdetect import detect, LangDetectException
from transformers import MarianMTModel, MarianTokenizer
from utils import get_pytorch_device, spaces_gpu, get_torch_dtype


# Mapping from ISO 639-1 language codes to Helsinki-NLP Opus-MT translation models.
# If no language-specific model is listed here, the configured fallback model is used instead.
LANGUAGE_TO_MODEL_MAP = {
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "de": "Helsinki-NLP/opus-mt-de-en",
    "es": "Helsinki-NLP/opus-mt-es-en",
    "it": "Helsinki-NLP/opus-mt-it-en",
    "pt": "Helsinki-NLP/opus-mt-pt-en",
    "ru": "Helsinki-NLP/opus-mt-ru-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
    "ja": "Helsinki-NLP/opus-mt-ja-en",
    "ko": "Helsinki-NLP/opus-mt-ko-en",
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "nl": "Helsinki-NLP/opus-mt-nl-en",
    "pl": "Helsinki-NLP/opus-mt-pl-en",
    "tr": "Helsinki-NLP/opus-mt-tr-en",
    "vi": "Helsinki-NLP/opus-mt-vi-en",
    "hi": "Helsinki-NLP/opus-mt-hi-en",
    "cs": "Helsinki-NLP/opus-mt-cs-en",
    "sv": "Helsinki-NLP/opus-mt-sv-en",
    "fi": "Helsinki-NLP/opus-mt-fi-en",
    "uk": "Helsinki-NLP/opus-mt-uk-en",
    "ro": "Helsinki-NLP/opus-mt-ro-en",
    "th": "Helsinki-NLP/opus-mt-th-en",
}


def detect_language(text: str) -> str:
    """Detect the language of the input text using langdetect library.
    
    Uses the langdetect library, a Python port of Google's language-detection library.
    It supports 55 languages out of the box and is known for high accuracy, especially
    for languages with distinct scripts such as Korean, Japanese, and Chinese.
    
    Args:
        text: Input text to detect the language of.
    
    Returns:
        ISO 639-1 language code (e.g., "en", "fr", "de", "ko", "ja") of the detected
        language, or "en" if detection fails (e.g., the text is too short); the
        underlying LangDetectException is caught internally.
    """
    try:
        language_code = detect(text)
        return language_code
    except LangDetectException:
        # If detection fails, default to English (will be handled by translation logic)
        return "en"


def get_translation_model(language_code: str, fallback_model: str) -> str | None:
    """Get the appropriate translation model for a given language code.
    
    Args:
        language_code: ISO 639-1 language code (e.g., "fr", "de", "en").
        fallback_model: Fallback model to use if no specific model is available.
    
    Returns:
        Model ID for translating the detected language to English, the fallback model
        if the language is not in the mapping, or None if the text is already in English.
    """
    if language_code == "en":
        return None  # Already in English
    return LANGUAGE_TO_MODEL_MAP.get(language_code, fallback_model)


@spaces_gpu
def translate_to_english(fallback_translation_model: str, text: str) -> str:
    """Translate text to English using automatic language detection.
    
    First detects the source language using the langdetect library, then selects
    the appropriate translation model and translates the text to English using
    a local MarianMT model.
    
    Args:
        fallback_translation_model: Fallback translation model to use if no
            language-specific model is available.
        text: Input text to translate to English.
    
    Returns:
        String containing the translated text in English, or the original text
        if it is already in English.
    
    Note:
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Cleans up model and GPU memory after inference.
    """
    # Detect the language using langdetect library
    detected_lang = detect_language(text)
    
    # Check if already in English
    if detected_lang == "en":
        return text
    
    # Get the appropriate translation model
    translation_model = get_translation_model(detected_lang, fallback_translation_model)
    
    # Load model and tokenizer
    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()
    
    tokenizer = MarianTokenizer.from_pretrained(translation_model)
    model = MarianMTModel.from_pretrained(
        translation_model,
        use_safetensors=True,
        dtype=dtype
    ).to(pytorch_device)
    
    # Tokenize and translate
    inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True).to(pytorch_device)
    # During inference, gradient calculations are unnecessary. Using torch.no_grad() avoids
    # storing gradients, which significantly reduces memory consumption during generation.
    with torch.no_grad():
        translated = model.generate(**inputs)
    translation = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
    
    # Clean up GPU memory
    del model, tokenizer, inputs, translated
    if pytorch_device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()
    
    return translation


def create_translation_tab(fallback_translation_model: str):
    """Create the translation to English tab in the Gradio interface.
    
    This function sets up all UI components for translation with automatic
    language detection, including input textbox, translate button, and output textbox.
    
    Args:
        fallback_translation_model: Fallback translation model to use if no
            language-specific model is available.
    """
    gr.Markdown("Translate text to English. The source language will be automatically detected.")
    translation_input = gr.Textbox(label="Input Text", lines=5)
    translation_button = gr.Button("Translate")
    translation_output = gr.Textbox(label="Translated Text", lines=5, interactive=False)
    translation_button.click(
        fn=partial(translate_to_english, fallback_translation_model),
        inputs=translation_input,
        outputs=translation_output
    )
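

# A minimal usage sketch, not part of the original module: it shows how create_translation_tab
# could be wired into a standalone Gradio app. The fallback model name below is illustrative
# (an assumed multilingual Opus-MT checkpoint); the real application may supply a different one.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tab("Translate to English"):
            create_translation_tab("Helsinki-NLP/opus-mt-mul-en")
    demo.launch()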