gangweiX / CGI-Stereo

A novel neural network architecture that can concurrently achieve real-time performance, competitive accuracy, and strong generalization ability.
MIT License
313 stars, 29 forks

Pre-trained model weights don't load correctly #11

Closed: ccaven closed this issue 11 months ago

ccaven commented 11 months ago

After downloading sceneflow.ckpt for the CGI_Stereo model into the ./pretrained_models/CGI_Stereo folder, I run this code:

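```python
import torch
from models import __models__  # assumed: run from the CGI-Stereo repo root, as in its training scripts

maxdisp = 192
model = __models__['CGI_Stereo'](maxdisp)
model.cuda()
state_dict = torch.load("./pretrained_models/CGI_Stereo/sceneflow.ckpt")
model.load_state_dict(state_dict['model'])  # raises the RuntimeError below
```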

and I get this error:

```
RuntimeError: Error(s) in loading state_dict for CGI_Stereo:
    Missing key(s) in state_dict: "feature.conv_stem.weight", "feature.bn1.weight", "feature.bn1.bias", "feature.bn1.running_mean", "feature.bn1.running_var", "feature.block0.0.0.conv_dw.weight", "feature.block0.0.0.bn1.weight", "feature.block0.0.0.bn1.bias", "feature.block0.0.0.bn1.running_mean", "feature.block0.0.0.bn1.running_var", "feature.block0.0.0.conv_pw.weight", "feature.block0.0.0.bn2.weight", "feature.block0.0.0.bn2.bias", "feature.block0.0.0.bn2.running_mean", "feature.block0.0.0.bn2.running_var", "feature.block1.0.0.conv_pw.weight", "feature.block1.0.0.bn1.weight", "feature.block1.0.0.bn1.bias", "feature.block1.0.0.bn1.running_mean", "feature.block1.0.0.bn1.running_var", "feature.block1.0.0.conv_dw.weight", "feature.block1.0.0.bn2.weight", "feature.block1.0.0.bn2.bias", "feature.block1.0.0.bn2.running_mean",
...
```

Do I need to set a different max disparity?
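The error lists missing key *names*; when a parameter name matches but its shape differs, `load_state_dict` reports a "size mismatch for ..." entry instead. That points to a naming problem rather than a wrong max disparity. A minimal sketch of a diagnostic, assuming you pass it `model.state_dict().keys()` and `state_dict['model'].keys()`: it lists which names differ and checks whether stripping a common prefix would reconcile them. The function name `diagnose_key_mismatch` and the `"module."` default are hypothetical, not part of CGI-Stereo or PyTorch.

```python
from typing import Iterable


def diagnose_key_mismatch(model_keys: Iterable[str], ckpt_keys: Iterable[str],
                          prefix: str = "module.") -> dict:
    """Compare the parameter names a model expects with those in a checkpoint.

    Hypothetical helper: returns the missing/unexpected name lists and whether
    removing `prefix` from the checkpoint names makes both sets identical.
    """
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    # Remove the prefix only from names that actually carry it
    stripped = {k[len(prefix):] if k.startswith(prefix) else k for k in ckpt_set}
    return {
        "missing": sorted(model_set - ckpt_set),     # expected by the model, absent from checkpoint
        "unexpected": sorted(ckpt_set - model_set),  # in the checkpoint, unknown to the model
        "prefix_fixes_it": stripped == model_set,
    }
```

For example, `diagnose_key_mismatch(model.state_dict().keys(), state_dict['model'].keys())` would return `prefix_fixes_it` as `True` if every checkpoint key is just a model key with `module.` in front.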

ccaven commented 11 months ago

Never mind: the problem was that every key in the checkpoint's state dict is prefixed with `module.`. I got around this by running this instead:

```python
state_dict = torch.load("./pretrained_models/CGI_Stereo/sceneflow.ckpt")
# Drop the leading "module." (7 characters) from every key before loading
model.load_state_dict({k[7:]: v for k, v in state_dict['model'].items()})
```
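A `module.` prefix is what `torch.nn.DataParallel` (and `DistributedDataParallel`) puts in front of parameter names, because the wrapped network is stored as the attribute `module`. It is an assumption here that sceneflow.ckpt was saved from such a wrapped model. Slicing with `k[7:]` works only if *every* key carries the prefix. A slightly safer sketch strips it only where present, so the same call also works on unwrapped checkpoints. The helper name `strip_prefix` is hypothetical.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Dict[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a copy of `state_dict` with `prefix` removed from keys that start with it.

    Keys without the prefix are kept unchanged, unlike slicing every key with k[7:].
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
```

Usage would be `model.load_state_dict(strip_prefix(state_dict['model']))`. Two alternatives avoid the helper. Recent PyTorch versions provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")`, which edits the dict in place. You can also wrap the model in `torch.nn.DataParallel` before calling `load_state_dict`, so its key names match the checkpoint as saved.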