pytorch / torchcodec

PyTorch video decoding
BSD 3-Clause "New" or "Revised" License

Fix binary search of getFramesDisplayedByTimestamps #286

Closed: NicolasHug closed this pull request 2 weeks ago

NicolasHug commented 2 weeks ago

This PR fixes the pts -> index conversion logic of getFramesDisplayedByTimestamps, and adds a non-regression test (which checks all pts-based APIs, not just the buggy one).

TL;DR: we were doing a binary search on an array that was NOT sorted, because the last frame's nextPts is 0. The fix is to exclude this frame from the binary search so that the search space stays sorted. The last frame is still returned whenever it should be.

Background 1

On our favorite nasa_13013.mp4 video, the last few entries of stream.allFrames (converted to seconds) look like this:

info.pts    info.nextPts
...
...
12.846167   12.879533
12.879533   12.912900
12.912900   12.946267
12.946267   12.979633   <--- second to last frame
12.979633   0.000000    <--- last frame
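A quick way to check the TL;DR's claim on these five rows is to copy them into Python and test the order of the nextPts column. This is a throwaway sketch: the list names `pts` and `next_pts` and the helper `is_sorted` are ours, not torchcodec's.

```python
# Last five rows of stream.allFrames for nasa_13013.mp4, copied from the table (seconds).
pts = [12.846167, 12.879533, 12.912900, 12.946267, 12.979633]
next_pts = [12.879533, 12.912900, 12.946267, 12.979633, 0.0]  # last nextPts hard-coded to 0


def is_sorted(values):
    """Return True if values is in non-decreasing order."""
    return all(a <= b for a, b in zip(values, values[1:]))


print(pts[1:] == next_pts[:-1])  # True: each frame's nextPts is the next frame's pts
print(is_sorted(next_pts))       # False: the trailing 0.0 breaks the order
print(is_sorted(next_pts[:-1]))  # True: leaving out the last frame restores it
```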

Background 2

For C++ noobs like me: our code in main is shown below. It basically means "return the first element of stream.allFrames for which the condition in the lambda is false", i.e. it returns the first frame f for which f.nextPts > framePts. Logically this is all sound: the returned frame f should be the one that is displayed at framePts.

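    // Find the first frame whose nextPts is strictly greater than framePts,
    // i.e. the frame displayed at framePts.
    auto it = std::lower_bound(
        stream.allFrames.begin(),
        stream.allFrames.end(),
        framePts,
        [&stream](const FrameInfo& info, double framePts) {
          return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
        });
    int64_t frameIndex = it - stream.allFrames.begin();
    // lower_bound returns end() when the lambda is true for every frame;
    // clamp that to the last valid index.
    frameIndex = std::min(
        frameIndex, static_cast<int64_t>(stream.allFrames.size()) - 1);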

Note: we need the final min() clamp because if no element makes the lambda return false, std::lower_bound() returns its end argument, here stream.allFrames.end(). If we query for the last frame, e.g. framePts = 12.98, the lambda returns true for every frame, and we end up with an index of 390, which we need to clamp to 389.
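The behavior described in this note can be reproduced with a small Python mirror of std::lower_bound. The helper `lower_bound(values, pred)` below is hypothetical (not part of torchcodec or the C++ standard library); it uses the same halving loop as common standard library implementations and returns `len(values)`, the analog of `end()`, when `pred` holds for every element. The data are the five nextPts values from the table in Background 1, so index 4 plays the role of 389.

```python
def lower_bound(values, pred):
    """Index of the first element for which pred is False, or len(values) if none.

    Mirrors the halving loop of typical std::lower_bound implementations;
    len(values) plays the role of end().
    """
    first, count = 0, len(values)
    while count > 0:
        half = count // 2
        middle = first + half
        if pred(values[middle]):
            first = middle + 1
            count -= half + 1
        else:
            count = half
    return first


# nextPts (seconds) of the last five frames, from the table in Background 1.
next_pts = [12.879533, 12.912900, 12.946267, 12.979633, 0.0]

frame_pts = 12.98  # a query that falls inside the last frame
index = lower_bound(next_pts, lambda n: n <= frame_pts)
print(index)  # 5, the analog of end(): the predicate is true for every frame
index = min(index, len(next_pts) - 1)  # the clamp from main
print(index)  # 4, the last frame
```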

Everything described so far is correct and logically sound. The bug isn't related to the last frame; it's related to the second-to-last frame.

The error

If we request a frame at pts = 12.95, we should comfortably get the second-to-last frame, which is displayed from 12.946267 to 12.979633. But on main, we get the last frame instead:

from torchcodec.decoders._core import (
    _add_video_stream,
    create_from_file,
    get_frames_by_pts,
    scan_all_streams_to_update_metadata,
)

path = "test/resources/nasa_13013.mp4"

decoder = create_from_file(path)
_add_video_stream(decoder)
scan_all_streams_to_update_metadata(decoder)
stream_index = 3

_, pts_seconds, _ = get_frames_by_pts(
    decoder,
    stream_index=stream_index,
    timestamps=[12.95]  # should be second-to-last frame
)

print(pts_seconds)  # Gives 12.979633, which is the last frame! This is wrong.
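The same wrong answer shows up on just the five rows from Background 1. In this sketch, `lower_bound` is a hypothetical hand-written mirror of std::lower_bound that also records the nextPts values it compares against. We use our own loop because the C++ standard does not specify what std::lower_bound returns on a range that violates its precondition, so a hand-written loop keeps the outcome reproducible.

```python
def lower_bound(values, pred, probes):
    """Index of the first element for which pred is False (len(values) if none).

    Same halving loop as typical std::lower_bound implementations; every
    compared value is appended to probes, like the instrumented log.
    """
    first, count = 0, len(values)
    while count > 0:
        half = count // 2
        middle = first + half
        probes.append(values[middle])
        if pred(values[middle]):
            first = middle + 1
            count -= half + 1
        else:
            count = half
    return first


# Last five frames of nasa_13013.mp4 (seconds), from Background 1.
pts = [12.846167, 12.879533, 12.912900, 12.946267, 12.979633]
next_pts = [12.879533, 12.912900, 12.946267, 12.979633, 0.0]

frame_pts = 12.95  # should map to the second-to-last frame (pts 12.946267)
probes = []
index = lower_bound(next_pts, lambda n: n <= frame_pts, probes)
index = min(index, len(next_pts) - 1)  # the clamp from main

print(probes)      # [12.946267, 0.0], the same last two probes as the real log
print(pts[index])  # 12.979633: the last frame, which is wrong
```

Both probes evaluate to true, so the search runs off the end and the clamp lands on the last frame; the second-to-last frame (index 3) is never compared.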

Why this happens

Instrumenting the binary search, we can see the following when we ask for 12.95:

Comparing 6.539867 with 12.950000
Comparing 9.809800 with 12.950000
Comparing 11.444767 with 12.950000
Comparing 12.245567 with 12.950000
Comparing 12.645967 with 12.950000
Comparing 12.846167 with 12.950000
Comparing 12.946267 with 12.950000  <-- At the frame *before* the second-to-last one!
Comparing 0.000000 with 12.950000    <-- At the last frame

The binary search is doing its job: it's binary searching. But after the probe at 12.946267 (the frame before the second-to-last one), only the last two frames remain in the search space, and the midpoint it picks is the last frame. It never gets a chance to test the second-to-last frame! The last frame's nextPts is 0, and 0 <= 12.95 is true, so the search concludes that the answer lies past the last frame: it returns stream.allFrames.end(), and the min() clamp turns that into the index of the last frame.

So this is it: the problem is that we hard-code the last frame's nextPts field to 0. Conceptually, this value is undefined, but we still need to give it one. And because it's 0, it breaks the working assumption of std::lower_bound() that the elements are sorted (strictly speaking, that every element for which the lambda returns true comes before every element for which it returns false).
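The precondition can be checked directly. std::lower_bound requires the range to be partitioned with respect to the comparison, which is what std::is_partitioned (in `<algorithm>`) tests. The sketch below is a Python analog; `is_partitioned` and `pred` are our own hypothetical helpers, applied to the five nextPts values from Background 1 with the 12.95 query.

```python
def is_partitioned(values, pred):
    """Python analog of std::is_partitioned: every element satisfying pred comes first."""
    seen_false = False
    for value in values:
        if not pred(value):
            seen_false = True
        elif seen_false:
            return False  # a "true" element after a "false" one
    return True


next_pts = [12.879533, 12.912900, 12.946267, 12.979633, 0.0]


def pred(next_pts_seconds):
    # The lambda from main, for framePts = 12.95.
    return next_pts_seconds <= 12.95


print([pred(n) for n in next_pts])          # [True, True, True, False, True]
print(is_partitioned(next_pts, pred))       # False: the precondition is violated
print(is_partitioned(next_pts[:-1], pred))  # True once the last frame is left out
```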

The fix

I think the fix is to slightly change the search so that the last frame is never considered, i.e. to pass stream.allFrames.end() - 1 as the end of the search range (see diff). This also lets us remove the min() logic.
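A minimal Python sketch of the proposed search, assuming a hand-written mirror of std::lower_bound that takes an explicit end index `last` (standing in for the end iterator). The names `lower_bound` and `frame_index_at` are hypothetical, not torchcodec code, and the data are the five rows from Background 1.

```python
def lower_bound(values, last, pred):
    """Index in [0, last) of the first element for which pred is False, else last.

    Mirrors the halving loop of typical std::lower_bound implementations;
    `last` plays the role of the end iterator.
    """
    first, count = 0, last
    while count > 0:
        half = count // 2
        middle = first + half
        if pred(values[middle]):
            first = middle + 1
            count -= half + 1
        else:
            count = half
    return first


def frame_index_at(next_pts, frame_pts):
    """Index of the frame displayed at frame_pts, using the proposed fix.

    The search stops one element early (the analog of end() - 1), so the last
    frame's placeholder nextPts of 0 is never compared. If every searched frame
    ends at or before frame_pts, the result is len(next_pts) - 1: the last frame.
    """
    return lower_bound(next_pts, len(next_pts) - 1, lambda n: n <= frame_pts)


pts = [12.846167, 12.879533, 12.912900, 12.946267, 12.979633]
next_pts = [12.879533, 12.912900, 12.946267, 12.979633, 0.0]

print(pts[frame_index_at(next_pts, 12.95)])  # 12.946267, the second-to-last frame
print(pts[frame_index_at(next_pts, 12.98)])  # 12.979633, the last frame
```

Querying exactly a frame's pts (for example 12.946267) also returns that frame, because its own nextPts is strictly greater than the query.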

What happens now when we query for the second-to-last frame:

Comparing 6.506500 with 12.950000
Comparing 9.776433 with 12.950000
Comparing 11.411400 with 12.950000
Comparing 12.212200 with 12.950000
Comparing 12.612600 with 12.950000
Comparing 12.812800 with 12.950000
Comparing 12.912900 with 12.950000
Comparing 12.979633 with 12.950000  <-- At the second-to-last frame, gets returned
Comparing 12.946267 with 12.950000  <-- At the frame *before* the second-to-last one!

Here the binary search finds that the first element for which the condition is false is the second-to-last frame. It doesn't get a chance to be fooled by the last frame's nextPts of 0, because we don't let it consider this value.

What happens now when we query for the last frame:

Comparing 6.506500 with 12.980000
Comparing 9.776433 with 12.980000
Comparing 11.411400 with 12.980000
Comparing 12.212200 with 12.980000
Comparing 12.612600 with 12.980000
Comparing 12.812800 with 12.980000
Comparing 12.912900 with 12.980000
Comparing 12.979633 with 12.980000  <-- At the second-to-last frame

Here the lambda never returns false, so the binary search returns the end of its range, stream.allFrames.end() - 1, which corresponds to the last frame.

Wait, but doesn't that mean that we just shifted the bug by one frame?

No. The real issue was that the last frame's nextPts field is 0: the array on which we were doing a binary search wasn't sorted, because its last element broke the order. Now the last frame is excluded from the search space, so every nextPts value we search over is a real timestamp and the search space is sorted.
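One way to convince ourselves that the bug is fixed rather than shifted is to compare each search against a plain linear scan over many query timestamps. In this Python sketch, `lower_bound`, `buggy_index`, `fixed_index` and `reference_index` are hypothetical helpers (the first mirrors std::lower_bound's halving loop with an explicit end index), run on the five nextPts values from Background 1.

```python
def lower_bound(values, last, pred):
    """Index in [0, last) of the first element for which pred is False, else last."""
    first, count = 0, last
    while count > 0:
        half = count // 2
        middle = first + half
        if pred(values[middle]):
            first = middle + 1
            count -= half + 1
        else:
            count = half
    return first


def buggy_index(next_pts, frame_pts):
    # Search on main: full range, then clamp end() to the last index.
    index = lower_bound(next_pts, len(next_pts), lambda n: n <= frame_pts)
    return min(index, len(next_pts) - 1)


def fixed_index(next_pts, frame_pts):
    # Proposed fix: the last frame is left out of the search range.
    return lower_bound(next_pts, len(next_pts) - 1, lambda n: n <= frame_pts)


def reference_index(next_pts, frame_pts):
    # Linear scan: first frame (other than the last) whose nextPts is after
    # frame_pts, falling back to the last frame.
    for i, n in enumerate(next_pts[:-1]):
        if n > frame_pts:
            return i
    return len(next_pts) - 1


next_pts = [12.879533, 12.912900, 12.946267, 12.979633, 0.0]
queries = [12.85 + 0.01 * k for k in range(14)]  # about 12.85, 12.86, ..., 12.98

fixed_mismatches = [q for q in queries if fixed_index(next_pts, q) != reference_index(next_pts, q)]
buggy_mismatches = [q for q in queries if buggy_index(next_pts, q) != reference_index(next_pts, q)]
print(len(fixed_mismatches))  # 0
print(len(buggy_mismatches))  # 3: the queries near 12.95, 12.96 and 12.97
```

Because the fixed search space is sorted, the binary search agrees with the linear scan for every query; the search from main disagrees exactly for the queries that fall inside the second-to-last frame.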

scotts commented 2 weeks ago

Bra-freaking-vo. Excellent catch. Of my bug. :) And based on your reasoning, I think we also have the same problem in getFramesDisplayedByTimestampInRange: https://github.com/pytorch/torchcodec/blob/b841eb3f40fcd17de39c56d3ba524e3777783438/src/torchcodec/decoders/_core/VideoDecoder.cpp#L1240-L1246

As an alternate fix, which I think will address both problems and perhaps lead to more idiomatic-looking code, we could change https://github.com/pytorch/torchcodec/blob/b841eb3f40fcd17de39c56d3ba524e3777783438/src/torchcodec/decoders/_core/VideoDecoder.h#L300-L303 to instead be:

  struct FrameInfo {
    int64_t pts = 0;
    int64_t nextPts = INT64_MAX;
  };

That makes nextPts more in line with how we're using it.
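In Python terms, the suggested default turns the last frame's nextPts into a value larger than any real timestamp, so the whole array is sorted and the original full-range search works without the min() clamp. This sketch assumes a hypothetical time base of 1/90000 and that converting INT64_MAX ticks to seconds yields a large finite number; `lower_bound` and `frame_index_at` are hand-written illustrations, not torchcodec code.

```python
INT64_MAX = 2**63 - 1
TIME_BASE = 1 / 90000  # hypothetical time base, for illustration only


def lower_bound(values, pred):
    """Index of the first element for which pred is False, or len(values)."""
    first, count = 0, len(values)
    while count > 0:
        half = count // 2
        middle = first + half
        if pred(values[middle]):
            first = middle + 1
            count -= half + 1
        else:
            count = half
    return first


# nextPts in seconds for the last five frames; the last frame now defaults to
# INT64_MAX ticks instead of 0, converted with the hypothetical TIME_BASE.
next_pts = [12.879533, 12.912900, 12.946267, 12.979633, INT64_MAX * TIME_BASE]


def frame_index_at(frame_pts):
    # Full range, no end() - 1 adjustment and no min() clamp.
    return lower_bound(next_pts, lambda n: n <= frame_pts)


print(next_pts == sorted(next_pts))  # True: the sentinel keeps the array sorted
print(frame_index_at(12.95))  # 3, the second-to-last frame
print(frame_index_at(12.98))  # 4, the last frame
```

As long as queries stay below the converted sentinel, the lambda is false for the last frame, so the search can no longer return end().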

NicolasHug commented 2 weeks ago

Closing in favor of https://github.com/pytorch/torchcodec/pull/287