2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
@@ -62,7 +62,7 @@ jobs:
run: python -m pip install --upgrade pip
- name: Install dependencies and FFmpeg
run: |
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
conda install "ffmpeg=7.0.1" pkg-config pybind11 -c conda-forge
ffmpeg -version
- name: Build and install torchcodec
1 change: 1 addition & 0 deletions mypy.ini
@@ -4,3 +4,4 @@ files = src/torchcodec
show_error_codes = True
pretty = True
allow_redefinition = True
follow_untyped_imports = True
2 changes: 1 addition & 1 deletion src/torchcodec/__init__.py
@@ -7,7 +7,7 @@
# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
# but that results in circular import.
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa
from . import decoders, encoders, samplers # noqa
from . import decoders, encoders, samplers, transforms # noqa

try:
# Note that version.py is generated during install.
61 changes: 60 additions & 1 deletion src/torchcodec/decoders/_video_decoder.py
@@ -8,7 +8,7 @@
import json
import numbers
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
from typing import Literal, Optional, Sequence, Tuple, Union

import torch
from torch import device as torch_device, Tensor
@@ -19,6 +19,7 @@
create_decoder,
ERROR_REPORTING_INSTRUCTIONS,
)
from torchcodec.transforms import DecoderNativeTransform, Resize


class VideoDecoder:
@@ -103,6 +104,7 @@ def __init__(
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, torch_device]] = "cpu",
transforms: Optional[Sequence[DecoderNativeTransform]] = None,
seek_mode: Literal["exact", "approximate"] = "exact",
custom_frame_mappings: Optional[
Union[str, bytes, io.RawIOBase, io.BufferedReader]
@@ -148,13 +150,16 @@ def __init__(

device_variant = _get_cuda_backend()

transform_specs = _make_transform_specs(transforms)

core.add_video_stream(
self._decoder,
stream_index=stream_index,
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
device=device,
device_variant=device_variant,
transform_specs=transform_specs,
custom_frame_mappings=custom_frame_mappings_data,
)

@@ -432,6 +437,60 @@ def _get_and_validate_stream_metadata(
)


# This function, _make_transform_specs, and the transforms argument to
# VideoDecoder actually accept a union of DecoderNativeTransform and
# TorchVision transforms. We don't put that in the type annotation because it
# would require importing torchvision at module scope, which would make
# torchvision a hard dependency.
# TODO: better explanation of the above.
def _convert_to_decoder_native_transforms(
transforms: Sequence[DecoderNativeTransform],
) -> Sequence[DecoderNativeTransform]:
try:
from torchvision.transforms import v2

tv_available = True
except ImportError:
tv_available = False

converted_transforms = []
for transform in transforms:
if not isinstance(transform, DecoderNativeTransform):
if not tv_available:
raise ValueError(
f"The supplied transform, {transform}, is not a TorchCodec "
" DecoderNativeTransform. TorchCodec also accept TorchVision "
"v2 transforms, but TorchVision is not installed."
)
if isinstance(transform, v2.Resize):
if len(transform.size) != 2:
raise ValueError(
"TorchVision Resize transform must have a (height, width) "
f"pair for the size, got {transform.size}."
)
converted_transforms.append(Resize(size=transform.size))
else:
raise ValueError(
f"Unsupported transform: {transform}. Transforms must be "
"either a TorchCodec DecoderNativeTransform or a TorchVision "
"v2 transform."
)
else:
converted_transforms.append(transform)

return converted_transforms


def _make_transform_specs(
transforms: Optional[Sequence[DecoderNativeTransform]],
) -> str:
if transforms is None:
return ""

transforms = _convert_to_decoder_native_transforms(transforms)
return ";".join([t.make_params() for t in transforms])

@scotts scotts (Contributor Author) commented on Oct 27, 2025:

Discussion point 2: This is what we'll have to do with TorchVision transforms at the moment. We'll need special handling for each transform, looking into its internals to get what we need and enforce decoder-native limitations.

In the future, we can change TorchVision transforms to have an API so that we can get what we need in a generic way. But for now, we'll need to do something like this.

Contributor commented:

I'm still undecided on whether we should accept TV transforms or not (ironic, I know), but I think this is totally OK.

And I think we'll need that level of coupling anyway, even if we were to write our own TC transforms. Echoing what you wrote:

If we were to [...] create [a] TorchCodec-specific user specification API, we'd want to make sure that its semantics match that of TorchVision. That is, if we had torchcodec.transforms.Resize(height=x, width=y), we'd want to make sure its semantics matched torchvision.transforms.v2.Resize(size=(x,y)). In that specific example, we'd want to make sure that both default to bilinear interpolation. Extrapolating that specific example across all transforms we want to support, we'd basically be creating a mirror version of what TorchVision has. That seems silly, since it's more for users to understand and more for us to maintain.

Basically, that coupling between TC and TV will have to exist either in the code (as in this PR), or in our heads as API designers.


Side note, slightly related: if we're going to have our own TC transforms, I think we'll want their API to exactly match (or be a strict subset of) the TV transforms. E.g. we'd have torchcodec.transforms.Resize(size=...) instead of torchcodec.transforms.Resize(height=..., width=...) ?

@scotts scotts (Contributor Author) commented on Oct 28, 2025:

@NicolasHug, I came to the same conclusion as:

Side note, slightly related: if we're going to have our own TC transforms, I think we'll want their API to exactly match (or be a strict subset of) the TV transforms. E.g. we'd have torchcodec.transforms.Resize(size=...) instead of torchcodec.transforms.Resize(height=..., width=...) ?

At which point, I don't think we've really gained anything by having them separate. And users will probably also start asking, hey, can you just accept the TorchVision ones? I also just realized a new counter-point, which I'll put up in the summary as counter point 3.
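
To make that concrete, the transforms argument in this PR already accepts both spellings; a minimal sketch (the file path and sizes are placeholders):

from torchvision.transforms import v2
from torchcodec.decoders import VideoDecoder
from torchcodec.transforms import Resize

# Both decoders apply the same decoder-native resize; the TorchVision
# transform is converted to the TorchCodec Resize internally.
dec_tv = VideoDecoder("video.mp4", transforms=[v2.Resize(size=(270, 480))])
dec_tc = VideoDecoder("video.mp4", transforms=[Resize(size=(270, 480))])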


def _read_custom_frame_mappings(
custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
) -> tuple[Tensor, Tensor, Tensor]:
7 changes: 7 additions & 0 deletions src/torchcodec/transforms/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._decoder_native_transforms import DecoderNativeTransform, Resize # noqa
39 changes: 39 additions & 0 deletions src/torchcodec/transforms/_decoder_native_transforms.py
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Sequence


@dataclass
class DecoderNativeTransform(ABC):
"""TODO: docstring"""

@abstractmethod
def make_params(self) -> str:
pass


@dataclass
class Resize(DecoderNativeTransform):
"""
TODO. One benefit of having parallel definitions is that it gives us a place
to put documentation about what behavior we do and do not support. For
example, we don't yet have fields for `interpolation` and `antialias`
because we don't allow users to control those yet in decoder-native
transforms.
"""

# Also note that this type is more restrictive than what TorchVision
# accepts, but it accurately reflects current decoder-native transform
# limitations. We can reflect that not just in our docs, but also in the
# type annotations.
size: Sequence[int]

def make_params(self) -> str:
assert len(self.size) == 2
return f"resize, {self.size[0]}, {self.size[1]}"
139 changes: 76 additions & 63 deletions test/test_transform_ops.py
@@ -21,6 +21,7 @@
get_frame_at_index,
get_json_metadata,
)
from torchcodec.decoders import VideoDecoder

from torchvision.transforms import v2

@@ -34,7 +35,81 @@
TEST_SRC_2_720P,
)

torch._dynamo.config.capture_dynamic_output_shape_ops = True

class TestPublicVideoDecoderTransformOps:
@pytest.mark.parametrize(
"height_scaling_factor, width_scaling_factor",
((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0), (2.0, 2.0)),
)
@pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
def test_resize_torchvision(
self, video, height_scaling_factor, width_scaling_factor
):
height = int(video.get_height() * height_scaling_factor)
width = int(video.get_width() * width_scaling_factor)

decoder_resize = VideoDecoder(
video.path, transforms=[v2.Resize(size=(height, width))]
)

decoder_full = VideoDecoder(video.path)

num_frames = len(decoder_resize)
assert num_frames == len(decoder_full)

for frame_index in [
0,
int(num_frames * 0.1),
int(num_frames * 0.2),
int(num_frames * 0.3),
int(num_frames * 0.4),
int(num_frames * 0.5),
int(num_frames * 0.75),
int(num_frames * 0.90),
num_frames - 1,
]:
expected_shape = (video.get_num_color_channels(), height, width)
frame_resize = decoder_resize[frame_index]
frame_full = decoder_full[frame_index]

frame_tv = v2.functional.resize(frame_full, size=(height, width))
frame_tv_no_antialias = v2.functional.resize(
frame_full, size=(height, width), antialias=False
)

assert frame_resize.shape == expected_shape
assert frame_tv.shape == expected_shape
assert frame_tv_no_antialias.shape == expected_shape

assert_tensor_close_on_at_least(
frame_resize, frame_tv, percentage=99.8, atol=1
)
torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6)

if height_scaling_factor < 1 or width_scaling_factor < 1:
# Antialias only relevant when down-scaling!
with pytest.raises(AssertionError, match="Expected at least"):
assert_tensor_close_on_at_least(
frame_resize, frame_tv_no_antialias, percentage=99, atol=1
)
with pytest.raises(AssertionError, match="Tensor-likes are not close"):
torch.testing.assert_close(
frame_resize, frame_tv_no_antialias, rtol=0, atol=6
)

def test_resize_fails(self):
with pytest.raises(
ValueError,
match=r"must have a \(height, width\) pair for the size",
):
VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))])

def test_transform_fails(self):
with pytest.raises(
ValueError,
match="Unsupported transform",
):
VideoDecoder(NASA_VIDEO.path, transforms=[v2.RandomHorizontalFlip(p=1.0)])


class TestCoreVideoDecoderTransformOps:
@@ -172,68 +247,6 @@ def test_transform_fails(self):
):
add_video_stream(decoder, transform_specs="invalid, 1, 2")

@pytest.mark.parametrize(
"height_scaling_factor, width_scaling_factor",
((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0), (2.0, 2.0)),
)
@pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P])
def test_resize_torchvision(
self, video, height_scaling_factor, width_scaling_factor
):
num_frames = self.get_num_frames_core_ops(video)

height = int(video.get_height() * height_scaling_factor)
width = int(video.get_width() * width_scaling_factor)
resize_spec = f"resize, {height}, {width}"

decoder_resize = create_from_file(str(video.path))
add_video_stream(decoder_resize, transform_specs=resize_spec)

decoder_full = create_from_file(str(video.path))
add_video_stream(decoder_full)

for frame_index in [
0,
int(num_frames * 0.1),
int(num_frames * 0.2),
int(num_frames * 0.3),
int(num_frames * 0.4),
int(num_frames * 0.5),
int(num_frames * 0.75),
int(num_frames * 0.90),
num_frames - 1,
]:
expected_shape = (video.get_num_color_channels(), height, width)
frame_resize, *_ = get_frame_at_index(
decoder_resize, frame_index=frame_index
)

frame_full, *_ = get_frame_at_index(decoder_full, frame_index=frame_index)
frame_tv = v2.functional.resize(frame_full, size=(height, width))
frame_tv_no_antialias = v2.functional.resize(
frame_full, size=(height, width), antialias=False
)

assert frame_resize.shape == expected_shape
assert frame_tv.shape == expected_shape
assert frame_tv_no_antialias.shape == expected_shape

assert_tensor_close_on_at_least(
frame_resize, frame_tv, percentage=99.8, atol=1
)
torch.testing.assert_close(frame_resize, frame_tv, rtol=0, atol=6)

if height_scaling_factor < 1 or width_scaling_factor < 1:
# Antialias only relevant when down-scaling!
with pytest.raises(AssertionError, match="Expected at least"):
assert_tensor_close_on_at_least(
frame_resize, frame_tv_no_antialias, percentage=99, atol=1
)
with pytest.raises(AssertionError, match="Tensor-likes are not close"):
torch.testing.assert_close(
frame_resize, frame_tv_no_antialias, rtol=0, atol=6
)

def test_resize_ffmpeg(self):
height = 135
width = 240