10 changes: 0 additions & 10 deletions .github/workflows/black.yml

This file was deleted.

26 changes: 26 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,26 @@
name: Lint

on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10' # or any version your project uses

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install black==25.1.0 ruff==0.12.2

- name: Run Black
run: black --check .

- name: Run Ruff (no formatting)
run: ruff check . --no-fix
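
> 💡 To reproduce these checks locally before pushing (a sketch using the same pinned versions as the workflow above):

```bash
python -m pip install --upgrade pip
pip install black==25.1.0 ruff==0.12.2
black --check .        # formatting check only; fails if a file would be reformatted
ruff check . --no-fix  # lint without applying autofixes
```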
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -9,7 +9,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
1 change: 1 addition & 0 deletions .gitignore
@@ -170,3 +170,4 @@ electra_pretrained.ckpt
.jupyter
.virtual_documents
.isort.cfg
.vscode
8 changes: 7 additions & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: "24.2.0"
rev: "25.1.0"
hooks:
- id: black
- id: black-jupyter # for formatting jupyter-notebook
@@ -23,3 +23,9 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.2
hooks:
- id: ruff
args: [--fix]
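
> 💡 To pick up the updated hooks locally, standard `pre-commit` usage applies:

```bash
pre-commit install            # install the git hook once per clone
pre-commit run --all-files    # run Black, Ruff, and the other hooks on everything
```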
108 changes: 10 additions & 98 deletions README.md
@@ -7,19 +7,22 @@
## 🔧 Installation


To install, follow these steps:
To install this repository, download both [`python-chebai`](https://github.com/ChEB-AI/python-chebai) and this repository, then run:

1. Clone the repository:
```
git clone https://github.com/ChEB-AI/python-chebai-proteins.git
cd python-chebai
pip install .

cd python-chebai-proteins
pip install .
```

2. Install the package:
_Note for developers_: If you want to install the package in editable mode, use the following command instead:

```bash
pip install -e .
```
cd python-chebai
pip install .
```


## 🗂 Recommended Folder Structure

@@ -43,39 +43,6 @@ This setup enables shared access to data and model configurations.

## 🚀 Training & Pretraining Guide

### ⚠️ Important Setup Instructions

Before running any training scripts, ensure the environment is correctly configured:

* Either:

* Install the `python-chebai` repository as a package using:

```bash
pip install .
```
* **OR**

* Manually set the `PYTHONPATH` environment variable if working across multiple directories (`python-chebai` and `python-chebai-proteins`):

* If your current working directory is `python-chebai-proteins`, set:

```bash
export PYTHONPATH=path/to/python-chebai
```
or vice versa.

* If you're working within both repositories simultaneously or are facing module-not-found errors, we **recommend configuring both directories**:

```bash
# Linux/macOS
export PYTHONPATH=path/to/python-chebai:path/to/python-chebai-proteins

# Windows (use semicolon instead of colon)
set PYTHONPATH=path\to\python-chebai;path\to\python-chebai-proteins
```

> 🔎 See the [PYTHONPATH Explained](#-pythonpath-explained) section below for more details.


### 📊 SCOPE hierarchy prediction
@@ -86,61 +86,3 @@ python -m chebai fit --trainer=../configs/training/default_trainer.yml --trainer
```

The same command can be used for **DeepGO** simply by changing the data config path.

## 🧭 PYTHONPATH Explained

### What is `PYTHONPATH`?

`PYTHONPATH` is an environment variable that tells Python where to search for modules that aren't installed via `pip` and aren't in your current working directory.

### Why You Need It

If your config refers to a custom module like:

```yaml
class_path: chebai_proteins.preprocessing.datasets.scope.scope.SCOPe50
```

...and you're running the code from `python-chebai`, Python won't know where to find `chebai_proteins` (from another repo like `python-chebai-proteins/`) unless you add it to `PYTHONPATH`.


### How Python Finds Modules

Python looks for imports in this order:

1. The directory of the running script (or the current directory)
2. Paths in `PYTHONPATH`
3. Standard library
4. Installed packages (`site-packages`)

You can inspect the full search path:

```bash
python -c "import sys; print(sys.path)"
```



### ✅ Setting `PYTHONPATH`

#### 🐧 Linux / macOS

```bash
export PYTHONPATH=/path/to/python-chebai-proteins
echo $PYTHONPATH
```

#### 🪟 Windows CMD

```cmd
set PYTHONPATH=C:\path\to\python-chebai-proteins
echo %PYTHONPATH%
```

> 💡 Note: This is temporary for your terminal session. To make it permanent, add it to your system environment variables.
7 changes: 7 additions & 0 deletions chebai_proteins/loss/bce_logits_loss.py
@@ -0,0 +1,7 @@
import torch


class WrappedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
def forward(self, input, target, **kwargs):
# As the custom passed kwargs are not used in BCEWithLogitsLoss, we can ignore them
return super().forward(input, target)
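
A minimal usage sketch for this wrapper (the `current_epoch` kwarg below is hypothetical, standing in for whatever extra kwargs a training loop might forward; plain `BCEWithLogitsLoss` would raise a `TypeError` on it):

```python
import torch

from chebai_proteins.loss.bce_logits_loss import WrappedBCEWithLogitsLoss

loss_fn = WrappedBCEWithLogitsLoss()
logits = torch.randn(4, 3)                     # raw model outputs
targets = torch.randint(0, 2, (4, 3)).float()  # multi-label 0/1 targets

# Extra kwargs are silently ignored instead of raising a TypeError:
loss = loss_fn(logits, targets, current_epoch=7)  # hypothetical extra kwarg
print(loss.item())
```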
51 changes: 35 additions & 16 deletions chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py
@@ -102,31 +102,43 @@ class _GOUniProtDataExtractor(_DynamicDataset, ABC):

# Gene Ontology (GO) has three major branches, one for biological processes (BP), molecular functions (MF) and
# cellular components (CC). The value "all" will take data related to all three branches into account.
# TODO: should we really be allowing all branches for a single dataset?
_ALL_GO_BRANCHES: str = "all"
_GO_BRANCH_NAMESPACE: Dict[str, str] = {
"BP": "biological_process",
"MF": "molecular_function",
"CC": "cellular_component",
"BP": "biological_process", # Huge branch, with 20,000+ GO terms
"MF": "molecular_function", # smaller branch, with 6000+ GO terms
"CC": "cellular_component", # smallest branch, with 2,000+ GO terms
}

def __init__(self, **kwargs):
self.go_branch: str = self._get_go_branch(**kwargs)
READER = None

def __init__(
self,
go_branch: str,
max_sequence_len: int = 1002,
use_esm2_embeddings: bool = False,
**kwargs,
):
if bool(use_esm2_embeddings):
self.READER = dr.ESM2EmbeddingReader

self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002))
self.go_branch: str = self._get_go_branch(go_branch)

self.max_sequence_length: int = int(max_sequence_len)
assert (
self.max_sequence_length >= 1
), "Max sequence length should be greater than or equal to 1."

super(_GOUniProtDataExtractor, self).__init__(**kwargs)

if self.reader.n_gram is not None:
if hasattr(self.reader, "n_gram") and self.reader.n_gram is not None:
assert self.max_sequence_length >= self.reader.n_gram, (
f"max_sequence_length ({self.max_sequence_length}) must be greater than "
f"or equal to n_gram ({self.reader.n_gram})."
)

@classmethod
def _get_go_branch(cls, **kwargs) -> str:
def _get_go_branch(cls, go_branch_value: str, **kwargs) -> str:
"""
Validates the provided Gene Ontology (GO) branch value and returns it.
This method checks that the given 'go_branch_value' is among the allowed branches.
@@ -141,7 +153,6 @@ def _get_go_branch(cls, **kwargs) -> str:
ValueError: If the provided 'go_branch' value is not in the allowed list of values.
"""

go_branch_value = kwargs.get("go_branch", cls._ALL_GO_BRANCHES)
allowed_values = list(cls._GO_BRANCH_NAMESPACE.keys()) + [cls._ALL_GO_BRANCHES]
if go_branch_value not in allowed_values:
raise ValueError(
@@ -181,7 +192,7 @@ def _download_gene_ontology_data(self) -> str:

if not os.path.isfile(go_path):
print("Missing Gene Ontology raw data")
print(f"Downloading Gene Ontology data....")
print("Downloading Gene Ontology data....")
r = requests.get(self._GO_DATA_URL, allow_redirects=True)
r.raise_for_status() # Check if the request was successful
open(go_path, "wb").write(r.content)
@@ -207,7 +218,7 @@ def _download_swiss_uni_prot_data(self) -> Optional[str]:
os.makedirs(os.path.dirname(uni_prot_file_path), exist_ok=True)

if not os.path.isfile(uni_prot_file_path):
print(f"Downloading Swiss UniProt data....")
print("Downloading Swiss UniProt data....")

# Create a temporary file
with NamedTemporaryFile(delete=False) as tf:
@@ -223,7 +234,7 @@ def _download_swiss_uni_prot_data(self) -> Optional[str]:

# Unpack the gzipped file
try:
print(f"Unzipping the file....")
print("Unzipping the file....")
with gzip.open(temp_filename, "rb") as f_in:
output_file_path = uni_prot_file_path
with open(output_file_path, "wb") as f_out:
@@ -375,7 +386,7 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
Returns:
pd.DataFrame: The raw dataset created from the graph.
"""
print(f"Processing graph")
print("Processing graph")

data_df = self._get_swiss_to_go_mapping()
# add ancestors to go ids
@@ -457,6 +468,14 @@ def _get_swiss_to_go_mapping(self) -> pd.DataFrame:

if not record.sequence or len(record.sequence) > self.max_sequence_length:
# Consider protein with only sequence representation and seq. length not greater than max seq. length

# DeepGO1 paper ignores proteins with sequence length greater than 1002: https://github.com/bio-ontology-research-group/deepgo/blob/master/aaindex.py#L9-L14
# But DeepGO2 paper truncates the sequence to 1000: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L26-L33
# Latest Discussion: https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2385693976
# So, we ignore proteins with sequence length greater than max_sequence_length
# The rationale is that with only a partial representation of the protein sequence, the model may not learn effectively.
# Also, proteins longer than 1002 are only 3.32% of the total proteins in Swiss-Prot dataset.
# https://github.com/ChEB-AI/python-chebai/issues/36#issuecomment-2431460448
continue

if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence):
@@ -559,8 +578,8 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
)
except FileNotFoundError:
raise FileNotFoundError(
f"File data.pt doesn't exists. "
f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
"File data.pt doesn't exists. "
"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
)

df_go_data = pd.DataFrame(data_go)
@@ -586,7 +605,7 @@ def base_dir(self) -> str:
Returns:
str: The path to the base directory, which is "data/GO_UniProt".
"""
return os.path.join("data", f"GO_UniProt")
return os.path.join("data", "GO_UniProt")

@property
def raw_file_names_dict(self) -> dict:
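
A sketch of the new constructor contract, assuming `GOUniProtOver250` is a concrete subclass of `_GOUniProtDataExtractor` defined in this module (the import path below is an assumption, not shown in the diff):

```python
from chebai_proteins.preprocessing.datasets.deepGO.go_uniprot import GOUniProtOver250

# go_branch is now an explicit, required argument ("BP", "MF", "CC", or "all"):
ds_mf = GOUniProtOver250(go_branch="MF")

# Opting in to ESM2 embeddings swaps READER to dr.ESM2EmbeddingReader:
ds_esm = GOUniProtOver250(go_branch="MF", use_esm2_embeddings=True)

# Proteins longer than max_sequence_len are skipped entirely, not truncated:
ds_all = GOUniProtOver250(go_branch="all", max_sequence_len=1000)
```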
@@ -38,7 +38,7 @@ def __init__(self, **kwargs):
Args:
**kwargs: Additional arguments for the superclass initialization.
"""
self._go_uniprot_extractor = GOUniProtOver250()
self._go_uniprot_extractor = GOUniProtOver250(go_branch="all")
assert self._go_uniprot_extractor.go_branch == GOUniProtOver250._ALL_GO_BRANCHES

self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002))
@@ -143,7 +143,6 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
has_valid_associated_go_label = False
for cross_ref in record.cross_references:
if cross_ref[0] == self._go_uniprot_extractor._GO_DATA_INIT:

if len(cross_ref) <= 3:
# No evidence code
continue
@@ -223,8 +222,8 @@ def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
)
except FileNotFoundError:
raise FileNotFoundError(
f"File data.pt doesn't exists. "
f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
"File data.pt doesn't exists. "
"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files"
)

df_go_data = pd.DataFrame(data_go)