diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..bb9154f --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' # or any version your project uses + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install black==25.1.0 ruff==0.12.2 + + - name: Run Black + run: black --check . + + - name: Run Ruff (no formatting) + run: ruff check . --no-fix diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..613c70b --- /dev/null +++ b/.gitignore @@ -0,0 +1,179 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# configs/ # commented as new configs can be added as a part of a feature + +/.idea +/data +/logs +/results_buffer +electra_pretrained.ckpt + +build +.virtual_documents +.jupyter +chebai.egg-info +lightning_logs +logs +.isort.cfg +/.vscode +/api/.cloned_repos diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..cbb7284 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: +- repo: https://github.com/psf/black + rev: "25.1.0" + hooks: + - id: black + - id: black-jupyter # for formatting jupyter-notebook + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile=black"] + +- repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.2 + hooks: + - id: ruff + args: [--fix] diff 
--git a/README.md b/README.md index c42e4df..0559817 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ cd python-chebifier pip install -e . ``` -Some dependencies of `chebai-graph` cannot be installed automatically. If you want to use Graph Neural Networks, follow +`chebai-graph` and its dependencies cannot be installed automatically. If you want to use Graph Neural Networks, follow the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph). ## Usage diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/__main__.py b/api/__main__.py new file mode 100644 index 0000000..ec70a17 --- /dev/null +++ b/api/__main__.py @@ -0,0 +1,10 @@ +from .cli import cli + +if __name__ == "__main__": + """ + Entry point for the CLI application. + + This script calls the `cli` function from the `api.cli` module + when executed as the main program. + """ + cli() diff --git a/api/api_registry.yml b/api/api_registry.yml new file mode 100644 index 0000000..b6e30bd --- /dev/null +++ b/api/api_registry.yml @@ -0,0 +1,24 @@ +electra: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: electra + files: + ckpt: electra.ckpt + labels: classes.txt + package_name: chebai + +resgated: + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: resgated + files: + ckpt: resgated.ckpt + labels: classes.txt + package_name: chebai-graph + +chemlog: + package_name: chemlog + + +en_mv: + ensemble_of: {electra, chemlog} diff --git a/api/check_env.py b/api/check_env.py new file mode 100644 index 0000000..b215fda --- /dev/null +++ b/api/check_env.py @@ -0,0 +1,30 @@ +import subprocess +import sys + + +def get_current_environment() -> str: + """ + Return the path of the Python executable for the current environment. + """ + return sys.executable + + +def check_package_installed(package_name: str) -> None: + """ + Check if the given package is installed in the current Python environment.
+ """ + python_exec = get_current_environment() + try: + subprocess.check_output( + [python_exec, "-m", "pip", "show", package_name], stderr=subprocess.DEVNULL + ) + print(f"✅ Package '{package_name}' is already installed.") + except subprocess.CalledProcessError as err: + raise ModuleNotFoundError( + f"❌ Please install '{package_name}' into your environment: {python_exec}" + ) from err + + +if __name__ == "__main__": + print(f"🔍 Using Python executable: {get_current_environment()}") + check_package_installed("numpy") # Replace with your desired package diff --git a/api/cli.py b/api/cli.py new file mode 100644 index 0000000..e20ed14 --- /dev/null +++ b/api/cli.py @@ -0,0 +1,121 @@ +from pathlib import Path + +import click +import yaml + +from chebifier.model_registry import ENSEMBLES, MODEL_TYPES + +from .check_env import check_package_installed, get_current_environment +from .hugging_face import download_model_files + +yaml_path = Path(__file__).with_name("api_registry.yml") +if yaml_path.exists(): + with yaml_path.open("r") as f: + api_registry = yaml.safe_load(f) +else: + raise FileNotFoundError(f"{yaml_path} not found.") + + +@click.group() +def cli(): + """Command line interface for Api-Chebifier.""" + pass + + +@cli.command() +@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") +@click.option( + "--smiles-file", + "-f", + type=click.Path(exists=True), + help="File containing SMILES strings (one per line)", +) +@click.option( + "--output", + "-o", + type=click.Path(), + help="Output file to save predictions (optional)", +) +@click.option( + "--model-type", + "-m", + type=click.Choice(api_registry.keys()), + default="en_mv", + help="Type of model to use", +) +def predict(smiles, smiles_file, output, model_type): + """Predict ChEBI classes for SMILES strings using a registered model or ensemble. + + The model is selected with --model-type, which must be a key of api_registry.yml.
+ """ + + # Collect SMILES strings from arguments and/or file + smiles_list = list(smiles) + if smiles_file: + with open(smiles_file, "r") as f: + smiles_list.extend([line.strip() for line in f if line.strip()]) + + if not smiles_list: + click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.") + return + + print("Current working environment is:", get_current_environment()) + + def get_individual_model(model_config): + predictor_kwargs = {} + if "hugging_face" in model_config: + predictor_kwargs = download_model_files(model_config["hugging_face"]) + check_package_installed(model_config["package_name"]) + return predictor_kwargs + + if model_type in MODEL_TYPES: + print(f"Predictor for Single/Individual Model: {model_type}") + model_config = api_registry[model_type] + predictor_kwargs = get_individual_model(model_config) + predictor_kwargs["model_name"] = model_type + model_instance = MODEL_TYPES[model_type](**predictor_kwargs) + + elif (ensemble_key := model_type.removeprefix("en_")) in ENSEMBLES: # "en_mv" -> "mv" + print(f"Predictor for Ensemble Model: {model_type}") + ensemble_config = {} + for i, en_comp in enumerate(api_registry[model_type]["ensemble_of"]): + assert en_comp in MODEL_TYPES + print(f"For ensemble component {en_comp}") + predictor_kwargs = get_individual_model(api_registry[en_comp]) + model_key = f"{en_comp}_model_{i + 1}" + ensemble_config[model_key] = { + "type": en_comp, + # model_name is supplied by BaseEnsemble from the dict key + **predictor_kwargs, + } + model_instance = ENSEMBLES[ensemble_key](ensemble_config) + + else: + raise ValueError(f"Unsupported model type: {model_type!r}") + + # Make predictions + predictions = model_instance.predict_smiles_list(smiles_list) + + if output: + # save as json + import json + + with open(output, "w") as f: + json.dump( + {smiles: pred for smiles, pred in zip(smiles_list, predictions)}, + f, + indent=2, + ) + + else: + # Print results + for i, (smiles, prediction) in enumerate(zip(smiles_list, predictions)): + click.echo(f"Result for: {smiles}") + if prediction: + click.echo(f" Predicted classes: {', 
'.join(map(str, prediction))}") + else: + click.echo(" No predictions") + + +if __name__ == "__main__": + cli() diff --git a/api/hugging_face.py b/api/hugging_face.py new file mode 100644 index 0000000..5569d86 --- /dev/null +++ b/api/hugging_face.py @@ -0,0 +1,51 @@ +""" +Hugging Face Api: + - For Windows Users check: https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache#limitations + + Refer for Hugging Face Hub caching and versioning documentation: + https://huggingface.co/docs/huggingface_hub/en/guides/download + https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache +""" + +from pathlib import Path + +from huggingface_hub import hf_hub_download + + +def download_model_files( + model_config: dict[str, str | dict[str, str]], +) -> dict[str, Path]: + """ + Downloads specified model files from a Hugging Face Hub repository using hf_hub_download. + + Hugging Face Hub provides internal caching and versioning, so file management or duplication + checks are not required. + + Args: + model_config (Dict[str, str | Dict[str, str]]): A dictionary containing: + - 'repo_id' (str): The Hugging Face repository ID (e.g., 'username/modelname'). + - 'subfolder' (str): The subfolder within the repo where the files are located. + - 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt', 'labels') to + actual file names (e.g., 'electra.ckpt', 'classes.txt'). + + Returns: + Dict[str, Path]: A dictionary mapping each file type to the local Path of the downloaded file. 
+ """ + repo_id = model_config["repo_id"] + subfolder = model_config["subfolder"] + filenames = model_config["files"] + + local_paths: dict[str, Path] = {} + for file_type, filename in filenames.items(): + downloaded_file_path = hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder, + ) + local_paths[file_type] = Path(downloaded_file_path) + print(f"\t Using file `{filename}` from: {downloaded_file_path}") + + return { + "ckpt_path": local_paths["ckpt"], + "target_labels_path": local_paths["labels"], + } diff --git a/chebifier/__main__.py b/chebifier/__main__.py new file mode 100644 index 0000000..22bf70c --- /dev/null +++ b/chebifier/__main__.py @@ -0,0 +1,4 @@ +from chebifier.cli import cli + +if __name__ == "__main__": + cli() diff --git a/chebifier/cli.py b/chebifier/cli.py index 5870db1..11c138b 100644 --- a/chebifier/cli.py +++ b/chebifier/cli.py @@ -1,10 +1,9 @@ - - +import os import click import yaml -from chebifier.ensemble.base_ensemble import BaseEnsemble -from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble + +from .model_registry import ENSEMBLES @click.group() @@ -12,50 +11,93 @@ def cli(): """Command line interface for Chebifier.""" pass -ENSEMBLES = { - "mv": BaseEnsemble, - "wmv-ppvnpv": WMVwithPPVNPVEnsemble, - "wmv-f1": WMVwithF1Ensemble -} @cli.command() -@click.argument('config_file', type=click.Path(exists=True)) -@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict') -@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)') -@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)') -@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)') -@click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking consistency 
(default: 241)") -@click.option("--use-confidence", "-c", is_flag=True, default=True, help="Weight predictions based on how 'confident' a model is in its prediction (default: True)") -def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version): +@click.option( + "--config_file", + type=click.Path(exists=True), + default=os.path.join("configs", "huggingface_config.yml"), + help="Configuration file for ensemble models", +) +@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict") +@click.option( + "--smiles-file", + "-f", + type=click.Path(exists=True), + help="File containing SMILES strings (one per line)", +) +@click.option( + "--output", + "-o", + type=click.Path(), + help="Output file to save predictions (optional)", +) +@click.option( + "--ensemble-type", + "-e", + type=click.Choice(ENSEMBLES.keys()), + default="mv", + help="Type of ensemble to use (default: Majority Voting)", +) +@click.option( + "--chebi-version", + "-v", + type=int, + default=241, + help="ChEBI version to use for checking consistency (default: 241)", +) +@click.option( + "--use-confidence/--no-use-confidence", + "-c", + is_flag=True, + default=True, + help="Weight predictions based on how 'confident' a model is in its prediction (default: True)", +) +def predict( + config_file, + smiles, + smiles_file, + output, + ensemble_type, + chebi_version, + use_confidence, +): """Predict ChEBI classes for SMILES strings using an ensemble model. - + CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
""" # Load configuration from YAML file - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = yaml.safe_load(f) - + # Instantiate ensemble model ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version) - + # Collect SMILES strings from arguments and/or file smiles_list = list(smiles) if smiles_file: - with open(smiles_file, 'r') as f: + with open(smiles_file, "r") as f: smiles_list.extend([line.strip() for line in f if line.strip()]) - + if not smiles_list: click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.") return # Make predictions - predictions = ensemble.predict_smiles_list(smiles_list) + predictions = ensemble.predict_smiles_list( + smiles_list, use_confidence=use_confidence + ) if output: # save as json import json - with open(output, 'w') as f: - json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2) + + with open(output, "w") as f: + json.dump( + {smiles: pred for smiles, pred in zip(smiles_list, predictions)}, + f, + indent=2, + ) else: # Print results @@ -67,5 +109,5 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_versi click.echo(" No predictions") -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/chebifier/ensemble/base_ensemble.py b/chebifier/ensemble/base_ensemble.py index c4db998..0795fc0 100644 --- a/chebifier/ensemble/base_ensemble.py +++ b/chebifier/ensemble/base_ensemble.py @@ -1,37 +1,39 @@ import os -from abc import ABC import torch import tqdm from chebai.preprocessing.datasets.chebi import ChEBIOver50 from chebai.result.analyse_sem import PredictionSmoother +from api.hugging_face import download_model_files from chebifier.prediction_models.base_predictor import BasePredictor -from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor -from chebifier.prediction_models.electra_predictor import ElectraPredictor -from chebifier.prediction_models.gnn_predictor import 
ResGatedPredictor -MODEL_TYPES = { - "electra": ElectraPredictor, - "resgated": ResGatedPredictor, - "chemlog": ChemLogPredictor -} -class BaseEnsemble(ABC): +class BaseEnsemble: def __init__(self, model_configs: dict, chebi_version: int = 241): + # Deferred Import: To avoid circular import error + from chebifier.model_registry import MODEL_TYPES + self.models = [] self.positive_prediction_threshold = 0.5 for model_name, model_config in model_configs.items(): model_cls = MODEL_TYPES[model_config["type"]] - model_instance = model_cls(model_name, **model_config) + if "hugging_face" in model_config: + hugging_face_kwargs = download_model_files(model_config["hugging_face"]) + else: + hugging_face_kwargs = {} + model_instance = model_cls( + model_name, **model_config, **hugging_face_kwargs + ) assert isinstance(model_instance, BasePredictor) self.models.append(model_instance) - self.smoother = PredictionSmoother(ChEBIOver50(chebi_version=chebi_version), disjoint_files=[ + self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version) + self.chebi_dataset._download_required_data() # download chebi if not already downloaded + self.disjoint_files = [ os.path.join("data", "disjoint_chebi.csv"), - os.path.join("data", "disjoint_additional.csv") - ]) - + os.path.join("data", "disjoint_additional.csv"), + ] def gather_predictions(self, smiles_list): # get predictions from all models for the SMILES list @@ -44,21 +46,27 @@ def gather_predictions(self, smiles_list): if logits_for_smiles is not None: for cls in logits_for_smiles: predicted_classes.add(cls) - print(f"Sorting predictions...") + print("Sorting predictions...") predicted_classes = sorted(list(predicted_classes)) predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)} - ordered_logits = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan + ordered_logits = ( + torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) + * torch.nan + ) for i, 
model_prediction in enumerate(model_predictions): - for j, logits_for_smiles in tqdm.tqdm(enumerate(model_prediction), - total=len(model_prediction), - desc=f"Sorting predictions for {self.models[i].model_name}"): + for j, logits_for_smiles in tqdm.tqdm( + enumerate(model_prediction), + total=len(model_prediction), + desc=f"Sorting predictions for {self.models[i].model_name}", + ): if logits_for_smiles is not None: for cls in logits_for_smiles: - ordered_logits[j, predicted_classes_dict[cls], i] = logits_for_smiles[cls] + ordered_logits[j, predicted_classes_dict[cls], i] = ( + logits_for_smiles[cls] + ) return ordered_logits, predicted_classes - def consolidate_predictions(self, predictions, classwise_weights, **kwargs): """ Aggregates predictions from multiple models using weighted majority voting. @@ -74,11 +82,17 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): has_valid_predictions = valid_counts > 0 # Calculate positive and negative predictions for all classes at once - positive_mask = (predictions > self.positive_prediction_threshold) & valid_predictions - negative_mask = (predictions < self.positive_prediction_threshold) & valid_predictions + positive_mask = ( + predictions > self.positive_prediction_threshold + ) & valid_predictions + negative_mask = ( + predictions < self.positive_prediction_threshold + ) & valid_predictions if "use_confidence" in kwargs and kwargs["use_confidence"]: - confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold) + confidence = 2 * torch.abs( + predictions.nan_to_num() - self.positive_prediction_threshold + ) else: confidence = torch.ones_like(predictions) @@ -89,8 +103,12 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): # Calculate weighted predictions using broadcasting # predictions shape: (num_smiles, num_classes, num_models) # weights shape: (num_classes, num_models) - positive_weighted = positive_mask.float() * confidence * 
pos_weights.unsqueeze(0) - negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0) + positive_weighted = ( + positive_mask.float() * confidence * pos_weights.unsqueeze(0) + ) + negative_weighted = ( + negative_mask.float() * confidence * neg_weights.unsqueeze(0) + ) # Sum over models dimension positive_sum = positive_weighted.sum(dim=2) # Shape: (num_smiles, num_classes) @@ -98,9 +116,9 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs): # Determine which classes to include for each SMILES net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes) - class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes) - - + class_decisions = ( + net_score > 0 + ) & has_valid_predictions # Shape: (num_smiles, num_classes) return class_decisions @@ -111,11 +129,15 @@ def calculate_classwise_weights(self, predicted_classes): return positive_weights, negative_weights - def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: + def predict_smiles_list( + self, smiles_list, load_preds_if_possible=True, **kwargs + ) -> list: preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt" predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt" if not load_preds_if_possible or not os.path.isfile(preds_file): - ordered_predictions, predicted_classes = self.gather_predictions(smiles_list) + ordered_predictions, predicted_classes = self.gather_predictions( + smiles_list + ) # save predictions torch.save(ordered_predictions, preds_file) with open(predicted_classes_file, "w") as f: @@ -123,17 +145,28 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: f.write(f"{cls}\n") predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)} else: - print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}") + print( + 
f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}" + ) ordered_predictions = torch.load(preds_file) with open(predicted_classes_file, "r") as f: - predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())} + predicted_classes = { + line.strip(): i for i, line in enumerate(f.readlines()) + } classwise_weights = self.calculate_classwise_weights(predicted_classes) - class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights) + class_decisions = self.consolidate_predictions( + ordered_predictions, classwise_weights, **kwargs + ) # Smooth predictions class_names = list(predicted_classes.keys()) - self.smoother.label_names = class_names - class_decisions = self.smoother(class_decisions) + # initialise new smoother class since we don't know the labels beforehand (this could be more efficient) + new_smoother = PredictionSmoother( + self.chebi_dataset, + label_names=class_names, + disjoint_files=self.disjoint_files, + ) + class_decisions = new_smoother(class_decisions) class_names = list(predicted_classes.keys()) class_indices = {predicted_classes[cls]: cls for cls in class_names} @@ -144,32 +177,38 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list: return result + if __name__ == "__main__": - ensemble = BaseEnsemble({"resgated_0ps1g189":{ - "type": "resgated", - "ckpt_path": "data/0ps1g189/epoch=122.ckpt", - "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", - "molecular_properties": [ - "chebai_graph.preprocessing.properties.AtomType", - "chebai_graph.preprocessing.properties.NumAtomBonds", - "chebai_graph.preprocessing.properties.AtomCharge", - "chebai_graph.preprocessing.properties.AtomAromaticity", - "chebai_graph.preprocessing.properties.AtomHybridization", - "chebai_graph.preprocessing.properties.AtomNumHs", - "chebai_graph.preprocessing.properties.BondType", - "chebai_graph.preprocessing.properties.BondInRing", - 
"chebai_graph.preprocessing.properties.BondAromaticity", - "chebai_graph.preprocessing.properties.RDKit2DNormalized", - ], - #"classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json" - }, - -"electra_14ko0zcf": { - "type": "electra", - "ckpt_path": "data/14ko0zcf/epoch=193.ckpt", - "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", - #"classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json", -} - }) - r = ensemble.predict_smiles_list(["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], load_preds_if_possible=False) - print(len(r), r[0]) \ No newline at end of file + ensemble = BaseEnsemble( + { + "resgated_0ps1g189": { + "type": "resgated", + "ckpt_path": "data/0ps1g189/epoch=122.ckpt", + "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", + "molecular_properties": [ + "chebai_graph.preprocessing.properties.AtomType", + "chebai_graph.preprocessing.properties.NumAtomBonds", + "chebai_graph.preprocessing.properties.AtomCharge", + "chebai_graph.preprocessing.properties.AtomAromaticity", + "chebai_graph.preprocessing.properties.AtomHybridization", + "chebai_graph.preprocessing.properties.AtomNumHs", + "chebai_graph.preprocessing.properties.BondType", + "chebai_graph.preprocessing.properties.BondInRing", + "chebai_graph.preprocessing.properties.BondAromaticity", + "chebai_graph.preprocessing.properties.RDKit2DNormalized", + ], + # "classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json" + }, + "electra_14ko0zcf": { + "type": "electra", + "ckpt_path": "data/14ko0zcf/epoch=193.ckpt", + "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt", + # "classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json", + }, + } + ) + r = ensemble.predict_smiles_list( + ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], + load_preds_if_possible=False, + ) + print(len(r), r[0]) diff --git 
a/chebifier/ensemble/weighted_majority_ensemble.py b/chebifier/ensemble/weighted_majority_ensemble.py index ac0a796..ed40626 100644 --- a/chebifier/ensemble/weighted_majority_ensemble.py +++ b/chebifier/ensemble/weighted_majority_ensemble.py @@ -3,9 +3,7 @@ from chebifier.ensemble.base_ensemble import BaseEnsemble - class WMVwithPPVNPVEnsemble(BaseEnsemble): - def calculate_classwise_weights(self, predicted_classes): """ Given the positions of predicted classes in the predictions tensor, assign weights to each class. The @@ -23,15 +21,18 @@ def calculate_classwise_weights(self, predicted_classes): positive_weights[predicted_classes[cls], j] *= weights["PPV"] negative_weights[predicted_classes[cls], j] *= weights["NPV"] - print(f"Calculated model weightings. The averages for positive / negative weights are:") + print( + "Calculated model weightings. The averages for positive / negative weights are:" + ) for i, model in enumerate(self.models): - print(f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}") + print( + f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}" + ) return positive_weights, negative_weights class WMVwithF1Ensemble(BaseEnsemble): - def calculate_classwise_weights(self, predicted_classes): """ Given the positions of predicted classes in the predictions tensor, assign weights to each class. The @@ -45,11 +46,15 @@ def calculate_classwise_weights(self, predicted_classes): continue for cls, weights in model.classwise_weights.items(): if (2 * weights["TP"] + weights["FP"] + weights["FN"]) > 0: - f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"]) + f1 = ( + 2 + * weights["TP"] + / (2 * weights["TP"] + weights["FP"] + weights["FN"]) + ) weights_by_cls[predicted_classes[cls], j] *= 1 + f1 - print(f"Calculated model weightings. The average weights are:") + print("Calculated model weightings. 
The average weights are:") for i, model in enumerate(self.models): print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}") - return weights_by_cls, weights_by_cls \ No newline at end of file + return weights_by_cls, weights_by_cls diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py new file mode 100644 index 0000000..cf7d6d0 --- /dev/null +++ b/chebifier/model_registry.py @@ -0,0 +1,29 @@ +from chebifier.ensemble.base_ensemble import BaseEnsemble +from chebifier.ensemble.weighted_majority_ensemble import ( + WMVwithF1Ensemble, + WMVwithPPVNPVEnsemble, +) +from chebifier.prediction_models import ( + ChemLogPredictor, + ElectraPredictor, + ResGatedPredictor, +) + +ENSEMBLES = { + "mv": BaseEnsemble, + "wmv-ppvnpv": WMVwithPPVNPVEnsemble, + "wmv-f1": WMVwithF1Ensemble, +} + + +MODEL_TYPES = { + "electra": ElectraPredictor, + "resgated": ResGatedPredictor, + "chemlog": ChemLogPredictor, +} + + +common_keys = MODEL_TYPES.keys() & ENSEMBLES.keys() +assert ( + not common_keys +), f"Overlapping keys between MODEL_TYPES and ENSEMBLES: {common_keys}" diff --git a/chebifier/prediction_models/__init__.py b/chebifier/prediction_models/__init__.py index e69de29..ed08890 100644 --- a/chebifier/prediction_models/__init__.py +++ b/chebifier/prediction_models/__init__.py @@ -0,0 +1,6 @@ +from .base_predictor import BasePredictor +from .chemlog_predictor import ChemLogPredictor +from .electra_predictor import ElectraPredictor +from .gnn_predictor import ResGatedPredictor + +__all__ = ["BasePredictor", "ChemLogPredictor", "ElectraPredictor", "ResGatedPredictor"] diff --git a/chebifier/prediction_models/base_predictor.py b/chebifier/prediction_models/base_predictor.py index b5229c1..287f097 100644 --- a/chebifier/prediction_models/base_predictor.py +++ b/chebifier/prediction_models/base_predictor.py @@ -1,13 +1,21 @@ -from abc import ABC import json +from abc import ABC -class BasePredictor(ABC): - def __init__(self, model_name: str, model_weight: int 
= 1, classwise_weights_path: str = None, **kwargs): +class BasePredictor(ABC): + def __init__( + self, + model_name: str, + model_weight: int = 1, + classwise_weights_path: str = None, + **kwargs, + ): self.model_name = model_name self.model_weight = model_weight if classwise_weights_path is not None: - self.classwise_weights = json.load(open(classwise_weights_path, encoding="utf-8")) + self.classwise_weights = json.load( + open(classwise_weights_path, encoding="utf-8") + ) else: self.classwise_weights = None diff --git a/chebifier/prediction_models/chemlog_predictor.py b/chebifier/prediction_models/chemlog_predictor.py index 854cffb..4bcb9b8 100644 --- a/chebifier/prediction_models/chemlog_predictor.py +++ b/chebifier/prediction_models/chemlog_predictor.py @@ -1,51 +1,57 @@ import tqdm +from chemlog.alg_classification.charge_classifier import get_charge_category +from chemlog.alg_classification.peptide_size_classifier import get_n_amino_acid_residues +from chemlog.alg_classification.proteinogenics_classifier import ( + get_proteinogenic_amino_acids, +) +from chemlog.alg_classification.substructure_classifier import ( + is_diketopiperazine, + is_emericellamide, +) +from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call -from chebifier.prediction_models.base_predictor import BasePredictor -from chemlog.alg_classification.charge_classifier import AlgChargeClassifier, get_charge_category -from chemlog.alg_classification.peptide_size_classifier import AlgPeptideSizeClassifier, get_n_amino_acid_residues -from chemlog.alg_classification.proteinogenics_classifier import AlgProteinogenicsClassifier, get_proteinogenic_amino_acids -from chemlog.alg_classification.substructure_classifier import AlgSubstructureClassifier, is_emericellamide, is_diketopiperazine -from chemlog.cli import strategy_call, _smiles_to_mol, CLASSIFIERS - +from .base_predictor import BasePredictor AA_DICT = { - "A": "L-alanine", - "C": "L-cysteine", - "D": "L-aspartic acid", - "E": 
"L-glutamic acid", - "F": "L-phenylalanine", - "G": "glycine", - "H": "L-histidine", - "I": "L-isoleucine", - "K": "L-lysine", - "L": "L-leucine", - "M": "L-methionine", - "fMet": "N-formylmethionine", - "N": "L-asparagine", - "O": "L-pyrrolysine", - "P": "L-proline", - "Q": "L-glutamine", - "R": "L-arginine", - "S": "L-serine", - "T": "L-threonine", - "U": "L-selenocysteine", - "V": "L-valine", - "W": "L-tryptophan", - "Y": "L-tyrosine", - } + "A": "L-alanine", + "C": "L-cysteine", + "D": "L-aspartic acid", + "E": "L-glutamic acid", + "F": "L-phenylalanine", + "G": "glycine", + "H": "L-histidine", + "I": "L-isoleucine", + "K": "L-lysine", + "L": "L-leucine", + "M": "L-methionine", + "fMet": "N-formylmethionine", + "N": "L-asparagine", + "O": "L-pyrrolysine", + "P": "L-proline", + "Q": "L-glutamine", + "R": "L-arginine", + "S": "L-serine", + "T": "L-threonine", + "U": "L-selenocysteine", + "V": "L-valine", + "W": "L-tryptophan", + "Y": "L-tyrosine", +} class ChemLogPredictor(BasePredictor): - def __init__(self, model_name: str, **kwargs): super().__init__(model_name, **kwargs) self.strategy = "algo" self.classifier_instances = { k: v() for k, v in CLASSIFIERS[self.strategy].items() } - self.peptide_labels = ["15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923", - "48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837"] - + # fmt: off + self.peptide_labels = [ + "15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923", + "48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837" + ] + # fmt: on print(f"Initialised ChemLog model {self.model_name}") def predict_smiles_list(self, smiles_list: list[str]) -> list: @@ -55,7 +61,19 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list: if mol is None: results.append(None) else: - results.append({label: 1 if label in strategy_call(self.strategy, self.classifier_instances, mol)["chebi_classes"] else 0 for label in 
self.peptide_labels}) + results.append( + { + label: ( + 1 + if label + in strategy_call( + self.strategy, self.classifier_instances, mol + )["chebi_classes"] + else 0 + ) + for label in self.peptide_labels + } + ) for classifier in self.classifier_instances.values(): classifier.on_finish() @@ -72,16 +90,15 @@ def get_chemlog_result_info(self, smiles): n_amino_acid_residues, add_output = get_n_amino_acid_residues(mol) if n_amino_acid_residues > 1: proteinogenics, proteinogenics_locations, _ = get_proteinogenic_amino_acids( - mol, - add_output["amino_residue"], - add_output["carboxy_residue"]) + mol, add_output["amino_residue"], add_output["carboxy_residue"] + ) else: proteinogenics, proteinogenics_locations, _ = [], [], [] results = { - 'charge_category': charge_category.name, - 'n_amino_acid_residues': n_amino_acid_residues, - 'proteinogenics': proteinogenics, - 'proteinogenics_locations': proteinogenics_locations, + "charge_category": charge_category.name, + "n_amino_acid_residues": n_amino_acid_residues, + "proteinogenics": proteinogenics, + "proteinogenics_locations": proteinogenics_locations, } if n_amino_acid_residues == 5: @@ -97,97 +114,217 @@ def get_chemlog_result_info(self, smiles): return {**results, **add_output} - def build_explain_blocks_atom_allocations(self, atoms, cls_name): return [ ("heading", cls_name), - ("text", f"The peptide has been identified as an instance of '" - f"{cls_name}'. This was decided based on the presence of the following structure:"), - ("single", atoms) + ( + "text", + f"The peptide has been identified as an instance of '" + f"{cls_name}'. 
This was decided based on the presence of the following structure:", + ), + ("single", atoms), ] def build_explain_blocks_peptides(self, info): blocks = [] if "error" in info: - blocks.append(("text", f"An error occurred while processing the molecule: {info['error']}")) + blocks.append( + ( + "text", + f"An error occurred while processing the molecule: {info['error']}", + ) + ) return blocks blocks.append(("heading", "Functional groups")) if len(info["amide_bond"]) == 0: - blocks.append(("text", "The molecule does not contain any amide. Therefore, it cannot be a peptide, " - "peptide anion, peptide zwitterion or peptide cation.")) + blocks.append( + ( + "text", + "The molecule does not contain any amide. Therefore, it cannot be a peptide, " + "peptide anion, peptide zwitterion or peptide cation.", + ) + ) return blocks - blocks.append(("text", "The molecule contains the following functional groups:")) - blocks.append(("tabs", {"Amide": info["amide_bond"], - "Carboxylic acid derivative": info["carboxy_residue"], - "Amino group": [[a] for a in info["amino_residue"]]})) + blocks.append( + ("text", "The molecule contains the following functional groups:") + ) + blocks.append( + ( + "tabs", + { + "Amide": info["amide_bond"], + "Carboxylic acid derivative": info["carboxy_residue"], + "Amino group": [[a] for a in info["amino_residue"]], + }, + ) + ) blocks.append(("heading", "Identifying the peptide structure")) if len(info["chunks"]) == 0: - blocks.append(("text", "All atoms in the molecule are connected via a chain of carbon atoms. " - "Therefore, the molecule cannot be a peptide, peptide anion, peptide zwitterion " - "or peptide cation.")) + blocks.append( + ( + "text", + "All atoms in the molecule are connected via a chain of carbon atoms. 
" + "Therefore, the molecule cannot be a peptide, peptide anion, peptide zwitterion " + "or peptide cation.", + ) + ) return blocks - blocks.append(("text", "To divide up the molecule into potential amino acids, it has been split into the " - f"{len(info['chunks'])} 'building blocks' (based on heteroatoms).")) - blocks.append(("text", "For each, we have checked if it constitutes an amino acid residue.")) + blocks.append( + ( + "text", + "To divide up the molecule into potential amino acids, it has been split into the " + f"{len(info['chunks'])} 'building blocks' (based on heteroatoms).", + ) + ) + blocks.append( + ( + "text", + "For each, we have checked if it constitutes an amino acid residue.", + ) + ) if len(info["chunks"]) == len(info["longest_aa_chain"]): - blocks.append(("text", "All chunks have been identified as amino acid residues that are connected " - "via amide bonds:")) + blocks.append( + ( + "text", + "All chunks have been identified as amino acid residues that are connected " + "via amide bonds:", + ) + ) blocks.append(("tabs", {"Amino acid residue": info["longest_aa_chain"]})) elif len(info["longest_aa_chain"]) == 0: blocks.append(("tabs", {"Chunks": info["chunks"]})) blocks.append( - ("text", "In these chunks, no amino acids have been identified. " - "Therefore, the molecule cannot be a peptide, " - "peptide anion, peptide zwitterion or peptide cation.")) + ( + "text", + "In these chunks, no amino acids have been identified. 
" + "Therefore, the molecule cannot be a peptide, " + "peptide anion, peptide zwitterion or peptide cation.", + ) + ) return blocks else: - blocks.append(("text", f"{len(info['longest_aa_chain'])} of these chunks have been identified as amino acid " - f"residues and are connected via amide bonds:")) - blocks.append(("tabs", {"Chunks": info["chunks"], - "Amino acid residue": info["longest_aa_chain"]})) + blocks.append( + ( + "text", + f"{len(info['longest_aa_chain'])} of these chunks have been identified as amino acid " + f"residues and are connected via amide bonds:", + ) + ) + blocks.append( + ( + "tabs", + { + "Chunks": info["chunks"], + "Amino acid residue": info["longest_aa_chain"], + }, + ) + ) if len(info["longest_aa_chain"]) < 2: - blocks.append(("text", "Only one amino acid has been identified. Therefore, the molecule cannot be a " - "peptide, peptide anion, peptide zwitterion or peptide cation.")) + blocks.append( + ( + "text", + "Only one amino acid has been identified. Therefore, the molecule cannot be a " + "peptide, peptide anion, peptide zwitterion or peptide cation.", + ) + ) return blocks blocks.append(("heading", "Charge-based classification")) if info["charge_category"] == "SALT": - blocks.append(("text", "The molecule consists of disconnected anionic and cationic fragments. " - "Therefore, we classify it as a peptide salt. Since there is no class 'peptide salt'" - "in ChEBI, no prediction is made.")) + blocks.append( + ( + "text", + "The molecule consists of disconnected anionic and cationic fragments. " + "Therefore, we classify it as a peptide salt. 
Since there is no class 'peptide salt'" + "in ChEBI, no prediction is made.", + ) + ) return blocks elif info["charge_category"] == "CATION": - blocks.append(("text", "The molecule has a net positive charge, therefore it is a 'peptide cation'.")) + blocks.append( + ( + "text", + "The molecule has a net positive charge, therefore it is a 'peptide cation'.", + ) + ) return blocks elif info["charge_category"] == "ANION": - blocks.append(("text", "The molecule has a net negative charge, therefore it is a 'peptide anion'.")) + blocks.append( + ( + "text", + "The molecule has a net negative charge, therefore it is a 'peptide anion'.", + ) + ) return blocks elif info["charge_category"] == "ZWITTERION": - blocks.append(("text", "The molecule is overall neutral, but a zwitterion, i.e., it contains connected " - "(but non-adjacent) atoms with opposite charges.")) + blocks.append( + ( + "text", + "The molecule is overall neutral, but a zwitterion, i.e., it contains connected " + "(but non-adjacent) atoms with opposite charges.", + ) + ) if info["n_amino_acid_residues"] == 2: - blocks.append(("text", "Since we have identified 2 amino acid residues, the final classification is " - "'dipeptide zwitterion'.")) + blocks.append( + ( + "text", + "Since we have identified 2 amino acid residues, the final classification is " + "'dipeptide zwitterion'.", + ) + ) if info["n_amino_acid_residues"] == 3: - blocks.append(("text", "Since we have identified 3 amino acid residues, the final classification is " - "'tripeptide zwitterion'.")) + blocks.append( + ( + "text", + "Since we have identified 3 amino acid residues, the final classification is " + "'tripeptide zwitterion'.", + ) + ) return blocks - subclasses_dict = {2: "di", 3: "tri", 4: "tetra", 5: "penta", 6: "oligo", 7: "oligo", 8: "oligo", 9: "oligo", - 10: "poly"} - blocks.append(("text", "The molecule is overall neutral and not a zwitterion. 
Therefore, it is a peptide.")) - blocks.append(("text", f"More specifically, since we have identified " - f"{info["n_amino_acid_residues"]} amino acid residues," - f"the final classification is '{subclasses_dict[min(10, info["n_amino_acid_residues"])]}peptide'.")) + subclasses_dict = { + 2: "di", + 3: "tri", + 4: "tetra", + 5: "penta", + 6: "oligo", + 7: "oligo", + 8: "oligo", + 9: "oligo", + 10: "poly", + } + blocks.append( + ( + "text", + "The molecule is overall neutral and not a zwitterion. Therefore, it is a peptide.", + ) + ) + blocks.append( + ( + "text", + f"More specifically, since we have identified " + f"{info['n_amino_acid_residues']} amino acid residues," + f"the final classification is '{subclasses_dict[min(10, info['n_amino_acid_residues'])]}peptide'.", + ) + ) return blocks def build_explain_blocks_proteinogenics(self, proteinogenics, atoms): blocks = [("heading", "Proteinogenic amino acids")] if len(proteinogenics) == 0: - blocks.append(("text", "No proteinogenic amino acids have been identified.")) + blocks.append( + ("text", "No proteinogenic amino acids have been identified.") + ) return blocks - blocks.append(("text", "In addition to the classification, we have searched for the residues of 23 " - "proteinogenic amino acids in the molecule.")) - blocks.append(("text", "The following proteinogenic amino acids have been identified:")) + blocks.append( + ( + "text", + "In addition to the classification, we have searched for the residues of 23 " + "proteinogenic amino acids in the molecule.", + ) + ) + blocks.append( + ("text", "The following proteinogenic amino acids have been identified:") + ) proteinogenics_dict = {AA_DICT[aa]: [] for aa in proteinogenics} for aa, atoms_aa in zip(proteinogenics, atoms): proteinogenics_dict[AA_DICT[aa]].append(atoms_aa) @@ -198,11 +335,18 @@ def explain_smiles(self, smiles) -> dict: info = self.get_chemlog_result_info(smiles) highlight_blocks = self.build_explain_blocks_peptides(info) - for chebi_id, internal_name 
in [(64372, "emericellamide"), (65061, "2,5-diketopiperazines")]: + for chebi_id, internal_name in [ + (64372, "emericellamide"), + (65061, "2,5-diketopiperazines"), + ]: if f"{internal_name}_atoms" in info: - highlight_blocks += self.build_explain_blocks_atom_allocations(info[f"{internal_name}_atoms"], internal_name) - highlight_blocks += self.build_explain_blocks_proteinogenics(info["proteinogenics"], info["proteinogenics_locations"]) + highlight_blocks += self.build_explain_blocks_atom_allocations( + info[f"{internal_name}_atoms"], internal_name + ) + highlight_blocks += self.build_explain_blocks_proteinogenics( + info["proteinogenics"], info["proteinogenics_locations"] + ) return { "smiles": smiles, "highlights": highlight_blocks, - } \ No newline at end of file + } diff --git a/chebifier/prediction_models/electra_predictor.py b/chebifier/prediction_models/electra_predictor.py index a1b7084..7d64418 100644 --- a/chebifier/prediction_models/electra_predictor.py +++ b/chebifier/prediction_models/electra_predictor.py @@ -1,8 +1,9 @@ import numpy as np - -from chebifier.prediction_models.nn_predictor import NNPredictor from chebai.models.electra import Electra -from chebai.preprocessing.reader import ChemDataReader, EMBEDDING_OFFSET +from chebai.preprocessing.reader import EMBEDDING_OFFSET, ChemDataReader + +from .nn_predictor import NNPredictor + def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0): n_nodes = len(node_labels) @@ -35,7 +36,6 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0): class ElectraPredictor(NNPredictor): - def __init__(self, model_name: str, ckpt_path: str, **kwargs): super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs) print(f"Initialised Electra model {self.model_name} (device: {self.device})") @@ -44,8 +44,10 @@ def init_model(self, ckpt_path: str, **kwargs) -> Electra: model = Electra.load_from_checkpoint( ckpt_path, map_location=self.device, - 
criterion=None, strict=False, - metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None + criterion=None, + strict=False, + metrics=dict(train=dict(), test=dict(), validation=dict()), + pretrained_checkpoint=None, ) model.eval() return model @@ -57,17 +59,16 @@ def explain_smiles(self, smiles) -> dict: result = self.calculate_results([token_dict]) token_labels = ( - ["[CLR]"] + [None for _ in range(EMBEDDING_OFFSET - 1)] + list(reader.cache.keys()) + ["[CLR]"] + + [None for _ in range(EMBEDDING_OFFSET - 1)] + + list(reader.cache.keys()) ) graphs = [ [ - build_graph_from_attention( - a[0, i], tokens, token_labels, threshold=0.1 - ) + build_graph_from_attention(a[0, i], tokens, token_labels, threshold=0.1) for i in range(a.shape[1]) ] for a in result["attentions"] ] return {"graphs": graphs} - diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py index 9038846..edddba7 100644 --- a/chebifier/prediction_models/gnn_predictor.py +++ b/chebifier/prediction_models/gnn_predictor.py @@ -1,16 +1,18 @@ -from chebifier.prediction_models.nn_predictor import NNPredictor import chebai_graph.preprocessing.properties as p import torch from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred -from chebai_graph.preprocessing.reader import GraphPropertyReader from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder +from chebai_graph.preprocessing.reader import GraphPropertyReader from torch_geometric.data.data import Data as GeomData +from .nn_predictor import NNPredictor -class ResGatedPredictor(NNPredictor): +class ResGatedPredictor(NNPredictor): def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs): - super().__init__(model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs) + super().__init__( + model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs + ) # molecular_properties is a list of class paths if 
molecular_properties is not None: properties = [self.load_class(prop)() for prop in molecular_properties] @@ -32,11 +34,13 @@ def load_class(self, class_path: str): def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred: model = ResGatedGraphConvNetGraphPred.load_from_checkpoint( - ckpt_path, map_location=torch.device(self.device), criterion=None, strict=False, input_dim=128, - metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None, - config={"in_length": 256, "hidden_length": 512, "dropout_rate": 0.1, "n_conv_layers": 3, - "n_linear_layers": 3, "n_atom_properties": 158, "n_bond_properties": 7, - "n_molecule_properties": 200}) + ckpt_path, + map_location=torch.device(self.device), + criterion=None, + strict=False, + metrics=dict(train=dict(), test=dict(), validation=dict()), + pretrained_checkpoint=None, + ) model.eval() return model @@ -55,14 +59,21 @@ def read_smiles(self, smiles): # use default value if we meet an unseen value if isinstance(prop.encoder, IndexEncoder): if str(value) in prop.encoder.cache: - index = prop.encoder.cache.index(str(value)) + prop.encoder.offset + index = ( + prop.encoder.cache.index(str(value)) + prop.encoder.offset + ) else: index = 0 - print(f"Unknown property value {value} for property {prop} at smiles {smiles}") + print( + f"Unknown property value {value} for property {prop} at smiles {smiles}" + ) if isinstance(prop.encoder, OneHotEncoder): - encoded_values.append(torch.nn.functional.one_hot( - torch.tensor(index), num_classes=prop.encoder.get_encoding_length() - )) + encoded_values.append( + torch.nn.functional.one_hot( + torch.tensor(index), + num_classes=prop.encoder.get_encoding_length(), + ) + ) else: encoded_values.append(torch.tensor([index])) @@ -77,9 +88,7 @@ def read_smiles(self, smiles): if len(encoded_values.size()) == 1: encoded_values = encoded_values.unsqueeze(1) else: - encoded_values = torch.zeros( - (0, prop.encoder.get_encoding_length()) - ) + 
encoded_values = torch.zeros((0, prop.encoder.get_encoding_length())) if isinstance(prop, p.AtomProperty): x = torch.cat([x, encoded_values], dim=1) elif isinstance(prop, p.BondProperty): @@ -93,4 +102,4 @@ def read_smiles(self, smiles): edge_attr=edge_attr, molecule_attr=molecule_attr, ) - return d \ No newline at end of file + return d diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index 1ee5e46..3b603b5 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -1,24 +1,34 @@ -import tqdm - -from chebifier.prediction_models.base_predictor import BasePredictor -from rdkit import Chem import numpy as np import torch +import tqdm +from rdkit import Chem + +from .base_predictor import BasePredictor -class NNPredictor(BasePredictor): - def __init__(self, model_name: str, ckpt_path: str, reader_cls, target_labels_path: str, **kwargs): +class NNPredictor(BasePredictor): + def __init__( + self, + model_name: str, + ckpt_path: str, + reader_cls, + target_labels_path: str, + **kwargs, + ): super().__init__(model_name, **kwargs) self.reader_cls = reader_cls self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = self.init_model(ckpt_path=ckpt_path) - self.target_labels = [line.strip() for line in open(target_labels_path, encoding="utf-8")] + self.target_labels = [ + line.strip() for line in open(target_labels_path, encoding="utf-8") + ] self.batch_size = kwargs.get("batch_size", 1) - def init_model(self, ckpt_path: str, **kwargs): - raise NotImplementedError("Model initialization must be implemented in subclasses.") + raise NotImplementedError( + "Model initialization must be implemented in subclasses." 
+ ) def calculate_results(self, batch): collator = self.reader_cls.COLLATOR() @@ -66,14 +76,27 @@ def predict_smiles_list(self, smiles_list) -> list: token_dicts.append(d) results = [] if token_dicts: - for batch in tqdm.tqdm(self.batchify(token_dicts), desc=f"{self.model_name}", total=len(token_dicts)//self.batch_size): + for batch in tqdm.tqdm( + self.batchify(token_dicts), + desc=f"{self.model_name}", + total=len(token_dicts) // self.batch_size, + ): result = self.calculate_results(batch) if isinstance(result, dict) and "logits" in result: result = result["logits"] results += torch.sigmoid(result).cpu().detach().tolist() results = np.stack(results, axis=0) - preds = [{self.target_labels[j]: p for j, p in enumerate(results[index_map[i]])} - if i not in could_not_parse else None for i in range(len(smiles_list))] + preds = [ + ( + { + self.target_labels[j]: p + for j, p in enumerate(results[index_map[i]]) + } + if i not in could_not_parse + else None + ) + for i in range(len(smiles_list)) + ] return preds else: - return [None for _ in smiles_list] \ No newline at end of file + return [None for _ in smiles_list] diff --git a/configs/huggingface_config.yml b/configs/huggingface_config.yml new file mode 100644 index 0000000..c26950d --- /dev/null +++ b/configs/huggingface_config.yml @@ -0,0 +1,22 @@ + +chemlog_peptides: + type: chemlog + model_weight: 100 + +#resgated_huggingface: +# type: resgated +# hugging_face: +# repo_id: aditya0by0/python-chebifier +# subfolder: resgated +# files: +# ckpt: resgated.ckpt +# labels: classes.txt + +electra_huggingface: + type: electra + hugging_face: + repo_id: aditya0by0/python-chebifier + subfolder: electra + files: + ckpt: electra.ckpt + labels: classes.txt diff --git a/pyproject.toml b/pyproject.toml index 8a0223f..ff7837d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,9 +27,6 @@ dependencies = [ "chemlog>=1.0.4" ] -[project.scripts] -chebifier = "chebifier.cli:cli" - [tool.setuptools] packages = ["chebifier", 
"chebifier.ensemble", "chebifier.prediction_models"]