# @Software: PyCharm
import logging
import os
+from enum import Enum
from pathlib import Path
-from typing import Dict, Union, List
+from typing import Dict, Union, List, Optional, Any

import fasttext
from robust_downloader import download

logger = logging.getLogger(__name__)
-MODELS = {"low_mem": None, "high_mem": None}
-FTLANG_CACHE = os.getenv("FTLANG_CACHE", "/tmp/fasttext-langdetect")
+CACHE_DIRECTORY = os.getenv("FTLANG_CACHE", "/tmp/fasttext-langdetect")
+LOCAL_SMALL_MODEL_PATH = Path(__file__).parent / "resources" / "lid.176.ftz"

+# Suppress FastText output if possible
try:
-    # silences warnings as the package does not properly use the python 'warnings' package
-    # see https://github.com/facebookresearch/fastText/issues/1056
    fasttext.FastText.eprint = lambda *args, **kwargs: None
except Exception:
    pass


+class ModelType(Enum):
+    LOW_MEMORY = "low_mem"
+    HIGH_MEMORY = "high_mem"
+
+
+class ModelCache:
+    def __init__(self):
+        self._models = {}
+
+    def get_model(self, model_type: ModelType) -> Optional["fasttext.FastText._FastText"]:
+        return self._models.get(model_type)
+
+    def set_model(self, model_type: ModelType, model: "fasttext.FastText._FastText"):
+        self._models[model_type] = model
+
+
+_model_cache = ModelCache()
+
+
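ModelCache replaces the old module-level MODELS dict: it keeps at most one loaded model per ModelType for the lifetime of the process. A minimal sketch of the intended behavior (illustrative only; assumes the definitions in this diff are in scope):

cache = ModelCache()
print(cache.get_model(ModelType.LOW_MEMORY))  # None: nothing cached yet

# Repeated load_model() calls (defined below) reuse the cached instance rather than reloading
m1 = load_model(low_memory=True)
m2 = load_model(low_memory=True)
assert m1 is m2  # second call is served from _model_cache
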
class DetectError(Exception):
+    """Custom exception for language detection errors."""
    pass


-def get_model_map(low_memory=False):
+def load_model(low_memory: bool = False,
+               download_proxy: Optional[str] = None,
+               use_strict_mode: bool = False) -> "fasttext.FastText._FastText":
    """
-    Getting model map
-    :param low_memory:
-    :return:
+    Load the FastText model based on memory preference.
+
+    :param low_memory: Indicates whether to load the smaller, memory-efficient model
+    :param download_proxy: Proxy to use for downloading the large model if necessary
+    :param use_strict_mode: If enabled, strictly load the large model or raise an error if that fails
+    :return: Loaded FastText model
+    :raises DetectError: If the model cannot be loaded
    """
-    if low_memory:
-        return "low_mem", FTLANG_CACHE, "lid.176.ftz", "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.ftz"
-    else:
-        return "high_mem", FTLANG_CACHE, "lid.176.bin", "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
+    model_type = ModelType.LOW_MEMORY if low_memory else ModelType.HIGH_MEMORY

+    # If the model is already loaded, return it
+    cached_model = _model_cache.get_model(model_type)
+    if cached_model:
+        return cached_model

-def get_model_loaded(
-        low_memory: bool = False,
-        download_proxy: str = None
-):
-    """
-    Getting model loaded
-    :param low_memory:
-    :param download_proxy:
-    :return:
-    """
-    mode, cache, name, url = get_model_map(low_memory)
-    loaded = MODELS.get(mode, None)
-    if loaded:
-        return loaded
-    model_path = os.path.join(cache, name)
-    if Path(model_path).exists():
-        if Path(model_path).is_dir():
-            raise Exception(f"{model_path} is a directory")
+    def load_local_small_model():
+        """Try to load the local small model."""
+        try:
+            _loaded_model = fasttext.load_model(str(LOCAL_SMALL_MODEL_PATH))
+            _model_cache.set_model(ModelType.LOW_MEMORY, _loaded_model)
+            return _loaded_model
+        except Exception as e:
+            logger.error(f"Failed to load the local small model '{LOCAL_SMALL_MODEL_PATH}': {e}")
+            raise DetectError("Unable to load low-memory model from local resources.")
+
+    def load_large_model():
+        """Try to load the large model."""
        try:
-            loaded_model = fasttext.load_model(model_path)
-            MODELS[mode] = loaded_model
+            loaded_model = fasttext.load_model(str(model_path))
+            _model_cache.set_model(ModelType.HIGH_MEMORY, loaded_model)
+            return loaded_model
        except Exception as e:
-            logger.error(f"Error loading model {model_path}: {e}")
-            download(url=url, folder=cache, filename=name, proxy=download_proxy)
-            raise e
-    else:
+            logger.error(f"Failed to load the large model '{model_path}': {e}")
+            return None
+
+    if low_memory:
+        # Attempt to load the local small model
+        return load_local_small_model()
+
+    # Path for the large model
+    large_model_name = "lid.176.bin"
+    model_path = Path(CACHE_DIRECTORY) / large_model_name
+
+    # If the large model is already present, load it
+    if model_path.exists():
+        # The model path must be a file, not a directory
+        if model_path.is_dir():
+            try:
+                model_path.rmdir()
+            except Exception as e:
+                logger.error(f"Failed to remove the directory '{model_path}': {e}")
+                raise DetectError(f"Unexpected directory found at large model path '{model_path}': {e}")
+        # Attempt to load the large model
+        loaded_model = load_large_model()
+        if loaded_model:
+            return loaded_model
+
+    # If the large model is not present, attempt to download it
+    model_url = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin"
+    try:
+        logger.info(f"Downloading large model from {model_url} to {model_path}")
+        download(
+            url=model_url,
+            folder=CACHE_DIRECTORY,
+            filename=large_model_name,
+            proxy=download_proxy,
+            retry_max=3,
+            timeout=20
+        )
+        # Try loading the model again after download
+        loaded_model = load_large_model()
+        if loaded_model:
            return loaded_model
+    except Exception as e:
+        logger.error(f"Failed to download the large model: {e}")

-    download(url=url, folder=cache, filename=name, proxy=download_proxy, retry_max=3, timeout=20)
-    loaded_model = fasttext.load_model(model_path)
-    MODELS[mode] = loaded_model
-    return loaded_model
+    # Fallback behavior for strict and non-strict modes
+    if use_strict_mode:
+        raise DetectError("Strict mode enabled: unable to download or load the large model.")
+    else:
+        logger.info("Attempting to fall back to the local small model.")
+        return load_local_small_model()


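For reference, a minimal usage sketch of the new load_model (illustrative; assumes this module is in scope):

# Low-memory path: loads the bundled lid.176.ftz from package resources
small_model = load_model(low_memory=True)

# High-memory path: loads lid.176.bin from CACHE_DIRECTORY, downloading it first if missing
large_model = load_model(low_memory=False)

# Strict mode: raise DetectError instead of falling back to the small model
try:
    strict_model = load_model(low_memory=False, use_strict_mode=True)
except DetectError as e:
    logger.warning(f"Large model unavailable: {e}")
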
def detect(text: str, *,
           low_memory: bool = True,
-           model_download_proxy: str = None
+           model_download_proxy: Optional[str] = None,
+           use_strict_mode: bool = False
           ) -> Dict[str, Union[str, float]]:
    """
-    Detect language of text
-
+    Detect the language of a text using FastText.
    This function assumes to be given a single line of text. We split words on whitespace (space, newline, tab, vertical tab) and the control characters carriage return, formfeed and the null character.
-
-    :param text: Text for language detection
-    :param low_memory: Whether to use low memory mode
-    :param model_download_proxy: model download proxy
-    :return: {"lang": "en", "score": 0.99}
-    :raise ValueError: predict processes one line at a time (remove '\n')
+    If the model is not supervised, this function will throw a ValueError.
+    :param text: The text for language detection
+    :param low_memory: Whether to use the memory-efficient model
+    :param model_download_proxy: Download proxy for the model if needed
+    :param use_strict_mode: If enabled, strictly load the large model or raise an error if that fails
+    :return: A dictionary with the detected language and confidence score
+    :raises DetectError: If the model cannot be loaded
    """
-    model = get_model_loaded(low_memory=low_memory, download_proxy=model_download_proxy)
+    model = load_model(low_memory=low_memory, download_proxy=model_download_proxy, use_strict_mode=use_strict_mode)
    labels, scores = model.predict(text)
-    label = labels[0].replace("__label__", '')
-    score = min(float(scores[0]), 1.0)
+    language_label = labels[0].replace("__label__", '')
+    confidence_score = min(float(scores[0]), 1.0)
    return {
-        "lang": label,
-        "score": score,
+        "lang": language_label,
+        "score": confidence_score,
    }


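A quick usage sketch for detect (the score shown is illustrative; actual values depend on the model):

result = detect("Hello, world!", low_memory=True)
print(result)  # e.g. {'lang': 'en', 'score': 0.98}

# predict() processes one line at a time, so flatten newlines first
paragraph = "Bonjour\nle monde"
result = detect(paragraph.replace("\n", " "))
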
def detect_multilingual(text: str, *,
                        low_memory: bool = True,
-                        model_download_proxy: str = None,
+                        model_download_proxy: Optional[str] = None,
                        k: int = 5,
                        threshold: float = 0.0,
-                        on_unicode_error: str = "strict"
-                        ) -> List[dict]:
+                        on_unicode_error: str = "strict",
+                        use_strict_mode: bool = False
+                        ) -> List[Dict[str, Any]]:
    """
-    Given a string, get a list of labels and a list of corresponding probabilities.
-    k controls the number of returned labels. A choice of 5, will return the 5 most probable labels.
-    By default this returns only the most likely label and probability. threshold filters the returned labels by a threshold on probability. A choice of 0.5 will return labels with at least 0.5 probability.
-    k and threshold will be applied together to determine the returned labels.
-
-    NOTE: This function assumes to be given a single line of text. We split words on whitespace (space, newline, tab, vertical tab) and the control characters carriage return, formfeed and the null character.
-
-    :param text: Text for language detection
-    :param low_memory: Whether to use low memory mode
-    :param model_download_proxy: model download proxy
-    :param k: Predict top k languages
-    :param threshold: Threshold for prediction
-    :param on_unicode_error: Error handling
-    :return:
+    Detect multiple potential languages and their probabilities in a given text.
+    k controls the number of returned labels: a choice of 5 will return the 5 most probable labels.
+    threshold filters the returned labels by probability: a choice of 0.5 will return only labels with at least 0.5 probability.
+    k and threshold are applied together to determine the returned labels.
+    This function assumes to be given a single line of text. We split words on whitespace (space, newline, tab, vertical tab) and the control characters carriage return, formfeed, and the null character.
+    If the model is not supervised, this function will throw a ValueError.
+
+    :param text: The text for language detection
+    :param low_memory: Whether to use the memory-efficient model
+    :param model_download_proxy: Proxy for downloading the model
+    :param k: Number of top language predictions to return
+    :param threshold: Minimum score threshold for predictions
+    :param on_unicode_error: Error handling for Unicode errors
+    :param use_strict_mode: If enabled, strictly load the large model or raise an error if that fails
+    :return: A list of dictionaries, each containing a language and its confidence score
+    :raises DetectError: If the model cannot be loaded
    """
-    model = get_model_loaded(low_memory=low_memory, download_proxy=model_download_proxy)
-    labels, scores = model.predict(text=text, k=k, threshold=threshold, on_unicode_error=on_unicode_error)
-    detect_result = []
+    model = load_model(low_memory=low_memory, download_proxy=model_download_proxy, use_strict_mode=use_strict_mode)
+    labels, scores = model.predict(text, k=k, threshold=threshold, on_unicode_error=on_unicode_error)
+    results = []
    for label, score in zip(labels, scores):
-        label = label.replace("__label__", '')
-        score = min(float(score), 1.0)
-        detect_result.append({
-            "lang": label,
-            "score": score,
+        language_label = label.replace("__label__", '')
+        confidence_score = min(float(score), 1.0)
+        results.append({
+            "lang": language_label,
+            "score": confidence_score,
        })
-    return sorted(detect_result, key=lambda i: i['score'], reverse=True)
+    return sorted(results, key=lambda x: x['score'], reverse=True)
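
And a sketch for the multilingual variant (values are illustrative, not actual model output):

results = detect_multilingual("Hello, bonjour, hola", k=3, threshold=0.0)
for item in results:
    print(item["lang"], item["score"])  # e.g. en 0.62, fr 0.21, es 0.11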