boheumd / A2Summ

The official implementation of 'Align and Attend: Multimodal Summarization with Dual Contrastive Losses' (CVPR 2023)
https://boheumd.github.io/A2Summ/

Implementation for inference. #20

Open jai613 opened 3 weeks ago

jai613 commented 3 weeks ago

Hi. I am working on the inference part of your code, using the Daily Mail dataset. For a single video, here is what I did (a rough sketch of steps 1-3 is included below):

1) Converted the video into frames (1 fps) using cv2.
2) Used a ResNet to extract the frame-level features.
3) Used RoBERTa to extract the sentence-level features.
4) While initializing and running the model, I make sure that inputs such as video_label and text_label are not passed, because they are only used for the loss calculation.
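In case it is useful, this is roughly how I do the feature extraction. It is only a sketch of my own setup: the choice of torchvision's resnet50 pooled features, roberta-base, and the <s>-token embedding are assumptions on my side and may not match the features your pretrained checkpoints were trained on.

import cv2
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as T
from transformers import RobertaTokenizer, RobertaModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Frame features: sample ~1 frame per second and encode with ResNet-50.
resnet = models.resnet50(pretrained=True)
resnet.fc = torch.nn.Identity()              # keep the 2048-d pooled feature
resnet = resnet.eval().to(device)
preprocess = T.Compose([
    T.ToPILImage(), T.Resize(256), T.CenterCrop(224), T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def extract_frame_features(video_path):
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25    # fall back if FPS is unavailable
    feats, idx = [], 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % int(round(fps)) == 0:       # ~1 fps sampling
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            x = preprocess(rgb).unsqueeze(0).to(device)
            with torch.no_grad():
                feats.append(resnet(x).squeeze(0).cpu())
        idx += 1
    cap.release()
    return torch.stack(feats)                # (num_frames, 2048)

# Sentence features: one RoBERTa <s>-token (CLS-like) embedding per sentence.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta = RobertaModel.from_pretrained('roberta-base').eval().to(device)

def extract_sentence_features(sentences):
    embs = []
    for s in sentences:
        enc = tokenizer(s, return_tensors='pt', truncation=True).to(device)
        with torch.no_grad():
            out = roberta(**enc).last_hidden_state[:, 0]   # <s> token embedding
        embs.append(out.squeeze(0).cpu())
    return torch.stack(embs)                 # (num_sentences, 768)

The outputs of these two functions (converted to numpy) are what I save to frame_features_resnet50.npy and sentence_embeddings.npy.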

Below is the implementation (I have given the implementation of inference alone). Could you please tell me if the implementation is right? Any suggestions?

import os
import cv2
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt

from config import *
from models import Model_MSMO
from img_load import *

def inference_single_video(video_list, text_list, model):
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Pad to a batch of size 1; since there is only a single (unpadded) video,
    # the attention masks are all ones.
    video = pad_sequence(video_list, batch_first=True)
    text = pad_sequence(text_list, batch_first=True)
    mask_video = torch.ones(video.shape[1], dtype=torch.long)
    mask_text = torch.ones(text.shape[1], dtype=torch.long)

    video = video.to(device)
    text = text.to(device)
    mask_video = mask_video.to(device)
    mask_text = mask_text.to(device)
    model.to(device)  # same device as the inputs (falls back to CPU if no GPU)

    with torch.no_grad():
        pred_video, pred_text = model(video=video, text=text,
                                      mask_video=mask_video, mask_text=mask_text,
                                      inference=True)

    # Take the top-k scoring frames and sentences as the summary.
    top_k_frames = 5
    top_k_sentences = 5
    _, keyframe_indices = torch.topk(pred_video[0], k=top_k_frames)
    keyframe_indices = keyframe_indices.cpu().tolist()
    _, keysentence_indices = torch.topk(pred_text[0], k=top_k_sentences)
    keysentence_indices = keysentence_indices.cpu().tolist()
    return keyframe_indices, keysentence_indices

args = get_arguments()
model = Model_MSMO(args=args)
model.load_state_dict(torch.load('saved_model/Daily_Mail/model_best_video.pt', map_location='cpu')['model_state_dict'])
# NOTE: this second load replaces the weights loaded above, so only the text checkpoint is in effect afterwards.
model.load_state_dict(torch.load('saved_model/Daily_Mail/model_best_text.pt', map_location='cpu')['model_state_dict'])

video_frames = [torch.from_numpy(np.load(r"inference/frame_features_resnet50.npy"))]
text = [torch.from_numpy(np.load(r"inference/sentence_embeddings.npy"))]

keyframe_indices, keysentence_indices = inference_single_video(video_frames, text, model)

print("Key Frames:", keyframe_indices)
print("Key Sentences:", keysentence_indices)

frames_dir = r"inference/frames"
frame_files = sorted(os.listdir(frames_dir))  # os.listdir order is arbitrary, so sort for a stable index -> filename mapping
images = get_frames()
for i in sorted(keyframe_indices):
    img_to_show = cv2.cvtColor(images[i], cv2.COLOR_BGR2RGB)
    plt.imshow(img_to_show)
    plt.axis('off')
    plt.title(f"Image at index {i}: {frame_files[i]}")
    plt.show()

Some contextual info about the code:

- frames_dir is the directory that contains the frames of the video (extracted at 1 fps).
- video_frames holds the encoded feature vectors of the video frames.
- text holds the encoded feature vectors of the sentences of the video.
- The number of keyframes and key sentences (k) is assumed to be 5.
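One more note on the masks: with a single video the all-ones masks should be fine, but if I ever batch several videos/articles of different lengths, I assume the masks need to reflect each sample's true length. A minimal sketch of how I would build them (assuming the model treats mask_video/mask_text as padding masks with 1 = real position and 0 = padding; build_batch_with_masks is just a hypothetical helper, not part of the repo):

import torch
from torch.nn.utils.rnn import pad_sequence

def build_batch_with_masks(feature_list):
    # Pad a list of (len_i, dim) tensors into (B, max_len, dim) and build
    # a (B, max_len) mask with 1 for real positions and 0 for padding.
    lengths = torch.tensor([f.shape[0] for f in feature_list])
    padded = pad_sequence(feature_list, batch_first=True)
    max_len = padded.shape[1]
    mask = (torch.arange(max_len)[None, :] < lengths[:, None]).long()
    return padded, mask

# e.g. two videos with 30 and 45 frame features of dimension 2048:
# feats, mask = build_batch_with_masks([torch.randn(30, 2048), torch.randn(45, 2048)])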