Skip to content

Improve Orientation transform to use the "space" (LPS vs RAS) of a metatensor by default #8473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: dev
Choose a base branch
from
49 changes: 44 additions & 5 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
GridSamplePadMode,
InterpolateMode,
NumpyPadMode,
SpaceKeys,
convert_to_cupy,
convert_to_dst_type,
convert_to_numpy,
Expand All @@ -75,6 +76,7 @@
issequenceiterable,
optional_import,
)
from monai.utils.deprecate_utils import deprecated_arg_default
from monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
Expand Down Expand Up @@ -556,11 +558,20 @@ class Orientation(InvertibleTransform, LazyTransform):

backend = [TransformBackends.NUMPY, TransformBackends.TORCH]

@deprecated_arg_default(
name="labels",
old_default=(("L", "R"), ("P", "A"), ("I", "S")),
new_default=None,
msg_suffix=(
"Default value changed to None meaning that the transform now uses the 'space' of a "
"meta-tensor, if applicable, to determine appropriate axis labels."
),
)
def __init__(
self,
axcodes: str | None = None,
as_closest_canonical: bool = False,
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
labels: Sequence[tuple[str, str]] | None = None,
lazy: bool = False,
) -> None:
"""
Expand All @@ -573,7 +584,14 @@ def __init__(
as_closest_canonical: if True, load the image as closest to canonical axis format.
labels: optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
If ``None``, an appropriate value is chosen depending on the
value of the ``"space"`` metadata item of a metatensor: if
``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
input is not a meta-tensor or has no ``"space"`` item, the
value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
``None``, the provided value is always used and the ``"space"``
metadata item (if any) of the input is ignored.
lazy: a flag to indicate whether this transform should execute lazily or not.
Defaults to False

Expand Down Expand Up @@ -619,9 +637,19 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
raise ValueError(f"data_array must have at least one spatial dimension, got {spatial_shape}.")
affine_: np.ndarray
affine_np: np.ndarray
labels = self.labels
if isinstance(data_array, MetaTensor):
affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray)
affine_ = to_affine_nd(sr, affine_np)

# Set up "labels" such that LPS tensors are handled correctly by default
if (
self.labels is None
and "space" in data_array.meta
and SpaceKeys(data_array.meta["space"]) == SpaceKeys.LPS
):
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS

else:
warnings.warn("`data_array` is not of type `MetaTensor, assuming affine to be identity.")
# default to identity
Expand All @@ -640,7 +668,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
f"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]},"
"please make sure the input is in the channel-first format."
)
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=labels)
if len(dst) < sr:
raise ValueError(
f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D"
Expand All @@ -653,8 +681,19 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
# Create inverse transform
orig_affine = transform[TraceKeys.EXTRA_INFO]["original_affine"]
orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=self.labels)
labels = self.labels

# Set up "labels" such that LPS tensors are handled correctly by default
if (
isinstance(data, MetaTensor)
and self.labels is None
and "space" in data.meta
and SpaceKeys(data.meta["space"]) == SpaceKeys.LPS
):
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS

orig_axcodes = nib.orientations.aff2axcodes(orig_affine, labels=labels)
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=labels)
# Apply inverse
with inverse_transform.trace_transform(False):
data = inverse_transform(data)
Expand Down
21 changes: 19 additions & 2 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
ensure_tuple_rep,
fall_back_tuple,
)
from monai.utils.deprecate_utils import deprecated_arg_default
from monai.utils.enums import TraceKeys
from monai.utils.module import optional_import

Expand Down Expand Up @@ -545,12 +546,21 @@ class Orientationd(MapTransform, InvertibleTransform, LazyTransform):

backend = Orientation.backend

@deprecated_arg_default(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also have replaced="1.6".

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried this already actually but then monai becomes unimportable, so I'm not really sure how this is supposed to work 🤷:

>>> import monai
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/cpb28/Developer/monai/monai/__init__.py", line 101, in <module>
    load_submodules(sys.modules[__name__], False, exclude_pattern=excludes)
  File "/Users/cpb28/Developer/monai/monai/utils/module.py", line 187, in load_submodules
    mod = import_module(name)
          ^^^^^^^^^^^^^^^^^^^
  File "/Users/cpb28/.pyenv/versions/3.12.1/lib/python3.12/importlib/__init__.py", line 90, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cpb28/Developer/monai/monai/apps/__init__.py", line 14, in <module>
    from .datasets import CrossValidation, DecathlonDataset, MedNISTDataset, TciaDataset
  File "/Users/cpb28/Developer/monai/monai/apps/datasets.py", line 33, in <module>
    from monai.data import (
  File "/Users/cpb28/Developer/monai/monai/data/__init__.py", line 29, in <module>
    from .dataset import (
  File "/Users/cpb28/Developer/monai/monai/data/dataset.py", line 39, in <module>
    from monai.transforms import Compose, Randomizable, RandomizableTrait, Transform, convert_to_contiguous, reset_ops_id
  File "/Users/cpb28/Developer/monai/monai/transforms/__init__.py", line 241, in <module>
    from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
  File "/Users/cpb28/Developer/monai/monai/transforms/io/array.py", line 31, in <module>
    from monai.data import image_writer
  File "/Users/cpb28/Developer/monai/monai/data/image_writer.py", line 23, in <module>
    from monai.transforms.spatial.array import Resize, SpatialResample
  File "/Users/cpb28/Developer/monai/monai/transforms/spatial/array.py", line 551, in <module>
    class Orientation(InvertibleTransform, LazyTransform):
  File "/Users/cpb28/Developer/monai/monai/transforms/spatial/array.py", line 561, in Orientation
    @deprecated_arg_default(
     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cpb28/Developer/monai/monai/utils/deprecate_utils.py", line 313, in _decorator
    raise ValueError(
ValueError: Argument `labels` was replaced to the new default value `None` before the specified version 1.6.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK let's leave that off and it'll be something we revisit when looking at the deprecated items before the next release.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK so is there anything left to do on this PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please also specify since?
something like since=1.5, replaced=1.7

Copy link
Author

@CPBridge CPBridge Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @KumoLiu, I'm happy to make whatever changes necessary but I just don't really understand how this is supposed to work. I feel like I'm missing something here.

If I specify since as "1.5" or earlier, I get an error on import like the above, ending in

ValueError: Argument `labels` was replaced to the new default value `None` before the specified version None.

Similarly, if I specify replaced as "1.6" or later, I get e.g. (for "1.7" as you suggest) I get:

ValueError: Argument `labels` was replaced to the new default value `None` before the specified version 1.7.

If I do both at the same time (as in your comment), I still get:

ValueError: Argument `labels` was replaced to the new default value `None` before the specified version 1.7.

In any of these cases, all tests fail of course.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@KumoLiu @ericspod any ideas on how to proceed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@KumoLiu I think we proceed with this one now and see about adjusting the arguments later along with other things that may need removing for 1.6. If you're ok we can resolve the comment and merge. Thank!

name="labels",
old_default=(("L", "R"), ("P", "A"), ("I", "S")),
new_default=None,
msg_suffix=(
"Default value changed to None meaning that the transform now uses the 'space' of a "
"meta-tensor, if applicable, to determine appropriate axis labels."
),
)
def __init__(
self,
keys: KeysCollection,
axcodes: str | None = None,
as_closest_canonical: bool = False,
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
labels: Sequence[tuple[str, str]] | None = None,
allow_missing_keys: bool = False,
lazy: bool = False,
) -> None:
Expand All @@ -564,7 +574,14 @@ def __init__(
as_closest_canonical: if True, load the image as closest to canonical axis format.
labels: optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
If ``None``, an appropriate value is chosen depending on the
value of the ``"space"`` metadata item of a metatensor: if
``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
input is not a meta-tensor or has no ``"space"`` item, the
value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
``None``, the provided value is always used and the ``"space"``
metadata item (if any) of the input is ignored.
allow_missing_keys: don't raise exception if key is missing.
lazy: a flag to indicate whether this transform should execute lazily or not.
Defaults to False
Expand Down
96 changes: 87 additions & 9 deletions tests/transforms/test_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import unittest
from typing import cast

import nibabel as nib
import numpy as np
Expand All @@ -21,6 +22,7 @@
from monai.data.meta_obj import set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.transforms import Orientation, create_rotate, create_translate
from monai.utils import SpaceKeys
from tests.lazy_transforms_utils import test_resampler_lazy
from tests.test_utils import TEST_DEVICES, assert_allclose

Expand All @@ -33,6 +35,18 @@
torch.eye(4),
torch.arange(12).reshape((2, 1, 2, 3)),
"RAS",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "LPS"},
torch.arange(12).reshape((2, 1, 2, 3)),
torch.eye(4),
torch.arange(12).reshape((2, 1, 2, 3)),
"LPS",
True,
*device,
]
)
Expand All @@ -43,6 +57,18 @@
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
"ALS",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "PRS"},
torch.arange(12).reshape((2, 1, 2, 3)),
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
"PRS",
True,
*device,
]
)
Expand All @@ -53,6 +79,18 @@
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
"RAS",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "LPS"},
torch.arange(12).reshape((2, 1, 2, 3)),
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
"LPS",
True,
*device,
]
)
Expand All @@ -63,6 +101,18 @@
torch.eye(3),
torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
"AL",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "PR"},
torch.arange(6).reshape((2, 1, 3)),
torch.eye(3),
torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
"PR",
True,
*device,
]
)
Expand All @@ -73,6 +123,18 @@
torch.eye(2),
torch.tensor([[2, 1, 0], [5, 4, 3]]),
"L",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "R"},
torch.arange(6).reshape((2, 3)),
torch.eye(2),
torch.tensor([[2, 1, 0], [5, 4, 3]]),
"R",
True,
*device,
]
)
Expand All @@ -83,6 +145,7 @@
torch.eye(2),
torch.tensor([[2, 1, 0], [5, 4, 3]]),
"L",
False,
*device,
]
)
Expand All @@ -93,6 +156,7 @@
torch.as_tensor(np.diag([-1, 1])),
torch.arange(6).reshape((2, 3)),
"L",
False,
*device,
]
)
Expand All @@ -107,6 +171,7 @@
),
torch.tensor([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]),
"LPS",
False,
*device,
]
)
Expand All @@ -121,6 +186,7 @@
),
torch.tensor([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]),
"RAS",
False,
*device,
]
)
Expand All @@ -131,6 +197,7 @@
torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
torch.tensor([[[3, 0], [4, 1], [5, 2]]]),
"RA",
False,
*device,
]
)
Expand All @@ -141,6 +208,7 @@
torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
torch.tensor([[[2, 5], [1, 4], [0, 3]]]),
"LP",
False,
*device,
]
)
Expand All @@ -151,6 +219,7 @@
torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
torch.zeros((1, 2, 3, 4, 5)),
"LPID",
False,
*device,
]
)
Expand All @@ -161,6 +230,7 @@
torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
torch.zeros((1, 2, 3, 4, 5)),
"RASD",
False,
*device,
]
)
Expand All @@ -175,6 +245,11 @@
[{"axcodes": "RA"}, torch.arange(12).reshape((2, 1, 2, 3)), torch.eye(4)]
]

TESTS_INVERSE = []
for device in TEST_DEVICES:
TESTS_INVERSE.append([True, *device])
TESTS_INVERSE.append([False, *device])


class TestOrientationCase(unittest.TestCase):
@parameterized.expand(TESTS)
Expand All @@ -185,17 +260,20 @@ def test_ornt_meta(
affine: torch.Tensor,
expected_data: torch.Tensor,
expected_code: str,
lps_convention: bool,
device,
):
img = MetaTensor(img, affine=affine).to(device)
meta = {"space": SpaceKeys.LPS} if lps_convention else None
img = MetaTensor(img, affine=affine, meta=meta).to(device)
ornt = Orientation(**init_param)
call_param = {"data_array": img}
res = ornt(**call_param) # type: ignore[arg-type]
if img.ndim in (3, 4):
test_resampler_lazy(ornt, res, init_param, call_param)

assert_allclose(res, expected_data.to(device))
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels) # type: ignore
labels = (("R", "L"), ("A", "P"), ("I", "S")) if lps_convention else ornt.labels
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=labels) # type: ignore
self.assertEqual("".join(new_code), expected_code)

@parameterized.expand(TESTS_TORCH)
Expand Down Expand Up @@ -224,23 +302,23 @@ def test_bad_params(self, init_param, img: torch.Tensor, affine: torch.Tensor):
with self.assertRaises(ValueError):
Orientation(**init_param)(img)

@parameterized.expand(TEST_DEVICES)
def test_inverse(self, device):
@parameterized.expand(TESTS_INVERSE)
def test_inverse(self, lps_convention: bool, device):
img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
affine = torch.tensor(
[[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
)
meta = {"fname": "somewhere"}
meta = {"fname": "somewhere", "space": SpaceKeys.LPS if lps_convention else SpaceKeys.RAS}
img = MetaTensor(img_t, affine=affine, meta=meta)
tr = Orientation("LPS")
# check that image and affine have changed
img = tr(img)
img = cast(MetaTensor, tr(img))
self.assertNotEqual(img.shape, img_t.shape)
self.assertGreater((affine - img.affine).max(), 0.5)
self.assertGreater(float((affine - img.affine).max()), 0.5)
# check that with inverse, image affine are back to how they were
img = tr.inverse(img)
img = cast(MetaTensor, tr.inverse(img))
self.assertEqual(img.shape, img_t.shape)
self.assertLess((affine - img.affine).max(), 1e-2)
self.assertLess(float((affine - img.affine).max()), 1e-2)


if __name__ == "__main__":
Expand Down
Loading
Loading