katha-ai / EmoTx-CVPR2023

[CVPR 2023] Official code repository for "How you feelin'? Learning Emotions and Mental States in Movie Scenes". https://arxiv.org/abs/2304.05634
https://katha-ai.github.io/projects/emotx

Testing File #9

Open VidyaPeddinti opened 1 year ago

VidyaPeddinti commented 1 year ago

Hello, can you please provide the testing code required for this model? After training, how do we go about testing? Thanks!

dhruvhacks commented 12 months ago

Hey @VidyaPeddinti, if you wish to check results on the test set of MovieGraphs, you can define a test_dataloader and call the utils/train_eval_utils.py:evaluate() method.

So the script would look like this (it may require some debugging):

from dataloaders.mg_emo_dataset import character_emo_dataset
from models.emotx import EmoTx
from omegaconf import OmegaConf
from pathlib import Path
from torch.utils.data import DataLoader
from utils.train_eval_utils import set_seed, evaluate

import torch
import torch.nn as nn
import utils.mg_utils as utils

def get_config():
    """
    Loads the config file and updates it with the command line arguments.
    The model name is also updated. The config is then converted to a dictionary.
    """
    base_conf = OmegaConf.load("config.yaml")
    overrides = OmegaConf.from_cli()
    updated_conf = OmegaConf.merge(base_conf, overrides)
    return OmegaConf.to_container(updated_conf)

if __name__ == "__main__":
    # Load config
    set_seed(0)
    config = get_config()

    # Set variables
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trained_model_checkpoint_path = "<PATH_TO_CHECKPOINT>.pt"

    # Load EmoTx trained checkpoint
    model_checkpoint_filepath = trained_model_checkpoint_path
    chkpt = torch.load(model_checkpoint_filepath)
    model = EmoTx(
        num_labels=chkpt["num_labels"],
        num_pos_embeddings=chkpt["num_pos_embeddings"],
        scene_feat_dim=chkpt["scene_feat_dim"],
        char_feat_dim=chkpt["char_feat_dim"],
        srt_feat_dim=chkpt["srt_feat_dim"],
        num_chars=chkpt["num_chars"],
        num_enc_layers=chkpt["num_enc_layers"],
        max_individual_tokens=chkpt["max_individual_tokens"],
        hidden_dim=chkpt["hidden_dim"]
    )
    model.load_state_dict(chkpt["state_dict"])
    model = model.to(device).eval()

    # Create the test dataloader
    data_split = utils.read_train_val_test_splits(config["resource_path"])
    train_dataset = character_emo_dataset(config=config,
                                          movie_ids=data_split["train"],
                                          split_type="train",
                                          random_feat_selection=config["random_feat_selection"],
                                          with_srt=config["use_srt_feats"])
    emo2id = train_dataset.get_emo2id_map()
    test_dataset = character_emo_dataset(config=config,
                                         movie_ids=data_split["test"],
                                         split_type="test",
                                         random_feat_selection=False,
                                         with_srt=config["use_srt_feats"],
                                         emo2id=emo2id)
    test_dataloader = DataLoader(test_dataset,
                                batch_size=config["batch_size"],
                                shuffle=False,
                                num_workers=config["num_cpus"],
                                collate_fn=test_dataset.collate)

    # Define the criterion to get the test_loss
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor(config["pos_weight"][str(train_dataset.top_k)]).to(device))

    # Call the evaluate method to get the test predictions and metrics
    test_loss, test_metrics = evaluate(chkpt["num_labels"], test_dataloader, device, model, criterion)
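
If the script above is saved as, say, test_emotx.py (a hypothetical name), the OmegaConf.from_cli() call in get_config() lets you override any key from config.yaml directly on the command line, for example:

python test_emotx.py batch_size=8 num_cpus=4

Run it from the repository root so that config.yaml and the resource/feature paths it references resolve correctly.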

We will add the inference script later.

richieej commented 11 months ago

Hello, can we use the model to test on data/movies not from MovieGraphs? Thank you

dhruvhacks commented 11 months ago

Hello @richieej, yes, you can. For this, you will have to go through the feature extraction steps first. As an overview:

  1. Extract frame (scene) features using MViT_v1 (a sketch is included after this list).
  2. Detect, track and tag characters within the videos.
  3. Extract the face features for these tagged characters.
  4. Extract subtitle features (a sketch is included at the end of this comment).
  5. Prepare your videos/data in a directory structure similar to the MovieGraphs dataset.
  6. Use the provided data loaders with your video/movies path.
  7. Use the utils/train_eval_utils.py:evaluate to evaluate them (example to evaluate on test data is provided above).
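
For step 1, here is a minimal sketch of pooling a clip-level feature with torchvision's Kinetics-400 pretrained MViT_v1. This is not the repository's official extractor; the exact checkpoint, frame sampling, and feature hook used to build the MovieGraphs features may differ, so treat it only as a starting point (it requires a recent torchvision that ships the video MViT models).

import torch
from torchvision.models.video import mvit_v1_b, MViT_V1_B_Weights

weights = MViT_V1_B_Weights.KINETICS400_V1
model = mvit_v1_b(weights=weights).eval()
model.head = torch.nn.Identity()  # drop the classifier to keep the pooled 768-dim clip feature

# The bundled transform expects uint8 frames shaped (T, C, H, W);
# mvit_v1_b is configured for 16-frame clips.
preprocess = weights.transforms()
clip = torch.randint(0, 256, (16, 3, 360, 640), dtype=torch.uint8)  # dummy 16-frame clip
batch = preprocess(clip).unsqueeze(0)  # -> (1, C, T, H, W)

with torch.no_grad():
    scene_feat = model(batch)  # (1, 768)
print(scene_feat.shape)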

In case some modalities are not available (e.g., subtitles), you may set use_{scene|char|srt}_feats=False in the config.
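
If subtitles are available, here is a similarly hedged sketch for step 4. It uses HuggingFace transformers with RoBERTa-base purely as an illustration of producing one 768-dim feature per subtitle line; check the repository's feature-extraction documentation for the exact text encoder and pooling used for the released srt features.

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
encoder = AutoModel.from_pretrained("roberta-base").eval()

subtitle_lines = ["I can't believe you did that.", "It's going to be okay."]
inputs = tokenizer(subtitle_lines, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    out = encoder(**inputs)

# Take the <s> (CLS) token embedding as one 768-dim feature per subtitle line.
srt_feats = out.last_hidden_state[:, 0]  # (num_lines, 768)
print(srt_feats.shape)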

richieej commented 10 months ago

Hello, how do you generate the pickle file with the metadata of the scenes, and is this step necessary? Thank you