CometoidWMTQualityEstimation processor implementation #151
Merged
4 commits
f265077
CometoidWMTQualityEstimation processor implementation
ssh-meister dbde948
Merge branch 'main' into CometoidWMTQualityEstimation
ssh-meister 1682a17
Merge branch 'main' into CometoidWMTQualityEstimation
ssh-meister 1404a02
Changes addressing the reviewer’s comments
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,5 +3,4 @@ pyyaml | |
Sphinx | ||
sphinx-book-theme | ||
sphinx-copybutton | ||
sphinxext-opengraph | ||
tabulate | ||
sphinxext-opengraph |
207 changes: 207 additions & 0 deletions
sdp/processors/inference/quality_estimation/pymarian.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import json | ||
from tqdm import tqdm | ||
import termplotlib as tpl | ||
import numpy as np | ||
|
||
from sdp.logging import logger | ||
from sdp.processors.base_processor import BaseProcessor | ||
|
||
|
||
class CometoidWMTQualityEstimation(BaseProcessor): | ||
""" | ||
A processor for estimating translation quality using pretrained COMET-like models | ||
based on MarianNMT and the pymarian Evaluator. | ||
|
||
This processor evaluates the quality of source-target text pairs (bitext) using | ||
COMETOID-style quality estimation and appends the resulting score to each dataset entry. | ||
|
||
Args: | ||
source_text_field (str): The key in the data entry containing the source (original) text. | ||
target_text_field (str): The key in the data entry containing the target (translated) text. | ||
model_name_or_path (str): Hugging Face model name or path to local model checkpoint. | ||
vocab_path (str, optional): Path to the vocabulary file. If None and model is from HF, it will be downloaded. | ||
save_model_to (str, optional): Directory to download and cache the model and vocab. | ||
mini_batch (int): Mini-batch size for evaluation. | ||
maxi_batch (int): Maxi-batch size for evaluation. | ||
output_field (str): The name of the field where the quality score will be saved in the output manifest. | ||
device_type (str): Device type to use: 'cpu' or 'gpu'. | ||
num_devices (int): Number of CPU threads or GPU devices to use. Use -1 to use all available. | ||
chunksize (int): Number of lines to process in each chunk. | ||
|
||
Returns: | ||
A manifest file where each entry has an added key (`output_field`) with the computed score. | ||
|
||
.. note:: | ||
This processor uses MarianNMT models fine-tuned for quality estimation. See https://marian-nmt.github.io/. | ||
|
||
Make sure to install `pymarian` before using this processor: | ||
pip install pymarian | ||
|
||
|
||
""" | ||
|
||
# Mapping of supported model aliases to Hugging Face repo paths | ||
MODEL_NAME_TO_HF_PATH = { | ||
"cometoid-wmt23": "marian-nmt/cometoid22-wmt23", | ||
"cometoid-wmt23-mqm": "marian-nmt/cometoid22-wmt23", | ||
} | ||
|
||
# Marian evaluation arguments depending on device | ||
MARIAN_GPU_ARGS = "-w 8000 -d {device_indicies}" | ||
MARIAN_CPU_ARGS = "-w 2000 --cpu-threads {num_threads}" | ||
|
||
def __init__(self, | ||
source_text_field: str, | ||
target_text_field: str, | ||
model_name_or_path: str, | ||
vocab_path: str = None, | ||
save_model_to: str = None, | ||
mini_batch: int = 16, | ||
maxi_batch: int = 96, | ||
output_field: str = 'cometoid_score', | ||
device_type: str = 'cpu', | ||
num_devices: int = -1, | ||
chunksize: int = 5000, | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
self.source_text_field = source_text_field | ||
self.target_text_field = target_text_field | ||
self.model_name_or_path = model_name_or_path | ||
self.vocab_path = vocab_path | ||
self.save_model_to = save_model_to | ||
self.device_type = device_type | ||
self.max_workers = num_devices | ||
self.mini_batch = mini_batch | ||
self.maxi_batch = maxi_batch | ||
self.output_field = output_field | ||
self.model = None | ||
self.chunksize = chunksize | ||
|
||
def load_model(self): | ||
""" | ||
Load the model and vocabulary from Hugging Face if necessary. | ||
Assemble command-line arguments for launching pymarian Evaluator. | ||
Depending on the device (CPU/GPU), configure parallelism parameters. | ||
""" | ||
try: | ||
from pymarian import Evaluator | ||
except ImportError as e: | ||
raise ImportError("`pymarian` is not installed. Please install it using `pip install pymarian`.") from e | ||
|
||
from huggingface_hub import hf_hub_download | ||
|
||
repo_id = None | ||
if self.model_name_or_path in self.MODEL_NAME_TO_HF_PATH: | ||
repo_id = self.MODEL_NAME_TO_HF_PATH[self.model_name_or_path] | ||
self.model_name_or_path = hf_hub_download(repo_id, filename="checkpoints/marian.model.bin", local_dir = self.save_model_to) | ||
|
||
if not os.path.exists(self.model_name_or_path): | ||
raise ValueError(f'`model_name_or_path`: model name is not valid or model path does not exist ({self.model_name_or_path}).') | ||
|
||
if not self.vocab_path and repo_id is not None: | ||
self.vocab_path = hf_hub_download(repo_id=repo_id, filename="vocab.spm", local_dir = self.save_model_to) | ||
|
||
if not self.vocab_path or not os.path.exists(self.vocab_path): | ||
raise FileNotFoundError(f'`vocab_path`: path does not exist ({self.vocab_path}).') | ||
|
||
marian_args = f"-m {self.model_name_or_path} -v {self.vocab_path} {self.vocab_path} --like comet-qe" | ||
|
||
if self.device_type == "cpu": | ||
max_available_cpus = os.cpu_count() | ||
if self.max_workers == -1 or self.max_workers > max_available_cpus: | ||
self.max_workers = max_available_cpus | ||
|
||
cpu_args = self.MARIAN_CPU_ARGS.format(num_threads = self.max_workers) | ||
marian_args += f' {cpu_args}' | ||
else: | ||
try: | ||
import torch | ||
if torch.cuda.is_available(): | ||
max_available_gpus = torch.cuda.device_count() | ||
if self.max_workers == -1 or self.max_workers > max_available_gpus: | ||
self.max_workers = max_available_gpus | ||
except Exception: | ||
pass | ||
|
||
device_indicies = ' '.join([str(i) for i in range(self.max_workers)]) | ||
gpu_args = self.MARIAN_GPU_ARGS.format(device_indicies = device_indicies) | ||
marian_args += f' {gpu_args}' | ||
|
||
marian_args += f' --mini-batch {self.mini_batch} --maxi-batch {self.maxi_batch}' | ||
|
||
self.model = Evaluator(marian_args) | ||
|
||
def process(self): | ||
""" | ||
Process the entire manifest in chunks. | ||
For each pair of texts (source–target), compute the translation quality score. | ||
Save the resulting scores in output_manifest_file. | ||
""" | ||
self.load_model() | ||
os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) | ||
metrics = [] | ||
|
||
with open(self.output_manifest_file, "wt", encoding="utf8") as fout: | ||
for manifest_chunk in self._chunk_manifest(): | ||
entries = [] | ||
bitext_pairs = [] | ||
for data_entry in manifest_chunk: | ||
src = str(data_entry[self.source_text_field]).replace('\t', ' ') | ||
tgt = str(data_entry[self.target_text_field]).replace('\t', ' ') | ||
bitext_pairs.append(f'{src}\t{tgt}') | ||
entries.append(data_entry) | ||
|
||
scores = self.model.evaluate(bitext_pairs) | ||
for entry, score in tqdm(zip(entries, scores)): | ||
metrics.append(score) | ||
entry[self.output_field] = score | ||
json.dump(entry, fout, ensure_ascii=False) | ||
self.number_of_entries += 1 | ||
fout.write("\n") | ||
|
||
self.finalize(metrics) | ||
|
||
def finalize(self, metrics): | ||
""" | ||
Print statistics about the quality scores: histogram, min, max, mean, median. | ||
Use termplotlib to render the histogram directly in the terminal. | ||
""" | ||
logger.info("Total number of entries after processing: %d", self.number_of_entries) | ||
if not metrics: | ||
logger.warning("No scores to summarize; skipping histogram and statistics.") | ||
return | ||
logger.info("Histogram of scores:") | ||
|
||
bins = np.arange(0, 1.1, 0.1) | ||
hist, bin_edges = np.histogram(metrics, bins=bins) | ||
|
||
labels = [] | ||
for i in range(len(bin_edges) - 1): | ||
left = f"{bin_edges[i]:.1f}" | ||
right = f"{bin_edges[i+1]:.1f}" | ||
if i < len(bin_edges) - 2: | ||
labels.append(f"[{left}–{right})") | ||
else: | ||
labels.append(f"[{left}–{right}]") | ||
|
||
fig = tpl.figure() | ||
fig.barh(hist, labels) | ||
fig.show() | ||
|
||
logger.info(f"Min score: {np.min(metrics):.4f}") | ||
logger.info(f"Max score: {np.max(metrics):.4f}") | ||
logger.info(f"Mean score: {np.mean(metrics):.4f}") | ||
logger.info(f"Median score: {np.median(metrics):.4f}") |
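For review purposes, the argument assembly in `load_model` can be read as a pure function. The sketch below is an illustration under stated assumptions, not the code in this PR: `clamp_workers` and `build_marian_args` are hypothetical helper names, the number of available devices is passed in explicitly rather than detected, and the template placeholder is spelled `device_indices` here. It assumes each device type should be clamped against its own availability count.

```python
# Templates mirror the class attributes in the diff; the placeholder name
# `device_indices` is this sketch's own spelling.
MARIAN_GPU_ARGS = "-w 8000 -d {device_indices}"
MARIAN_CPU_ARGS = "-w 2000 --cpu-threads {num_threads}"


def clamp_workers(requested: int, available: int) -> int:
    """Return `available` when `requested` is -1 or exceeds what the host has."""
    if requested == -1 or requested > available:
        return available
    return requested


def build_marian_args(
    model_path: str,
    vocab_path: str,
    device_type: str,
    num_devices: int,
    available: int,
    mini_batch: int = 16,
    maxi_batch: int = 96,
) -> str:
    """Assemble the Marian argument string for a CPU or GPU run (hypothetical helper)."""
    # The vocabulary path is passed twice, as in the diff.
    args = f"-m {model_path} -v {vocab_path} {vocab_path} --like comet-qe"
    workers = clamp_workers(num_devices, available)
    if device_type == "cpu":
        args += " " + MARIAN_CPU_ARGS.format(num_threads=workers)
    else:
        # One index per GPU, e.g. "0 1" for two devices.
        indices = " ".join(str(i) for i in range(workers))
        args += " " + MARIAN_GPU_ARGS.format(device_indices=indices)
    args += f" --mini-batch {mini_batch} --maxi-batch {maxi_batch}"
    return args


if __name__ == "__main__":
    print(build_marian_args("marian.model.bin", "vocab.spm", "gpu", -1, available=2))
```

Keeping device detection outside the helper makes the string construction testable on machines without a GPU or `torch` installed.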
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
from unittest.mock import MagicMock, patch | ||
|
||
from sdp.processors.inference.quality_estimation.pymarian import CometoidWMTQualityEstimation | ||
|
||
@pytest.fixture(scope="module") | ||
def mock_processor(): | ||
processor = CometoidWMTQualityEstimation( | ||
source_text_field="src", | ||
target_text_field="tgt", | ||
model_name_or_path="cometoid-wmt23", | ||
output_field="cometoid_score", | ||
device_type="cpu", | ||
num_devices=1, | ||
chunksize=1, | ||
output_manifest_file="/tmp/test_output.jsonl", | ||
) | ||
return processor | ||
|
||
|
||
@patch("huggingface_hub.hf_hub_download", return_value="/tmp/dummy") | ||
@patch("sdp.processors.inference.quality_estimation.pymarian.os.path.exists", return_value=True) | ||
@patch("pymarian.Evaluator") | ||
def test_load_model_with_mock(mock_eval, mock_exists, mock_hf_download, mock_processor): | ||
mock_eval.return_value = MagicMock() | ||
mock_processor.load_model() | ||
assert mock_processor.model is not None | ||
mock_hf_download.assert_called() | ||
mock_eval.assert_called() | ||
|
||
|
||
def test_process_dataset_entry(mock_processor): | ||
mock_processor.model = MagicMock() | ||
mock_processor.model.evaluate = MagicMock(return_value=[0.875]) | ||
|
||
entry = { | ||
"src": "This is a test sentence.", | ||
"tgt": "Dies ist ein Testsatz." | ||
} | ||
|
||
mock_processor._chunk_manifest = lambda: [[entry]] | ||
mock_processor.finalize = MagicMock() | ||
mock_processor.number_of_entries = 0 | ||
|
||
# Patch load_model to avoid real downloading | ||
with patch.object(mock_processor, "load_model"), \ | ||
patch("builtins.open"), \ | ||
patch("json.dump"), \ | ||
patch("os.makedirs"): | ||
mock_processor.process() | ||
|
||
mock_processor.model.evaluate.assert_called_once() | ||
assert mock_processor.number_of_entries == 1 | ||
|
||
|
||
@pytest.mark.parametrize("source,target", [ | ||
("Hello", "Hallo"), | ||
("Good morning", "Guten Morgen"), | ||
("How are you?", "Wie geht's dir?"), | ||
]) | ||
def test_score_format(mock_processor, source, target): | ||
mock_processor.model = MagicMock() | ||
mock_processor.model.evaluate = MagicMock(return_value=[0.9]) | ||
|
||
entry = {"src": source, "tgt": target} | ||
mock_processor.output_field = "cometoid_score" | ||
|
||
bitext_pairs = [f"{source}\t{target}"] | ||
scores = mock_processor.model.evaluate(bitext_pairs) | ||
|
||
assert isinstance(scores, list) | ||
assert len(scores) == 1 | ||
score = scores[0] | ||
assert 0.0 <= score <= 1.0 |
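The tests above exercise mocks, so the tab handling and the written records are not checked directly. A self-contained sketch of how the per-chunk loop in `process` could be checked without pymarian follows. `FakeEvaluator`, `score_chunk`, and `run_example` are hypothetical names introduced here for illustration; the only assumption about the real evaluator is that `evaluate` returns one score per input line, as the diff's `zip(entries, scores)` implies.

```python
import io
import json


class FakeEvaluator:
    """Hypothetical stand-in for the pymarian Evaluator: one score per input line."""

    def __init__(self):
        self.seen = []

    def evaluate(self, lines):
        # Record what was sent so a test can inspect the exact TSV lines.
        self.seen.extend(lines)
        return [0.5 for _ in lines]


def score_chunk(entries, evaluator, src_key, tgt_key, out_key, fout):
    """Score one chunk and write it as JSON lines, mirroring the loop in process()."""
    pairs = []
    for entry in entries:
        # Tabs inside the text would break the source<TAB>target format.
        src = str(entry[src_key]).replace("\t", " ")
        tgt = str(entry[tgt_key]).replace("\t", " ")
        pairs.append(f"{src}\t{tgt}")
    scores = evaluator.evaluate(pairs)
    for entry, score in zip(entries, scores):
        entry[out_key] = score
        fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
    return len(entries)


def run_example():
    """Run one chunk through the fake evaluator and return what was observed."""
    evaluator = FakeEvaluator()
    buffer = io.StringIO()
    entries = [{"src": "Hello\tworld", "tgt": "Hallo Welt"}]
    written = score_chunk(entries, evaluator, "src", "tgt", "cometoid_score", buffer)
    return evaluator.seen, written, json.loads(buffer.getvalue())


if __name__ == "__main__":
    print(run_example())
```

Writing to an in-memory buffer avoids patching `builtins.open` and `json.dump`, so the assertions can cover the serialized record as well as the call count.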