
Conversation

@NicolasHug (Contributor) commented Oct 21, 2025

Towards #943

This PR speeds up the BETA CUDA interface when we need to fall back to the CPU. The idea is simple: we do the color conversion on the GPU instead of doing it on the CPU. This has two benefits:

  • the color conversion is faster since it runs on the GPU
  • the CPU -> GPU transfer is faster since we now transfer a smaller YUV frame instead of a bigger RGB frame.

Before

Decode Frame on CPU (fallback) -> YUV to RGB Conversion **on CPU** -> Send bigger RGB frame from CPU to GPU

Now

Decode Frame on CPU (fallback) -> Send smaller YUV frame from CPU to GPU -> YUV to RGB conversion **on GPU**

This is for the BETA interface only. I'll handle the FFmpeg interface as a follow-up.
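
For illustration, here is a minimal sketch of what the GPU-side conversion step can look like, assuming NPP's NV12 -> RGB routine. Function and variable names are hypothetical, not the actual BetaCudaInterface.cpp code:

// Hedged sketch: convert an NV12 frame that already lives in GPU memory
// to packed RGB using NPP. Names are illustrative, not the actual
// BetaCudaInterface.cpp implementation.
#include <nppi_color_conversion.h>

NppStatus nv12ToRgbOnGpu(
    const Npp8u* yPlane,   // device pointer to the Y plane
    const Npp8u* uvPlane,  // device pointer to the interleaved UV plane
    int srcPitch,          // line stride of the NV12 planes, in bytes
    Npp8u* rgbOut,         // device pointer to the packed RGB output
    int dstPitch,          // line stride of the RGB output, in bytes
    int width,
    int height) {
  const Npp8u* src[2] = {yPlane, uvPlane};
  NppiSize roi = {width, height};
  // NPP also offers BT.709 variants (e.g. nppiNV12ToRGB_709CSC_8u_P2C3R);
  // a real implementation would pick based on the frame's colorspace.
  return nppiNV12ToRGB_8u_P2C3R(src, srcPitch, rgbOut, dstPitch, roi);
}

Since NV12 at 4:2:0 is 1.5 bytes per pixel versus 3 for packed RGB, the host-to-device copy in the new pipeline also moves half as much data.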

Benchmarks

Benchmarks show a 1.6X speed-up on 1080p frames. We should expect larger speed-ups for larger resolutions. I used the snippet below.

import torch
from time import perf_counter_ns
import argparse
from torchcodec.decoders import VideoDecoder, set_cuda_backend


def bench(f, *args, num_exp=100, warmup=0, **kwargs):
    for _ in range(warmup):
        f(*args, **kwargs)

    times = []
    for _ in range(num_exp):
        start = perf_counter_ns()
        f(*args, **kwargs)
        end = perf_counter_ns()
        times.append(end - start)
    return torch.tensor(times).float()


def report_stats(times, unit="ms"):
    mul = {
        "ns": 1,
        "µs": 1e-3,
        "ms": 1e-6,
        "s": 1e-9,
    }[unit]
    times = times * mul
    std = times.std().item()
    med = times.median().item()
    print(f"{med = :.2f}{unit} +- {std:.2f}")
    return med


def decode_one_video(video_path):
    with set_cuda_backend("beta"):
        decoder = VideoDecoder(str(video_path), device="cuda:0", seek_mode="approximate")
    indices = torch.arange(len(decoder))
    decoder.get_frames_at(indices)

    torch.cuda.synchronize()


parser = argparse.ArgumentParser()
parser.add_argument("video_path")
args = parser.parse_args()

times = bench(decode_one_video, video_path=args.video_path, warmup=1, num_exp=10)
report_stats(times)

It's impossible to benchmark the new and old strategies side by side, since switching between them requires recompilation. I also had to modify our code to force the CPU fallback on videos that would otherwise not fall back (basically just changed if (!nativeNVDECSupport(codecContext))... to if (true)).
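
Concretely, the temporary change amounted to something like this (paraphrased from the description above; a benchmarking-only hack, not part of the PR):

// Benchmarking-only hack: always take the CPU fallback path, even for
// videos that NVDEC could decode natively.
// Before:
//   if (!nativeNVDECSupport(codecContext)) {
//     // fall back to CPU decoding
//   }
// After (temporary):
if (true) {
  // fall back to CPU decoding
}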

# OLD
~/dev/torchcodec-cuda (fallback-colorconversion*) » python benchmark_fallback.py h264_1080.mp4
med = 543.16ms +- 56.39
~/dev/torchcodec-cuda (fallback-colorconversion*) » python benchmark_fallback.py h264_1080.mp4
med = 533.95ms +- 55.50
~/dev/torchcodec-cuda (fallback-colorconversion*) » python benchmark_fallback.py h264_1080.mp4
med = 659.17ms +- 56.99
~/dev/torchcodec-cuda (fallback-colorconversion*) » python benchmark_fallback.py h264_1080.mp4
med = 549.27ms +- 53.50
~/dev/torchcodec-cuda (fallback-colorconversion*) » python benchmark_fallback.py h264_1080.mp4
med = 538.66ms +- 34.52

# NEW
~/dev/torchcodec-cuda (fallback-colorconversion*) » python benchmark_fallback.py h264_1080.mp4
med = 318.23ms +- 8.21
~/dev/torchcodec-cuda (fallback-colorconversion*) » python benchmark_fallback.py h264_1080.mp4
med = 319.42ms +- 11.30
~/dev/torchcodec-cuda (fallback-colorconversion*) » python benchmark_fallback.py h264_1080.mp4
med = 325.05ms +- 9.74
~/dev/torchcodec-cuda (fallback-colorconversion*) » python benchmark_fallback.py h264_1080.mp4
med = 325.94ms +- 12.64

meta-cla bot added the CLA Signed label Oct 21, 2025
NicolasHug marked this pull request as ready for review October 24, 2025
@NicolasHug (Contributor, Author) commented:
As can be seen above, we are now using swscale in BetaCudaInterface.cpp to do the YUV -> NV12 conversion. So I had to extract all the swscale logic out of CpuDeviceInterface.cpp and put it into FFmpegCommon.cpp so it can be reused across the two interfaces. Almost everything below this comment can be ignored and treated as code being copy/pasted around. Just pay attention to the test at the bottom.
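
For context, the extracted swscale logic amounts to roughly the following. This is a hedged sketch with hypothetical names, not the actual helper moved into FFmpegCommon.cpp:

// Hedged sketch of a reusable swscale-based conversion to NV12. All
// names here are hypothetical; this is not the actual helper that was
// moved into FFmpegCommon.cpp.
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}
#include <stdexcept>

AVFrame* convertFrameToNV12(const AVFrame* src) {
  AVFrame* dst = av_frame_alloc();
  if (dst == nullptr) {
    throw std::runtime_error("Failed to allocate NV12 CPU frame");
  }
  dst->format = AV_PIX_FMT_NV12;
  dst->width = src->width;
  dst->height = src->height;
  if (av_frame_get_buffer(dst, 0) < 0) {
    av_frame_free(&dst);
    throw std::runtime_error("Failed to allocate NV12 frame buffers");
  }

  // Same dimensions in and out: sws_scale only reshuffles the planes,
  // no actual scaling happens.
  SwsContext* swsCtx = sws_getContext(
      src->width, src->height, static_cast<AVPixelFormat>(src->format),
      dst->width, dst->height, AV_PIX_FMT_NV12,
      SWS_BILINEAR, nullptr, nullptr, nullptr);
  sws_scale(swsCtx, src->data, src->linesize, 0, src->height,
            dst->data, dst->linesize);
  sws_freeContext(swsCtx);
  return dst;
}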

beta_frame = beta_dec.get_frame_at(0)

torch.testing.assert_close(ffmpeg.data, beta.data, rtol=0, atol=0)
assert psnr(ref_frames.data.cpu(), beta_frame.data.cpu()) > 25
A reviewer (Contributor) commented:

Does comparing frames on GPU vs CPU change floating point precision, or is there some other reason to move the frames here?

@NicolasHug (Contributor, Author) replied:

There will be a small difference due to floating point precision, but that wasn't the reason. In fact, there was no good reason to call .cpu(); it was probably left over from my previous write_png debugging. I removed it, thanks for catching!

UniqueAVFrame nv12CpuFrame(av_frame_alloc());
TORCH_CHECK(nv12CpuFrame != nullptr, "Failed to allocate NV12 CPU frame");

nv12CpuFrame->format = AV_PIX_FMT_NV12;
A reviewer (Contributor) commented:

Is it accurate to say that NV12 is very similar to AV_PIX_FMT_YUV420P (it uses YUV and has 4:2:0 chroma subsampling), but we use NV12 here because that is the format the NPP library requires? As explained in this comment.

@NicolasHug (Contributor, Author) replied:

Yes, that's exactly right. NV12 contains the exact same values as AV_PIX_FMT_YUV420P, just ordered a bit differently.
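
To make the layout difference concrete (illustrative comment only):

// Both formats carry the same 4:2:0 subsampled YUV values; only the
// chroma layout differs.
//
// AV_PIX_FMT_YUV420P (planar, 3 planes):
//   data[0]: Y Y Y Y ...   (full resolution)
//   data[1]: U U U ...     (half resolution in both dimensions)
//   data[2]: V V V ...     (half resolution in both dimensions)
//
// AV_PIX_FMT_NV12 (semi-planar, 2 planes):
//   data[0]: Y Y Y Y ...   (full resolution)
//   data[1]: U V U V ...   (interleaved chroma, half resolution)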
