antoyang / TubeDETR

[CVPR 2022 Oral] TubeDETR: Spatio-Temporal Video Grounding with Transformers
Apache License 2.0
167 stars 8 forks source link

Hi, I couldn't find the visualization code for the time-aligned cross-attention visualization #20

Closed freeman-1995 closed 12 months ago

freeman-1995 commented 1 year ago

I appreciate your nice work. Could you provide the visualization code to make analysis more convenient?

antoyang commented 12 months ago

Hey, I have not tested it thoroughly, but it should look like the following:

import argparse
import json
import os
import subprocess

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from PIL import Image
from tqdm import tqdm
from transformers import RobertaTokenizerFast

# Command-line interface. Directory options default to "." rather than "" because
# later code builds paths as f"{dir}/name"; an empty default would silently turn
# those into filesystem-root paths such as "/test.json".
parser = argparse.ArgumentParser(
    description='Visualize ground-truth / predicted tubes and attention maps for one video.')
parser.add_argument(
    '--idx',
    type=int,
    default=5,
    help='video_id of the annotation entry to visualize')
parser.add_argument(
    '--pred_path',
    default="",
    type=str,
    help='JSON prediction file; leave empty to draw ground truth only')
parser.add_argument(
    '--viz_attn',
    action='store_true',
    help='also visualize attention weights (requires --pred_path)'
)
parser.add_argument(
    '--out_dir',
    default=".",
    type=str,
    help='where to save output videos and attention plots')
parser.add_argument(
    '--data_dir',
    default=".",
    type=str,
    help='directory containing the test.json annotations')
parser.add_argument(
    '--video_path',
    default="",
    type=str,
    help='path to videos stored as image files')
parser.add_argument(
    '--outvideo_path',
    default=".",
    type=str,
    help='where to save the rendered frames (one image file per frame)')
parser.add_argument(
    '--save_vid',
    action='store_true',
    help='stitch the rendered frames into an mp4 with ffmpeg'
)
args = parser.parse_args()

# Attention weights are read from the prediction file, so --viz_attn is
# meaningless (and would crash on an undefined `preds`) without --pred_path.
if args.viz_attn and not args.pred_path:
    parser.error("--viz_attn requires --pred_path (attention weights are read from the prediction file)")

fps = 5  # target frame rate used when temporally subsampling the video
video_max_len = 200  # cap on the number of sampled frames per video

# Load the annotation file; `with` guarantees the handle is closed.
ann_file = os.path.join(args.data_dir, "test.json")
with open(ann_file, "r", encoding="utf-8") as ann_fp:
    anns = json.load(ann_fp)

# All annotation entries whose video_id matches --idx.
videos = [x for x in anns["videos"] if int(x["video_id"]) == args.idx]
print([x['caption'] for x in videos])
print(len(videos))

if args.pred_path:
    # Predictions live under the 'test_vidstg_vidstg' key of the prediction file.
    with open(args.pred_path, "r", encoding="utf-8") as pred_fp:
        preds = json.load(pred_fp)['test_vidstg_vidstg']
    predictions, video_predictions = preds['predictions'], preds['video_predictions']

print("annotations loaded")
if not videos:
    raise SystemExit(f"no annotation in {ann_file} has video_id == {args.idx}")
# Only the first matching entry is visualized.
video = videos[0]
video_original_id = video['original_video_id']
print(video_original_id)
print(video["caption"])

if args.viz_attn:
    # Decode each RoBERTa token id on its own to get one display string per token.
    tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base', local_files_only=True)
    tks = tokenizer.encode(video['caption'])
    tokens = [tokenizer.decode([i]) for i in tks]

# Temporal subsampling: walk every frame of the annotated segment and keep one
# whenever its index, rescaled to the target `fps`, lands in a new integer bucket.
video_fps = video["fps"]  # used for extraction
sampling_rate = fps / video_fps
start_frame, end_frame = video["start_frame"], video["end_frame"]

frame_ids = [start_frame]
last_bucket = int(start_frame * sampling_rate)  # bucket of the most recently kept frame
for candidate in range(start_frame, end_frame):
    bucket = int(candidate * sampling_rate)
    if bucket > last_bucket:
        frame_ids.append(candidate)
        last_bucket = bucket

# Too many frames: pick video_max_len of them at evenly spaced positions.
if len(frame_ids) > video_max_len:
    n_kept = len(frame_ids)
    frame_ids = [frame_ids[j * n_kept // video_max_len] for j in range(video_max_len)]

# Report predicted and ground-truth temporal boundaries in seconds.
if args.pred_path:
    pred_window = video_predictions[str(args.idx)]["sted"]
    print("predicted start", pred_window[0] / video["fps"])
    print("predicted end", pred_window[1] / video["fps"])
gt_start, gt_end = video['tube_start_frame'], video['tube_end_frame']
print("gt start", gt_start / video["fps"])
print("gt end", gt_end / video["fps"])

# Choose which sampled frames get rendered.
if args.save_vid or args.viz_attn:
    # A full video / attention dump needs every sampled frame.
    interf = frame_ids
elif args.pred_path:
    # Span from the earliest start to the latest end of the predicted and GT windows.
    window_lo = min(pred_window[0], gt_start)
    window_hi = max(pred_window[1], gt_end)
    interf = [fid for fid in frame_ids if window_lo <= fid < window_hi]
else:
    margin = 30  # frames of context kept on each side of the GT tube
    interf = [fid for fid in frame_ids if gt_start - margin <= fid < gt_end + margin]

# Position of each sampled frame id inside frame_ids.
frame2idx = dict(zip(frame_ids, range(len(frame_ids))))
video_trajectories = anns['trajectories'][video_original_id]
trajectory = video_trajectories[str(video['target_id'])]
video_id = video['video_id']

# Render each selected frame to its own JPEG, drawing the GT box, the predicted
# box (when --pred_path is given) and, with --viz_attn, the spatial attention map.
# Paths are built with os.path.join so an empty directory option stays relative.
frames_dir = os.path.join(
    args.outvideo_path, f"{video_id}predviz" if args.viz_attn else f"{video_id}pred")
os.makedirs(frames_dir, exist_ok=True)
if args.viz_attn:
    # Per-video attention tensors from the prediction file; the trailing shape
    # notes are the original author's annotations.
    cross_attn = np.array(preds['spatial_weights'][str(video_id)])  # thw
    text_attn = np.array(preds['text_weights'][str(video_id)])  # tl
    tsa_attn = np.array(preds['tsa_weights'][str(video_id)])  # tt
    # Softmax over axis 0 (assumed to be time, per the `t2` note) turns the
    # start/end scores into distributions over sampled frames.
    pred_sted = torch.tensor(preds['pred_sted'][str(video_id)]).softmax(0).numpy()  # t2
if args.pred_path:
    pred_window = video_predictions[str(args.idx)]["sted"]

for img_id in tqdm(interf):
    # Frame images on disk are 1-indexed and zero-padded to 5 digits.
    img_path = os.path.join(args.video_path, video_original_id, str(int(img_id) + 1).zfill(5) + '.jpg')
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    imgw, imgh = img.size
    fig, ax = plt.subplots()
    ax.axis('off')
    if args.viz_attn:
        cross_attn_frame = cross_attn[frame2idx[img_id]]
        # Scale the map so its peak is 1 (guarding an all-zero map), resize it to
        # image resolution and use it to dim unattended pixels.
        peak = cross_attn_frame.max()
        mask = cv2.resize(cross_attn_frame / (peak if peak != 0 else 1.0), img.size)[..., np.newaxis]
        img = (mask * np.array(img)).astype("uint8")
    ax.imshow(img, aspect='auto')
    if video['tube_start_frame'] <= img_id < video['tube_end_frame']:
        x1, y1, w, h = trajectory[str(img_id)]['bbox']
        # GT is green when predictions are drawn too, red otherwise.
        gt_color = '#2FFF00' if args.pred_path else 'r'
        ax.add_patch(plt.Rectangle((x1, y1), w, h, linewidth=2, edgecolor=gt_color, fill=False))
    if args.pred_path and pred_window[0] <= img_id < pred_window[1]:
        # Predicted boxes are stored as corner coordinates (x1, y1, x2, y2).
        x1, y1, x2, y2 = predictions[f"{args.idx}_{img_id}"]['boxes'][0]
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='#FAFF00', fill=False))
    # Match the output resolution to the source frame.
    fig.set_dpi(200)
    fig.set_size_inches(imgw / 200, imgh / 200)
    fig.savefig(os.path.join(frames_dir, f"{str(img_id).zfill(5)}.jpg"), format='jpg',
                bbox_inches='tight', pad_inches=0)
    plt.close(fig)

if args.save_vid:
    video_name = f"{video_id}vizattn.mp4" if args.viz_attn else f"{video_id}.mp4"
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed verbatim; ffmpeg itself expands the glob.
    ffmpeg_cmd = [
        "ffmpeg",
        "-r", str(fps),
        "-pattern_type", "glob",
        "-i", os.path.join(frames_dir, "*.jpg"),
        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",  # libx264/yuv420p needs even dimensions
        "-r", str(fps),
        "-crf", "25",
        "-c:v", "libx264",
        "-pix_fmt", "yuv420p",
        "-movflags", "+faststart",
        os.path.join(args.out_dir, video_name),
    ]
    # Export is best-effort (the attention plots should still be produced), but
    # failures are reported instead of being silently ignored.
    try:
        ffmpeg_result = subprocess.run(ffmpeg_cmd, check=False)
    except FileNotFoundError:
        print("ffmpeg executable not found; skipping video export")
    else:
        if ffmpeg_result.returncode != 0:
            print(f"ffmpeg exited with status {ffmpeg_result.returncode}; video export failed")

def _save_plot(fig, filename):
    """Save *fig* as a tight JPEG named *filename* inside ``args.out_dir`` and close it."""
    fig.savefig(os.path.join(args.out_dir, filename), format='jpg', bbox_inches='tight', pad_inches=0)
    plt.close(fig)


if args.viz_attn:
    # os.path.join keeps an empty --out_dir relative instead of writing to "/".
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # Text cross-attention heatmap: rows are time steps, columns are caption
    # tokens; each row is rescaled so its maximum is 1.
    fig, ax = plt.subplots(figsize=(8, 16))
    sns.heatmap(text_attn / text_attn.max(1)[:, np.newaxis], vmin=0, vmax=1, ax=ax, xticklabels=tokens)
    ax.set_ylabel('time')
    _save_plot(fig, f'{video_id}text.jpg')

    # Temporal self-attention heatmap (time x time), row-normalized the same way.
    fig, ax = plt.subplots(figsize=(20, 16))
    sns.heatmap(tsa_attn / tsa_attn.max(1)[:, np.newaxis], vmin=0, vmax=1, ax=ax)
    ax.set_ylabel('time')
    ax.set_xlabel('time')
    _save_plot(fig, f'{video_id}tsa.jpg')

    # Start/end probability curves over sampled frames; the legend is built
    # from the `label=` arguments.
    fig, ax = plt.subplots()
    ax.plot(pred_sted[:, 0], '-o', color='b', label='start probability')
    ax.plot(pred_sted[:, 1], '-o', color='r', label='end probability')
    ax.set_xlabel('time')
    ax.set_ylabel('probability')
    ax.legend()
    _save_plot(fig, f'{video_id}sted.jpg')