From 02c54090d5f5372db9ea0be8bdb906e4859ba18e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 24 Jun 2025 20:29:07 +0200 Subject: [PATCH 01/20] api code to download model from hugging face --- .gitignore | 1 + api/hugging_face.py | 56 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 .gitignore create mode 100644 api/hugging_face.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..905568c --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/api/api_models diff --git a/api/hugging_face.py b/api/hugging_face.py new file mode 100644 index 0000000..cfcf97c --- /dev/null +++ b/api/hugging_face.py @@ -0,0 +1,56 @@ +import shutil +from pathlib import Path + +from huggingface_hub import hf_hub_download + +# Updated registry: use a list of filenames if you're downloading a folder +MODEL_REGISTRY = { + "electra": { + "repo_id": "aditya0by0/python-chebifier", + "subfolder": "electra", + "filenames": ["electra.ckpt", "classes.txt"], + } +} + +DOWNLOAD_PATH = Path(__file__).resolve().parent / "api_models" + + +def download_model(model_name): + if model_name not in MODEL_REGISTRY: + raise ValueError( + f"Unknown model name. 
Available models: {list(MODEL_REGISTRY.keys())}" + ) + + model_info = MODEL_REGISTRY[model_name] + repo_id = model_info["repo_id"] + subfolder = model_info["subfolder"] + filenames = model_info["filenames"] + + local_paths = [] + for filename in filenames: + local_model_path = DOWNLOAD_PATH / model_name / filename + if local_model_path.exists(): + print(f"File already exists: {local_model_path}") + local_paths.append(local_model_path) + continue + + print(f"Downloading: {repo_id}/{filename} (subfolder: {subfolder})") + downloaded_file = hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder, + ) + + local_model_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(downloaded_file, local_model_path) + print(f"Saved to: {local_model_path}") + local_paths.append(local_model_path) + + return local_paths + + +if __name__ == "__main__": + paths = download_model("electra") + print("Downloaded files:") + for p in paths: + print(p) From b539f0af136692aa83ea4370896f97151b359dc7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 00:33:39 +0200 Subject: [PATCH 02/20] Create .pre-commit-config.yaml --- .pre-commit-config.yaml | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e32d80c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: +- repo: https://github.com/psf/black + rev: "24.2.0" + hooks: + - id: black + - id: black-jupyter # for formatting jupyter-notebook + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile=black"] + +- repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + +- repo: 
https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.1 + hooks: + - id: ruff + args: [] # No --fix, disables formatting From 2c2aba2315ae0fc10c501dd4db4c44e8619df8c0 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 00:34:40 +0200 Subject: [PATCH 03/20] utility to setup env and model package dependencies --- api/__init__.py | 0 api/setup_env.py | 165 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 api/__init__.py create mode 100644 api/setup_env.py diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/setup_env.py b/api/setup_env.py new file mode 100644 index 0000000..a246c26 --- /dev/null +++ b/api/setup_env.py @@ -0,0 +1,165 @@ +import os +import re +import subprocess +import sys +from pathlib import Path + +# Conditional import of tomllib based on Python version +if sys.version_info >= (3, 11): + import tomllib # built-in in Python 3.11+ +else: + import toml as tomllib # use third-party toml library for older versions + + +class SetupEnvAndPackage: + """Utility class for cloning a repository, setting up a virtual environment, and installing a package.""" + + def setup( + self, + repo_url: str, + clone_dir: Path, + venv_dir: Path, + venv_name: str = ".venv-chebifier", + ) -> None: + """ + Orchestrates the full setup process: cloning the repository, + creating a virtual environment, and installing the package. + + Args: + repo_url (str): URL of the Git repository. + clone_dir (Path): Directory to clone the repo into. + venv_dir (Path): Directory where the virtual environment will be created. + venv_name (str): Name of the virtual environment folder. 
+ """ + cloned_repo_path = self._clone_repo(repo_url, clone_dir) + venv_path = self._create_virtualenv(venv_dir, venv_name) + self._install_from_pyproject(venv_path, cloned_repo_path) + + def _clone_repo(self, repo_url: str, clone_dir: Path) -> Path: + """ + Clone a Git repository into a specified directory. + + Args: + repo_url (str): Git URL to clone. + clone_dir (Path): Directory to clone into. + + Returns: + Path: Path to the cloned repository. + """ + repo_name = repo_url.rstrip("/").split("/")[-1].replace(".git", "") + clone_path = Path(clone_dir or repo_name) + + if not clone_path.exists(): + print(f"Cloning {repo_url} into {clone_path}...") + subprocess.check_call( + ["git", "clone", "--depth=1", repo_url, str(clone_path)] + ) + else: + print(f"Repo already exists at {clone_path}") + + return clone_path + + @staticmethod + def _create_virtualenv(venv_dir: Path, venv_name: str = ".venv-chebifier") -> Path: + """ + Create a virtual environment at the specified path. + + Args: + venv_dir (Path): Base directory where the venv will be created. + venv_name (str): Name of the virtual environment directory. + + Returns: + Path: Path to the virtual environment. + """ + venv_path = venv_dir / venv_name + + if venv_path.exists(): + print(f"Virtual environment already exists at: {venv_path}") + return venv_path + + print(f"Creating virtual environment at: {venv_path}") + + try: + subprocess.check_call(["virtualenv", str(venv_path)]) + except FileNotFoundError: + print("virtualenv not found, installing it now...") + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "virtualenv"] + ) + subprocess.check_call(["virtualenv", str(venv_path)]) + + return venv_path + + def _install_from_pyproject(self, venv_dir: Path, cloned_repo_path: Path) -> None: + """ + Install the cloned package in editable mode. + + Args: + venv_dir (Path): Path to the virtual environment. + cloned_repo_path (Path): Path to the cloned repository. 
+ """ + pip_executable = ( + venv_dir / "Scripts" / "pip.exe" + if os.name == "nt" + else venv_dir / "bin" / "pip" + ) + + if not pip_executable.exists(): + raise FileNotFoundError(f"pip not found at {pip_executable}") + + try: + package_name = self._get_package_name(cloned_repo_path) + except Exception as e: + raise RuntimeError(f"Error extracting package name: {e}") + + try: + subprocess.check_output( + [str(pip_executable), "show", package_name], stderr=subprocess.DEVNULL + ) + print(f"Package '{package_name}' is already installed.") + except subprocess.CalledProcessError: + print(f"Installing '{package_name}' from {cloned_repo_path}...") + subprocess.check_call( + [str(pip_executable), "install", "-e", "."], + cwd=cloned_repo_path, + ) + + @staticmethod + def _get_package_name(cloned_repo_path: Path) -> str: + """ + Extracts the package name from `pyproject.toml` or `setup.py`. + + Args: + cloned_repo_path (Path): Path to the cloned repository. + + Returns: + str: Name of the Python package. + + Raises: + ValueError: If parsing fails. + FileNotFoundError: If neither config file is found. 
+ """ + pyproject_path = cloned_repo_path / "pyproject.toml" + setup_path = cloned_repo_path / "setup.py" + + if pyproject_path.exists(): + try: + with pyproject_path.open("rb") as f: + pyproject = tomllib.load(f) + return pyproject["project"]["name"] + except Exception as e: + raise ValueError(f"Failed to parse pyproject.toml: {e}") + + elif setup_path.exists(): + try: + setup_contents = setup_path.read_text() + match = re.search(r'name\s*=\s*[\'"]([^\'"]+)[\'"]', setup_contents) + if match: + return match.group(1) + else: + raise ValueError("Could not find package name in setup.py") + except Exception as e: + raise ValueError(f"Failed to parse setup.py: {e}") + + else: + raise FileNotFoundError("Neither pyproject.toml nor setup.py found.") From 2b9f335c060040c6f86242435fa8bd05d3678ea6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 00:37:03 +0200 Subject: [PATCH 04/20] `gather_predictions` will return predicted_classes_dict --- chebifier/ensemble/base_ensemble.py | 70 ++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index edaaf5e..1869923 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,8 +1,8 @@ import os from abc import ABC + import torch import tqdm -from rdkit import Chem from chebifier.prediction_models.base_predictor import BasePredictor from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor @@ -12,11 +12,11 @@ MODEL_TYPES = { "electra": ElectraPredictor, "resgated": ResGatedPredictor, - "chemlog": ChemLogPredictor + "chemlog": ChemLogPredictor, } -class BaseEnsemble(ABC): +class BaseEnsemble(ABC): def __init__(self, model_configs: dict): self.models = [] self.positive_prediction_threshold = 0.5 @@ -37,22 +37,30 @@ def gather_predictions(self, smiles_list): if logits_for_smiles is not None: for cls in logits_for_smiles: predicted_classes.add(cls) - 
print(f"Sorting predictions...") + print("Sorting predictions...") predicted_classes = sorted(list(predicted_classes)) predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)} - ordered_logits = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan + ordered_logits = ( + torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) + * torch.nan + ) for i, model_prediction in enumerate(model_predictions): - for j, logits_for_smiles in tqdm.tqdm(enumerate(model_prediction), - total=len(model_prediction), - desc=f"Sorting predictions for {self.models[i].model_name}"): + for j, logits_for_smiles in tqdm.tqdm( + enumerate(model_prediction), + total=len(model_prediction), + desc=f"Sorting predictions for {self.models[i].model_name}", + ): if logits_for_smiles is not None: for cls in logits_for_smiles: - ordered_logits[j, predicted_classes_dict[cls], i] = logits_for_smiles[cls] + ordered_logits[j, predicted_classes_dict[cls], i] = ( + logits_for_smiles[cls] + ) - return ordered_logits, predicted_classes + return ordered_logits, predicted_classes_dict - - def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs): + def consolidate_predictions( + self, predictions, predicted_classes, classwise_weights, **kwargs + ): """ Aggregates predictions from multiple models using weighted majority voting. Optimized version using tensor operations instead of for loops. 
@@ -74,7 +82,9 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig positive_mask = (predictions > 0.5) & valid_predictions negative_mask = (predictions < 0.5) & valid_predictions - confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold) + confidence = 2 * torch.abs( + predictions.nan_to_num() - self.positive_prediction_threshold + ) # Extract positive and negative weights pos_weights = classwise_weights[0] # Shape: (num_classes, num_models) @@ -83,8 +93,12 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig # Calculate weighted predictions using broadcasting # predictions shape: (num_smiles, num_classes, num_models) # weights shape: (num_classes, num_models) - positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0) - negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0) + positive_weighted = ( + positive_mask.float() * confidence * pos_weights.unsqueeze(0) + ) + negative_weighted = ( + negative_mask.float() * confidence * neg_weights.unsqueeze(0) + ) # Sum over models dimension positive_sum = positive_weighted.sum(dim=2) # Shape: (num_smiles, num_classes) @@ -92,17 +106,21 @@ def consolidate_predictions(self, predictions, predicted_classes, classwise_weig # Determine which classes to include for each SMILES net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes) - class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes) + class_decisions = ( + net_score > 0 + ) & has_valid_predictions # Shape: (num_smiles, num_classes) # Convert tensor decisions to result list using list comprehension for efficiency result = [ - [class_indices[idx.item()] for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0]] + [ + class_indices[idx.item()] + for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0] + ] for i in range(num_smiles) ] return result - def 
calculate_classwise_weights(self, predicted_classes): """No weights, simple majority voting""" positive_weights = torch.ones(len(predicted_classes), len(self.models)) @@ -114,18 +132,26 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" if not load_preds_if_possible or not os.path.isfile(preds_file): - ordered_predictions, predicted_classes = self.gather_predictions(smiles_list) + ordered_predictions, predicted_classes = self.gather_predictions( + smiles_list + ) # save predictions torch.save(ordered_predictions, preds_file) with open(predicted_classes_file, "w") as f: for cls in predicted_classes: f.write(f"{cls}\n") else: - print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}") + print( + f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}" + ) ordered_predictions = torch.load(preds_file) with open(predicted_classes_file, "r") as f: - predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())} + predicted_classes = { + line.strip(): i for i, line in enumerate(f.readlines()) + } classwise_weights = self.calculate_classwise_weights(predicted_classes) - aggregated_predictions = self.consolidate_predictions(ordered_predictions, predicted_classes, classwise_weights) + aggregated_predictions = self.consolidate_predictions( + ordered_predictions, predicted_classes, classwise_weights + ) return aggregated_predictions From 6faf3bd67298b6e33efc226f1ff612b1e4846498 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 10:25:52 +0200 Subject: [PATCH 05/20] use package namespace imports for prediction models --- chebifier/ensemble/base_ensemble.py | 10 ++++++---- chebifier/prediction_models/__init__.py | 6 ++++++ 2 files changed, 12 
insertions(+), 4 deletions(-) diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 1869923..d4a4fe3 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -4,10 +4,12 @@ import torch import tqdm -from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor -from chebifier.prediction_models.electra_predictor import ElectraPredictor -from chebifier.prediction_models.gnn_predictor import ResGatedPredictor +from chebifier.prediction_models import ( + BasePredictor, + ChemLogPredictor, + ElectraPredictor, + ResGatedPredictor, +) MODEL_TYPES = { "electra": ElectraPredictor, diff --git a/chebifier/prediction_models/__init__.py b/chebifier/prediction_models/__init__.py index e69de29..ed08890 100644 --- a/chebifier/prediction_models/__init__.py +++ b/chebifier/prediction_models/__init__.py @@ -0,0 +1,6 @@ +from .base_predictor import BasePredictor +from .chemlog_predictor import ChemLogPredictor +from .electra_predictor import ElectraPredictor +from .gnn_predictor import ResGatedPredictor + +__all__ = ["BasePredictor", "ChemLogPredictor", "ElectraPredictor", "ResGatedPredictor"] From a4f5f85cc4eaaeffe4a6a5973826e49ba43f520c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 13:20:22 +0200 Subject: [PATCH 06/20] add hugging face api --- api/hugging_face.py | 58 ++++++++++++++------------------------------- 1 file changed, 18 insertions(+), 40 deletions(-) diff --git a/api/hugging_face.py b/api/hugging_face.py index cfcf97c..19debb4 100644 --- a/api/hugging_face.py +++ b/api/hugging_face.py @@ -3,54 +3,32 @@ from huggingface_hub import hf_hub_download -# Updated registry: use a list of filenames if you're downloading a folder -MODEL_REGISTRY = { - "electra": { - "repo_id": "aditya0by0/python-chebifier", - "subfolder": "electra", - "filenames": ["electra.ckpt", "classes.txt"], - } -} -DOWNLOAD_PATH = 
Path(__file__).resolve().parent / "api_models" - - -def download_model(model_name): - if model_name not in MODEL_REGISTRY: - raise ValueError( - f"Unknown model name. Available models: {list(MODEL_REGISTRY.keys())}" - ) - - model_info = MODEL_REGISTRY[model_name] - repo_id = model_info["repo_id"] - subfolder = model_info["subfolder"] - filenames = model_info["filenames"] - - local_paths = [] - for filename in filenames: - local_model_path = DOWNLOAD_PATH / model_name / filename - if local_model_path.exists(): - print(f"File already exists: {local_model_path}") - local_paths.append(local_model_path) +def download_model_files(model_config: dict, download_path: Path): + repo_id = model_config["repo_id"] + subfolder = model_config["subfolder"] + filenames = model_config["files"] + + local_paths = {} + for file_type, filename in filenames.items(): + local_file_path = download_path / filename + if local_file_path.exists(): + print(f"File already exists: {local_file_path}") + local_paths[file_type] = local_file_path continue - print(f"Downloading: {repo_id}/{filename} (subfolder: {subfolder})") + print( + f"Downloading file from: https://huggingface.co/{repo_id}/{subfolder}/{filename}" + ) downloaded_file = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder, ) - local_model_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(downloaded_file, local_model_path) - print(f"Saved to: {local_model_path}") - local_paths.append(local_model_path) + local_file_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(downloaded_file, local_file_path) + print(f"Saved to: {local_file_path}") + local_paths[file_type] = local_file_path return local_paths - - -if __name__ == "__main__": - paths = download_model("electra") - print("Downloaded files:") - for p in paths: - print(p) From 481a2eb6da4fd12de4f98ecd3efdb6268e2affcc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 13:20:57 +0200 Subject: [PATCH 07/20] api registry --- 
api/registry.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 api/registry.yml diff --git a/api/registry.yml b/api/registry.yml new file mode 100644 index 0000000..c9069a8 --- /dev/null +++ b/api/registry.yml @@ -0,0 +1,23 @@ +electra: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: electra + files: + ckpt: electra.ckpt + labels: classes.txt + repo_url: https://github.com/ChEB-AI/python-chebai + wrapper: chebifier.prediction_models.ElectraPredictor + +resgated: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: resgated + files: + ckpt: resgated.ckpt + labels: classes.txt + repo_url: https://github.com/ChEB-AI/python-chebai-graph + wrapper: chebifier.prediction_models.ResGatedPredictor + +chemlog: + repo_url: https://github.com/sfluegel05/chemlog-peptides + wrapper: chebifier.prediction_models.ChemLogPredictor From 584b6a6ac21b230732f7200078100e13896657c9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 13:21:19 +0200 Subject: [PATCH 08/20] api cli --- api/__main__.py | 10 +++++ api/cli.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+) create mode 100644 api/__main__.py create mode 100644 api/cli.py diff --git a/api/__main__.py b/api/__main__.py new file mode 100644 index 0000000..ec70a17 --- /dev/null +++ b/api/__main__.py @@ -0,0 +1,10 @@ +from .cli import cli + +if __name__ == "__main__": + """ + Entry point for the CLI application. + + This script calls the `cli` function from the `api.cli` module + when executed as the main program. 
+ """ + cli() diff --git a/api/cli.py b/api/cli.py new file mode 100644 index 0000000..6f94431 --- /dev/null +++ b/api/cli.py @@ -0,0 +1,114 @@ +import importlib +from pathlib import Path + +import click +import yaml + +from chebifier.prediction_models.base_predictor import BasePredictor + +from .hugging_face import download_model_files +from .setup_env import SetupEnvAndPackage + +yaml_path = Path("api/registry.yml") +if yaml_path.exists(): + with yaml_path.open("r") as f: + model_registry = yaml.safe_load(f) +else: + raise FileNotFoundError(f"{yaml_path} not found.") + + +@click.group() +def cli(): + """Command line interface for Api-Chebifier.""" + pass + + +@cli.command() +@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") +@click.option( + "--smiles-file", + "-f", + type=click.Path(exists=True), + help="File containing SMILES strings (one per line)", +) +@click.option( + "--output", + "-o", + type=click.Path(), + help="Output file to save predictions (optional)", +) +@click.option( + "--model-type", + "-m", + type=click.Choice(model_registry.keys()), + default="mv", + help="Type of model to use", +) +def predict(smiles, smiles_file, output, model_type): + """Predict ChEBI classes for SMILES strings using an ensemble model. + + CONFIG_FILE is the path to a YAML configuration file for the ensemble model. + """ + + # Collect SMILES strings from arguments and/or file + smiles_list = list(smiles) + if smiles_file: + with open(smiles_file, "r") as f: + smiles_list.extend([line.strip() for line in f if line.strip()]) + + if not smiles_list: + click.echo("No SMILES strings provided. 
Use --smiles or --smiles-file options.") + return + + model_config = model_registry[model_type] + predictor_kwargs = {"model_name": model_type} + + current_dir = Path(__file__).resolve().parent + + if "hugging_face" in model_config: + local_file_path = download_model_files( + model_config["hugging_face"], + current_dir / ".api_models" / model_type, + ) + predictor_kwargs["ckpt_path"] = local_file_path["ckpt"] + predictor_kwargs["target_labels_path"] = local_file_path["labels"] + + SetupEnvAndPackage().setup( + repo_url=model_config["repo_url"], + clone_dir=current_dir / ".cloned_repos", + venv_dir=current_dir, + ) + + model_cls_path = model_config["wrapper"] + module_path, class_name = model_cls_path.rsplit(".", 1) + module = importlib.import_module(module_path) + model_cls: type = getattr(module, class_name) + model_instance = model_cls(**predictor_kwargs) + assert isinstance(model_instance, BasePredictor) + + # Make predictions + predictions = model_instance.predict_smiles_list(smiles_list) + + if output: + # save as json + import json + + with open(output, "w") as f: + json.dump( + {smiles: pred for smiles, pred in zip(smiles_list, predictions)}, + f, + indent=2, + ) + + else: + # Print results + for i, (smiles, prediction) in enumerate(zip(smiles_list, predictions)): + click.echo(f"Result for: {smiles}") + if prediction: + click.echo(f" Predicted classes: {', '.join(map(str, prediction))}") + else: + click.echo(" No predictions") + + +if __name__ == "__main__": + cli() From 05d8580358038047aa387c2080f2e4f78aa84196 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 13:21:33 +0200 Subject: [PATCH 09/20] Update .gitignore --- .gitignore | 180 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) diff --git a/.gitignore b/.gitignore index 905568c..90044ae 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,181 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# 
Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# configs/ # commented as new configs can be added as a part of a feature + +/.idea +/data +/logs +/results_buffer +electra_pretrained.ckpt + +build +.virtual_documents +.jupyter +chebai.egg-info +lightning_logs +logs +.isort.cfg +/.vscode /api/api_models +/api/.api_models +/api/.cloned_repos From 997120e9574c260b3c8199aa20cc05ac67032193 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 19:20:17 +0200 Subject: [PATCH 10/20] use hugging face's cache system instead of custom file management --- .gitignore | 2 -- api/cli.py | 6 ++---- api/hugging_face.py | 50 +++++++++++++++++++++++++++++---------------- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index 90044ae..613c70b 100644 --- a/.gitignore +++ b/.gitignore @@ -176,6 +176,4 @@ lightning_logs logs .isort.cfg /.vscode -/api/api_models -/api/.api_models /api/.cloned_repos diff --git a/api/cli.py b/api/cli.py index 6f94431..99f7572 100644 --- a/api/cli.py +++ b/api/cli.py @@ -66,10 +66,8 @@ def predict(smiles, smiles_file, output, model_type): current_dir = Path(__file__).resolve().parent if "hugging_face" in model_config: - local_file_path = download_model_files( - model_config["hugging_face"], - current_dir / ".api_models" / model_type, - ) + print(f"For model type `{model_type}` following files are used:") + local_file_path = download_model_files(model_config["hugging_face"]) predictor_kwargs["ckpt_path"] = local_file_path["ckpt"] predictor_kwargs["target_labels_path"] = local_file_path["labels"] diff --git a/api/hugging_face.py b/api/hugging_face.py index 19debb4..62d16e8 100644 --- a/api/hugging_face.py +++ b/api/hugging_face.py @@ -1,34 +1,48 @@ -import shutil +""" +Hugging Face Api: + - For Windows Users check: https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache#limitations + + Refer for Hugging Face Hub caching and versioning documentation: + https://huggingface.co/docs/huggingface_hub/en/guides/download + 
https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache +""" + from pathlib import Path from huggingface_hub import hf_hub_download -def download_model_files(model_config: dict, download_path: Path): +def download_model_files( + model_config: dict[str, str | dict[str, str]], +) -> dict[str, Path]: + """ + Downloads specified model files from a Hugging Face Hub repository using hf_hub_download. + + Hugging Face Hub provides internal caching and versioning, so file management or duplication + checks are not required. + + Args: + model_config (Dict[str, str | Dict[str, str]]): A dictionary containing: + - 'repo_id' (str): The Hugging Face repository ID (e.g., 'username/modelname'). + - 'subfolder' (str): The subfolder within the repo where the files are located. + - 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt', 'labels') to + actual file names (e.g., 'electra.ckpt', 'classes.txt'). + + Returns: + Dict[str, Path]: A dictionary mapping each file type to the local Path of the downloaded file. 
+ """ repo_id = model_config["repo_id"] subfolder = model_config["subfolder"] filenames = model_config["files"] - local_paths = {} + local_paths: dict[str, Path] = {} for file_type, filename in filenames.items(): - local_file_path = download_path / filename - if local_file_path.exists(): - print(f"File already exists: {local_file_path}") - local_paths[file_type] = local_file_path - continue - - print( - f"Downloading file from: https://huggingface.co/{repo_id}/{subfolder}/{filename}" - ) - downloaded_file = hf_hub_download( + downloaded_file_path = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder, ) - - local_file_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(downloaded_file, local_file_path) - print(f"Saved to: {local_file_path}") - local_paths[file_type] = local_file_path + local_paths[file_type] = Path(downloaded_file_path) + print(f"\t Using file `{filename}` from: {downloaded_file_path}") return local_paths From 9c3beea542985ca28e38e8a69843bec83a6e77e7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 28 Jun 2025 19:33:08 +0200 Subject: [PATCH 11/20] pre-commit -run -a --- chebifier/cli.py | 63 +++++++++++++------ .../ensemble/weighted_majority_ensemble.py | 21 ++++--- chebifier/prediction_models/base_predictor.py | 18 ++++-- .../prediction_models/chemlog_predictor.py | 33 ++++++---- .../prediction_models/electra_predictor.py | 11 ++-- chebifier/prediction_models/gnn_predictor.py | 54 +++++++++++----- chebifier/prediction_models/nn_predictor.py | 46 ++++++++++---- 7 files changed, 169 insertions(+), 77 deletions(-) diff --git a/chebifier/cli.py b/chebifier/cli.py index 704f8a0..a6b8743 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,11 +1,11 @@ - - - import click import yaml -import sys + from chebifier.ensemble.base_ensemble import BaseEnsemble -from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble +from chebifier.ensemble.weighted_majority_ensemble import ( + 
WMVwithF1Ensemble, + WMVwithPPVNPVEnsemble, +) @click.group() @@ -13,36 +13,54 @@ def cli(): """Command line interface for Chebifier.""" pass + ENSEMBLES = { "mv": BaseEnsemble, "wmv-ppvnpv": WMVwithPPVNPVEnsemble, - "wmv-f1": WMVwithF1Ensemble + "wmv-f1": WMVwithF1Ensemble, } + @cli.command() -@click.argument('config_file', type=click.Path(exists=True)) -@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') -@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)') -@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)') -@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)') +@click.argument("config_file", type=click.Path(exists=True)) +@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") +@click.option( + "--smiles-file", + "-f", + type=click.Path(exists=True), + help="File containing SMILES strings (one per line)", +) +@click.option( + "--output", + "-o", + type=click.Path(), + help="Output file to save predictions (optional)", +) +@click.option( + "--ensemble-type", + "-e", + type=click.Choice(ENSEMBLES.keys()), + default="mv", + help="Type of ensemble to use (default: Majority Voting)", +) def predict(config_file, smiles, smiles_file, output, ensemble_type): """Predict ChEBI classes for SMILES strings using an ensemble model. - + CONFIG_FILE is the path to a YAML configuration file for the ensemble model. 
""" # Load configuration from YAML file - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = yaml.safe_load(f) - + # Instantiate ensemble model ensemble = ENSEMBLES[ensemble_type](config) - + # Collect SMILES strings from arguments and/or file smiles_list = list(smiles) if smiles_file: - with open(smiles_file, 'r') as f: + with open(smiles_file, "r") as f: smiles_list.extend([line.strip() for line in f if line.strip()]) - + if not smiles_list: click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.") return @@ -53,8 +71,13 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type): if output: # save as json import json - with open(output, 'w') as f: - json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2) + + with open(output, "w") as f: + json.dump( + {smiles: pred for smiles, pred in zip(smiles_list, predictions)}, + f, + indent=2, + ) else: # Print results @@ -66,5 +89,5 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type): click.echo(" No predictions") -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/chebifier/ensemble/weighted_majority_ensemble.py b/chebifier/ensemble/weighted_majority_ensemble.py index 811770d..95e9956 100644 --- a/chebifier/ensemble/weighted_majority_ensemble.py +++ b/chebifier/ensemble/weighted_majority_ensemble.py @@ -3,9 +3,7 @@ from chebifier.ensemble.base_ensemble import BaseEnsemble - class WMVwithPPVNPVEnsemble(BaseEnsemble): - def calculate_classwise_weights(self, predicted_classes): """ Given the positions of predicted classes in the predictions tensor, assign weights to each class. The @@ -23,15 +21,18 @@ def calculate_classwise_weights(self, predicted_classes): positive_weights[predicted_classes[cls], j] *= weights["PPV"] negative_weights[predicted_classes[cls], j] *= weights["NPV"] - print(f"Calculated model weightings. 
The averages for positive / negative weights are:") + print( + "Calculated model weightings. The averages for positive / negative weights are:" + ) for i, model in enumerate(self.models): - print(f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}") + print( + f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}" + ) return positive_weights, negative_weights class WMVwithF1Ensemble(BaseEnsemble): - def calculate_classwise_weights(self, predicted_classes): """ Given the positions of predicted classes in the predictions tensor, assign weights to each class. The @@ -45,11 +46,15 @@ def calculate_classwise_weights(self, predicted_classes): continue for cls, weights in model.classwise_weights.items(): if (2 * weights["TP"] + weights["FP"] + weights["FN"]) > 0: - f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"]) + f1 = ( + 2 + * weights["TP"] + / (2 * weights["TP"] + weights["FP"] + weights["FN"]) + ) weights_by_cls[predicted_classes[cls], j] *= f1 - print(f"Calculated model weightings. The average weights are:") + print("Calculated model weightings. 
The average weights are:") for i, model in enumerate(self.models): print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}") - return weights_by_cls, weights_by_cls \ No newline at end of file + return weights_by_cls, weights_by_cls diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index 5633458..e6b7952 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -1,16 +1,24 @@ -from abc import ABC import json +from abc import ABC + class BasePredictor(ABC): - def __init__(self, model_name: str, model_weight: int = 1, classwise_weights_path: str = None, **kwargs): + def __init__( + self, + model_name: str, + model_weight: int = 1, + classwise_weights_path: str = None, + **kwargs + ): self.model_name = model_name self.model_weight = model_weight if classwise_weights_path is not None: - self.classwise_weights = json.load(open(classwise_weights_path, encoding="utf-8")) + self.classwise_weights = json.load( + open(classwise_weights_path, encoding="utf-8") + ) else: self.classwise_weights = None - def predict_smiles_list(self, smiles_list: list[str]) -> dict: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 54b020a..692c79c 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -1,23 +1,22 @@ import tqdm +from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call from chebifier.prediction_models.base_predictor import BasePredictor -from chemlog.alg_classification.charge_classifier import AlgChargeClassifier -from chemlog.alg_classification.peptide_size_classifier import AlgPeptideSizeClassifier -from chemlog.alg_classification.proteinogenics_classifier import AlgProteinogenicsClassifier -from 
chemlog.alg_classification.substructure_classifier import AlgSubstructureClassifier -from chemlog.cli import strategy_call, _smiles_to_mol, CLASSIFIERS -class ChemLogPredictor(BasePredictor): +class ChemLogPredictor(BasePredictor): def __init__(self, model_name: str, **kwargs): super().__init__(model_name, **kwargs) self.strategy = "algo" self.classifier_instances = { k: v() for k, v in CLASSIFIERS[self.strategy].items() } - self.peptide_labels = ["15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923", - "48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837"] - + # fmt: off + self.peptide_labels = [ + "15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923", + "48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837" + ] + # fmt: on print(f"Initialised ChemLog model {self.model_name}") def predict_smiles_list(self, smiles_list: list[str]) -> list: @@ -27,9 +26,21 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list: if mol is None: results.append(None) else: - results.append({label: 1 if label in strategy_call(self.strategy, self.classifier_instances, mol)["chebi_classes"] else 0 for label in self.peptide_labels}) + results.append( + { + label: ( + 1 + if label + in strategy_call( + self.strategy, self.classifier_instances, mol + )["chebi_classes"] + else 0 + ) + for label in self.peptide_labels + } + ) for classifier in self.classifier_instances.values(): classifier.on_finish() - return results \ No newline at end of file + return results diff --git a/chebifier/prediction_models/electra_predictor.py b/chebifier/prediction_models/electra_predictor.py index 075eafa..7a3bcaa 100644 --- a/chebifier/prediction_models/electra_predictor.py +++ b/chebifier/prediction_models/electra_predictor.py @@ -1,7 +1,8 @@ -from chebifier.prediction_models.nn_predictor import NNPredictor from chebai.models.electra import Electra from chebai.preprocessing.reader import 
ChemDataReader +from chebifier.prediction_models.nn_predictor import NNPredictor + class ElectraPredictor(NNPredictor): @@ -13,10 +14,10 @@ def init_model(self, ckpt_path: str, **kwargs) -> Electra: model = Electra.load_from_checkpoint( ckpt_path, map_location=self.device, - criterion=None, strict=False, - metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None + criterion=None, + strict=False, + metrics=dict(train=dict(), test=dict(), validation=dict()), + pretrained_checkpoint=None, ) model.eval() return model - - diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py index b139c6c..ef354c1 100644 --- a/chebifier/prediction_models/gnn_predictor.py +++ b/chebifier/prediction_models/gnn_predictor.py @@ -1,16 +1,19 @@ -from chebifier.prediction_models.nn_predictor import NNPredictor import chebai_graph.preprocessing.properties as p import torch from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred -from chebai_graph.preprocessing.reader import GraphPropertyReader from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder +from chebai_graph.preprocessing.reader import GraphPropertyReader from torch_geometric.data.data import Data as GeomData +from chebifier.prediction_models.nn_predictor import NNPredictor + class ResGatedPredictor(NNPredictor): def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs): - super().__init__(model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs) + super().__init__( + model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs + ) # molecular_properties is a list of class paths if molecular_properties is not None: properties = [self.load_class(prop)() for prop in molecular_properties] @@ -32,11 +35,23 @@ def load_class(self, class_path: str): def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred: model = ResGatedGraphConvNetGraphPred.load_from_checkpoint( 
- ckpt_path, map_location=torch.device(self.device), criterion=None, strict=False, - metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None, - config={"in_length": 256, "hidden_length": 512, "dropout_rate": 0.1, "n_conv_layers": 3, - "n_linear_layers": 3, "n_atom_properties": 158, "n_bond_properties": 7, - "n_molecule_properties": 200}) + ckpt_path, + map_location=torch.device(self.device), + criterion=None, + strict=False, + metrics=dict(train=dict(), test=dict(), validation=dict()), + pretrained_checkpoint=None, + config={ + "in_length": 256, + "hidden_length": 512, + "dropout_rate": 0.1, + "n_conv_layers": 3, + "n_linear_layers": 3, + "n_atom_properties": 158, + "n_bond_properties": 7, + "n_molecule_properties": 200, + }, + ) model.eval() return model @@ -55,14 +70,21 @@ def read_smiles(self, smiles): # use default value if we meet an unseen value if isinstance(prop.encoder, IndexEncoder): if str(value) in prop.encoder.cache: - index = prop.encoder.cache.index(str(value)) + prop.encoder.offset + index = ( + prop.encoder.cache.index(str(value)) + prop.encoder.offset + ) else: index = 0 - print(f"Unknown property value {value} for property {prop} at smiles {smiles}") + print( + f"Unknown property value {value} for property {prop} at smiles {smiles}" + ) if isinstance(prop.encoder, OneHotEncoder): - encoded_values.append(torch.nn.functional.one_hot( - torch.tensor(index), num_classes=prop.encoder.get_encoding_length() - )) + encoded_values.append( + torch.nn.functional.one_hot( + torch.tensor(index), + num_classes=prop.encoder.get_encoding_length(), + ) + ) else: encoded_values.append(torch.tensor([index])) @@ -77,9 +99,7 @@ def read_smiles(self, smiles): if len(encoded_values.size()) == 1: encoded_values = encoded_values.unsqueeze(1) else: - encoded_values = torch.zeros( - (0, prop.encoder.get_encoding_length()) - ) + encoded_values = torch.zeros((0, prop.encoder.get_encoding_length())) if isinstance(prop, p.AtomProperty): x = 
torch.cat([x, encoded_values], dim=1) elif isinstance(prop, p.BondProperty): @@ -93,4 +113,4 @@ def read_smiles(self, smiles): edge_attr=edge_attr, molecule_attr=molecule_attr, ) - return d \ No newline at end of file + return d diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index 1ee5e46..9f2e00a 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -1,24 +1,35 @@ +import numpy as np +import torch import tqdm +from rdkit import Chem from chebifier.prediction_models.base_predictor import BasePredictor -from rdkit import Chem -import numpy as np -import torch + class NNPredictor(BasePredictor): - def __init__(self, model_name: str, ckpt_path: str, reader_cls, target_labels_path: str, **kwargs): + def __init__( + self, + model_name: str, + ckpt_path: str, + reader_cls, + target_labels_path: str, + **kwargs, + ): super().__init__(model_name, **kwargs) self.reader_cls = reader_cls self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = self.init_model(ckpt_path=ckpt_path) - self.target_labels = [line.strip() for line in open(target_labels_path, encoding="utf-8")] + self.target_labels = [ + line.strip() for line in open(target_labels_path, encoding="utf-8") + ] self.batch_size = kwargs.get("batch_size", 1) - def init_model(self, ckpt_path: str, **kwargs): - raise NotImplementedError("Model initialization must be implemented in subclasses.") + raise NotImplementedError( + "Model initialization must be implemented in subclasses." 
+ ) def calculate_results(self, batch): collator = self.reader_cls.COLLATOR() @@ -66,14 +77,27 @@ def predict_smiles_list(self, smiles_list) -> list: token_dicts.append(d) results = [] if token_dicts: - for batch in tqdm.tqdm(self.batchify(token_dicts), desc=f"{self.model_name}", total=len(token_dicts)//self.batch_size): + for batch in tqdm.tqdm( + self.batchify(token_dicts), + desc=f"{self.model_name}", + total=len(token_dicts) // self.batch_size, + ): result = self.calculate_results(batch) if isinstance(result, dict) and "logits" in result: result = result["logits"] results += torch.sigmoid(result).cpu().detach().tolist() results = np.stack(results, axis=0) - preds = [{self.target_labels[j]: p for j, p in enumerate(results[index_map[i]])} - if i not in could_not_parse else None for i in range(len(smiles_list))] + preds = [ + ( + { + self.target_labels[j]: p + for j, p in enumerate(results[index_map[i]]) + } + if i not in could_not_parse + else None + ) + for i in range(len(smiles_list)) + ] return preds else: - return [None for _ in smiles_list] \ No newline at end of file + return [None for _ in smiles_list] From e6602ef24249d20634fbdf539e8117696973946b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Tue, 1 Jul 2025 23:13:34 +0200 Subject: [PATCH 12/20] remove explicit config kwargs for resgated --- chebifier/prediction_models/gnn_predictor.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py index ef354c1..57afcfc 100644 --- a/chebifier/prediction_models/gnn_predictor.py +++ b/chebifier/prediction_models/gnn_predictor.py @@ -9,7 +9,6 @@ class ResGatedPredictor(NNPredictor): - def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs): super().__init__( model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs @@ -41,16 +40,6 @@ def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred: strict=False, 
metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None, - config={ - "in_length": 256, - "hidden_length": 512, - "dropout_rate": 0.1, - "n_conv_layers": 3, - "n_linear_layers": 3, - "n_atom_properties": 158, - "n_bond_properties": 7, - "n_molecule_properties": 200, - }, ) model.eval() return model From fd814e928a62d17899859942c58f684613df2a1d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 18:30:46 +0200 Subject: [PATCH 13/20] api support for ensemble --- api/api_registry.yml | 24 ++++ api/check_env.py | 30 +++++ api/cli.py | 67 ++++++----- api/hugging_face.py | 5 +- api/registry.yml | 23 ---- api/setup_env.py | 165 ---------------------------- chebifier/cli.py | 13 +-- chebifier/ensemble/base_ensemble.py | 16 +-- chebifier/model_registry.py | 29 +++++ 9 files changed, 130 insertions(+), 242 deletions(-) create mode 100644 api/api_registry.yml create mode 100644 api/check_env.py delete mode 100644 api/registry.yml delete mode 100644 api/setup_env.py create mode 100644 chebifier/model_registry.py diff --git a/api/api_registry.yml b/api/api_registry.yml new file mode 100644 index 0000000..b6e30bd --- /dev/null +++ b/api/api_registry.yml @@ -0,0 +1,24 @@ +electra: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: electra + files: + ckpt: electra.ckpt + labels: classes.txt + package_name: chebai + +resgated: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: resgated + files: + ckpt: resgated.ckpt + labels: classes.txt + package_name: chebai-graph + +chemlog: + package_name: chemlog + + +en_mv: + ensemble_of: {electra, chemlog} diff --git a/api/check_env.py b/api/check_env.py new file mode 100644 index 0000000..b215fda --- /dev/null +++ b/api/check_env.py @@ -0,0 +1,30 @@ +import subprocess +import sys + + +def get_current_environment() -> str: + """ + Return the path of the Python executable for the current environment. 
+ """ + return sys.executable + + +def check_package_installed(package_name: str) -> None: + """ + Check if the given package is installed in the current Python environment. + """ + python_exec = get_current_environment() + try: + subprocess.check_output( + [python_exec, "-m", "pip", "show", package_name], stderr=subprocess.DEVNULL + ) + print(f"✅ Package '{package_name}' is already installed.") + except subprocess.CalledProcessError: + raise ( + f"❌ Please install '{package_name}' into your environment: {python_exec}" + ) + + +if __name__ == "__main__": + print(f"🔍 Using Python executable: {get_current_environment()}") + check_package_installed("numpy") # Replace with your desired package diff --git a/api/cli.py b/api/cli.py index 99f7572..e20ed14 100644 --- a/api/cli.py +++ b/api/cli.py @@ -1,18 +1,17 @@ -import importlib from pathlib import Path import click import yaml -from chebifier.prediction_models.base_predictor import BasePredictor +from chebifier.model_registry import ENSEMBLES, MODEL_TYPES +from .check_env import check_package_installed, get_current_environment from .hugging_face import download_model_files -from .setup_env import SetupEnvAndPackage -yaml_path = Path("api/registry.yml") +yaml_path = Path("api/api_registry.yml") if yaml_path.exists(): with yaml_path.open("r") as f: - model_registry = yaml.safe_load(f) + api_registry = yaml.safe_load(f) else: raise FileNotFoundError(f"{yaml_path} not found.") @@ -40,7 +39,7 @@ def cli(): @click.option( "--model-type", "-m", - type=click.Choice(model_registry.keys()), + type=click.Choice(api_registry.keys()), default="mv", help="Type of model to use", ) @@ -60,29 +59,39 @@ def predict(smiles, smiles_file, output, model_type): click.echo("No SMILES strings provided. 
Use --smiles or --smiles-file options.") return - model_config = model_registry[model_type] - predictor_kwargs = {"model_name": model_type} - - current_dir = Path(__file__).resolve().parent - - if "hugging_face" in model_config: - print(f"For model type `{model_type}` following files are used:") - local_file_path = download_model_files(model_config["hugging_face"]) - predictor_kwargs["ckpt_path"] = local_file_path["ckpt"] - predictor_kwargs["target_labels_path"] = local_file_path["labels"] - - SetupEnvAndPackage().setup( - repo_url=model_config["repo_url"], - clone_dir=current_dir / ".cloned_repos", - venv_dir=current_dir, - ) - - model_cls_path = model_config["wrapper"] - module_path, class_name = model_cls_path.rsplit(".", 1) - module = importlib.import_module(module_path) - model_cls: type = getattr(module, class_name) - model_instance = model_cls(**predictor_kwargs) - assert isinstance(model_instance, BasePredictor) + print("Current working environment is:", get_current_environment()) + + def get_individual_model(model_config): + predictor_kwargs = {} + if "hugging_face" in model_config: + predictor_kwargs = download_model_files(model_config["hugging_face"]) + check_package_installed(model_config["package_name"]) + return predictor_kwargs + + if model_type in MODEL_TYPES: + print(f"Predictor for Single/Individual Model: {model_type}") + model_config = api_registry[model_type] + predictor_kwargs = get_individual_model(model_config) + predictor_kwargs["model_name"] = model_type + model_instance = MODEL_TYPES[model_type](**predictor_kwargs) + + elif model_type in ENSEMBLES: + print(f"Predictor for Ensemble Model: {model_type}") + ensemble_config = {} + for i, en_comp in enumerate(api_registry[model_type]["ensemble_of"]): + assert en_comp in MODEL_TYPES + print(f"For ensemble component {en_comp}") + predictor_kwargs = get_individual_model(api_registry[en_comp]) + model_key = f"model_{i + 1}" + ensemble_config[model_key] = { + "type": en_comp, + "model_name": 
f"{en_comp}_{model_key}", + **predictor_kwargs, + } + model_instance = ENSEMBLES[model_type](ensemble_config) + + else: + raise ValueError("") # Make predictions predictions = model_instance.predict_smiles_list(smiles_list) diff --git a/api/hugging_face.py b/api/hugging_face.py index 62d16e8..5569d86 100644 --- a/api/hugging_face.py +++ b/api/hugging_face.py @@ -45,4 +45,7 @@ def download_model_files( local_paths[file_type] = Path(downloaded_file_path) print(f"\t Using file `{filename}` from: {downloaded_file_path}") - return local_paths + return { + "ckpt_path": local_paths["ckpt"], + "target_labels_path": local_paths["labels"], + } diff --git a/api/registry.yml b/api/registry.yml deleted file mode 100644 index c9069a8..0000000 --- a/api/registry.yml +++ /dev/null @@ -1,23 +0,0 @@ -electra: - hugging_face: - repo_id: aditya0by0/python-chebifier - subfolder: electra - files: - ckpt: electra.ckpt - labels: classes.txt - repo_url: https://github.com/ChEB-AI/python-chebai - wrapper: chebifier.prediction_models.ElectraPredictor - -resgated: - hugging_face: - repo_id: aditya0by0/python-chebifier - subfolder: resgated - files: - ckpt: resgated.ckpt - labels: classes.txt - repo_url: https://github.com/ChEB-AI/python-chebai-graph - wrapper: chebifier.prediction_models.ResGatedPredictor - -chemlog: - repo_url: https://github.com/sfluegel05/chemlog-peptides - wrapper: chebifier.prediction_models.ChemLogPredictor diff --git a/api/setup_env.py b/api/setup_env.py deleted file mode 100644 index a246c26..0000000 --- a/api/setup_env.py +++ /dev/null @@ -1,165 +0,0 @@ -import os -import re -import subprocess -import sys -from pathlib import Path - -# Conditional import of tomllib based on Python version -if sys.version_info >= (3, 11): - import tomllib # built-in in Python 3.11+ -else: - import toml as tomllib # use third-party toml library for older versions - - -class SetupEnvAndPackage: - """Utility class for cloning a repository, setting up a virtual environment, and installing 
a package.""" - - def setup( - self, - repo_url: str, - clone_dir: Path, - venv_dir: Path, - venv_name: str = ".venv-chebifier", - ) -> None: - """ - Orchestrates the full setup process: cloning the repository, - creating a virtual environment, and installing the package. - - Args: - repo_url (str): URL of the Git repository. - clone_dir (Path): Directory to clone the repo into. - venv_dir (Path): Directory where the virtual environment will be created. - venv_name (str): Name of the virtual environment folder. - """ - cloned_repo_path = self._clone_repo(repo_url, clone_dir) - venv_path = self._create_virtualenv(venv_dir, venv_name) - self._install_from_pyproject(venv_path, cloned_repo_path) - - def _clone_repo(self, repo_url: str, clone_dir: Path) -> Path: - """ - Clone a Git repository into a specified directory. - - Args: - repo_url (str): Git URL to clone. - clone_dir (Path): Directory to clone into. - - Returns: - Path: Path to the cloned repository. - """ - repo_name = repo_url.rstrip("/").split("/")[-1].replace(".git", "") - clone_path = Path(clone_dir or repo_name) - - if not clone_path.exists(): - print(f"Cloning {repo_url} into {clone_path}...") - subprocess.check_call( - ["git", "clone", "--depth=1", repo_url, str(clone_path)] - ) - else: - print(f"Repo already exists at {clone_path}") - - return clone_path - - @staticmethod - def _create_virtualenv(venv_dir: Path, venv_name: str = ".venv-chebifier") -> Path: - """ - Create a virtual environment at the specified path. - - Args: - venv_dir (Path): Base directory where the venv will be created. - venv_name (str): Name of the virtual environment directory. - - Returns: - Path: Path to the virtual environment. 
- """ - venv_path = venv_dir / venv_name - - if venv_path.exists(): - print(f"Virtual environment already exists at: {venv_path}") - return venv_path - - print(f"Creating virtual environment at: {venv_path}") - - try: - subprocess.check_call(["virtualenv", str(venv_path)]) - except FileNotFoundError: - print("virtualenv not found, installing it now...") - subprocess.check_call( - [sys.executable, "-m", "pip", "install", "virtualenv"] - ) - subprocess.check_call(["virtualenv", str(venv_path)]) - - return venv_path - - def _install_from_pyproject(self, venv_dir: Path, cloned_repo_path: Path) -> None: - """ - Install the cloned package in editable mode. - - Args: - venv_dir (Path): Path to the virtual environment. - cloned_repo_path (Path): Path to the cloned repository. - """ - pip_executable = ( - venv_dir / "Scripts" / "pip.exe" - if os.name == "nt" - else venv_dir / "bin" / "pip" - ) - - if not pip_executable.exists(): - raise FileNotFoundError(f"pip not found at {pip_executable}") - - try: - package_name = self._get_package_name(cloned_repo_path) - except Exception as e: - raise RuntimeError(f"Error extracting package name: {e}") - - try: - subprocess.check_output( - [str(pip_executable), "show", package_name], stderr=subprocess.DEVNULL - ) - print(f"Package '{package_name}' is already installed.") - except subprocess.CalledProcessError: - print(f"Installing '{package_name}' from {cloned_repo_path}...") - subprocess.check_call( - [str(pip_executable), "install", "-e", "."], - cwd=cloned_repo_path, - ) - - @staticmethod - def _get_package_name(cloned_repo_path: Path) -> str: - """ - Extracts the package name from `pyproject.toml` or `setup.py`. - - Args: - cloned_repo_path (Path): Path to the cloned repository. - - Returns: - str: Name of the Python package. - - Raises: - ValueError: If parsing fails. - FileNotFoundError: If neither config file is found. 
- """ - pyproject_path = cloned_repo_path / "pyproject.toml" - setup_path = cloned_repo_path / "setup.py" - - if pyproject_path.exists(): - try: - with pyproject_path.open("rb") as f: - pyproject = tomllib.load(f) - return pyproject["project"]["name"] - except Exception as e: - raise ValueError(f"Failed to parse pyproject.toml: {e}") - - elif setup_path.exists(): - try: - setup_contents = setup_path.read_text() - match = re.search(r'name\s*=\s*[\'"]([^\'"]+)[\'"]', setup_contents) - if match: - return match.group(1) - else: - raise ValueError("Could not find package name in setup.py") - except Exception as e: - raise ValueError(f"Failed to parse setup.py: {e}") - - else: - raise FileNotFoundError("Neither pyproject.toml nor setup.py found.") diff --git a/chebifier/cli.py b/chebifier/cli.py index a6b8743..b51dc04 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,11 +1,7 @@ import click import yaml -from chebifier.ensemble.base_ensemble import BaseEnsemble -from chebifier.ensemble.weighted_majority_ensemble import ( - WMVwithF1Ensemble, - WMVwithPPVNPVEnsemble, -) +from .model_registry import ENSEMBLES @click.group() @@ -14,13 +10,6 @@ def cli(): pass -ENSEMBLES = { - "mv": BaseEnsemble, - "wmv-ppvnpv": WMVwithPPVNPVEnsemble, - "wmv-f1": WMVwithF1Ensemble, -} - - @cli.command() @click.argument("config_file", type=click.Path(exists=True)) @click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index d4a4fe3..19f49d2 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -4,22 +4,14 @@ import torch import tqdm -from chebifier.prediction_models import ( - BasePredictor, - ChemLogPredictor, - ElectraPredictor, - ResGatedPredictor, -) - -MODEL_TYPES = { - "electra": ElectraPredictor, - "resgated": ResGatedPredictor, - "chemlog": ChemLogPredictor, -} +from chebifier.prediction_models import BasePredictor class 
BaseEnsemble(ABC): def __init__(self, model_configs: dict): + # Deferred Import: To avoid circular import error + from chebifier.model_registry import MODEL_TYPES + self.models = [] self.positive_prediction_threshold = 0.5 for model_name, model_config in model_configs.items(): diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py new file mode 100644 index 0000000..4961f3e --- /dev/null +++ b/chebifier/model_registry.py @@ -0,0 +1,29 @@ +from chebifier.ensemble.base_ensemble import BaseEnsemble +from chebifier.ensemble.weighted_majority_ensemble import ( + WMVwithF1Ensemble, + WMVwithPPVNPVEnsemble, +) +from chebifier.prediction_models import ( + ChemLogPredictor, + ElectraPredictor, + ResGatedPredictor, +) + +ENSEMBLES = { + "en_mv": BaseEnsemble, + "en_wmv-ppvnpv": WMVwithPPVNPVEnsemble, + "en_wmv-f1": WMVwithF1Ensemble, +} + + +MODEL_TYPES = { + "electra": ElectraPredictor, + "resgated": ResGatedPredictor, + "chemlog": ChemLogPredictor, +} + + +common_keys = MODEL_TYPES.keys() & ENSEMBLES.keys() +assert ( + not common_keys +), f"Overlapping keys between MODEL_TYPES and ENSEMBLES: {common_keys}" From a044f23d6ec797645d05987eee590b9f3c46adf6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 18:34:31 +0200 Subject: [PATCH 14/20] add ruff action workflow --- .github/workflows/lint.yml | 26 ++++++++++++++++++++++++++ .pre-commit-config.yaml | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..1b63c41 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' # or any version your project uses + + - name: Install dependencies + run: | + python -m pip install --upgrade pip 
+ pip install black ruff + + - name: Run Black + run: black --check . + + - name: Run Ruff (no formatting) + run: ruff check . --no-fix diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e32d80c..b8a785a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,4 +28,4 @@ repos: rev: v0.12.1 hooks: - id: ruff - args: [] # No --fix, disables formatting + args: [--fix] From 51a2d348e5d2dcaf4a1b6a2b79debb4c63a47964 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 18:49:42 +0200 Subject: [PATCH 15/20] same version for workflow and pre-commit yaml --- .github/workflows/lint.yml | 2 +- .pre-commit-config.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1b63c41..bb9154f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install black ruff + pip install black==25.1.0 ruff==0.12.2 - name: Run Black run: black --check . 
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b8a785a..cbb7284 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: "24.2.0" + rev: "25.1.0" hooks: - id: black - id: black-jupyter # for formatting jupyter-notebook @@ -25,7 +25,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.1 + rev: v0.12.2 hooks: - id: ruff args: [--fix] From d2c586aa9e8e25a3734fa3de48351264351d615e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 6 Jul 2025 19:02:48 +0200 Subject: [PATCH 16/20] Update base_predictor.py --- chebifier/prediction_models/base_predictor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index e6b7952..3eeee52 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -3,13 +3,12 @@ class BasePredictor(ABC): - def __init__( self, model_name: str, model_weight: int = 1, classwise_weights_path: str = None, - **kwargs + **kwargs, ): self.model_name = model_name self.model_weight = model_weight From f3b39052857f4357b5baa528a058ff2b8c836dc2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 11 Jul 2025 11:56:19 +0200 Subject: [PATCH 17/20] fix readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c42e4df..0559817 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ cd python-chebifier pip install -e . ``` -Some dependencies of `chebai-graph` cannot be installed automatically. If you want to use Graph Neural Networks, follow +`chebai-graph` and its dependencies cannot be installed automatically. If you want to use Graph Neural Networks, follow the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph).
## Usage From 001538daf33abc584f22c4a0afccb2c62a510f33 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 11 Jul 2025 11:59:41 +0200 Subject: [PATCH 18/20] fix cli and ensemble imports --- chebifier/cli.py | 9 --------- chebifier/ensemble/base_ensemble.py | 13 ++----------- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/chebifier/cli.py b/chebifier/cli.py index 2c3ad0d..5fa9679 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -2,21 +2,12 @@ import yaml from .model_registry import ENSEMBLES -from chebifier.ensemble.base_ensemble import BaseEnsemble -from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble - @click.group() def cli(): """Command line interface for Chebifier.""" pass -ENSEMBLES = { - "mv": BaseEnsemble, - "wmv-ppvnpv": WMVwithPPVNPVEnsemble, - "wmv-f1": WMVwithF1Ensemble -} - @cli.command() @click.argument('config_file', type=click.Path(exists=True)) @click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index 5f94d02..a071a33 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,22 +1,13 @@ import os -from abc import ABC import torch import tqdm from chebai.preprocessing.datasets.chebi import ChEBIOver50 from chebai.result.analyse_sem import PredictionSmoother from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor -from chebifier.prediction_models.electra_predictor import ElectraPredictor -from chebifier.prediction_models.gnn_predictor import ResGatedPredictor - -MODEL_TYPES = { - "electra": ElectraPredictor, - "resgated": ResGatedPredictor, - "chemlog": ChemLogPredictor -} -class BaseEnsemble(ABC): + +class BaseEnsemble: def __init__(self, model_configs: dict, chebi_version: int = 241): # Deferred Import: To avoid circular import error From 
f8583cbdfa0059378e3039cb2c763c6866e737bd Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 11 Jul 2025 13:08:03 +0200 Subject: [PATCH 19/20] add huggingface download to cli --- chebifier/__main__.py | 4 ++++ chebifier/cli.py | 8 +++++--- chebifier/ensemble/base_ensemble.py | 22 +++++++++++++++------- chebifier/model_registry.py | 6 +++--- configs/huggingface_config.yml | 22 ++++++++++++++++++++++ pyproject.toml | 3 --- 6 files changed, 49 insertions(+), 16 deletions(-) create mode 100644 chebifier/__main__.py create mode 100644 configs/huggingface_config.yml diff --git a/chebifier/__main__.py b/chebifier/__main__.py new file mode 100644 index 0000000..9aebe0f --- /dev/null +++ b/chebifier/__main__.py @@ -0,0 +1,4 @@ +from chebifier.cli import cli + +if __name__ == '__main__': + cli() \ No newline at end of file diff --git a/chebifier/cli.py b/chebifier/cli.py index 5fa9679..a21ebf3 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,3 +1,5 @@ +import os + import click import yaml @@ -9,14 +11,14 @@ def cli(): pass @cli.command() -@click.argument('config_file', type=click.Path(exists=True)) +@click.option('--config_file', type=click.Path(exists=True), default=os.path.join('configs', 'huggingface_config.yml'), help="Configuration file for ensemble models") @click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') @click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)') @click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)') @click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)') @click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking consistency (default: 241)") @click.option("--use-confidence", "-c", is_flag=True, default=True, help="Weight predictions based on how 'confident' a model is in 
its prediction (default: True)") -def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version): +def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version, use_confidence): """Predict ChEBI classes for SMILES strings using an ensemble model. CONFIG_FILE is the path to a YAML configuration file for the ensemble model. @@ -39,7 +41,7 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_versi return # Make predictions - predictions = ensemble.predict_smiles_list(smiles_list) + predictions = ensemble.predict_smiles_list(smiles_list, use_confidence=use_confidence) if output: # save as json diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index a071a33..a946c2a 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -4,6 +4,7 @@ from chebai.preprocessing.datasets.chebi import ChEBIOver50 from chebai.result.analyse_sem import PredictionSmoother +from api.hugging_face import download_model_files from chebifier.prediction_models.base_predictor import BasePredictor @@ -17,14 +18,20 @@ def __init__(self, model_configs: dict, chebi_version: int = 241): self.positive_prediction_threshold = 0.5 for model_name, model_config in model_configs.items(): model_cls = MODEL_TYPES[model_config["type"]] - model_instance = model_cls(model_name, **model_config) + if "hugging_face" in model_config: + hugging_face_kwargs = download_model_files(model_config["hugging_face"]) + else: + hugging_face_kwargs = {} + model_instance = model_cls(model_name, **model_config, **hugging_face_kwargs) assert isinstance(model_instance, BasePredictor) self.models.append(model_instance) - self.smoother = PredictionSmoother(ChEBIOver50(chebi_version=chebi_version), disjoint_files=[ + self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version) + self.chebi_dataset._download_required_data() # download chebi if not already downloaded + self.disjoint_files=[ 
os.path.join("data", "disjoint_chebi.csv"), os.path.join("data", "disjoint_additional.csv") - ]) + ] def gather_predictions(self, smiles_list): @@ -110,7 +117,7 @@ def calculate_classwise_weights(self, predicted_classes): return positive_weights, negative_weights - def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: + def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs) -> list: preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" if not load_preds_if_possible or not os.path.isfile(preds_file): @@ -128,11 +135,12 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())} classwise_weights = self.calculate_classwise_weights(predicted_classes) - class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights) + class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights, **kwargs) # Smooth predictions class_names = list(predicted_classes.keys()) - self.smoother.label_names = class_names - class_decisions = self.smoother(class_decisions) + # initialise new smoother class since we don't know the labels beforehand (this could be more efficient) + new_smoother = PredictionSmoother(self.chebi_dataset, label_names=class_names, disjoint_files=self.disjoint_files) + class_decisions = new_smoother(class_decisions) class_names = list(predicted_classes.keys()) class_indices = {predicted_classes[cls]: cls for cls in class_names} diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 4961f3e..cf7d6d0 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -10,9 +10,9 @@ ) ENSEMBLES = { - "en_mv": BaseEnsemble, - "en_wmv-ppvnpv": WMVwithPPVNPVEnsemble, - "en_wmv-f1": WMVwithF1Ensemble, + 
"mv": BaseEnsemble, + "wmv-ppvnpv": WMVwithPPVNPVEnsemble, + "wmv-f1": WMVwithF1Ensemble, } diff --git a/configs/huggingface_config.yml b/configs/huggingface_config.yml new file mode 100644 index 0000000..c26950d --- /dev/null +++ b/configs/huggingface_config.yml @@ -0,0 +1,22 @@ + +chemlog_peptides: + type: chemlog + model_weight: 100 + +#resgated_huggingface: +# type: resgated +# hugging_face: +# repo_id: aditya0by0/python-chebifier +# subfolder: resgated +# files: +# ckpt: resgated.ckpt +# labels: classes.txt + +electra_huggingface: + type: electra + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: electra + files: + ckpt: electra.ckpt + labels: classes.txt diff --git a/pyproject.toml b/pyproject.toml index 8a0223f..ff7837d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,6 @@ dependencies = [ "chemlog>=1.0.4" ] -[project.scripts] -chebifier = "chebifier.cli:cli" - [tool.setuptools] packages = ["chebifier", "chebifier.ensemble", "chebifier.prediction_models"] From 90aedd43a105f83c65f2351bd59d124fc0bc0c51 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 11 Jul 2025 13:11:10 +0200 Subject: [PATCH 20/20] reformat with black --- chebifier/__main__.py | 4 +- chebifier/cli.py | 86 +++++++++++++++---- chebifier/ensemble/base_ensemble.py | 125 +++++++++++++++++----------- 3 files changed, 148 insertions(+), 67 deletions(-) diff --git a/chebifier/__main__.py b/chebifier/__main__.py index 9aebe0f..22bf70c 100644 --- a/chebifier/__main__.py +++ b/chebifier/__main__.py @@ -1,4 +1,4 @@ from chebifier.cli import cli -if __name__ == '__main__': - cli() \ No newline at end of file +if __name__ == "__main__": + cli() diff --git a/chebifier/cli.py b/chebifier/cli.py index a21ebf3..11c138b 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -5,49 +5,99 @@ from .model_registry import ENSEMBLES + @click.group() def cli(): """Command line interface for Chebifier.""" pass + @cli.command() -@click.option('--config_file', type=click.Path(exists=True), 
default=os.path.join('configs', 'huggingface_config.yml'), help="Configuration file for ensemble models") -@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') -@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)') -@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)') -@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)') -@click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking consistency (default: 241)") -@click.option("--use-confidence", "-c", is_flag=True, default=True, help="Weight predictions based on how 'confident' a model is in its prediction (default: True)") -def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version, use_confidence): +@click.option( + "--config_file", + type=click.Path(exists=True), + default=os.path.join("configs", "huggingface_config.yml"), + help="Configuration file for ensemble models", +) +@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") +@click.option( + "--smiles-file", + "-f", + type=click.Path(exists=True), + help="File containing SMILES strings (one per line)", +) +@click.option( + "--output", + "-o", + type=click.Path(), + help="Output file to save predictions (optional)", +) +@click.option( + "--ensemble-type", + "-e", + type=click.Choice(ENSEMBLES.keys()), + default="mv", + help="Type of ensemble to use (default: Majority Voting)", +) +@click.option( + "--chebi-version", + "-v", + type=int, + default=241, + help="ChEBI version to use for checking consistency (default: 241)", +) +@click.option( + "--use-confidence", + "-c", + is_flag=True, + default=True, + help="Weight predictions based on how 'confident' a model is in its prediction (default: True)", +) +def predict( + config_file, + 
smiles, + smiles_file, + output, + ensemble_type, + chebi_version, + use_confidence, +): """Predict ChEBI classes for SMILES strings using an ensemble model. - + CONFIG_FILE is the path to a YAML configuration file for the ensemble model. """ # Load configuration from YAML file - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = yaml.safe_load(f) - + # Instantiate ensemble model ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version) - + # Collect SMILES strings from arguments and/or file smiles_list = list(smiles) if smiles_file: - with open(smiles_file, 'r') as f: + with open(smiles_file, "r") as f: smiles_list.extend([line.strip() for line in f if line.strip()]) - + if not smiles_list: click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.") return # Make predictions - predictions = ensemble.predict_smiles_list(smiles_list, use_confidence=use_confidence) + predictions = ensemble.predict_smiles_list( + smiles_list, use_confidence=use_confidence + ) if output: # save as json import json - with open(output, 'w') as f: - json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2) + + with open(output, "w") as f: + json.dump( + {smiles: pred for smiles, pred in zip(smiles_list, predictions)}, + f, + indent=2, + ) else: # Print results @@ -59,5 +109,5 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_versi click.echo(" No predictions") -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index a946c2a..0795fc0 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -22,18 +22,19 @@ def __init__(self, model_configs: dict, chebi_version: int = 241): hugging_face_kwargs = download_model_files(model_config["hugging_face"]) else: hugging_face_kwargs = {} - model_instance = model_cls(model_name, **model_config, 
**hugging_face_kwargs) + model_instance = model_cls( + model_name, **model_config, **hugging_face_kwargs + ) assert isinstance(model_instance, BasePredictor) self.models.append(model_instance) self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version) self.chebi_dataset._download_required_data() # download chebi if not already downloaded - self.disjoint_files=[ + self.disjoint_files = [ os.path.join("data", "disjoint_chebi.csv"), - os.path.join("data", "disjoint_additional.csv") + os.path.join("data", "disjoint_additional.csv"), ] - def gather_predictions(self, smiles_list): # get predictions from all models for the SMILES list # order them by alphabetically by label class @@ -60,11 +61,12 @@ def gather_predictions(self, smiles_list): ): if logits_for_smiles is not None: for cls in logits_for_smiles: - ordered_logits[j, predicted_classes_dict[cls], i] = logits_for_smiles[cls] + ordered_logits[j, predicted_classes_dict[cls], i] = ( + logits_for_smiles[cls] + ) return ordered_logits, predicted_classes - def consolidate_predictions(self, predictions, classwise_weights, **kwargs): """ Aggregates predictions from multiple models using weighted majority voting. 
@@ -80,11 +82,17 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): has_valid_predictions = valid_counts > 0 # Calculate positive and negative predictions for all classes at once - positive_mask = (predictions > self.positive_prediction_threshold) & valid_predictions - negative_mask = (predictions < self.positive_prediction_threshold) & valid_predictions + positive_mask = ( + predictions > self.positive_prediction_threshold + ) & valid_predictions + negative_mask = ( + predictions < self.positive_prediction_threshold + ) & valid_predictions if "use_confidence" in kwargs and kwargs["use_confidence"]: - confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold) + confidence = 2 * torch.abs( + predictions.nan_to_num() - self.positive_prediction_threshold + ) else: confidence = torch.ones_like(predictions) @@ -95,8 +103,12 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): # Calculate weighted predictions using broadcasting # predictions shape: (num_smiles, num_classes, num_models) # weights shape: (num_classes, num_models) - positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0) - negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0) + positive_weighted = ( + positive_mask.float() * confidence * pos_weights.unsqueeze(0) + ) + negative_weighted = ( + negative_mask.float() * confidence * neg_weights.unsqueeze(0) + ) # Sum over models dimension positive_sum = positive_weighted.sum(dim=2) # Shape: (num_smiles, num_classes) @@ -104,9 +116,9 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): # Determine which classes to include for each SMILES net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes) - class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes) - - + class_decisions = ( + net_score > 0 + ) & has_valid_predictions # Shape: (num_smiles, 
num_classes) return class_decisions @@ -117,11 +129,15 @@ def calculate_classwise_weights(self, predicted_classes): return positive_weights, negative_weights - def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs) -> list: + def predict_smiles_list( + self, smiles_list, load_preds_if_possible=True, **kwargs + ) -> list: preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" if not load_preds_if_possible or not os.path.isfile(preds_file): - ordered_predictions, predicted_classes = self.gather_predictions(smiles_list) + ordered_predictions, predicted_classes = self.gather_predictions( + smiles_list + ) # save predictions torch.save(ordered_predictions, preds_file) with open(predicted_classes_file, "w") as f: @@ -129,17 +145,27 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs f.write(f"{cls}\n") predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)} else: - print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}") + print( + f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}" + ) ordered_predictions = torch.load(preds_file) with open(predicted_classes_file, "r") as f: - predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())} + predicted_classes = { + line.strip(): i for i, line in enumerate(f.readlines()) + } classwise_weights = self.calculate_classwise_weights(predicted_classes) - class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights, **kwargs) + class_decisions = self.consolidate_predictions( + ordered_predictions, classwise_weights, **kwargs + ) # Smooth predictions class_names = list(predicted_classes.keys()) # initialise new smoother class since we don't know the labels beforehand (this could be more efficient) - 
new_smoother = PredictionSmoother(self.chebi_dataset, label_names=class_names, disjoint_files=self.disjoint_files) + new_smoother = PredictionSmoother( + self.chebi_dataset, + label_names=class_names, + disjoint_files=self.disjoint_files, + ) class_decisions = new_smoother(class_decisions) class_names = list(predicted_classes.keys()) @@ -153,31 +179,36 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs if __name__ == "__main__": - ensemble = BaseEnsemble({"resgated_0ps1g189":{ - "type": "resgated", - "ckpt_path": "data/0ps1g189/epoch=122.ckpt", - "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", - "molecular_properties": [ - "chebai_graph.preprocessing.properties.AtomType", - "chebai_graph.preprocessing.properties.NumAtomBonds", - "chebai_graph.preprocessing.properties.AtomCharge", - "chebai_graph.preprocessing.properties.AtomAromaticity", - "chebai_graph.preprocessing.properties.AtomHybridization", - "chebai_graph.preprocessing.properties.AtomNumHs", - "chebai_graph.preprocessing.properties.BondType", - "chebai_graph.preprocessing.properties.BondInRing", - "chebai_graph.preprocessing.properties.BondAromaticity", - "chebai_graph.preprocessing.properties.RDKit2DNormalized", - ], - #"classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json" - }, - -"electra_14ko0zcf": { - "type": "electra", - "ckpt_path": "data/14ko0zcf/epoch=193.ckpt", - "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", - #"classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json", -} - }) - r = ensemble.predict_smiles_list(["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], load_preds_if_possible=False) + ensemble = BaseEnsemble( + { + "resgated_0ps1g189": { + "type": "resgated", + "ckpt_path": "data/0ps1g189/epoch=122.ckpt", + "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", + "molecular_properties": [ + 
"chebai_graph.preprocessing.properties.AtomType", + "chebai_graph.preprocessing.properties.NumAtomBonds", + "chebai_graph.preprocessing.properties.AtomCharge", + "chebai_graph.preprocessing.properties.AtomAromaticity", + "chebai_graph.preprocessing.properties.AtomHybridization", + "chebai_graph.preprocessing.properties.AtomNumHs", + "chebai_graph.preprocessing.properties.BondType", + "chebai_graph.preprocessing.properties.BondInRing", + "chebai_graph.preprocessing.properties.BondAromaticity", + "chebai_graph.preprocessing.properties.RDKit2DNormalized", + ], + # "classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json" + }, + "electra_14ko0zcf": { + "type": "electra", + "ckpt_path": "data/14ko0zcf/epoch=193.ckpt", + "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", + # "classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json", + }, + } + ) + r = ensemble.predict_smiles_list( + ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], + load_preds_if_possible=False, + ) print(len(r), r[0])