pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

torchvision mvit_v2_s export to onnx failed #7003

Open chenzhengda opened 1 year ago

chenzhengda commented 1 year ago

🐛 Describe the bug

import torch
import torch.onnx
import torchvision

# MViTv2-S expects a video clip: (batch, channels, frames, height, width)
input_ = torch.randn(1, 3, 16, 224, 224, requires_grad=False)
model = torchvision.models.get_model('mvit_v2_s', weights="DEFAULT")
torch.onnx.export(model, input_, "mvit_v2_s.onnx", output_names=['output'])

The program terminates unexpectedly (it hangs in graph = _C._jit_pass_onnx(graph, operator_export_type)).

Versions

torch version: 1.13.0a0+d0d6b1f

janeyx99 commented 1 year ago

@chenzhengda I moved this issue to torchvision, as it seems more relevant here.

oke-aditya commented 1 year ago

I tried to reproduce this, but my machine (16 GB RAM + 8 GB swap) ran out of memory. Are you sure your process completed successfully, and that it didn't crash or get killed?

chenzhengda commented 1 year ago

Yes, the program was killed by the system. I traced it and found that it crashes in a JIT pass (it seems to be in graph = _C._jit_pass_onnx(graph, operator_export_type) inside torch.onnx.export). I'm not sure whether this is a bug in torch.onnx.export.
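Because the export is only killed after it has exhausted RAM and swap, one way to make the failure fast and visible is to run it in a child process with a capped address space. A minimal Linux sketch using only the standard library; run_with_memory_cap is a hypothetical helper name, and the cap is an assumption you would tune for your machine:

```python
import resource
import subprocess
import sys


def run_with_memory_cap(cmd, max_bytes):
    """Run cmd in a child process whose virtual address space is capped.

    RLIMIT_AS is set in the child just before exec, so an allocation beyond
    the cap fails inside the child (Python raises MemoryError) instead of
    pushing the whole machine into swap. POSIX only; enforced on Linux.
    """
    def set_limit():
        resource.setrlimit(resource.RLIMIT_AS, (max_bytes, max_bytes))

    return subprocess.run(cmd, preexec_fn=set_limit,
                          capture_output=True, text=True)
```

For example, run_with_memory_cap([sys.executable, "export_mvit.py"], 32 * 1024**3), where export_mvit.py is a hypothetical file holding the reproduction script above, should end with a non-zero exit code and an allocation error in result.stderr rather than freezing the machine. RLIMIT_AS caps virtual address space, not resident memory, so libraries that reserve large virtual regions at import time may need a higher cap.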

oke-aditya commented 1 year ago

It might not be a bug. Maybe someone with a bigger machine can run this.

chenzhengda commented 1 year ago

This model is not big, so why does it take up so much memory?

oke-aditya commented 1 year ago

It's probably some memory issue. I will try running the ONNX export on other models.

Frankly, they should fit in less than 64 GB of RAM.

@YosuaMichael also tried and was able to reproduce the memory issue on bigger infrastructure.

YosuaMichael commented 1 year ago

Thanks for the report @chenzhengda. I tried to debug this issue; here is a summary of what I found so far:

  1. I ran the script on an AWS cluster, and it used more than 234 GB of memory before the process was killed automatically.
  2. Running the same script with mvit_v1_b works, with peak memory around 2 GB (it also works with another video model, swin3d_t).

From these two observations, I think there is a memory issue specific to the combination of torch.onnx.export and mvit_v2_s. I will try to debug this further.
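To compare peak memory across models (mvit_v1_b, swin3d_t, mvit_v2_s) without memory_profiler, a standard-library alternative is to run each export in a child process and read the child's peak resident set size afterwards. A minimal POSIX sketch; peak_child_rss_mib is a hypothetical helper name:

```python
import resource
import subprocess
import sys


def peak_child_rss_mib(cmd):
    """Run cmd to completion and return the peak RSS (MiB) of reaped children.

    RUSAGE_CHILDREN's ru_maxrss is the maximum over *all* children waited for
    so far, so call this from a fresh parent process per configuration
    when comparing models.
    """
    subprocess.run(cmd, check=True)
    maxrss = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return maxrss / divisor
```

Running one export script per model through this helper, each from its own parent process, gives peak numbers that can be compared directly with the ~2 GB figure above.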

YosuaMichael commented 1 year ago

I debugged mvit_v2 more deeply and found that the cause of the memory blowup is the function _add_rel_pos, which is used in mvit_v2. Here is a script that shows the problem:

import torch
import torch.onnx
import torchvision
from memory_profiler import memory_usage

from torchvision.models.video.mvit import _add_rel_pos, _unsqueeze

class PartialMViT(torch.nn.Module):
    def __init__(self, mvit, use_add_rel_pos=True):
        super().__init__()
        self.mvit = mvit
        self.use_add_rel_pos = use_add_rel_pos

    def forward(self, x):
        # Copied from MViT forward
        x = _unsqueeze(x, 5, 2)[0]
        x = self.mvit.conv_proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.mvit.pos_encoding(x)
        thw = (self.mvit.pos_encoding.temporal_size,) + self.mvit.pos_encoding.spatial_size

        # Expand mvit first block (copied from MultiscaleBlock forward)
        block0 = self.mvit.blocks[0]
        x_norm1 = block0.norm1(x.transpose(1, 2)).transpose(1, 2) if block0.needs_transposal else block0.norm1(x)

        # Expand attn in the block (copied from MultiscaleAttention forward);
        # the block feeds the normalized input x_norm1 into attention
        attn0 = block0.attn
        B, N, C = x_norm1.shape
        q, k, v = attn0.qkv(x_norm1).reshape(B, N, 3, attn0.num_heads, attn0.head_dim).transpose(1, 3).unbind(dim=2)
        if attn0.pool_k is not None:
            k, k_thw = attn0.pool_k(k, thw)
        else:
            k_thw = thw
        if attn0.pool_v is not None:
            v = attn0.pool_v(v, thw)[0]
        if attn0.pool_q is not None:
            q, thw = attn0.pool_q(q, thw)

        attn = torch.matmul(attn0.scaler * q, k.transpose(2, 3))
        has_rel_pos = (
            attn0.rel_pos_h is not None
            and attn0.rel_pos_w is not None
            and attn0.rel_pos_t is not None
        )
        if has_rel_pos and self.use_add_rel_pos:
            # This is the part that causes the ONNX memory blowup
            attn = _add_rel_pos(
                attn,
                q,
                thw,
                k_thw,
                attn0.rel_pos_h,
                attn0.rel_pos_w,
                attn0.rel_pos_t,
            )

        return attn

mvitv1 = torchvision.models.get_model('mvit_v1_b', weights="DEFAULT").eval()
mvitv2 = torchvision.models.get_model('mvit_v2_s', weights="DEFAULT").eval()

def gen_profiled_func(mvit_model, use_add_rel_pos=False):
    def profiled_func():
        input_ = torch.randn(1, 3, 16, 224, 224, requires_grad=False)
        partial_mvit = PartialMViT(mvit_model, use_add_rel_pos=use_add_rel_pos)
        torch.onnx.export(partial_mvit, input_, "mvit.onnx", output_names=['output'])
        return
    return profiled_func

mem = max(memory_usage(proc=gen_profiled_func(mvitv1, use_add_rel_pos=True)))
print(f"v1 with use_add_rel_pos=True using {mem/1000} GB memory! --> BASELINE")

mem = max(memory_usage(proc=gen_profiled_func(mvitv2, use_add_rel_pos=False)))
print(f"v2 with use_add_rel_pos=False using {mem/1000} GB memory!")

mem = max(memory_usage(proc=gen_profiled_func(mvitv2, use_add_rel_pos=True)))
print(f"v2 with use_add_rel_pos=True using {mem/1000} GB memory!")

On my laptop, it prints the following output:

v1 with use_add_rel_pos=True using 1.1580625 GB memory! --> BASELINE
v2 with use_add_rel_pos=False using 1.230203125 GB memory!
v2 with use_add_rel_pos=True using 17.05446875 GB memory!

In this script, I run MViT only partially (from the beginning up to the attention of the first block) and expand its components (MultiscaleBlock and MultiscaleAttention) inline. I also added a boolean flag that controls whether _add_rel_pos is used (normally mvit_v1 does not use it, while mvit_v2 always does).

As the results show, v2 without _add_rel_pos consumes roughly the same memory as v1. However, with _add_rel_pos, peak memory rises to about 17 GB, roughly 14x the other two runs (note that this is the memory usage of the partial MViT, not the full model). Although I have isolated the problem to the _add_rel_pos function, I still don't know how to fix it or why this function causes the problem (my guess is some incompatibility with torch.onnx).
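The blowup factor can be checked directly from the three printed values. A small sketch of the arithmetic; it assumes, as memory_profiler documents, that memory_usage() reports MiB, so the script's division by 1000 makes the "GB" label approximate:

```python
# Peak values printed by the script above (MiB / 1000, labelled "GB").
v1_baseline = 1.1580625
v2_without_rel_pos = 1.230203125
v2_with_rel_pos = 17.05446875

ratio_vs_v1 = v2_with_rel_pos / v1_baseline
ratio_vs_v2_off = v2_with_rel_pos / v2_without_rel_pos

print(f"v2 with _add_rel_pos vs v1 baseline: {ratio_vs_v1:.1f}x")      # 14.7x
print(f"v2 with vs without _add_rel_pos:     {ratio_vs_v2_off:.1f}x")  # 13.9x
```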

@datumbox, do you have any idea why _add_rel_pos might cause this problem? Also, do you know any torch.onnx folks who might be able to help?

datumbox commented 1 year ago

I can confirm that the Relative Positional Embeddings are quite slow and memory heavy. This is why the implementation allows you to turn them off via the rel_pos_embed parameter. MViTv2 has a few improvements over v1 that are cheap (such as the shortcuts or the way the projection happens in the attention), and these are often used in production settings. Perhaps @lyttonhao and @haooooooqi can provide more details, as the authors of the papers.
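To see where the cost of relative positional embeddings comes from, here is a pure-Python sketch of a decomposed relative-position bias over a 2-D grid of tokens. It is a simplified illustration, not torchvision's _add_rel_pos: it assumes equal query and key grids, a single head, no batch, no class token and no temporal axis (MViTv2 also adds a time term), and the helper names are hypothetical. The per-axis lookups are cheap, but the result is a dense N x N bias, one entry per query/key pair:

```python
from itertools import product


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def rel_index(i, j, size):
    """Map the offset i - j, in [-(size-1), size-1], to a table row in [0, 2*size-2]."""
    return i - j + size - 1


def decomposed_rel_pos_bias(q, grid_h, grid_w, rel_h, rel_w):
    """Return the N x N bias q_i . R_h[dh] + q_i . R_w[dw], with N = grid_h * grid_w.

    q     -- N query vectors, tokens in row-major (h, w) order
    rel_h -- 2 * grid_h - 1 embedding vectors, one per vertical offset
    rel_w -- 2 * grid_w - 1 embedding vectors, one per horizontal offset
    """
    positions = list(product(range(grid_h), range(grid_w)))
    bias = []
    for q_i, (h_i, w_i) in zip(q, positions):
        row = []
        for h_j, w_j in positions:
            r_h = rel_h[rel_index(h_i, h_j, grid_h)]
            r_w = rel_w[rel_index(w_i, w_j, grid_w)]
            row.append(dot(q_i, r_h) + dot(q_i, r_w))
        bias.append(row)
    return bias


# Toy example: 2 x 2 grid, 1-dimensional queries and embeddings.
bias = decomposed_rel_pos_bias(
    q=[[1], [2], [3], [4]],
    grid_h=2, grid_w=2,
    rel_h=[[10], [20], [30]],
    rel_w=[[1], [2], [3]],
)
print(bias[0])  # [22, 21, 12, 11]
```

In the script above, a bias of this kind is added to the scaled q·k^T attention logits (attn) before the softmax.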

Concerning why _add_rel_pos is problematic for ONNX: unfortunately, I'm not enough of an ONNX power user to give good advice. I suspect that the many matrix multiplications and tensor manipulations this method performs cause ONNX to run out of memory (it's unclear why this memory is not cleaned up). This is probably a question for someone with strong ONNX expertise. I wonder if @BowenBao or @prabhat00155 have thoughts on this.
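A back-of-the-envelope count of how large the tensors in a decomposed relative-position bias get. The function name and the example sizes are hypothetical (they are not read from mvit_v2_s's configuration), and this only counts eager-mode tensor elements; it says nothing about what the ONNX export passes themselves allocate:

```python
from math import prod


def rel_pos_sizes(q_thw, k_thw, batch=1, heads=1):
    """Element counts for a decomposed relative-position bias.

    per_axis: the three per-axis terms (q.R_t, q.R_h, q.R_w), together
              batch * heads * Nq * (k_t + k_h + k_w) elements.
    combined: their broadcast sum, shape [batch, heads, Nq, Nk],
              with Nq = q_t * q_h * q_w and Nk = k_t * k_h * k_w.
    """
    n_q = prod(q_thw)
    n_k = prod(k_thw)
    per_axis = batch * heads * n_q * sum(k_thw)
    combined = batch * heads * n_q * n_k
    return per_axis, combined


# Hypothetical sizes, chosen for illustration only.
per_axis, combined = rel_pos_sizes(q_thw=(4, 8, 8), k_thw=(4, 4, 4), heads=2)
print(per_axis, combined)  # 6144 32768
```

The per-axis terms grow with the sum of the key extents, while the combined bias grows with their product, which is one reason this step is memory heavy even before any ONNX-specific overhead.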