moabitcoin / ig65m-pytorch

PyTorch 3D video classification models pre-trained on 65 million Instagram videos
MIT License

Need help with prediction code #34

Closed wuvei closed 4 years ago

wuvei commented 4 years ago

Hi, I am working on using a pretrained model to do video classification and I'm a beginner. I borrowed code from extract.py in the CLI and other sources. The following code did produce some results, but they did not seem correct. In addition, for some videos there were indices larger than 400 in max_indices. I'd appreciate it if anyone could help with the code!

classes.json is from here.

import sys

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision.transforms import Compose

import numpy as np
from einops.layers.torch import Rearrange, Reduce
from tqdm import tqdm

from ig65m.models import r2plus1d_34_32_kinetics
from ig65m.datasets import VideoDataset
from ig65m.transforms import ToTensor, Resize, Normalize
from pathlib import Path
import json

class VideoModel(nn.Module):
    def __init__(self, pool_spatial="mean", pool_temporal="mean"):
        super().__init__()

        self.model = r2plus1d_34_32_kinetics(num_classes=400, pretrained=True, progress=True)

        self.pool_spatial = Reduce("n c t h w -> n c t", reduction=pool_spatial)
        self.pool_temporal = Reduce("n c t -> n c", reduction=pool_temporal)

    def forward(self, x):
        x = self.model.stem(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        x = self.pool_spatial(x)
        x = self.pool_temporal(x)

        return x

if __name__ == "__main__":
    if torch.cuda.is_available():
        print("🐎 Running on GPU(s)", file=sys.stderr)
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        print("🐌 Running on CPU(s)", file=sys.stderr)
        device = torch.device("cpu")

    model = VideoModel(pool_spatial="mean",
                       pool_temporal="mean")

    model.eval()

    for params in model.parameters():
        params.requires_grad = False

    model = model.to(device)
    model = nn.DataParallel(model)
    with open('classes.json','r') as load_f:
        load_dict = json.load(load_f)
    class_names = np.array(load_dict)

    transform = Compose([
        ToTensor(),
        Rearrange("t h w c -> c t h w"),
        Resize(128),
        Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
    ])

    dataset = VideoDataset(Path("./Yoga3.mp4"), clip=32, transform=transform)
    loader = DataLoader(dataset, batch_size=1, num_workers=0, shuffle=False, pin_memory=True)

    video_outputs = []

    with torch.no_grad():
        for inputs in tqdm(loader, total=len(dataset)):
            inputs = inputs.to(device)
            outputs = model(inputs)

            video_outputs.append(outputs.cpu().data)

    video_outputs = torch.cat(video_outputs)

    results = {
        'video': "Yoga-3.mp4",
        'clips': []
    }

    _, max_indices = video_outputs.max(dim=1)

    for i in range(video_outputs.size(0)):
        clip_results = {}
        clip_results['label'] = class_names[max_indices[i]]
        results['clips'].append(clip_results)

    print(results)
daniel-j-h commented 4 years ago

The extract tool writes out .npy files with clip features pooled in space and time.

https://github.com/moabitcoin/ig65m-pytorch/blob/fc749e2ee354c3e4ddbb144cf511bb868b008f61/ig65m/cli/extract.py#L18-L37
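In other words, each saved row is a pooled feature vector, not a class distribution. A minimal sketch of reading such a file back (features.npy is a hypothetical output path, and the shape is just an example):

import numpy as np

# Hypothetical path: whatever output file you passed to the extract tool.
features = np.load("features.npy")

# One pooled feature vector per clip, e.g. (num_clips, 512); these are
# embeddings, not class scores.
print(features.shape)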

The VideoModel class you copied chops off the classification head, so it returns pooled clip features (512-dimensional for this architecture) rather than 400 class logits. That is why you see indices larger than 400 in max_indices: you are taking an argmax over feature dimensions, not over class scores. You can use the models directly and don't need the VideoModel class on top of them. Then you will also get the classification layer at the end:

https://github.com/moabitcoin/ig65m-pytorch/blob/fc749e2ee354c3e4ddbb144cf511bb868b008f61/ig65m/models.py#L73
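For example, a minimal sketch along those lines. The random input tensor is a stand-in for one of the preprocessed clips from your DataLoader, and classes.json is assumed to hold the list of 400 Kinetics-400 label names you linked above:

import json

import torch
import torch.nn.functional as F

from ig65m.models import r2plus1d_34_32_kinetics

# Full pre-trained model, classification head included: outputs (n, 400) logits.
model = r2plus1d_34_32_kinetics(num_classes=400, pretrained=True, progress=True)
model.eval()

with open("classes.json") as f:
    class_names = json.load(f)  # list of 400 Kinetics-400 label names

# Stand-in for one preprocessed clip: (batch, channels, frames, height, width).
clip = torch.rand(1, 3, 32, 112, 112)

with torch.no_grad():
    logits = model(clip)              # shape (1, 400)
    probs = F.softmax(logits, dim=1)
    score, index = probs.max(dim=1)   # index is now guaranteed to be < 400

print(class_names[index.item()], score.item())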