phil-bergmann / tracking_wo_bnw

Implementation of "Tracking without bells and whistles" and the multi-object tracker "Tracktor"
https://arxiv.org/abs/1903.05625
GNU General Public License v3.0

Replace Faster-RCNN pre-trained weights with Mask-RCNN, Cascade-RCNN #100

Closed starkgate closed 4 years ago

starkgate commented 4 years ago

How modular is your network? Would simply swapping the Faster-RCNN weights for another model's work?

timmeinhardt commented 4 years ago

Hello, our code is quite modular. If you use the Mask R-CNN code from the official torchvision package, you do not have to change much. You only have to account for the additional mask head.

starkgate commented 4 years ago

Thanks. My problem is that I'm not using weights from the official torchvision package, but from mmdetection. Their Faster-RCNN seems a little different from the torchvision one, which causes issues when loading.

Below is a comparison of the first few layers of Faster-RCNN, on torchvision and mmdetection. What would you advise to make the mmdetection model work? I could simply rename the layers, e.g. from backbone.conv1.weight to backbone.body.conv1.weight, and drop the layers that only exist in mmdetection. The other possibility would be to replace torchvision's implementation with mmdetection's, but since torchvision is so tightly integrated into your project, replacing it altogether seems unfeasible. Maybe I'm missing something.

mmdetection:

backbone.conv1.weight    torch.Size([64, 3, 7, 7])
backbone.bn1.weight    torch.Size([64])
backbone.bn1.bias    torch.Size([64])
backbone.bn1.running_mean    torch.Size([64])
backbone.bn1.running_var    torch.Size([64])
backbone.bn1.num_batches_tracked    torch.Size([])
backbone.layer1.0.conv1.weight    torch.Size([64, 64, 1, 1])
backbone.layer1.0.bn1.weight    torch.Size([64])
backbone.layer1.0.bn1.bias    torch.Size([64])
backbone.layer1.0.bn1.running_mean    torch.Size([64])
backbone.layer1.0.bn1.running_var    torch.Size([64])
backbone.layer1.0.bn1.num_batches_tracked    torch.Size([])

torchvision:

backbone.body.conv1.weight    torch.Size([64, 3, 7, 7])
backbone.body.bn1.weight    torch.Size([64])
backbone.body.bn1.bias    torch.Size([64])
backbone.body.bn1.running_mean    torch.Size([64])
backbone.body.bn1.running_var    torch.Size([64])
backbone.body.layer1.0.conv1.weight    torch.Size([64, 64, 1, 1])
backbone.body.layer1.0.bn1.weight    torch.Size([64])
backbone.body.layer1.0.bn1.bias    torch.Size([64])
backbone.body.layer1.0.bn1.running_mean    torch.Size([64])
backbone.body.layer1.0.bn1.running_var    torch.Size([64])

timmeinhardt commented 4 years ago

If the layers only differ in name, you can just copy the weights into a dict with the required names and load that one as a model file. There is a lot of literature online on how to remove/add layers in existing state dictionaries. But if there are more differences between the torchvision and mmdetection models, this will not work without some additional training. Ideally you would use Mask R-CNN weights with the torchvision implementation.
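The renaming approach can be sketched as follows. This is a minimal illustration using plain dicts in place of real tensors; with actual checkpoints you would `torch.load` the mmdetection file and `torch.save` the converted dict. The `drop_prefixes` list and the example keys are assumptions for illustration, not taken from either codebase.

```python
# Sketch: convert mmdetection-style state-dict keys to torchvision-style ones.
# Plain strings stand in for real tensors; with real checkpoints you would use
# torch.load(...) / torch.save(...) around this conversion.

def convert_state_dict(mmdet_state, drop_prefixes=()):
    """Rename backbone.* keys to backbone.body.* and drop mmdet-only keys."""
    converted = {}
    for key, value in mmdet_state.items():
        if any(key.startswith(p) for p in drop_prefixes):
            continue  # layer exists only in the mmdetection model
        if key.startswith("backbone."):
            key = "backbone.body." + key[len("backbone."):]
        converted[key] = value
    return converted

mmdet_state = {
    "backbone.conv1.weight": "tensor(64, 3, 7, 7)",
    "backbone.bn1.num_batches_tracked": "tensor()",
    "neck.lateral_convs.0.conv.weight": "tensor(...)",  # hypothetical mmdet-only key
}
new_state = convert_state_dict(mmdet_state, drop_prefixes=("neck.",))
print(sorted(new_state))
# ['backbone.body.bn1.num_batches_tracked', 'backbone.body.conv1.weight']
```

As noted above, this only works when the architectures actually match layer for layer; a key that renames cleanly but holds a differently shaped tensor will still fail in `load_state_dict`.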

starkgate commented 4 years ago

Alright, figured it out. The key is to replace: https://github.com/phil-bergmann/tracking_wo_bnw/blob/7c7c9cd40520186b219dbf3bd27e4967ffc54f33/experiments/scripts/test_tracktor.py#L59

with the code used by mmdetection, to load the model you want: https://github.com/open-mmlab/mmdetection/blob/9ee13ab6c0008e10d031fef8933313a904c2d4b3/tools/train.py#L151

Basically integrate mmdetection into BNW. I've attached my code as a zip, it might be helpful to others. I'll probably publish a fork at some point. Thanks for the help!
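One way to keep such a swap tractable is an adapter that exposes only the small interface the tracker actually calls, hiding which backend produced the detections. The class and method names below are hypothetical, not part of Tracktor or mmdetection, and a dummy backend stands in for a real detector:

```python
# Sketch of an adapter pattern for swapping detector backends.
# DetectorAdapter and DummyBackend are hypothetical names; a real backend
# would wrap mmdetection's inference call instead of returning fixed boxes.

class DummyBackend:
    """Stand-in for a detector; returns fixed boxes and confidence scores."""
    def __call__(self, image):
        return [(10, 10, 50, 50), (60, 60, 90, 90)], [0.9, 0.3]

class DetectorAdapter:
    """Expose the minimal interface the tracker needs, whatever the backend."""
    def __init__(self, backend, score_thresh=0.5):
        self.backend = backend
        self.score_thresh = score_thresh

    def detect(self, image):
        boxes, scores = self.backend(image)
        keep = [i for i, s in enumerate(scores) if s >= self.score_thresh]
        return [boxes[i] for i in keep], [scores[i] for i in keep]

adapter = DetectorAdapter(DummyBackend())
boxes, scores = adapter.detect(image=None)
print(boxes, scores)  # [(10, 10, 50, 50)] [0.9]
```

The tracking loop then only ever sees `detect()`, so replacing torchvision with mmdetection (or anything else) is confined to one wrapper class rather than scattered through the code.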

starkgate commented 3 years ago

FYI, I published my code here: https://github.com/starkgate/tracking_wo_bnw

In case anyone has the same problem.

Leo63963 commented 3 years ago

Hi @starkgate, I found this issue really helpful and got your code from the link. Could you please give more information about it? For example, how do you load the checkpoint and config files from mmdetection? Also, in test_tracktor.py it seems you are still using FRCNN_FPN as the model, instead of the model loaded from mmdetection. Could you please clarify? Many thanks.

Leo63963 commented 3 years ago

Hi @starkgate, I revised your code and tried to run it. In my environment I first installed mmdetection and then installed the requirements from your code, but the code fails with unknown bugs, and mmdetection cannot run in that environment either. May I kindly ask which version of mmdetection you used to run the code? Many thanks.

starkgate commented 3 years ago

Oops, looks like I forgot to push some files. Should be fixed now.

I'm using mmcv 1.1.4 and mmdet 2.3.0. The way I have it set up, Tracktor is installed in a subfolder of mmdetection, so my folder structure looks like this:

Leo63963 commented 3 years ago

Hi @starkgate, thank you so much for your reply. I am still trying to run your code and encountered some unknown bugs, which I am trying to solve. Could you please show some details of test_tracktor.py? Many thanks.

starkgate commented 3 years ago

I added instructions in the readme, it should help quite a bit.

Leo63963 commented 3 years ago

Thank you so much !

Leo63963 commented 3 years ago

Hi! Can I just use the 'dataset' module from the original repository to test on the MOT17 dataset? Thanks.

starkgate commented 3 years ago

Sure, should work just fine.

Leo63963 commented 3 years ago

Hi starkgate! I replaced the 'dataset' file with the original 'dataset' file, and loading data now seems to work fine. But another bug happened:

Traceback (most recent call last):
  File "experiments/scripts/test_tracktor.py", line 106, in
    tracker.step(frame)
  File "/home/lee/Detection/mmdetection/tracking_wo_bnw/src/tracktor/tracker.py", line 259, in step
    blob['img']['filename'] = blob['img_path'][0]
TypeError: can't assign a str to a torch.FloatTensor

Do you have any idea about this one? Thanks!
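The error says `blob['img']` is a tensor, and a tensor cannot hold a string entry. One possible fix, sketched under the assumption that `blob` is a plain dict (the exact blob layout here is illustrative, not taken from the actual code), is to keep per-frame metadata in a sibling dict instead of assigning into the image tensor:

```python
# Sketch: keep image metadata next to the tensor instead of inside it.
# The blob layout and the 'img_meta' key are assumptions for illustration;
# in the real code, blob['img'] is a torch.FloatTensor and cannot store strings.

def attach_metadata(blob):
    """Store the filename in a sibling dict, leaving the image tensor alone."""
    blob['img_meta'] = {'filename': blob['img_path'][0]}  # not blob['img']['filename'] = ...
    return blob

# A nested list stands in for the image tensor in this sketch.
blob = {'img': [[0.0, 0.0]], 'img_path': ['000001.jpg']}
blob = attach_metadata(blob)
print(blob['img_meta']['filename'])  # 000001.jpg
```

Whatever downstream code reads the filename would then need to look it up in the metadata dict rather than on the tensor.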

starkgate commented 3 years ago

We should probably take this to a new issue in my fork. If you have PyCharm installed, you could add a breakpoint at tracker.py:259 to see what data is in blob['img_path'][0].