Error when running inference on the MELD dataset #3

I downloaded all the weights and the dataset, but when I run inference on the MELD dataset the output is NaN (maybe there is a problem with the weights or the inference code?) @yuntaeyang


First of all, we can't determine the cause just by looking at the screenshot, but we don't think it is a problem with the code. To narrow the problem down, check audio_inputs: if audio_inputs looks normal, the problem could be in the model setup; if audio_inputs is abnormal, the data for that part needs to be inspected. I'm sorry that I can't give you a clear solution.

I used audio_inputs from the MELD raw data (I downloaded and extracted them, then checked that they are still okay). @yuntaeyang I use the same code as in this repo (I only changed the paths of all the checkpoints and test_data_csv_path):

import argparse
import gc
import glob
import os
import random
import warnings
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
from transformers import RobertaTokenizer, RobertaModel

from preprocessing import *
from utils import *
from dataset import *
from model import *
#from basic_fusion import *

def parse_args(argv=None):
    """Parse the command-line options of the MELD inference script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with the original options (``epochs``,
        ``learning_rate``, ``batch_size``, ``seed``) plus optional
        ``test_path``, ``teacher_ckpt``, ``audio_ckpt``, ``video_ckpt`` and
        ``fusion_ckpt`` strings, each ``None`` when not supplied.
    """
    parser = argparse.ArgumentParser(description='Process some arguments')
    parser.add_argument('--epochs', default=10, type=int, help='epoch for training.')
    parser.add_argument('--learning_rate', default=1e-5, type=float, help='learning rate for training.')
    parser.add_argument('--batch_size', default=4, type=int, help='batch for training.')
    parser.add_argument('--seed', default=42, type=int, help='random seed fix')
    # Optional paths so inference can be pointed at the data and trained
    # weights without editing the source.
    parser.add_argument('--test_path', default=None, type=str,
                        help='path to the test CSV (optional).')
    parser.add_argument('--teacher_ckpt', default=None, type=str,
                        help='state_dict checkpoint for the text teacher model (optional).')
    parser.add_argument('--audio_ckpt', default=None, type=str,
                        help='state_dict checkpoint for the audio student model (optional).')
    parser.add_argument('--video_ckpt', default=None, type=str,
                        help='state_dict checkpoint for the video student model (optional).')
    parser.add_argument('--fusion_ckpt', default=None, type=str,
                        help='state_dict checkpoint for the fusion model (optional).')
    return parser.parse_args(argv)

def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch (CPU and
    all CUDA devices). It also forces deterministic cuDNN kernels and disables
    cuDNN autotuning.

    Args:
        seed: integer seed applied to all generators.
    """
    # Only influences Python processes started after this point; the hash
    # seed of the already-running interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every GPU as well; this is deferred safely when CUDA is
    # not initialised.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def evaluation(model_t, audio_s, video_s, fusion, dataloader, device='cuda', *, verbose=False):
    """Run the multimodal models over ``dataloader`` and collect predictions.

    Each batch is unpacked as
    ``(input_tokens, attention_masks, audio_inputs, video_inputs, labels)``.
    The text model and the audio/video models each return ``(hidden, logits)``.
    The three hidden states are passed to ``fusion``, and the argmax over
    dimension 1 of its output is taken as the predicted class.

    Args:
        model_t: text model, called as ``model_t(tokens, attention_masks)``.
        audio_s: audio model, called as ``audio_s(audio_inputs)``.
        video_s: video model, called as ``video_s(video_inputs)``.
        fusion: fusion model, called as
            ``fusion(text_hidden, audio_hidden, video_hidden)`` and returning logits.
        dataloader: iterable of 5-tuples of tensors as described above.
        device: device the batch tensors are moved to; it must match where the
            models live. Defaults to ``'cuda'``, as before.
        verbose: when True, print per-batch hidden shapes, unimodal logits and
            predicted/true labels (useful for spotting NaNs).

    Returns:
        ``(pred_list, label_list)``: Python lists of ints, one entry per
        sample, in dataloader order.
    """
    # nn.Module starts in training mode; switch to eval so dropout and similar
    # layers do not make the predictions stochastic.
    for module in (model_t, audio_s, video_s, fusion):
        module.eval()

    pred_list = []
    label_list = []

    with torch.no_grad():
        for batch in dataloader:
            (batch_input_tokens, attention_masks, audio_inputs,
             video_inputs, batch_labels) = (t.to(device) for t in batch)

            text_hidden, text_logits = model_t(batch_input_tokens, attention_masks)
            audio_hidden, audio_logits = audio_s(audio_inputs)
            video_hidden, video_logits = video_s(video_inputs)

            pred_logits = fusion(text_hidden, audio_hidden, video_hidden)
            pred_label = pred_logits.argmax(1)

            if verbose:
                print(text_hidden.shape, audio_hidden.shape, video_hidden.shape)
                print(text_logits, audio_logits, video_logits)
                print("pred label: ", pred_label.cpu().numpy())
                print("true label: ", batch_labels.cpu().numpy())

            # The previous version computed these but never stored them, so
            # it always returned two empty lists.
            pred_list.extend(pred_label.tolist())
            label_list.extend(batch_labels.tolist())

    return pred_list, label_list

def main(args):
    """Evaluate the TelME models on the MELD test split and print metrics.

    Builds the text teacher (``Teacher_model``), the audio and video students
    (``Student_Audio`` / ``Student_Video``) and the ``ASF`` fusion model. It
    restores their weights from checkpoint files when paths are supplied,
    freezes them, moves them to CUDA and runs ``evaluation`` on the test
    loader. Finally it prints a classification report and a row-normalised
    confusion matrix.

    Args:
        args: parsed command-line namespace. ``batch_size`` is required.
            ``test_path``, ``teacher_ckpt``, ``audio_ckpt``, ``video_ckpt`` and
            ``fusion_ckpt`` are optional and read with ``getattr``, so a
            namespace without them still works.
    """
    # Minimal config object handed to Student_Audio below.
    class Config():
        mask_time_length: int = 3

    text_model = "roberta-large"
    audio_model = "facebook/data2vec-audio-base-960h"
    video_model = "facebook/timesformer-base-finetuned-k400"

    # Dataset loading. The test CSV is resolved under data_path instead of a
    # machine-specific absolute path; an explicit args.test_path overrides it.
    data_path = './dataset/MELD.Raw/'
    test_path = getattr(args, 'test_path', None) or data_path + 'test_meld_emo.csv'

    test_dataset = meld_dataset(preprocessing(test_path))
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=16, collate_fn=make_batchs)

    clsNum = len(test_dataset.emoList)
    init_config = Config()

    model_t = Teacher_model(text_model, clsNum)
    audio_s = Student_Audio(audio_model, clsNum, init_config)
    video_s = Student_Video(video_model, clsNum)
    hidden_size, beta_shift, dropout_prob, num_head = 768, 1e-1, 0.2, 3
    fusion = ASF(clsNum, hidden_size, beta_shift, dropout_prob, num_head)

    modules = (
        ('teacher', model_t, getattr(args, 'teacher_ckpt', None)),
        ('audio student', audio_s, getattr(args, 'audio_ckpt', None)),
        ('video student', video_s, getattr(args, 'video_ckpt', None)),
        ('fusion', fusion, getattr(args, 'fusion_ckpt', None)),
    )
    for name, module, ckpt_path in modules:
        if ckpt_path:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # Assumes the file holds a state_dict - TODO confirm against the
            # training script that produced it.
            module.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        else:
            # Without trained weights the model keeps whatever its constructor
            # produced, so the reported metrics are not meaningful.
            warnings.warn(f'No checkpoint given for the {name} model; '
                          'evaluating with untrained weights.')
        module.requires_grad_(False)
        module.eval()
        module.cuda()  # nn.Module.cuda() moves the parameters in place

    # Evaluation
    test_pred_list, test_label_list = evaluation(model_t, audio_s, video_s, fusion, test_loader)

    # Pass every class index explicitly so the report and the matrix stay
    # aligned with emoList even if some emotion never appears in the labels or
    # predictions (assumes labels are integer indices into emoList).
    all_labels = list(range(clsNum))
    print(classification_report(test_label_list, test_pred_list, labels=all_labels,
                                target_names=test_dataset.emoList, digits=5))
    print(confusion_matrix(test_label_list, test_pred_list, labels=all_labels,
                           normalize='true'))

if __name__ == "__main__":
    args = parse_args()
    # Seed before the models and the loader are built so runs are reproducible.
    seed_everything(args.seed)
    # Previously the script parsed its arguments and then exited without
    # running anything.
    main(args)
yuntaeyang commented 4 months ago

Hello, I ran our code with the model checkpoints we shared in a Google Colab setup to check whether anything is wrong with them. In that setup the results come out without any problem, so I think I need to look more closely at your setup rather than at the code. If you want, I will share the link to the Colab notebook I tested. [Screenshot 2024-06-07, 3:55:57 PM]

yuntaeyang commented 4 months ago

As for the issue you are currently seeing, some MELD data may be missing from your setup, though this is only a guess without evidence.

kevinkhanhvu commented 4 months ago

Thanks @yuntaeyang, can you share the link to the Colab notebook you tested? I think the MELD dataset I downloaded may have some problem!

yuntaeyang commented 4 months ago

For inference, you can run the test.ipynb file, but you need to restart the session after installing requirements.txt.

Thank you!