facebookresearch / pytorchvideo

A deep learning library for video understanding research.
https://pytorchvideo.org/
Apache License 2.0
3.29k stars 406 forks source link

Converting pretrained Slowfast_r50 model to Torchscript #212

Open realsazzad opened 2 years ago

realsazzad commented 2 years ago

I have been trying to convert the pretrained slowfast_r50 model to torchscript. But getting the following error. Could anyone help me out on this matter? Is it possible to convert the existing pretrained pytorchvideo model to torchscript or ONNX format? Thanks.

import torch
import json
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
    UniformCropVideo
)
with open("kinetics_classnames.json", "r") as f:
    kinetics_classnames = json.load(f)

# Create an id to label name mapping
kinetics_id_to_classname = {}
for k, v in kinetics_classnames.items():
    kinetics_id_to_classname[v] = str(k).replace('"', "")

# Device on which to run the model
# Set to cuda to load on GPU
device = "cpu"

# Pick a pretrained model
model_name = "slowfast_r50"
model = torch.hub.load("facebookresearch/pytorchvideo:main", model=model_name, pretrained=True)

# Set to eval mode and move to desired device
model = model.to(device)
model = model.eval()

####################
# SlowFast transform
####################

side_size = 256
mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]
crop_size = 256
num_frames = 32
sampling_rate = 2
frames_per_second = 30
alpha = 4

class PackPathway(torch.nn.Module):
    """
    Transform for converting video frames as a list of tensors.
    """

    def __init__(self):
        super().__init__()

    def forward(self, frames: torch.Tensor):
        fast_pathway = frames
        # Perform temporal sampling from the fast pathway.
        slow_pathway = torch.index_select(
            frames,
            1,
            torch.linspace(
                0, frames.shape[1] - 1, frames.shape[1] // alpha
            ).long(),
        )
        frame_list = [slow_pathway, fast_pathway]
        return frame_list

transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            UniformTemporalSubsample(num_frames),
            Lambda(lambda x: x / 255.0),
            NormalizeVideo(mean, std),
            ShortSideScale(
                size=side_size
            ),
            CenterCropVideo(crop_size),
            PackPathway()
        ]
    ),
)

# The duration of the input clip is also specific to the model.
clip_duration = (num_frames * sampling_rate) / frames_per_second

# Load the example video
video_path = "demo.mp4"

# Select the duration of the clip to load by specifying the start and end duration
# The start_sec should correspond to where the action occurs in the video
start_sec = 0
end_sec = start_sec + clip_duration

# Initialize an EncodedVideo helper class
video = EncodedVideo.from_path(video_path)

# Load the desired clip
video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)

# Apply a transform to normalize the video input
video_data = transform(video_data)

# Move the inputs to the desired device
inputs = video_data["video"]
inputs = [i.to(device)[None, ...] for i in inputs]

# Pass the input clip through the model
preds = model(inputs)
traced_script_module = torch.jit.trace(model, inputs)

#Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")

Error stack:

Traceback (most recent call last): File "D:\Project Code\pytorchvideo\inference.py", line 138, in <module> traced_script_module = torch.jit.trace(model, inputs) File "C:\Users\sazza\AppData\Roaming\Python\Python39\site-packages\torch\jit\_trace.py", line 741, in trace return trace_module( File "C:\Users\sazza\AppData\Roaming\Python\Python39\site-packages\torch\jit\_trace.py", line 958, in trace_module module._c._create_method_from_trace( File "C:\Users\sazza\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl return forward_call(*input, **kwargs) File "C:\Users\sazza\AppData\Roaming\Python\Python39\site-packages\torch\nn\modules\module.py", line 1090, in _slow_forward result = self.forward(*input, **kwargs) TypeError: forward() takes 2 positional arguments but 3 were given

matteosal commented 1 year ago

A ping on this. Can these models be traced with TorchScript?