yeliudev / R2-Tuning

🌀 R^2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding (ECCV 2024)
http://arxiv.org/abs/2404.00801
BSD 3-Clause "New" or "Revised" License

Inference on long videos #9

Closed · ir2718 closed 4 months ago

ir2718 commented 4 months ago

Hi,

First of all, thanks for the repository, it's very cool.

I'm trying to do inference on a long video (50 mins), but I keep encountering an error regarding the buffer size:

  File "home/xxx/R2-Tuning/models/generator.py", line 59, in forward
    assert size <= buffer.size(0), 'reached max buffer size'
           ^^^^^^^^^^^^^^^^^^^^^^
AssertionError: reached max buffer size

What's the preferred way to run inference in this case? Is splitting the video into smaller chunks the solution?

yeliudev commented 4 months ago

Hi @ir2718. Thanks for your interest in our work!

For long videos, if you want to re-train the model, you may try setting larger values (>1024) here and here. If you would like to run inference using the provided checkpoints, it's better to split the video into chunks of ~150s each.
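
For example, a minimal sketch of this chunking, assuming the video tensor has shape (batch, frames, ...) and that `data["fps"]` is the sampling rate of the extracted frames (as in the snippet below):

    import torch

    target_sec = 150                             # suggested chunk length in seconds
    chunk_size = int(target_sec * data["fps"])   # chunk length in frames

    # split along the temporal dimension (dim=1)
    chunks = torch.split(data["video"], chunk_size, dim=1)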

ir2718 commented 4 months ago

@yeliudev

I'm trying to run inference with the trained checkpoints. Does this look okay to you?

    #### do chunking on the fly
    chunk_size = cfg.model.buffer_size
    dicts = [
        dict(video=x, query=data["query"], fps=data["fps"])
        for x in torch.split(data["video"], chunk_size, dim=1)  # split along the temporal dim
    ]

    #### do inference on each chunk
    preds = {"_out": {"boundary": []}}
    top_probs = []
    with torch.inference_mode():
        for i, chunk in enumerate(dicts):  # renamed from `data` to avoid shadowing
            pred = model(chunk)
            # keep only the top moments so boundaries and scores stay aligned
            boundary = pred["_out"]["boundary"][:cfg.model.max_num_moment]
            top_probs.append(boundary[:, 2])  # column 2 is the confidence score
            # shift the (start, end) timestamps by this chunk's offset in seconds
            offset = chunk_size * i / cfg.data["test"].fps
            preds["_out"]["boundary"].append(
                boundary + torch.tensor([offset, offset, 0.], device=boundary.device)
            )

    #### connect processed chunks and sort by confidence (descending)
    preds["_out"]["boundary"] = torch.cat(preds["_out"]["boundary"], dim=0)
    top_probs_indices = torch.argsort(torch.cat(top_probs, dim=0), descending=True)
    preds["_out"]["boundary"] = preds["_out"]["boundary"][top_probs_indices]

yeliudev commented 4 months ago

Yeah, the code looks good to me. If you are using the checkpoint trained on QVHighlights, it's better to set chunk_size = 75, as the videos in QVHighlights are all truncated to 150s and sampled at 0.5 FPS.
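
To make the arithmetic explicit, the chunk size in frames is just the chunk duration times the sampling rate, so one could hard-code it for the QVHighlights checkpoint (a sketch, replacing `chunk_size = cfg.model.buffer_size` above):

    chunk_size = int(150 * 0.5)  # 150s per chunk x 0.5 FPS = 75 frames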