Status: Closed — freeman-1995 closed this issue 12 months ago.
Hey, I haven't tested it thoroughly, but the visualization script should look roughly like the following:
import json
import os
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse
import torch
import cv2
import numpy as np
from transformers import RobertaTokenizerFast
import seaborn as sns
# Command-line options for visualizing one annotated video (ground-truth boxes,
# optionally predicted boxes and attention maps).
parser = argparse.ArgumentParser()
parser.add_argument(
    '--idx',
    type=int,
    default=5,
    help='Index of video (matched against the "video_id" field of the annotations)')
parser.add_argument(
    '--pred_path',
    default="",
    type=str,
    help='JSON file with model predictions; when empty only ground truth is drawn')
parser.add_argument(
    '--viz_attn',
    action='store_true',
    help='Overlay spatial attention and plot text/temporal attention (requires --pred_path)')
parser.add_argument(
    '--out_dir',
    default="",
    type=str,
    help='Where to save output viz (encoded videos and attention plots)')
parser.add_argument(
    '--data_dir',
    default="",
    type=str,
    help='Path to annotations (directory containing test.json)')
parser.add_argument(
    '--video_path',
    default="",
    type=str,
    help='Path to videos stored as image files')
parser.add_argument(
    '--outvideo_path',
    default="",
    type=str,
    help='Where to save output videos stored as image files (rendered frames)')
parser.add_argument(
    '--save_vid',
    action='store_true',
    help='Render every sampled frame and encode them into an mp4 with ffmpeg')
args = parser.parse_args()
if args.viz_attn and not args.pred_path:
    # The attention weights are read from the predictions file; without it the
    # script would otherwise die later with a NameError on `preds`.
    parser.error('--viz_attn requires --pred_path')
fps = 5  # target sampling rate (frames per second) for the visualization
video_max_len = 200  # maximum number of sampled frames kept per video
# os.path.join keeps an empty --data_dir relative to the cwd instead of
# turning the path into "/test.json" at the filesystem root.
ann_file = os.path.join(args.data_dir, "test.json")
with open(ann_file, 'r') as ann_fh:
    anns = json.load(ann_fh)
videos = [x for x in anns["videos"] if int(x["video_id"]) == args.idx]
print([x['caption'] for x in videos])
print(len(videos))
if args.pred_path:
    with open(args.pred_path, 'r') as pred_fh:
        preds = json.load(pred_fh)['test_vidstg_vidstg']
    predictions, video_predictions = preds['predictions'], preds['video_predictions']
print("annotations loaded")
if not videos:
    # Fail with a readable message instead of an IndexError on videos[0].
    raise SystemExit(f"No video with video_id == {args.idx} found in {ann_file}")
video = videos[0]
video_original_id = video['original_video_id']
print(video_original_id)
print(video["caption"])
if args.viz_attn:
    # One decoded string per caption token id; these label the columns of the
    # text-attention heatmap drawn at the end of the script.
    tokenizer = RobertaTokenizerFast.from_pretrained(
        'roberta-base',
        local_files_only=True,
    )
    caption_ids = tokenizer.encode(video['caption'])
    tokens = [tokenizer.decode([token_id]) for token_id in caption_ids]
# Resample the annotated segment from the native frame rate down to `fps`:
# a frame is kept when it falls into a later 1/fps time bucket than the last
# kept frame. Afterwards the list is thinned uniformly to video_max_len.
rate = fps / video["fps"]  # video["fps"] is the native rate used for extraction
first, last = video["start_frame"], video["end_frame"]
frame_ids = [first]
for candidate in range(first, last):
    previous_bucket = int(frame_ids[-1] * rate)
    if int(candidate * rate) > previous_bucket:
        frame_ids.append(candidate)
n_sampled = len(frame_ids)
if n_sampled > video_max_len:  # subsample at video_max_len
    frame_ids = [frame_ids[k * n_sampled // video_max_len] for k in range(video_max_len)]
# Report predicted / ground-truth temporal bounds in seconds, then choose which
# sampled frames to render.
tube_start, tube_end = video['tube_start_frame'], video['tube_end_frame']
if args.pred_path:
    sted = video_predictions[str(args.idx)]["sted"]
    print("predicted start", sted[0] / video["fps"])
    print("predicted end", sted[1] / video["fps"])
print("gt start", tube_start / video["fps"])
print("gt end", tube_end / video["fps"])
if args.save_vid or args.viz_attn:
    # Video export and attention views need every sampled frame.
    interf = frame_ids
else:
    if args.pred_path:
        # Union of the predicted and ground-truth temporal spans.
        lo, hi = min(sted[0], tube_start), max(sted[1], tube_end)
    else:
        # Ground-truth span padded by a fixed margin of frames on each side.
        margin = 30
        lo, hi = tube_start - margin, tube_end + margin
    interf = [f for f in frame_ids if lo <= f < hi]
# Position of each sampled frame within frame_ids; used to index the
# per-frame attention arrays in the rendering loop.
frame2idx = {frame_id: i_frame for i_frame, frame_id in enumerate(frame_ids)}
# Ground-truth boxes of the target object, keyed by str(frame id).
trajectory = anns['trajectories'][video_original_id][str(video['target_id'])]
video_id = video['video_id']
# Rendered frames go to "<id>predviz" when attention is overlaid, otherwise
# "<id>pred". os.path.join keeps an empty --outvideo_path relative to the cwd
# instead of creating "/<id>pred" at the filesystem root.
frame_subdir = f'{video_id}predviz' if args.viz_attn else f'{video_id}pred'
os.makedirs(os.path.join(args.outvideo_path, frame_subdir), exist_ok=True)
if args.viz_attn:
    cross_attn = np.array(preds['spatial_weights'][str(video_id)])  # thw: (time, height, width)
    text_attn = np.array(preds['text_weights'][str(video_id)])  # tl: (time, text tokens)
    tsa_attn = np.array(preds['tsa_weights'][str(video_id)])  # tt: (time, time)
    # t2: (time, [start, end]); softmax over the time axis gives per-frame probabilities.
    pred_sted = torch.tensor(preds['pred_sted'][str(video_id)]).softmax(0).numpy()
# Render each selected frame with its ground-truth box and, when available, the
# predicted box and/or the spatial-attention mask, and save it as a JPEG.
# os.path.join avoids writing to "/<id>pred" when --outvideo_path is empty.
frame_dir = os.path.join(args.outvideo_path, f'{video_id}predviz' if args.viz_attn else f'{video_id}pred')
# Ground truth is green next to (yellow) predictions, red when drawn alone.
gt_color = '#2FFF00' if args.pred_path else 'r'
if args.pred_path:
    # Predicted [start, end) frame span; loop-invariant, so look it up once.
    pred_span = video_predictions[str(args.idx)]["sted"]
for img_id in tqdm(interf):
    img_path = os.path.join(args.video_path, video_original_id, str(int(img_id) + 1).zfill(5) + '.jpg')
    # Close the underlying file once the RGB copy has been made.
    with Image.open(img_path) as frame:
        img = frame.convert('RGB')
    imgw, imgh = img.size
    fig, ax = plt.subplots()
    ax.axis('off')
    if args.viz_attn:
        cross_attn_frame = cross_attn[frame2idx[img_id]]
        peak = cross_attn_frame.max()
        # Scale to [0, 1]; an all-zero map is left as-is instead of producing NaNs.
        normalized = cross_attn_frame / peak if peak > 0 else cross_attn_frame
        mask = cv2.resize(normalized, img.size)[..., np.newaxis]
        img = (mask * np.array(img)).astype("uint8")
    ax.imshow(img, aspect='auto')
    if video['tube_start_frame'] <= img_id < video['tube_end_frame']:
        x1, y1, w, h = trajectory[str(img_id)]['bbox']
        ax.add_patch(plt.Rectangle((x1, y1), w, h, linewidth=2, edgecolor=gt_color, fill=False))
    if args.pred_path and pred_span[0] <= img_id < pred_span[1]:
        # Predicted boxes are stored as corners (x1, y1, x2, y2).
        x1, y1, x2, y2 = predictions[f"{args.idx}_{img_id}"]['boxes'][0]
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='#FAFF00', fill=False))
    # 200 dpi with size (imgw/200, imgh/200) inches keeps the original pixel size.
    fig.set_dpi(200)
    fig.set_size_inches(imgw / 200, imgh / 200)
    fig.savefig(os.path.join(frame_dir, f'{str(img_id).zfill(5)}.jpg'), format='jpg',
                bbox_inches='tight', pad_inches=0)
    plt.close(fig)
if args.save_vid:
    import subprocess  # only needed for video export; kept local to this step

    # Encode the rendered frames into an H.264 mp4 at the sampling rate `fps`.
    if args.viz_attn:
        frames_dir = os.path.join(args.outvideo_path, f'{video_id}predviz')
        out_video = os.path.join(args.out_dir, f'{video_id}vizattn.mp4')
    else:
        frames_dir = os.path.join(args.outvideo_path, f'{video_id}pred')
        out_video = os.path.join(args.out_dir, f'{video_id}.mp4')
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)
    # An argument list (no shell) passes paths with spaces or shell
    # metacharacters through verbatim; ffmpeg expands the glob itself.
    ffmpeg_cmd = [
        'ffmpeg', '-y',  # overwrite an existing output instead of blocking on a prompt
        '-r', str(fps),
        '-pattern_type', 'glob', '-i', os.path.join(frames_dir, '*.jpg'),
        '-vf', 'pad=ceil(iw/2)*2:ceil(ih/2)*2',  # pad to even width/height for yuv420p
        '-r', str(fps),
        '-crf', '25', '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
        '-movflags', '+faststart',
        out_video,
    ]
    try:
        ffmpeg_result = subprocess.run(ffmpeg_cmd, check=False)
    except FileNotFoundError:
        print("ffmpeg not found on PATH; skipping video encoding")
    else:
        if ffmpeg_result.returncode != 0:
            print(f"ffmpeg failed with exit code {ffmpeg_result.returncode}")
if args.viz_attn:
    # os.path.join keeps an empty --out_dir relative to the cwd instead of
    # writing "/<id>text.jpg" at the filesystem root.
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)

    # Text attention: rows are time steps, columns are caption tokens; each row
    # is scaled so that its maximum is 1.
    fig, ax = plt.subplots(figsize=(8, 16))
    sns.heatmap(text_attn / text_attn.max(1)[:, np.newaxis], vmin=0, vmax=1, ax=ax, xticklabels=tokens)
    ax.set_ylabel('time')
    fig.savefig(os.path.join(args.out_dir, f'{video_id}text.jpg'), format='jpg', bbox_inches='tight',
                pad_inches=0)
    plt.close(fig)

    # Temporal self-attention (time x time), row-normalized the same way.
    fig, ax = plt.subplots(figsize=(20, 16))
    sns.heatmap(tsa_attn / tsa_attn.max(1)[:, np.newaxis], vmin=0, vmax=1, ax=ax)
    ax.set_ylabel('time')
    ax.set_xlabel('time')
    fig.savefig(os.path.join(args.out_dir, f'{video_id}tsa.jpg'), format='jpg', bbox_inches='tight',
                pad_inches=0)
    plt.close(fig)

    # Start/end probabilities over time (softmax over the time axis).
    fig, ax = plt.subplots()
    ax.plot(pred_sted[:, 0], '-o', color='b', label='start probability')
    ax.plot(pred_sted[:, 1], '-o', color='r', label='end probability')
    ax.set_xlabel('time')
    ax.legend()
    fig.savefig(os.path.join(args.out_dir, f'{video_id}sted.jpg'), format='jpg', bbox_inches='tight',
                pad_inches=0)
    plt.close(fig)
I appreciate your nice work. Could you provide the visualization code to make analysis more convenient?