Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
02c5409
api code to download model from hugging face
aditya0by0 Jun 24, 2025
c13423c
Merge branch 'dev' into feature/api_downloadble_models
aditya0by0 Jun 27, 2025
b539f0a
Create .pre-commit-config.yaml
aditya0by0 Jun 27, 2025
2c2aba2
utility to setup env and model package dependencies
aditya0by0 Jun 27, 2025
2b9f335
`gather_predictions` will return predicted_classes_dict
aditya0by0 Jun 27, 2025
6faf3bd
use package namespace imports for prediction models
aditya0by0 Jun 28, 2025
a4f5f85
add hugging face api
aditya0by0 Jun 28, 2025
481a2eb
api registry
aditya0by0 Jun 28, 2025
584b6a6
api cli
aditya0by0 Jun 28, 2025
05d8580
Update .gitignore
aditya0by0 Jun 28, 2025
997120e
use hugging face's cache system instead of custom file management
aditya0by0 Jun 28, 2025
9c3beea
pre-commit -run -a
aditya0by0 Jun 28, 2025
e6602ef
remove explicit config kwargs for resgated
aditya0by0 Jul 1, 2025
fd814e9
api support for ensemble
aditya0by0 Jul 6, 2025
a044f23
add ruff action workflow
aditya0by0 Jul 6, 2025
51a2d34
same version for workflow and pre-commit yaml
aditya0by0 Jul 6, 2025
d2c586a
Update base_predictor.py
aditya0by0 Jul 6, 2025
e0b3ca7
merge from dev
aditya0by0 Jul 9, 2025
ebc450f
Merge branch 'refs/heads/dev' into feature/api_downloadble_models
sfluegel05 Jul 11, 2025
f3b3905
fix readme
sfluegel05 Jul 11, 2025
001538d
fix cli and ensemble imports
sfluegel05 Jul 11, 2025
f8583cb
add huggingface download to cli
sfluegel05 Jul 11, 2025
90aedd4
reformat with black
sfluegel05 Jul 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# configs/ # commented as new configs can be added as a part of a feature

/.idea
/data
/logs
/results_buffer
electra_pretrained.ckpt

build
.virtual_documents
.jupyter
chebai.egg-info
lightning_logs
logs
.isort.cfg
/.vscode
/api/.cloned_repos
31 changes: 31 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
repos:
- repo: https://github.com/psf/black
rev: "24.2.0"
hooks:
- id: black
- id: black-jupyter # for formatting jupyter-notebook

- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
args: ["--profile=black"]

- repo: https://github.com/asottile/seed-isort-config
rev: v2.2.0
hooks:
- id: seed-isort-config

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.1
hooks:
- id: ruff
args: [] # No --fix: ruff only reports lint violations and does not rewrite files
Empty file added api/__init__.py
Empty file.
10 changes: 10 additions & 0 deletions api/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .cli import cli

# Entry point for the CLI application.
#
# When this module is executed as the main program, call the `cli` function
# imported from the `api.cli` module. (A string literal placed inside the `if`
# body is not a docstring, only a discarded expression, so this is a comment.)
if __name__ == "__main__":
    cli()
112 changes: 112 additions & 0 deletions api/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import importlib
from pathlib import Path

import click
import yaml

from chebifier.prediction_models.base_predictor import BasePredictor

from .hugging_face import download_model_files
from .setup_env import SetupEnvAndPackage

# Load the model registry once, at import time.
#
# The registry file sits next to this module (api/registry.yml), so resolve it
# relative to `__file__` instead of the current working directory; a
# CWD-relative path only works when the CLI is launched from the repository root.
yaml_path = Path(__file__).resolve().parent / "registry.yml"
try:
    with yaml_path.open("r", encoding="utf-8") as f:
        # safe_load: the registry is plain data, never construct arbitrary objects.
        model_registry = yaml.safe_load(f)
except FileNotFoundError as err:
    # Keep the original exception type and message shape for callers/users.
    raise FileNotFoundError(f"{yaml_path} not found.") from err


# Root click group; subcommands attach to it through `@cli.command()`.
# The docstring doubles as the group's `--help` text, so it is kept verbatim.
@click.group()
def cli():
    """Command line interface for Api-Chebifier."""


# `predict` subcommand: gathers SMILES strings from --smiles and/or
# --smiles-file, instantiates the predictor registered under --model-type,
# and either prints the predictions or writes them to --output as JSON.
@cli.command()
@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict")
@click.option(
"--smiles-file",
"-f",
type=click.Path(exists=True),
help="File containing SMILES strings (one per line)",
)
@click.option(
"--output",
"-o",
type=click.Path(),
help="Output file to save predictions (optional)",
)
@click.option(
"--model-type",
"-m",
type=click.Choice(model_registry.keys()),
default="mv",
help="Type of model to use",
)
def predict(smiles, smiles_file, output, model_type):
"""Predict ChEBI classes for SMILES strings using an ensemble model.

CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
"""
# NOTE(review): click uses the docstring above as this command's help text, but
# the command declares no CONFIG_FILE argument (only --smiles, --smiles-file,
# --output and --model-type), so the second sentence looks stale.

# Collect SMILES strings from arguments and/or file
smiles_list = list(smiles)
if smiles_file:
with open(smiles_file, "r") as f:
# Strip surrounding whitespace and skip blank lines.
smiles_list.extend([line.strip() for line in f if line.strip()])

# Nothing to predict: tell the user and exit without loading any model.
if not smiles_list:
click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
return

# Registry entry for the chosen model type, plus the keyword arguments that
# will be passed to the wrapper class constructor below.
model_config = model_registry[model_type]
predictor_kwargs = {"model_name": model_type}

# Directory containing this module; used as the base for the clone/venv dirs.
current_dir = Path(__file__).resolve().parent

# Registry entries with a "hugging_face" section get their files downloaded;
# the returned mapping is read under the "ckpt" and "labels" keys.
if "hugging_face" in model_config:
print(f"For model type `{model_type}` following files are used:")
local_file_path = download_model_files(model_config["hugging_face"])
predictor_kwargs["ckpt_path"] = local_file_path["ckpt"]
predictor_kwargs["target_labels_path"] = local_file_path["labels"]

# NOTE(review): SetupEnvAndPackage comes from .setup_env (not visible here);
# presumably it clones `repo_url` into `.cloned_repos` and installs it — confirm.
# This call indexes model_config["repo_url"], so the registry entry must
# define that key or a KeyError is raised.
SetupEnvAndPackage().setup(
repo_url=model_config["repo_url"],
clone_dir=current_dir / ".cloned_repos",
venv_dir=current_dir,
)

# "wrapper" is a dotted path "<module>.<ClassName>": import the module
# dynamically and instantiate the class with the collected kwargs.
model_cls_path = model_config["wrapper"]
module_path, class_name = model_cls_path.rsplit(".", 1)
module = importlib.import_module(module_path)
model_cls: type = getattr(module, class_name)
model_instance = model_cls(**predictor_kwargs)
# NOTE(review): `assert` is stripped under `python -O`, so this type check
# is not enforced in optimized mode.
assert isinstance(model_instance, BasePredictor)

# Make predictions
predictions = model_instance.predict_smiles_list(smiles_list)

if output:
# save as json
import json

# The mapping is keyed by SMILES string, so duplicate inputs collapse
# into a single entry (the last prediction wins).
with open(output, "w") as f:
json.dump(
{smiles: pred for smiles, pred in zip(smiles_list, predictions)},
f,
indent=2,
)

else:
# Print results
# NOTE(review): the index `i` from enumerate is unused, and the loop variable
# `smiles` rebinds the function parameter of the same name.
for i, (smiles, prediction) in enumerate(zip(smiles_list, predictions)):
click.echo(f"Result for: {smiles}")
if prediction:
click.echo(f" Predicted classes: {', '.join(map(str, prediction))}")
else:
click.echo(" No predictions")


# Allow running this module directly as a script: invoke the click group.
if __name__ == "__main__":
cli()
48 changes: 48 additions & 0 deletions api/hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Hugging Face Hub API:
- Windows users, see the cache-system limitations described at:
  https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache#limitations

For the Hugging Face Hub caching and versioning documentation, refer to:
https://huggingface.co/docs/huggingface_hub/en/guides/download
https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache
"""

from pathlib import Path

from huggingface_hub import hf_hub_download


def download_model_files(
    model_config: dict[str, str | dict[str, str]],
) -> dict[str, Path]:
    """
    Download the model files named in a registry entry from the Hugging Face Hub.

    Every file is fetched with ``hf_hub_download`` and the path it returns is
    wrapped in a ``Path``. Hugging Face Hub provides internal caching and
    versioning, so file management or duplication checks are not required here.
    A line naming each file and its local location is printed as it is resolved.

    Args:
        model_config: A dictionary containing:
            - 'repo_id' (str): The Hugging Face repository ID
              (e.g., 'username/modelname').
            - 'subfolder' (str, optional): The subfolder within the repo where
              the files are located. When the key is absent, ``None`` is
              forwarded to ``hf_hub_download``.
            - 'files' (dict[str, str]): A mapping from file type
              (e.g., 'ckpt', 'labels') to actual file names
              (e.g., 'electra.ckpt', 'classes.txt').

    Returns:
        A dictionary mapping each file type to the local Path of the
        downloaded file.

    Raises:
        KeyError: If 'repo_id' or 'files' is missing from ``model_config``.
    """
    repo_id = model_config["repo_id"]
    # 'subfolder' is optional: hf_hub_download's own default for it is None,
    # so a registry entry without the key no longer fails with a KeyError.
    subfolder = model_config.get("subfolder")
    filenames = model_config["files"]

    local_paths: dict[str, Path] = {}
    for file_type, filename in filenames.items():
        downloaded_file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
        )
        local_paths[file_type] = Path(downloaded_file_path)
        print(f"\t Using file `{filename}` from: {downloaded_file_path}")

    return local_paths
Loading