Added and integrated C++ graphium_cpp library, a Python module implem… #510
@@ -1,12 +1,12 @@
 """
 --------------------------------------------------------------------------------
-Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals and Graphcore.
+Copyright (c) 2023 Valence Labs, Recursion Pharmaceuticals, Graphcore, and NVIDIA Corporation & Affiliates.

 Use of this software is subject to the terms and conditions outlined in the LICENSE file.
 Unauthorized modification, distribution, or use is prohibited. Provided 'as is' without
 warranties of any kind.

-Valence Labs, Recursion Pharmaceuticals and Graphcore are not liable for any damages arising from its use.
+Valence Labs, Recursion Pharmaceuticals, Graphcore, and NVIDIA Corporation & Affiliates are not liable for any damages arising from its use.
 Refer to the LICENSE file for the full terms and conditions.
 --------------------------------------------------------------------------------
 """
@@ -26,11 +26,11 @@ | |
| from graphium.utils.packing import fast_packing, get_pack_sizes, node_to_pack_indices_mask | ||
| from loguru import logger | ||
| from graphium.data.utils import get_keys | ||
|
|
||
| from graphium.data.dataset import torch_enum_to_dtype | ||
|
|
||
| def graphium_collate_fn( | ||
| elements: Union[List[Any], Dict[str, List[Any]]], | ||
| labels_size_dict: Optional[Dict[str, Any]] = None, | ||
| labels_num_cols_dict: Optional[Dict[str, Any]] = None, | ||
| labels_dtype_dict: Optional[Dict[str, Any]] = None, | ||
| mask_nan: Union[str, float, Type[None]] = "raise", | ||
| do_not_collate_keys: List[str] = [], | ||
|
|
@@ -52,7 +52,7 @@ def graphium_collate_fn( | |
| elements: | ||
| The elements to batch. See `torch.utils.data.dataloader.default_collate`. | ||
|
|
||
| labels_size_dict: | ||
| labels_num_cols_dict: | ||
| (Note): This is an attribute of the `MultitaskDataset`. | ||
| A dictionary of the form Dict[tasks, sizes] which has task names as keys | ||
| and the size of the label tensor as value. The size of the tensor corresponds to how many | ||
|
|
@@ -86,14 +86,26 @@ def graphium_collate_fn( | |
| The batched elements. See `torch.utils.data.dataloader.default_collate`. | ||
| """ | ||
|
|
||
| # Skip any elements that failed | ||
| if None in elements: | ||
| elements = [e for e in elements if e is not None] | ||
|
|
||
| elem = elements[0] | ||
| if isinstance(elem, Mapping): | ||
| batch = {} | ||
| for key in elem: | ||
| # Multitask setting: We have to pad the missing labels | ||
| if key == "labels": | ||
| labels = [d[key] for d in elements] | ||
| batch[key] = collate_labels(labels, labels_size_dict, labels_dtype_dict) | ||
| if "features" in elem: | ||
| num_nodes = [d["features"].num_nodes for d in elements] | ||
| num_edges = [d["features"].num_edges for d in elements] | ||
| else: | ||
| num_nodes = [d["num_nodes"] for d in elements] | ||
| num_edges = [d["num_edges"] for d in elements] | ||
|
||
| batch[key] = collate_labels(labels, labels_num_cols_dict, labels_dtype_dict, num_nodes, num_edges) | ||
| elif key == "num_nodes" or key == "num_edges": | ||
| continue | ||
|
|
||
| # If the features are a dictionary containing GraphDict elements, | ||
| # Convert to pyg graphs and use the pyg batching. | ||
|
|
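For context, a minimal sketch (not PR code) of the two element layouts the updated `graphium_collate_fn` accepts: when `"features"` is present, node/edge counts are read from the PyG object; otherwise they must be carried as explicit `"num_nodes"`/`"num_edges"` entries, which the new `elif` branch drops from the batch as pure bookkeeping. The label field names and values below are hypothetical.

```python
import torch
from torch_geometric.data import Data

# Element with full features: counts come from the pyg Data object.
elem_with_features = {
    "features": Data(x=torch.randn(5, 16), edge_index=torch.zeros(2, 8, dtype=torch.long)),
    "labels": Data(graph_task=torch.tensor([1.0])),  # hypothetical task name
}

# Lightweight element without features: counts are carried explicitly,
# which is why "num_nodes"/"num_edges" keys are skipped during collation
# instead of being batched like ordinary tensors.
elem_without_features = {
    "labels": Data(graph_task=torch.tensor([0.0])),
    "num_nodes": 5,
    "num_edges": 8,
}
```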
@@ -182,23 +194,21 @@ def collage_pyg_graph(pyg_graphs: Iterable[Union[Data, Dict]], batch_size_per_pa | |
| return Batch.from_data_list(pyg_batch) | ||
|
|
||
|
|
||
| def pad_to_expected_label_size(labels: torch.Tensor, label_size: List[int]): | ||
| def pad_to_expected_label_size(labels: torch.Tensor, label_rows: int, label_cols: int): | ||
| """Determine difference of ``labels`` shape to expected shape `label_size` and pad | ||
| with ``torch.nan`` accordingly. | ||
| """ | ||
| if label_size == list(labels.shape): | ||
| if len(labels.shape) == 2 and label_rows == labels.shape[0] and label_cols == labels.shape[1]: | ||
| return labels | ||
|
|
||
| missing_dims = len(label_size) - len(labels.shape) | ||
| missing_dims = 2 - len(labels.shape) | ||
| for _ in range(missing_dims): | ||
| labels.unsqueeze(-1) | ||
|
|
||
| pad_sizes = [(0, expected - actual) for expected, actual in zip(label_size, labels.shape)] | ||
| pad_sizes = [item for before_after in pad_sizes for item in before_after] | ||
| pad_sizes.reverse() | ||
| pad_sizes = [label_cols - labels.shape[1], 0, label_rows - labels.shape[0], 0] | ||
|
|
||
| if any([s < 0 for s in pad_sizes]): | ||
| logger.warning(f"More labels available than expected. Will remove data to fit expected size.") | ||
| logger.warning(f"More labels available than expected. Will remove data to fit expected size. cols: {labels.shape[1]}->{label_cols}, rows: {labels.shape[0]}->{label_rows}") | ||
|
|
||
| return torch.nn.functional.pad(labels, pad_sizes, value=torch.nan) | ||
|
|
||
|
|
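The rewritten `pad_to_expected_label_size` relies on `torch.nn.functional.pad`'s flat-pair convention, where (before, after) pairs apply to the last dimension first, and on the fact that a negative pad size crops rather than pads, which is the case the new warning reports. A standalone sketch of both behaviors:

```python
import torch

# Sketch of the padding convention used above (not PR code).
# F.pad consumes (before, after) pairs starting from the LAST dimension:
# [cols_before, cols_after, rows_before, rows_after] for a 2D tensor.
labels = torch.ones(2, 3)    # 2 rows (e.g. nodes), 3 label columns
label_rows, label_cols = 4, 5
pad_sizes = [label_cols - labels.shape[1], 0, label_rows - labels.shape[0], 0]
padded = torch.nn.functional.pad(labels, pad_sizes, value=torch.nan)
print(padded.shape)          # torch.Size([4, 5]); new entries are NaN

# A negative pad size crops instead of padding, i.e. data is dropped;
# that is the situation the logger.warning above flags.
cropped = torch.nn.functional.pad(labels, [-1, 0, 0, 0])
print(cropped.shape)         # torch.Size([2, 2])
```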
@@ -226,31 +236,41 @@ def collate_pyg_graph_labels(pyg_labels: List[Data]): | |
| return Batch.from_data_list(pyg_batch) | ||
|
|
||
|
|
||
| def get_expected_label_size(label_data: Data, task: str, label_size: List[int]): | ||
| def get_expected_label_rows( | ||
| label_data: Data, | ||
| task: str, | ||
| num_nodes: int, | ||
| num_edges: int | ||
| ): | ||
| """Determines expected label size based on the specfic graph properties | ||
| and the number of targets in the task-dataset. | ||
| """ | ||
| if task.startswith("graph_"): | ||
| num_labels = 1 | ||
| elif task.startswith("node_"): | ||
| num_labels = label_data.x.size(0) | ||
| num_labels = num_nodes | ||
| elif task.startswith("edge_"): | ||
| num_labels = label_data.edge_index.size(1) | ||
| num_labels = num_edges | ||
| elif task.startswith("nodepair_"): | ||
| raise NotImplementedError() | ||
| return [num_labels] + label_size | ||
| else: | ||
| print("Task name "+task+" in get_expected_label_rows") | ||
| raise NotImplementedError() | ||
| return num_labels | ||
|
|
||
|
|
||
| def collate_labels( | ||
| labels: List[Data], | ||
| labels_size_dict: Optional[Dict[str, Any]] = None, | ||
| labels_num_cols_dict: Optional[Dict[str, Any]] = None, | ||
| labels_dtype_dict: Optional[Dict[str, Any]] = None, | ||
| num_nodes: List[int] = None, | ||
| num_edges: List[int] = None | ||
| ): | ||
| """Collate labels for multitask learning. | ||
|
|
||
| Parameters: | ||
| labels: List of labels | ||
| labels_size_dict: Dict of the form Dict[tasks, sizes] which has task names as keys | ||
| labels_num_cols_dict: Dict of the form Dict[tasks, sizes] which has task names as keys | ||
| and the size of the label tensor as value. The size of the tensor corresponds to how many | ||
| labels/values there are to predict for that task. | ||
| labels_dtype_dict: | ||
|
|
@@ -260,45 +280,39 @@ def collate_labels( | |
|
|
||
| Returns: | ||
| A dictionary of the form Dict[tasks, labels] where tasks is the name of the task and labels | ||
| is a tensor of shape (batch_size, *labels_size_dict[task]). | ||
| is a tensor of shape (batch_size, *labels_num_cols_dict[task]). | ||
| """ | ||
| if labels_size_dict is not None: | ||
| for this_label in labels: | ||
| for task in labels_size_dict.keys(): | ||
| labels_size_dict[task] = list(labels_size_dict[task]) | ||
| if len(labels_size_dict[task]) >= 2: | ||
| labels_size_dict[task] = labels_size_dict[task][1:] | ||
| elif not task.startswith("graph_"): | ||
| labels_size_dict[task] = [1] | ||
| if labels_num_cols_dict is not None: | ||
| for index, this_label in enumerate(labels): | ||
| label_keys_set = set(get_keys(this_label)) | ||
| empty_task_labels = set(labels_size_dict.keys()) - label_keys_set | ||
| empty_task_labels = set(labels_num_cols_dict.keys()) - label_keys_set | ||
| for task in empty_task_labels: | ||
| labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task]) | ||
| dtype = labels_dtype_dict[task] | ||
| this_label[task] = torch.full([*labels_size_dict[task]], torch.nan, dtype=dtype) | ||
| label_rows = get_expected_label_rows(this_label, task, num_nodes[index], num_edges[index]) | ||
| dtype = torch_enum_to_dtype(labels_dtype_dict[task]) | ||
| this_label[task] = torch.full((label_rows, labels_num_cols_dict[task]), fill_value=torch.nan, dtype=dtype) | ||
|
|
||
| for task in label_keys_set - set(["x", "edge_index"]) - empty_task_labels: | ||
| labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task]) | ||
| label_rows = get_expected_label_rows(this_label, task, num_nodes[index], num_edges[index]) | ||
|
|
||
| if not isinstance(this_label[task], (torch.Tensor)): | ||
| this_label[task] = torch.as_tensor(this_label[task]) | ||
|
|
||
| # Ensure explicit task dimension also for single task labels | ||
| if len(this_label[task].shape) == 1: | ||
| # Distinguish whether target dim or entity dim is missing | ||
| if labels_size_dict[task][0] == this_label[task].shape[0]: | ||
| if label_rows == this_label[task].shape[0]: | ||
| # num graphs/nodes/edges/nodepairs already matching | ||
| this_label[task] = this_label[task].unsqueeze(1) | ||
| else: | ||
| # data lost unless entity dim is supposed to be 1 | ||
| if labels_size_dict[task][0] == 1: | ||
| if label_rows == 1: | ||
| this_label[task] = this_label[task].unsqueeze(0) | ||
| else: | ||
| raise ValueError( | ||
| f"Labels for {labels_size_dict[task][0]} nodes/edges/nodepairs expected, got 1." | ||
| f"Labels for {label_rows} nodes/edges/nodepairs expected, got 1." | ||
| ) | ||
|
|
||
| this_label[task] = pad_to_expected_label_size(this_label[task], labels_size_dict[task]) | ||
| this_label[task] = pad_to_expected_label_size(this_label[task], label_rows, labels_num_cols_dict[task]) | ||
|
|
||
| return collate_pyg_graph_labels(labels) | ||
|
|
||
|
|
||
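Taken together, a hedged sketch of how the new row/column split plays out during label collation. `expected_rows` below only mirrors the prefix dispatch of `get_expected_label_rows`, the task names are invented, and `torch.float32` stands in for the result of `torch_enum_to_dtype`, whose internals are not visible in this diff.

```python
import torch

def expected_rows(task: str, num_nodes: int, num_edges: int) -> int:
    # Simplified mirror of get_expected_label_rows' task-prefix dispatch.
    if task.startswith("graph_"):
        return 1
    if task.startswith("node_"):
        return num_nodes
    if task.startswith("edge_"):
        return num_edges
    raise NotImplementedError(task)

# Hypothetical multitask setup: one graph-level and one node-level task.
labels_num_cols_dict = {"graph_toxicity": 3, "node_charge": 1}
num_nodes, num_edges = 5, 8

# A sample missing the "node_charge" labels entirely: collate_labels would
# materialize a NaN placeholder of the expected (rows, cols) shape so the
# batch stays rectangular per task.
present = {"graph_toxicity": torch.randn(1, 3)}
for task, num_cols in labels_num_cols_dict.items():
    if task not in present:
        rows = expected_rows(task, num_nodes, num_edges)
        present[task] = torch.full((rows, num_cols), torch.nan, dtype=torch.float32)

print({k: tuple(v.shape) for k, v in present.items()})
# {'graph_toxicity': (1, 3), 'node_charge': (5, 1)}
```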