sniklaus / pytorch-pwc

a reimplementation of PWC-Net in PyTorch that matches the official Caffe version
GNU General Public License v3.0

Convert pwc_net.caffemodel to network-default.pytorch #34

Closed lhao0301 closed 4 years ago

lhao0301 commented 4 years ago

Thanks for your wonderful code!

I tried to convert the official pwc_net.caffemodel to PyTorch and found an interesting problem: the biases of the FeatureUpsample layers in all the decoders of this repo are not consistent with the official model.

I compared my PyTorch weights (converted from the official caffemodel) with the network-default.pytorch from this repo as follows:

import torch

# compare the two checkpoints parameter by parameter
data1 = torch.load('network-default.pytorch')
data2 = torch.load('my-network.pytorch')
for key in data1.keys():
    print(key, torch.max(data1[key] - data2[key]), torch.min(data1[key] - data2[key]))

moduleExtractor.moduleOne.0.weight tensor(0.) tensor(0.)
moduleExtractor.moduleOne.0.bias tensor(0.) tensor(0.)
moduleExtractor.moduleOne.2.weight tensor(0.) tensor(0.)
moduleExtractor.moduleOne.2.bias tensor(0.) tensor(0.)
moduleExtractor.moduleOne.4.weight tensor(0.) tensor(0.)
moduleExtractor.moduleOne.4.bias tensor(0.) tensor(0.)
moduleExtractor.moduleTwo.0.weight tensor(0.) tensor(0.)
moduleExtractor.moduleTwo.0.bias tensor(0.) tensor(0.)
moduleExtractor.moduleTwo.2.weight tensor(0.) tensor(0.)
moduleExtractor.moduleTwo.2.bias tensor(0.) tensor(0.)
moduleExtractor.moduleTwo.4.weight tensor(0.) tensor(0.)
moduleExtractor.moduleTwo.4.bias tensor(0.) tensor(0.)
moduleExtractor.moduleThr.0.weight tensor(0.) tensor(0.)
moduleExtractor.moduleThr.0.bias tensor(0.) tensor(0.)
moduleExtractor.moduleThr.2.weight tensor(0.) tensor(0.)
moduleExtractor.moduleThr.2.bias tensor(0.) tensor(0.)
moduleExtractor.moduleThr.4.weight tensor(0.) tensor(0.)
moduleExtractor.moduleThr.4.bias tensor(0.) tensor(0.)
moduleExtractor.moduleFou.0.weight tensor(0.) tensor(0.)
moduleExtractor.moduleFou.0.bias tensor(0.) tensor(0.)
moduleExtractor.moduleFou.2.weight tensor(0.) tensor(0.)
moduleExtractor.moduleFou.2.bias tensor(0.) tensor(0.)
moduleExtractor.moduleFou.4.weight tensor(0.) tensor(0.)
moduleExtractor.moduleFou.4.bias tensor(0.) tensor(0.)
moduleExtractor.moduleFiv.0.weight tensor(0.) tensor(0.)
moduleExtractor.moduleFiv.0.bias tensor(0.) tensor(0.)
moduleExtractor.moduleFiv.2.weight tensor(0.) tensor(0.)
moduleExtractor.moduleFiv.2.bias tensor(0.) tensor(0.)
moduleExtractor.moduleFiv.4.weight tensor(0.) tensor(0.)
moduleExtractor.moduleFiv.4.bias tensor(0.) tensor(0.)
moduleExtractor.moduleSix.0.weight tensor(0.) tensor(0.)
moduleExtractor.moduleSix.0.bias tensor(0.) tensor(0.)
moduleExtractor.moduleSix.2.weight tensor(0.) tensor(0.)
moduleExtractor.moduleSix.2.bias tensor(0.) tensor(0.)
moduleExtractor.moduleSix.4.weight tensor(0.) tensor(0.)
moduleExtractor.moduleSix.4.bias tensor(0.) tensor(0.)
moduleTwo.moduleUpflow.weight tensor(0.) tensor(0.)
moduleTwo.moduleUpflow.bias tensor(0.) tensor(0.)
moduleTwo.moduleUpfeat.weight tensor(0.) tensor(0.)
moduleTwo.moduleUpfeat.bias tensor(0.0873) tensor(0.0774)
moduleTwo.moduleOne.0.weight tensor(0.) tensor(0.)
moduleTwo.moduleOne.0.bias tensor(0.) tensor(0.)
moduleTwo.moduleTwo.0.weight tensor(0.) tensor(0.)
moduleTwo.moduleTwo.0.bias tensor(0.) tensor(0.)
moduleTwo.moduleThr.0.weight tensor(0.) tensor(0.)
moduleTwo.moduleThr.0.bias tensor(0.) tensor(0.)
moduleTwo.moduleFou.0.weight tensor(0.) tensor(0.)
moduleTwo.moduleFou.0.bias tensor(0.) tensor(0.)
moduleTwo.moduleFiv.0.weight tensor(0.) tensor(0.)
moduleTwo.moduleFiv.0.bias tensor(0.) tensor(0.)
moduleTwo.moduleSix.0.weight tensor(0.) tensor(0.)
moduleTwo.moduleSix.0.bias tensor(0.) tensor(0.)
moduleThr.moduleUpflow.weight tensor(0.) tensor(0.)
moduleThr.moduleUpflow.bias tensor(0.) tensor(0.)
moduleThr.moduleUpfeat.weight tensor(0.) tensor(0.)
moduleThr.moduleUpfeat.bias tensor(-0.1421) tensor(-0.1514)
moduleThr.moduleOne.0.weight tensor(0.) tensor(0.)
moduleThr.moduleOne.0.bias tensor(0.) tensor(0.)
moduleThr.moduleTwo.0.weight tensor(0.) tensor(0.)
moduleThr.moduleTwo.0.bias tensor(0.) tensor(0.)
moduleThr.moduleThr.0.weight tensor(0.) tensor(0.)
moduleThr.moduleThr.0.bias tensor(0.) tensor(0.)
moduleThr.moduleFou.0.weight tensor(0.) tensor(0.)
moduleThr.moduleFou.0.bias tensor(0.) tensor(0.)
moduleThr.moduleFiv.0.weight tensor(0.) tensor(0.)
moduleThr.moduleFiv.0.bias tensor(0.) tensor(0.)
moduleThr.moduleSix.0.weight tensor(0.) tensor(0.)
moduleThr.moduleSix.0.bias tensor(0.) tensor(0.)
moduleFou.moduleUpflow.weight tensor(0.) tensor(0.)
moduleFou.moduleUpflow.bias tensor(0.) tensor(0.)
moduleFou.moduleUpfeat.weight tensor(0.) tensor(0.)
moduleFou.moduleUpfeat.bias tensor(0.1554) tensor(-0.1341)
moduleFou.moduleOne.0.weight tensor(0.) tensor(0.)
moduleFou.moduleOne.0.bias tensor(0.) tensor(0.)
moduleFou.moduleTwo.0.weight tensor(0.) tensor(0.)
moduleFou.moduleTwo.0.bias tensor(0.) tensor(0.)
moduleFou.moduleThr.0.weight tensor(0.) tensor(0.)
moduleFou.moduleThr.0.bias tensor(0.) tensor(0.)
moduleFou.moduleFou.0.weight tensor(0.) tensor(0.)
moduleFou.moduleFou.0.bias tensor(0.) tensor(0.)
moduleFou.moduleFiv.0.weight tensor(0.) tensor(0.)
moduleFou.moduleFiv.0.bias tensor(0.) tensor(0.)
moduleFou.moduleSix.0.weight tensor(0.) tensor(0.)
moduleFou.moduleSix.0.bias tensor(0.) tensor(0.)
moduleFiv.moduleUpflow.weight tensor(0.) tensor(0.)
moduleFiv.moduleUpflow.bias tensor(0.) tensor(0.)
moduleFiv.moduleUpfeat.weight tensor(0.) tensor(0.)
moduleFiv.moduleUpfeat.bias tensor(0.0935) tensor(-0.1689)
moduleFiv.moduleOne.0.weight tensor(0.) tensor(0.)
moduleFiv.moduleOne.0.bias tensor(0.) tensor(0.)
moduleFiv.moduleTwo.0.weight tensor(0.) tensor(0.)
moduleFiv.moduleTwo.0.bias tensor(0.) tensor(0.)
moduleFiv.moduleThr.0.weight tensor(0.) tensor(0.)
moduleFiv.moduleThr.0.bias tensor(0.) tensor(0.)
moduleFiv.moduleFou.0.weight tensor(0.) tensor(0.)
moduleFiv.moduleFou.0.bias tensor(0.) tensor(0.)
moduleFiv.moduleFiv.0.weight tensor(0.) tensor(0.)
moduleFiv.moduleFiv.0.bias tensor(0.) tensor(0.)
moduleFiv.moduleSix.0.weight tensor(0.) tensor(0.)
moduleFiv.moduleSix.0.bias tensor(0.) tensor(0.)
moduleSix.moduleOne.0.weight tensor(0.) tensor(0.)
moduleSix.moduleOne.0.bias tensor(0.) tensor(0.)
moduleSix.moduleTwo.0.weight tensor(0.) tensor(0.)
moduleSix.moduleTwo.0.bias tensor(0.) tensor(0.)
moduleSix.moduleThr.0.weight tensor(0.) tensor(0.)
moduleSix.moduleThr.0.bias tensor(0.) tensor(0.)
moduleSix.moduleFou.0.weight tensor(0.) tensor(0.)
moduleSix.moduleFou.0.bias tensor(0.) tensor(0.)
moduleSix.moduleFiv.0.weight tensor(0.) tensor(0.)
moduleSix.moduleFiv.0.bias tensor(0.) tensor(0.)
moduleSix.moduleSix.0.weight tensor(0.) tensor(0.)
moduleSix.moduleSix.0.bias tensor(0.) tensor(0.)
moduleRefiner.moduleMain.0.weight tensor(0.) tensor(0.)
moduleRefiner.moduleMain.0.bias tensor(0.) tensor(0.)
moduleRefiner.moduleMain.2.weight tensor(0.) tensor(0.)
moduleRefiner.moduleMain.2.bias tensor(0.) tensor(0.)
moduleRefiner.moduleMain.4.weight tensor(0.) tensor(0.)
moduleRefiner.moduleMain.4.bias tensor(0.) tensor(0.)
moduleRefiner.moduleMain.6.weight tensor(0.) tensor(0.)
moduleRefiner.moduleMain.6.bias tensor(0.) tensor(0.)
moduleRefiner.moduleMain.8.weight tensor(0.) tensor(0.)
moduleRefiner.moduleMain.8.bias tensor(0.) tensor(0.)
moduleRefiner.moduleMain.10.weight tensor(0.) tensor(0.)
moduleRefiner.moduleMain.10.bias tensor(0.) tensor(0.)
moduleRefiner.moduleMain.12.weight tensor(0.) tensor(0.)
moduleRefiner.moduleMain.12.bias tensor(0.) tensor(0.)

All parameters match exactly except the FeatureUpsample biases noted above. The biases in my converted PyTorch model are 0, while in network-default.pytorch they are not. Looking into the official training prototxt, the lr_mult for all deconvolution layers is 0, and the bias initializer is the constant filler (which also defaults to 0). So the true value of these biases should be 0.
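
As a minimal sketch (not part of the original conversion), one could force these biases to 0 in a converted checkpoint; the file names here are placeholders:

import torch

# hypothetical fix: zero the Upfeat deconvolution biases to match Caffe's
# constant(0) filler with lr_mult 0; file names are placeholders
state = torch.load('network-default.pytorch')
for key in state.keys():
    if key.endswith('moduleUpfeat.bias'):
        state[key].zero_()
torch.save(state, 'network-fixed.pytorch')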

I also checked the sintel.pytorch from an older version of this repo. Comparing sintel.pytorch with network-default.pytorch shows the same phenomenon: all parameters match except the FeatureUpsample biases. There may be some randomness in your conversion within the FeatureUpsample layers; I guess the biases are initialized differently every time.
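
This would be consistent with PyTorch's default initialization: if the converter builds the network and never overwrites a deconvolution bias, that bias keeps its random initial value, which differs between runs. A small illustration:

import torch

# ConvTranspose2d biases are initialized uniformly at random by default,
# so a bias that is never overwritten during conversion differs per run
print(torch.nn.ConvTranspose2d(2, 2, 4).bias)
print(torch.nn.ConvTranspose2d(2, 2, 4).bias)  # different values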

@sniklaus

sniklaus commented 4 years ago

Thank you for investigating this, I greatly appreciate it! I found a typo on my end and updated the provided models. I also updated the flow visualization, but it remains unaffected by this bug. Would you mind downloading the provided models again to see whether they are correct now? Thanks!

lhao0301 commented 4 years ago

Thanks for your reply!

The updated model works well and is identical to the official model. The visualized flow changes only slightly, and the difference can be ignored.

flow_viz_before: viz_flow_raw
flow_viz_updated: viz_flow_updated
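
To quantify the change rather than eyeball it, one could compute the average end-point error between the two predictions; a minimal sketch, assuming both flows were saved as 2 x H x W tensors (file names are hypothetical):

import torch

# average end-point error between the flows from the old and updated models
flow_before = torch.load('flow_before.pt')
flow_updated = torch.load('flow_updated.pt')
epe = (flow_before - flow_updated).pow(2).sum(0).sqrt().mean()
print('average EPE:', epe.item())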

sniklaus commented 4 years ago

Thank you for confirming, and thanks again for bringing this to my attention!

lhao0301 commented 4 years ago

I also found this problem when comparing your repo to the official repo, hence I am raising it here.

After looking into the official PWC-Net and FlowNet2 repos, I found that mean subtraction is omitted from the image preprocessing in this repo. Interestingly, the predicted optical flow still looks reasonable without mean subtraction (in fact, whether the input is RGB or BGR also matters little for flow prediction). I am a little confused about the mean subtraction because I could not find it in the original PyTorch implementation either.

In pwc_net_test.prototxt, recompute_mean is non-zero, and judging from the code in data_augmentation_layer.cu, both training and testing in the official Caffe version apply mean subtraction. @sniklaus

sniklaus commented 4 years ago

Thank you for your message! I was wondering about that too when I converted the Caffe model, but I noticed that the results from this reimplementation differ from the official Caffe model if I add the following lines to the preprocessing:

# subtract each image's per-channel mean before feeding it to the network
tensorPreprocessedFirst = tensorPreprocessedFirst - tensorPreprocessedFirst.view(1, 3, -1).mean(2).view(1, 3, 1, 1)
tensorPreprocessedSecond = tensorPreprocessedSecond - tensorPreprocessedSecond.view(1, 3, -1).mean(2).view(1, 3, 1, 1)

Output from the official Caffe implementation: official - caffe

Output from this reimplementation: out2

Output from this reimplementation with the additional mean subtraction: out

I never bothered to look into this further; Deqing stated that it makes no difference whether or not they use the mean subtraction: https://github.com/NVlabs/PWC-Net/issues/40

I would be happy if you could find the reason behind this; I am afraid I am a little busy at the moment.