chenzhengda opened this issue 1 year ago
@chenzhengda moved to torchvision as it seems more relevant here
I tried to reproduce this but my 16GB + 8GB swap machine ran out of RAM. Are you sure your process completed successfully? And it didn't crash or get killed?
Yes, the program was killed by the system. I traced the program and found it crashes in a JIT pass (it seems to be in
graph = _C._jit_pass_onnx(graph, operator_export_type)
in torch.onnx.export). I'm not sure if this is a bug in torch.onnx.export.
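For reference, something along these lines is what I run (a minimal sketch, not my exact script; the input shape matches the profiling code further down in this thread):

```python
import torch
import torchvision

# Export mvit_v2_s with a dummy (B, C, T, H, W) video clip.
# On my machine the process gets OOM-killed inside _C._jit_pass_onnx during this call.
model = torchvision.models.get_model("mvit_v2_s", weights="DEFAULT").eval()
dummy_input = torch.randn(1, 3, 16, 224, 224)

torch.onnx.export(model, dummy_input, "mvit_v2_s.onnx", output_names=["output"])
```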
It might not be a bug. Maybe someone with a bigger machine can run this.
This model is not big, so why does it take up so much memory?
Probably some memory issue; I will try running the ONNX export over other models.
Frankly, they should fit in < 64GB of RAM.
@YosuaMichael also tried and could reproduce the memory issue on bigger infra.
Thanks for the report @chenzhengda. I tried to debug this issue; here is a summary of what I found so far:
- Exporting mvit_v2_s consumes around 234 GB of memory before the process gets killed automatically.
- mvit_v1_b works, with peak memory around 2 GB (it also works with another video model, swin3d_t).
From these two observations, I think there is a memory issue between torch.onnx.export and mvit_v2_s specifically; I will try to debug more on this.
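For reference, a sketch of how such peak numbers can be measured with memory_profiler (the same tool used in the script further below); only the model name changes between the two runs:

```python
import torch
import torchvision
from memory_profiler import memory_usage


def gen_export_func(name):
    def export_func():
        # Build the full video model and export it to ONNX;
        # memory_usage samples the process memory while this runs.
        model = torchvision.models.get_model(name, weights="DEFAULT").eval()
        dummy_input = torch.randn(1, 3, 16, 224, 224)
        torch.onnx.export(model, dummy_input, f"{name}.onnx", output_names=["output"])
    return export_func


# Expect mvit_v1_b to peak around a couple of GB, while the mvit_v2_s run needs
# far more RAM than most machines have before it gets killed.
for name in ["mvit_v1_b", "mvit_v2_s"]:
    peak_mib = max(memory_usage(proc=gen_export_func(name)))
    print(f"{name}: peak memory {peak_mib / 1000:.2f} GB")
```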
I have tried to debug mvit_v2 deeper, and I found that the cause of the memory blowup is the function _add_rel_pos that is used by mvit_v2. Here is a script to show the problem:
```python
import torchvision
import torch.onnx
import torch
from memory_profiler import memory_usage
from torchvision.models.video.mvit import _add_shortcut, _add_rel_pos, _unsqueeze


class PartialMViT(torch.nn.Module):
    def __init__(self, mvit, use_add_rel_pos=True):
        super().__init__()
        self.mvit = mvit
        self.use_add_rel_pos = use_add_rel_pos

    def forward(self, x):
        # Copied from MViT forward
        x = _unsqueeze(x, 5, 2)[0]
        x = self.mvit.conv_proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.mvit.pos_encoding(x)
        thw = (self.mvit.pos_encoding.temporal_size,) + self.mvit.pos_encoding.spatial_size

        # Expand mvit first block (copied from MultiscaleBlock forward)
        block0 = self.mvit.blocks[0]
        x_norm1 = block0.norm1(x.transpose(1, 2)).transpose(1, 2) if block0.needs_transposal else block0.norm1(x)

        # Expand attn in the block (copied from MultiscaleAttention forward)
        attn0 = block0.attn
        B, N, C = x.shape
        q, k, v = attn0.qkv(x).reshape(B, N, 3, attn0.num_heads, attn0.head_dim).transpose(1, 3).unbind(dim=2)

        if attn0.pool_k is not None:
            k, k_thw = attn0.pool_k(k, thw)
        else:
            k_thw = thw
        if attn0.pool_v is not None:
            v = attn0.pool_v(v, thw)[0]
        if attn0.pool_q is not None:
            q, thw = attn0.pool_q(q, thw)

        attn = torch.matmul(attn0.scaler * q, k.transpose(2, 3))
        if attn0.rel_pos_h is not None and attn0.rel_pos_w is not None and attn0.rel_pos_t is not None and self.use_add_rel_pos:
            # This is the part that causes the ONNX memory blowup
            attn = _add_rel_pos(
                attn,
                q,
                thw,
                k_thw,
                attn0.rel_pos_h,
                attn0.rel_pos_w,
                attn0.rel_pos_t,
            )
        return attn


mvitv1 = torchvision.models.get_model('mvit_v1_b', weights="DEFAULT").eval()
mvitv2 = torchvision.models.get_model('mvit_v2_s', weights="DEFAULT").eval()


def gen_profiled_func(mvit_model, use_add_rel_pos=False):
    def profiled_func():
        input_ = torch.randn(1, 3, 16, 224, 224, requires_grad=False)
        partial_mvit = PartialMViT(mvit_model, use_add_rel_pos=use_add_rel_pos)
        torch.onnx.export(partial_mvit, input_, "mvit.onnx", output_names=['output'])
        return
    return profiled_func


mem = max(memory_usage(proc=gen_profiled_func(mvitv1, use_add_rel_pos=True)))
print(f"v1 with use_add_rel_pos=True using {mem/1000} GB memory! --> BASELINE")
mem = max(memory_usage(proc=gen_profiled_func(mvitv2, use_add_rel_pos=False)))
print(f"v2 with use_add_rel_pos=False using {mem/1000} GB memory!")
mem = max(memory_usage(proc=gen_profiled_func(mvitv2, use_add_rel_pos=True)))
print(f"v2 with use_add_rel_pos=True using {mem/1000} GB memory!")
```
On my laptop it shows the following output:
```
v1 with use_add_rel_pos=True using 1.1580625 GB memory! --> BASELINE
v2 with use_add_rel_pos=False using 1.230203125 GB memory!
v2 with use_add_rel_pos=True using 17.05446875 GB memory!
```
In this script, I run MViT partially (from the beginning) and expand its components (MultiscaleBlock and MultiscaleAttention). I also add a boolean to control whether or not we use _add_rel_pos (normally mvit_v1 does not use it, while mvit_v2 always does).
As we can see from the result, v2 without _add_rel_pos consumes roughly the same memory as v1. However, when we use _add_rel_pos the memory usage goes up about 17x in this case. (Note: this is the memory usage of the partial MViT, not the whole model.) Although I can isolate the problem to the _add_rel_pos function, I still don't really know how to fix it or why this function causes the problem (my guess is that there might be some incompatibility with torch.onnx).
@datumbox, do you have any idea why _add_rel_pos might have this problem? Also, do you know any torch.onnx folks who might be able to help?
I confirm that the Relative Positional Embeddings are quite slow and memory heavy. This is the reason the implementation allows you to turn them off via the rel_pos_embed parameter. MViTv2 has a few improvements over v1 that are cheap (such as the shortcuts or the way the projection happens in the attention), and these are often used in production settings. Perhaps @lyttonhao and @haooooooqi can provide more details as the authors of the papers.
Concerning why _add_rel_pos is problematic on ONNX, unfortunately I'm not enough of an ONNX power user to give good advice. I suspect the many matrix multiplications and manipulations that this method does cause ONNX to run out of memory (it's unclear why this is not cleaned up). This is probably a question for someone with strong expertise in ONNX. I wonder if @BowenBao or @prabhat00155 have thoughts on this.
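Not an ONNX answer, but one way to narrow this down further might be to inspect the TorchScript trace before any ONNX pass runs: if the traced graph of the partial model is already huge (many nodes or large folded constants), the blowup starts on the tracing side rather than in the exporter itself. A sketch, reusing PartialMViT and mvitv2 from the script above:

```python
import torch

# Assumes PartialMViT and mvitv2 from the profiling script above are in scope.
partial_mvit = PartialMViT(mvitv2, use_add_rel_pos=True)
dummy_input = torch.randn(1, 3, 16, 224, 224)

with torch.no_grad():
    traced = torch.jit.trace(partial_mvit, dummy_input)

# Count nodes in the traced graph; a huge node count (or many/large Constant nodes)
# would suggest the blowup begins before the ONNX-specific passes.
nodes = list(traced.inlined_graph.nodes())
constants = [n for n in nodes if n.kind() == "prim::Constant"]
print(f"traced graph: {len(nodes)} nodes, {len(constants)} constants")
```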
🐛 Describe the bug
The program terminates unexpectedly (it hangs in
graph = _C._jit_pass_onnx(graph, operator_export_type)
).
Versions
torch 1.13.0a0+d0d6b1f