diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp
index 6083986e1..aa1c2d2e7 100644
--- a/src/torchcodec/_core/Transform.cpp
+++ b/src/torchcodec/_core/Transform.cpp
@@ -42,7 +42,7 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
 std::string ResizeTransform::getFilterGraphCpu() const {
   return "scale=" + std::to_string(outputDims_.width) + ":" +
       std::to_string(outputDims_.height) +
-      ":sws_flags=" + toFilterGraphInterpolation(interpolationMode_);
+      ":flags=" + toFilterGraphInterpolation(interpolationMode_);
 }
 
 std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index 130927c2e..5a29b8b19 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -8,7 +8,7 @@
 import json
 import numbers
 from pathlib import Path
-from typing import Literal, Optional, Tuple, Union
+from typing import Any, List, Literal, Optional, Tuple, Union
 
 import torch
 from torch import device as torch_device, Tensor
@@ -103,6 +103,7 @@ def __init__(
         dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
         num_ffmpeg_threads: int = 1,
         device: Optional[Union[str, torch_device]] = "cpu",
+        transforms: List[Any] = [],  # TRANSFORMS TODO: what is the user-facing type?
         seek_mode: Literal["exact", "approximate"] = "exact",
         custom_frame_mappings: Optional[
             Union[str, bytes, io.RawIOBase, io.BufferedReader]
@@ -148,6 +149,8 @@ def __init__(
         device_variant = _get_cuda_backend()
 
+        transform_specs = make_transform_specs(transforms)
+
         core.add_video_stream(
             self._decoder,
             stream_index=stream_index,
@@ -155,6 +158,7 @@ def __init__(
             num_threads=num_ffmpeg_threads,
             device=device,
             device_variant=device_variant,
+            transform_specs=transform_specs,
             custom_frame_mappings=custom_frame_mappings_data,
         )
@@ -432,6 +436,22 @@ def _get_and_validate_stream_metadata(
     )
 
 
+def make_transform_specs(transforms: List[Any]) -> str:
+    from torchvision.transforms import v2
+
+    transform_specs = []
+    for transform in transforms:
+        if isinstance(transform, v2.Resize):
+            if len(transform.size) != 2:
+                raise ValueError(
+                    f"Resize transform must have a (height, width) pair for the size, got {transform.size}."
+                )
+            transform_specs.append(f"resize, {transform.size[0]}, {transform.size[1]}")
+        else:
+            raise ValueError(f"Unsupported transform {transform}.")
+    return ";".join(transform_specs)
+
+
 def _read_custom_frame_mappings(
     custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
 ) -> tuple[Tensor, Tensor, Tensor]:
diff --git a/test/generate_reference_resources.py b/test/generate_reference_resources.py
index 953fb996e..18a0183b2 100644
--- a/test/generate_reference_resources.py
+++ b/test/generate_reference_resources.py
@@ -123,6 +123,15 @@ def generate_nasa_13013_references():
             NASA_VIDEO, frame_index=frame, stream_index=3, filters=crop_filter
         )
 
+    frames = [17, 230, 389]
+    # Note that the resize algorithm passed to flags is exposed to users,
+    # but bilinear is the default we use.
+ resize_filter = "scale=240:135:flags=bilinear" + for frame in frames: + generate_frame_by_index( + NASA_VIDEO, frame_index=frame, stream_index=3, filters=resize_filter + ) + def generate_h265_video_references(): # This video was generated by running the following: diff --git a/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000017.pt b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000017.pt new file mode 100644 index 000000000..5da3e81fe Binary files /dev/null and b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000017.pt differ diff --git a/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000230.pt b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000230.pt new file mode 100644 index 000000000..5094e44da Binary files /dev/null and b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000230.pt differ diff --git a/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000389.pt b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000389.pt new file mode 100644 index 000000000..a15622389 Binary files /dev/null and b/test/resources/nasa_13013.mp4.scale_240_135_flags_bilinear.stream3.frame000389.pt differ diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index 8d1ba5e53..238d44da1 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -22,15 +22,75 @@ get_json_metadata, get_next_frame, ) +from torchcodec.decoders import VideoDecoder from torchvision.transforms import v2 -from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda +from .utils import assert_frames_equal, NASA_VIDEO, needs_cuda, psnr torch._dynamo.config.capture_dynamic_output_shape_ops = True -class TestVideoDecoderTransformOps: +class TestPublicVideoDecoderTransformOps: + @pytest.mark.parametrize( + "height_scaling_factor, width_scaling_factor", + ((1.5, 1.31), (0.5, 0.71), (0.7, 1.31), (1.5, 0.71), (1.0, 1.0)), + ) + def test_resize_torchvision(self, height_scaling_factor, width_scaling_factor): + height = int(NASA_VIDEO.get_height() * height_scaling_factor) + width = int(NASA_VIDEO.get_width() * width_scaling_factor) + + decoder_resize = VideoDecoder( + NASA_VIDEO.path, transforms=[v2.Resize(size=(height, width))] + ) + decoder_full = VideoDecoder(NASA_VIDEO.path) + for frame_index in [0, 10, 17, 100, 230, 389]: + expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width) + frame_resize = decoder_resize[frame_index] + + frame_full = decoder_full[frame_index] + frame_tv = v2.functional.resize(frame_full, size=(height, width)) + + assert frame_resize.shape == expected_shape + assert frame_tv.shape == expected_shape + + # Copied from PR #992; not sure if it's the best way to check + assert psnr(frame_resize, frame_tv) > 25 + + def test_resize_ffmpeg(self): + height = 135 + width = 240 + expected_shape = (NASA_VIDEO.get_num_color_channels(), height, width) + resize_filtergraph = f"scale={width}:{height}:flags=bilinear" + decoder_resize = VideoDecoder( + NASA_VIDEO.path, transforms=[v2.Resize(size=(height, width))] + ) + for frame_index in [17, 230, 389]: + frame_resize = decoder_resize[frame_index] + frame_ref = NASA_VIDEO.get_frame_data_by_index( + frame_index, filters=resize_filtergraph + ) + + assert frame_resize.shape == expected_shape + assert frame_ref.shape == expected_shape + assert_frames_equal(frame_resize, frame_ref) + + def test_resize_fails(self): + with pytest.raises( + ValueError, 
+ match=r"must have a \(height, width\) pair for the size", + ): + VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))]) + + def test_transform_fails(self): + with pytest.raises( + ValueError, + match="Unsupported transform", + ): + VideoDecoder(NASA_VIDEO.path, transforms=[v2.RandomHorizontalFlip(p=1.0)]) + + +class TestCoreVideoDecoderTransformOps: # We choose arbitrary values for width and height scaling to get better # test coverage. Some pairs upscale the image while others downscale it. @pytest.mark.parametrize( diff --git a/test/utils.py b/test/utils.py index cbd6a5bf4..43f29cf5a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -430,7 +430,7 @@ def empty_chw_tensor(self) -> torch.Tensor: [0, self.num_color_channels, self.height, self.width], dtype=torch.uint8 ) - def get_width(self, *, stream_index: Optional[int]) -> int: + def get_width(self, *, stream_index: Optional[int] = None) -> int: if stream_index is None: stream_index = self.default_stream_index