Closed: shuaizzZ closed this issue 4 years ago.
Maybe you can send me the conversion code so I can check it for you.
I made a mistake and I have solved this problem.
It seems that using nn.ModuleList() takes up more GPU memory than explicitly defining each module.
Is it because of this reason that res2net takes up more GPU memory?
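One way to sanity-check the `nn.ModuleList()` claim: the list container holds exactly the same parameters as explicitly defined attributes, so any memory difference would have to come from activations or the autograd graph, not the weights themselves. A minimal sketch (layer sizes are arbitrary, chosen only for illustration):

```python
import torch.nn as nn

class ExplicitNet(nn.Module):
    """Two convs defined as explicit attributes."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 16, 3, padding=1)

class ListNet(nn.Module):
    """The same two convs held in an nn.ModuleList."""
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList(
            nn.Conv2d(16, 16, 3, padding=1) for _ in range(2)
        )

# The parameter footprint is identical either way.
n_explicit = sum(p.numel() for p in ExplicitNet().parameters())
n_list = sum(p.numel() for p in ListNet().parameters())
assert n_explicit == n_list
```

If the measured difference persists, it is worth comparing `torch.cuda.max_memory_allocated()` around a forward/backward pass of the two variants rather than relying on `nvidia-smi`, which also counts the caching allocator's reserved blocks.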
Yes, this may be part of the reason that Res2Net uses more memory. Another part is that concatenation in PyTorch is not well optimized.
When I tried to apply the idea of Res2Net to other networks, I found that the time consumption did not increase, but memory usage grew significantly.
It would be really great if this could be improved.
If you don't need to finetune the batchnorm layers, merging each batchnorm layer into the preceding conv layer will save you a lot of memory. Also, using half precision will cost only half the memory with almost no performance drop. Let's hope that PyTorch optimizes the concat op. I will let you know if I find another solution to this issue.
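The batchnorm-into-conv merge mentioned above can be sketched as follows. This is a minimal standalone version (the helper name `fuse_conv_bn` is mine, not from the repo): at inference time, BN applies a fixed per-channel affine transform, so its scale and shift can be folded into the conv's weight and bias.

```python
import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold BatchNorm running statistics into the preceding conv (inference only)."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels,
                      conv.kernel_size, conv.stride, conv.padding,
                      conv.dilation, conv.groups, bias=True)
    # Per-channel scale applied by BN in eval mode: gamma / sqrt(var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias.data if conv.bias is not None \
        else torch.zeros(conv.out_channels)
    fused.bias.data = bn.bias.data + (bias - bn.running_mean) * scale
    return fused

# Sanity check: the fused conv matches conv + bn in eval mode.
conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
conv.eval(); bn.eval()
fused = fuse_conv_bn(conv, bn).eval()
x = torch.randn(2, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```

Note this is valid only for frozen BN: once the statistics must be updated during training, the fold no longer holds.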
In my experience, the extra memory used by Res2Net is affordable. Maybe you can check whether you are using the in-place version of ReLU; in-place ops can also save a lot of memory.
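For reference, the in-place switch is a one-liner: `nn.ReLU(inplace=True)` overwrites its input buffer instead of allocating a fresh activation tensor of the same size.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 64, 56, 56)
relu = nn.ReLU(inplace=True)  # reuses x's storage instead of allocating a new tensor
y = relu(x)

# Same underlying storage: no extra activation buffer was allocated.
assert y.data_ptr() == x.data_ptr()
```

One caveat: in-place ops are rejected by autograd when the overwritten tensor is still needed for the backward pass, so they cannot be sprinkled in blindly.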
I tried the following change to the cat logic, and found that Res2Net uses almost the same memory as ResNet.

Code before the change:

```python
out = sp
...
out = torch.cat((out, sp), 1)
```

Code after the change:

```python
out = [sp]
...
out.append(sp)
out = torch.cat(out, 1)
```
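The two variants are numerically identical; the only difference is whether `torch.cat` is called once at the end or repeatedly, creating an intermediate tensor per iteration. A standalone check (shapes are illustrative, not from the repo):

```python
import torch

chunks = [torch.randn(2, 8, 4, 4) for _ in range(4)]

# Incremental concatenation (original): one intermediate tensor per step.
out = chunks[0]
for sp in chunks[1:]:
    out = torch.cat((out, sp), 1)

# Collect-then-concat (proposed): a single cat at the end.
outs = [chunks[0]]
for sp in chunks[1:]:
    outs.append(sp)
out2 = torch.cat(outs, 1)

assert torch.equal(out, out2)
```

In principle the single-cat version avoids the chain of growing intermediates, which is a plausible source of savings during training, where those intermediates may be kept alive by the autograd graph.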
I am verifying that there is no problem, what do you think?
Thanks! I think your version is quite good. I am going to test it and update the code.
Haha, but I haven't figured out why this works.
I have modified the code. It's strange that the new version shows almost no memory change during inference. Here is my code: https://gist.github.com/gasvn/0c359e2d6a79760196c828f763879273
I am using PyTorch 1.3. Maybe it's related to the PyTorch version, or perhaps this modification only saves memory during training?
I am also using PyTorch 1.3. Sorry, I tried this change on my own network first, and there it did save memory during both training and inference. Then I tested on Res2Net, but I forgot to switch the network during the test, so that result was wrong. I just re-tested it: in fact, it really does not save memory, which I also find very strange.
Thanks anyway. Feel free to let me know if you have other problems.
```
Assertion failed: *tensor = importer_ctx->network()->addInput( input.name().c_str(), trt_dtype, trt_dims)
```