SwinTransformer / Video-Swin-Transformer

This is an official implementation for "Video Swin Transformer".
https://arxiv.org/abs/2106.13230
Apache License 2.0

Steps to convert mmaction's video-swin-transformer to ONNX successfully #89


gigasurgeon commented 11 months ago

Hello all. I have been trying to export mmaction's Video Swin Transformer model to ONNX. However, the script tools/deployment/pytorch2onnx.py provided in this repo gave me the following errors:

error 1) Floating point exception (core dumped)
error 2) RuntimeError: input_shape_value == reshape_value || input_shape_value == 1 || reshape_value == 1 INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/onnx/shape_type_inference.cpp":513, please report a bug to PyTorch. ONNX Expand input shape constraint not satisfied.
error 3) other issues

I also tried another repo's reimplementation, i.e. https://github.com/haofanwang/video-swin-transformer-pytorch, but hit the same set of issues. No luck. The default model was able to infer properly, but the ONNX export kept failing. Even torch.jit.script() was failing.

So, after days of effort, I came up with a way to successfully export my trained Video Swin model to ONNX. Here's the code. I hope this helps.

from torchvision.models.video.swin_transformer import SwinTransformer3d
import torch
from collections import OrderedDict

# These hyperparameters must match the mmaction config the checkpoint was trained with
# (this example is a Swin-B backbone with 5 output classes).
torchvision_model = SwinTransformer3d(
        patch_size=[2, 4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[16, 7, 7],
        mlp_ratio=4.0,
        dropout=0.0,
        attention_dropout=0.0,
        stochastic_depth_prob=0.1,
        num_classes=5)

mmaction_weights = torch.load('../dl_model_ckpt_swin/frames/swin_last.pth')

assert len(torchvision_model.state_dict()) == len(mmaction_weights['state_dict']), \
    "mmaction video-swin checkpoint's number of tensors doesn't match the torchvision video-swin model's architecture"

# print(torchvision_model)

############################
######## printing pytorch torchvision's swin state_dict without loading checkpoint
# for k, i in enumerate(torchvision_model.state_dict()):
#     print(i, torchvision_model.state_dict()[i].shape)

# print('*'*50)
# print()

######## printing mmaction's swin checkpoints state_dict
# for k, i in enumerate(mmaction_weights['state_dict']):
#     print(i, mmaction_weights['state_dict'][i].shape)

############################

########## checking that corresponding tensors in both state dicts have the same shape

torchvision_model_keys = [i for i in torchvision_model.state_dict()]
mmaction_weight_keys = [i for i in mmaction_weights['state_dict']]

for i in range(len(torchvision_model_keys)):
    shape_1 = torchvision_model.state_dict()[torchvision_model_keys[i]].shape
    shape_2 = mmaction_weights['state_dict'][mmaction_weight_keys[i]].shape

    if shape_1!=shape_2:
        print('shapes not matching')
        break
print('done')

############################ copying mmaction weight values into the torchvision swin
# NOTE: this relies on both state_dicts listing their tensors in the same order,
# which the shape check above is meant to confirm.

new_torchvision_state_dict = OrderedDict()

for i in range(len(torchvision_model_keys)):
    new_torchvision_state_dict[torchvision_model_keys[i]] = mmaction_weights['state_dict'][mmaction_weight_keys[i]]

torchvision_model.load_state_dict(new_torchvision_state_dict)

# for i in range(len(torchvision_model_keys)):
# #     print(torchvision_model.state_dict()[torchvision_model_keys[i]][-1])
# #     print('a')
# #     torchvision_model.state_dict()[torchvision_model_keys[i]] = mmaction_weights['state_dict'][mmaction_weight_keys[i]]
#     print(mmaction_weights['state_dict'][mmaction_weight_keys[i]][[-1]])
#     print('a')
#     print(torchvision_model.state_dict()[torchvision_model_keys[i]][-1])
#     print('b')
#     print(new_torchvision_state_dict[torchvision_model_keys[i]][-1])
#     exit()

print('done')

# sanity-check a forward pass (N, C, T, H, W) before exporting
input_shape = [1, 3, 8, 224, 224]
input_tensor = torch.randn(input_shape)
a = torchvision_model(input_tensor)
# torch.jit.script(torchvision_model, (input_tensor))
# torchvision_model = torch.compile(torchvision_model)

torch.onnx.export(
        torchvision_model,
        input_tensor,
        'video_swin.onnx',
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=True,
        opset_version=15)
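
To sanity-check the exported file, you can compare the ONNX Runtime output against the PyTorch output. A minimal sketch, assuming onnxruntime is installed and this runs in the same session as the code above:

import numpy as np
import onnxruntime as ort

torchvision_model.eval()
with torch.no_grad():
    torch_out = torchvision_model(input_tensor).numpy()

# run the exported graph on the same input and compare
sess = ort.InferenceSession('video_swin.onnx', providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
onnx_out = sess.run(None, {input_name: input_tensor.numpy()})[0]

print('max abs diff:', np.abs(torch_out - onnx_out).max())
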
adeljalalyousif commented 7 months ago

I have tried the following method, but it did not work with the Video Swin Transformer; it only works with a CNN model from torchvision. Can you help me with this method:


from mmcv import Config
from vst.mmaction.models import build_model  # vst is the folder containing the cloned Video Swin Transformer repo
from mmcv.runner import load_checkpoint

import torch
import torch.onnx

config = './vst/configs/recognition/swin/swin_tiny_patch244_window877_kinetics400_1k.py'
checkpoint = './vst/checkpoints/swin_tiny_patch244_window877_kinetics400_1k.pth'

cfg = Config.fromfile(config)
model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))

model.eval()
model.cuda()

# Load the checkpoint onto the GPU
checkpoint = load_checkpoint(model, checkpoint, map_location='cuda')

BATCH_SIZE = 1
T = 16

dummy_input = torch.randn(BATCH_SIZE, 3, T, 224, 224).cuda()  # keep the input on the same device as the model

# export the model to ONNX
torch.onnx.export(model, dummy_input, "siwn_T.onnx", verbose=False)

I got this error:

Traceback (most recent call last):
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\spyder_kernels\py3compat.py", line 356, in compat_exec
    exec(code, globals, locals)
  File "c:\users\msi\untitled2.py", line 30, in <module>
    torch.onnx.export(model, dummy_input, "siwn_T.onnx", verbose=False)
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\__init__.py", line 350, in export
    return utils.export(
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 163, in export
    _export(
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\onnx\utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\jit\_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\jit\_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\nn\modules\module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\MSI\miniconda3\envs\tr_121\lib\site-packages\torch\nn\modules\module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "C:\Users\MSI\vst\mmaction\models\recognizers\base.py", line 253, in forward
    raise ValueError('Label should not be None.')
ValueError: Label should not be None.

The original block of code in vst that causes the error is:

def forward(self, imgs, label=None, return_loss=True, **kwargs):
    """Define the computation performed at every call."""
    if kwargs.get('gradcam', False):
        del kwargs['gradcam']
        return self.forward_gradcam(imgs, **kwargs)
    if return_loss:
        if label is None:
            raise ValueError('Label should not be None.')
        if self.blending is not None:
            imgs, label = self.blending(imgs, label)
        return self.forward_train(imgs, label, **kwargs)
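
The traceback shows the root cause: torch.onnx.export traces model(dummy_input), which lands in forward() with its default return_loss=True and no label, so the recognizer raises. A common workaround is to export an inference-only forward instead. This is a hedged sketch, not the repo's exact recipe: it assumes the recognizer exposes forward_dummy (mmaction's recognizers generally do, and the repo's tools/deployment/pytorch2onnx.py relies on it), and the input shape and output filename below are placeholders:

model = model.cpu().eval()

if hasattr(model, 'forward_dummy'):
    # inference-only path that needs no label
    model.forward = model.forward_dummy
else:
    # fall back to calling forward() with return_loss=False
    import functools
    model.forward = functools.partial(model.forward, return_loss=False)

# NOTE: the recognizer may expect a 6-D input (N, num_clips, C, T, H, W) and
# reshape it internally; adjust this to whatever your test pipeline produces.
dummy_input = torch.randn(1, 1, 3, T, 224, 224)

torch.onnx.export(model, dummy_input, 'swin_tiny_export.onnx',
                  opset_version=12, verbose=False)
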
innat-asj commented 7 months ago

The model is available in ONNX format here: https://www.kaggle.com/models/ipythonx/videoswin/frameworks/onnx

Or,

https://github.com/keras-team/keras-cv/pull/2369#issuecomment-2031466737

adeljalalyousif commented 6 months ago

Thanks a lot, but when I try to convert the ONNX models provided at the Kaggle link to a TensorRT engine using: trtexec --onnx=VideoSwinB_K400_IN1K_P244_W877_32x224.onnx --saveEngine=RT_engine_pytorch.trt --explicitBatch, it fails.
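
When trtexec fails without a clear message, it can help to first confirm that the ONNX file itself is valid and to list its opset and operators, since TensorRT rejects unsupported ops or opsets. A minimal sketch with the onnx Python package (the filename is the one from the trtexec command above):

import onnx

m = onnx.load('VideoSwinB_K400_IN1K_P244_W877_32x224.onnx')
onnx.checker.check_model(m)  # raises if the graph is structurally invalid
print('opset:', [(op.domain or 'ai.onnx', op.version) for op in m.opset_import])
print('ops:', sorted({node.op_type for node in m.graph.node}))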

innat-asj commented 5 months ago

@adeljalalyousif Could you please share more details? A broken Colab notebook would be a good starting point for debugging. Here is the file that is used to convert the Keras 3 model to ONNX.