diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b3e9dade..351b4167 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -74,6 +74,7 @@ jobs: pip install Cython wheel # need to pre-install to avoid error in nemo installation pip install nemo-toolkit[asr,nlp]==2.2.1 pip install nemo_text_processing + pip install pymarian pip install -r requirements/huggingface.txt pip install certifi #this needed to avoid problems with certificates [COORAL] export SSL_CERT_FILE=$(python -m certifi) diff --git a/docs/src/sdp/api.rst b/docs/src/sdp/api.rst index 26e9fb0a..9d1cb59c 100644 --- a/docs/src/sdp/api.rst +++ b/docs/src/sdp/api.rst @@ -208,6 +208,9 @@ used in the downstream processing for additional enhancement or filtering. .. autodata:: sdp.processors.AudioLid :annotation: +.. autodata:: sdp.processors.CometoidWMTQualityEstimation + :annotation: + .. autodata:: sdp.processors.FastTextLangIdClassifier :annotation: diff --git a/requirements/docs.txt b/requirements/docs.txt index 2f44117e..a2abeed4 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -3,5 +3,4 @@ pyyaml Sphinx sphinx-book-theme sphinx-copybutton -sphinxext-opengraph -tabulate \ No newline at end of file +sphinxext-opengraph \ No newline at end of file diff --git a/requirements/main.txt b/requirements/main.txt index 9553fc7c..31a4a87d 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -10,6 +10,8 @@ pandas rarfile regex sox +tabulate +termplotlib tqdm gdown webvtt-py @@ -30,5 +32,6 @@ datasets>=2.14.0,<3.0.0 # pip install pytorch-lightning nvidia-cublas-cu12 nvidia-cudnn-cu12==9.* faster_whisper # export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'` # for vLLMInference processor is required: pip install "optree>=0.13.0" vllm +# for CometoidWMTQualityEstimation processor is required: 
pip install pymarian # for FastTextLangIdClassifier processor is required: pip install fasttext # for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.2.1" \ No newline at end of file diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index 69428ce8..dce8e584 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -151,6 +151,7 @@ from sdp.processors.inference.nlp.fasttext.fasttext import FastTextLangIdClassifier from sdp.processors.inference.llm.vllm.vllm import vLLMInference from sdp.processors.inference.llm.utils.qwen_cleaning import CleanQwenGeneration +from sdp.processors.inference.quality_estimation.pymarian import CometoidWMTQualityEstimation from sdp.processors.manage_files.convert_audio import ( FfmpegConvert, diff --git a/sdp/processors/inference/quality_estimation/pymarian.py b/sdp/processors/inference/quality_estimation/pymarian.py new file mode 100644 index 00000000..9a1c4d94 --- /dev/null +++ b/sdp/processors/inference/quality_estimation/pymarian.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import json
import os
from typing import Optional

import numpy as np
import termplotlib as tpl
from tqdm import tqdm

from sdp.logging import logger
from sdp.processors.base_processor import BaseProcessor


class CometoidWMTQualityEstimation(BaseProcessor):
    """
    A processor for estimating translation quality using pretrained COMET-like models
    based on MarianNMT and the pymarian Evaluator.

    This processor evaluates the quality of source-target text pairs (bitext) using
    COMETOID-style quality estimation and appends the resulting score to each dataset entry.

    Args:
        source_text_field (str): The key in the data entry containing the source (original) text.
        target_text_field (str): The key in the data entry containing the target (translated) text.
        model_name_or_path (str): One of the aliases in ``MODEL_NAME_TO_HF_PATH`` (the checkpoint is
            then downloaded from Hugging Face) or a path to a local model checkpoint.
        vocab_path (str, optional): Path to the vocabulary file. If not set and the model is given
            by alias, the vocabulary is downloaded from the same Hugging Face repository. Required
            when ``model_name_or_path`` is a local path.
        save_model_to (str, optional): Directory to download and cache the model and vocab.
        mini_batch (int): Mini-batch size for evaluation.
        maxi_batch (int): Maxi-batch size for evaluation.
        output_field (str): The name of the field where the quality score will be saved in the output manifest.
        device_type (str): Device type to use: 'cpu' or 'gpu'. Any value other than 'cpu' is
            treated as 'gpu'.
        num_devices (int): Number of CPU threads or GPU devices to use. Use -1 to use all available.
        chunksize (int): Number of lines to process in each chunk.

    Returns:
        A manifest file where each entry has an added key (`output_field`) with the computed score.

    Raises:
        ValueError: If ``num_devices`` is neither -1 nor a positive integer.

    .. note::
        This processor uses MarianNMT models fine-tuned for quality estimation.
        See https://marian-nmt.github.io/.

        Make sure to install `pymarian` before using this processor:
            pip install pymarian
    """

    # Mapping of supported model aliases to Hugging Face repo paths
    # NOTE(review): both aliases resolve to the same repository; confirm whether
    # "cometoid-wmt23-mqm" is meant to point at a different checkpoint.
    MODEL_NAME_TO_HF_PATH = {
        "cometoid-wmt23": "marian-nmt/cometoid22-wmt23",
        "cometoid-wmt23-mqm": "marian-nmt/cometoid22-wmt23",
    }

    # Marian evaluation arguments depending on device
    MARIAN_GPU_ARGS = "-w 8000 -d {device_indicies}"
    MARIAN_CPU_ARGS = "-w 2000 --cpu-threads {num_threads}"

    def __init__(
        self,
        source_text_field: str,
        target_text_field: str,
        model_name_or_path: str,
        vocab_path: Optional[str] = None,
        save_model_to: Optional[str] = None,
        mini_batch: int = 16,
        maxi_batch: int = 96,
        output_field: str = 'cometoid_score',
        device_type: str = 'cpu',
        num_devices: int = -1,
        chunksize: int = 5000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Fail early: 0 or other negative counts would otherwise be forwarded
        # verbatim into the Marian command line.
        if num_devices != -1 and num_devices < 1:
            raise ValueError(f'`num_devices` must be -1 (use all available) or a positive integer, got {num_devices}.')

        self.source_text_field = source_text_field
        self.target_text_field = target_text_field
        self.model_name_or_path = model_name_or_path
        self.vocab_path = vocab_path
        self.save_model_to = save_model_to
        self.device_type = device_type
        self.max_workers = num_devices
        self.mini_batch = mini_batch
        self.maxi_batch = maxi_batch
        self.output_field = output_field
        self.model = None  # populated by load_model()
        self.chunksize = chunksize

    @staticmethod
    def _resolve_worker_count(requested, available):
        """Clamp the requested worker count to what is available (-1 means "all")."""
        if not available:
            # Availability is unknown: trust an explicit request, otherwise use one worker.
            return requested if requested > 0 else 1
        if requested == -1 or requested > available:
            return available
        return requested

    @staticmethod
    def _count_visible_gpus():
        """Return the number of CUDA devices visible to torch, or None if it cannot be determined."""
        try:
            import torch

            if torch.cuda.is_available():
                return torch.cuda.device_count()
        except (ImportError, OSError, RuntimeError) as err:
            # Best effort only: torch is optional for this processor.
            logger.warning("Could not query GPU count via torch: %s", err)
        return None

    def load_model(self):
        """
        Load the model and vocabulary from Hugging Face if necessary.
        Assemble command-line arguments for launching pymarian Evaluator.
        Depending on the device (CPU/GPU), configure parallelism parameters.

        On success ``self.model`` holds the Evaluator; ``self.model_name_or_path``,
        ``self.vocab_path`` and ``self.max_workers`` are updated to the resolved values.

        Raises:
            ImportError: If `pymarian` is not installed.
            ValueError: If the model alias/path is invalid, or no vocabulary path is
                available for a local model.
            FileNotFoundError: If the vocabulary path does not exist.
        """
        try:
            from pymarian import Evaluator
        except ImportError as err:
            raise ImportError(
                "`pymarian` is not installed. Please install it using `pip install pymarian`."
            ) from err

        from huggingface_hub import hf_hub_download

        repo_id = self.MODEL_NAME_TO_HF_PATH.get(self.model_name_or_path)
        if repo_id is not None:
            self.model_name_or_path = hf_hub_download(
                repo_id, filename="checkpoints/marian.model.bin", local_dir=self.save_model_to
            )

        if not os.path.exists(self.model_name_or_path):
            raise ValueError(
                f'`model_name_or_path`: model name is not valid or model path does not exist ({self.model_name_or_path}).'
            )

        if not self.vocab_path and repo_id is not None:
            self.vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.spm", local_dir=self.save_model_to)

        if not self.vocab_path:
            # Without this guard os.path.exists(None) raises an opaque TypeError.
            raise ValueError('`vocab_path` must be provided when `model_name_or_path` is a local model path.')
        if not os.path.exists(self.vocab_path):
            raise FileNotFoundError(f'`vocab_path`: path does not exist ({self.vocab_path}).')

        marian_args = [
            f"-m {self.model_name_or_path}",
            f"-v {self.vocab_path} {self.vocab_path}",
            "--like comet-qe",
        ]

        if self.device_type == "cpu":
            # os.cpu_count() may return None; the resolver then falls back gracefully.
            self.max_workers = self._resolve_worker_count(self.max_workers, os.cpu_count())
            marian_args.append(self.MARIAN_CPU_ARGS.format(num_threads=self.max_workers))
        else:
            available_gpus = self._count_visible_gpus()
            if available_gpus is None:
                logger.warning("Unable to detect available GPUs; the requested device count is used as is.")
            self.max_workers = self._resolve_worker_count(self.max_workers, available_gpus)
            device_indicies = ' '.join(str(i) for i in range(self.max_workers))
            marian_args.append(self.MARIAN_GPU_ARGS.format(device_indicies=device_indicies))

        marian_args.append(f"--mini-batch {self.mini_batch} --maxi-batch {self.maxi_batch}")

        self.model = Evaluator(' '.join(marian_args))

    def process(self):
        """
        Process the entire manifest in chunks.
        For each pair of texts (source–target), compute the translation quality score.
        Save the resulting scores in output_manifest_file.

        Raises:
            RuntimeError: If the evaluator returns a different number of scores than
                the number of text pairs it was given.
        """
        self.load_model()

        # dirname is "" for a bare file name, and os.makedirs("") raises.
        output_dir = os.path.dirname(self.output_manifest_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        metrics = []
        with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
            for manifest_chunk in self._chunk_manifest():
                entries = list(manifest_chunk)
                # Source and target are tab-joined into one line, so tabs inside
                # the texts themselves must be neutralised first.
                bitext_pairs = [
                    '{}\t{}'.format(
                        str(entry[self.source_text_field]).replace('\t', ' '),
                        str(entry[self.target_text_field]).replace('\t', ' '),
                    )
                    for entry in entries
                ]

                scores = list(self.model.evaluate(bitext_pairs))
                if len(scores) != len(entries):
                    # zip() would silently drop entries on a mismatch.
                    raise RuntimeError(
                        f'Evaluator returned {len(scores)} scores for {len(entries)} text pairs.'
                    )

                for entry, score in tqdm(zip(entries, scores), total=len(entries)):
                    metrics.append(score)
                    entry[self.output_field] = score
                    json.dump(entry, fout, ensure_ascii=False)
                    self.number_of_entries += 1
                    fout.write("\n")

        self.finalize(metrics)

    def finalize(self, metrics):
        """
        Print statistics about the quality scores: histogram, min, max, mean, median.
        Use termplotlib to render the histogram directly in the terminal.

        Args:
            metrics: Scores collected by :meth:`process`. When empty, only the
                entry count is logged.
        """
        logger.info("Total number of entries after processing: %d", self.number_of_entries)

        scores = np.asarray(metrics)
        if scores.size == 0:
            # np.min / np.max raise on empty input.
            logger.warning("No scores were produced; skipping score statistics.")
            return

        logger.info("Histogram of scores:")

        bins = np.linspace(0.0, 1.0, 11)
        hist, bin_edges = np.histogram(scores, bins=bins)

        # np.histogram ignores values outside the bin range; report them.
        out_of_range = int(np.count_nonzero((scores < 0.0) | (scores > 1.0)))
        if out_of_range:
            logger.info("%d score(s) fall outside [0, 1] and are not shown in the histogram.", out_of_range)

        last_bin = len(bin_edges) - 2
        labels = []
        for i, (left_edge, right_edge) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
            left = f"{left_edge:.1f}"
            right = f"{right_edge:.1f}"
            # Only the last bin is closed on the right.
            closing = ")" if i < last_bin else "]"
            labels.append(f"[{left}–{right}{closing}")

        fig = tpl.figure()
        fig.barh(hist, labels)
        fig.show()

        logger.info("Min score: %.4f", np.min(scores))
        logger.info("Max score: %.4f", np.max(scores))
        logger.info("Mean score: %.4f", np.mean(scores))
        logger.info("Median score: %.4f", np.median(scores))
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from unittest.mock import MagicMock, patch

import pytest

from sdp.processors.inference.quality_estimation.pymarian import CometoidWMTQualityEstimation


@pytest.fixture
def mock_processor(tmp_path):
    """Build a fresh processor per test, writing to a per-test temporary manifest.

    Function scope matters: the tests below replace attributes on the processor
    (model, _chunk_manifest, finalize), which would leak across tests with a
    shared module-scoped instance.
    """
    return CometoidWMTQualityEstimation(
        source_text_field="src",
        target_text_field="tgt",
        model_name_or_path="cometoid-wmt23",
        output_field="cometoid_score",
        device_type="cpu",
        num_devices=1,
        chunksize=1,
        output_manifest_file=str(tmp_path / "test_output.jsonl"),
    )


def _run_process(processor, entries, scores):
    """Run ``process()`` on in-memory entries with a mocked evaluator; return the written rows."""
    processor.model = MagicMock()
    processor.model.evaluate = MagicMock(return_value=scores)
    processor._chunk_manifest = lambda: [entries]
    processor.finalize = MagicMock()
    processor.number_of_entries = 0

    # Patch load_model to avoid real downloading
    with patch.object(processor, "load_model"):
        processor.process()

    with open(processor.output_manifest_file, encoding="utf8") as fin:
        return [json.loads(line) for line in fin]


@patch("huggingface_hub.hf_hub_download", return_value="/tmp/dummy")
@patch("sdp.processors.inference.quality_estimation.pymarian.os.path.exists", return_value=True)
@patch("pymarian.Evaluator")
def test_load_model_with_mock(mock_eval, mock_exists, mock_hf_download, mock_processor):
    mock_eval.return_value = MagicMock()

    mock_processor.load_model()

    assert mock_processor.model is mock_eval.return_value
    mock_hf_download.assert_called()
    mock_eval.assert_called_once()

    # The evaluator must receive the resolved paths and the CPU settings.
    cli_string = mock_eval.call_args[0][0]
    assert "-m /tmp/dummy" in cli_string
    assert "-v /tmp/dummy /tmp/dummy" in cli_string
    assert "--like comet-qe" in cli_string
    assert "--cpu-threads 1" in cli_string


def test_process_dataset_entry(mock_processor):
    entry = {
        "src": "This is a test sentence.",
        "tgt": "Dies ist ein Testsatz.",
    }

    rows = _run_process(mock_processor, [entry], [0.875])

    mock_processor.model.evaluate.assert_called_once()
    mock_processor.finalize.assert_called_once_with([0.875])
    assert mock_processor.number_of_entries == 1
    assert rows == [{**entry, "cometoid_score": 0.875}]


@pytest.mark.parametrize(
    "source,target",
    [
        ("Hello", "Hallo"),
        ("Good morning", "Guten Morgen"),
        ("How are you?", "Wie geht's dir?"),
    ],
)
def test_score_format(mock_processor, source, target):
    rows = _run_process(mock_processor, [{"src": source, "tgt": target}], [0.9])

    # The pair is handed to the evaluator as one tab-separated line.
    mock_processor.model.evaluate.assert_called_once_with([f"{source}\t{target}"])

    assert len(rows) == 1
    score = rows[0]["cometoid_score"]
    assert isinstance(score, float)
    assert 0.0 <= score <= 1.0