Umdog786 closed this issue 1 year ago.
This may be due to `nn.DataParallel` prepending `module.` to all layer names.
Some quick fixes can be found here: https://github.com/gulvarol/bsldict/blob/8629693ae210062a7cdb41373123ba2ca4e5d856/demo/utils.py#L314 or here: https://github.com/gulvarol/bsl1k/blob/bfeaa3a8463152a79bcfd0fafd358c4d203b4c16/utils/misc.py#L71
Hiya there, I solved this by removing the `module.` prefix from each key in the state dict:
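```python
def remove_module_from_keys(state_dict):
    # Build a new dict whose keys no longer contain the "module." added by nn.DataParallel
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key.replace("module.", "")
        new_state_dict[new_key] = value
    return new_state_dict
```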
Thanks for the help :)
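One caveat with the function above: `str.replace` removes `module.` wherever it occurs in a key, not only at the start, so a hypothetical key like `submodule.weight` would become `subweight`. None of the I3D key names in this thread contain that substring, so the fix above works here; the sketch below is a slightly more defensive variant that strips the prefix only when it leads the key. The name `strip_prefix` is hypothetical, not taken from the linked repositories.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    Keys that do not start with `prefix` are copied unchanged, and the prefix
    is never removed from the middle of a key. Key order is preserved.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # Slice instead of str.removeprefix so this also runs on Python < 3.9
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped
```

Because it only inspects keys, the same function works on the dict returned by `torch.load` before it is passed to `model.load_state_dict`.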
I am attempting to retrain the I3D model to output 20 classes. I have defined the model with:
```python
model = InceptionI3d(NUM_CLASSES, in_channels=3, num_in_frames=16)
```
I am attempting to load the provided state dict with:
```python
model.load_state_dict(torch.load("...bsldicti3d.pth.tar"))
```
I received an error relating to missing/unexpected keys in the state dict:
```
RuntimeError: Error(s) in loading state_dict for InceptionI3d:
    Missing key(s) in state_dict: "logits.conv3d.weight", "logits.conv3d.bias", "Conv3d_1a_7x7.conv3d.weight", "Conv3d_1a_7x7.bn.weight"...
    Unexpected key(s) in state_dict: "module.logits.conv3d.weight", "module.logits.conv3d.bias", "module.Conv3d_1a_7x7.conv3d.weight", "module.Conv3d_1a_7x7.bn.weight"...
```
Any ideas on how to fix this? Most of them appear to be the expected keys with "module." prepended, though there are some additional ones as well.
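Once the `module.` prefix has been stripped, a further mismatch is likely when retraining for 20 classes: if the checkpoint's classifier was trained with a different number of classes than `NUM_CLASSES` (an assumption, since the thread does not say how many it has), the `logits.conv3d.*` tensors will differ in shape. Even with `strict=False`, `load_state_dict` still raises on shape mismatches, so one common approach is to keep only the entries whose key and shape both match the model. The sketch below does that and also reports leftover keys, which can help identify the "additional ones" mentioned above. The helper name `split_compatible` is hypothetical.

```python
from collections import OrderedDict
from typing import Any, List, Mapping, Tuple


def split_compatible(
    checkpoint: Mapping[str, Any], model_state: Mapping[str, Any]
) -> Tuple["OrderedDict[str, Any]", List[str], List[str]]:
    """Split checkpoint entries by whether they fit the target model.

    Returns (compatible, skipped, missing):
      compatible - entries whose key exists in model_state with the same shape
      skipped    - checkpoint keys that are absent from the model or differ in shape
      missing    - model keys that will NOT be filled from the checkpoint
    Values only need a `.shape` attribute (torch.Size is a tuple subclass).
    """
    compatible = OrderedDict()
    skipped: List[str] = []
    for key, value in checkpoint.items():
        target = model_state.get(key)
        if target is not None and tuple(value.shape) == tuple(target.shape):
            compatible[key] = value
        else:
            skipped.append(key)
    # Model parameters/buffers left untouched keep their fresh initialization
    missing = [key for key in model_state if key not in compatible]
    return compatible, skipped, missing
```

Assuming `checkpoint` is the prefix-stripped dict, the result could be used as `compatible, skipped, missing = split_compatible(checkpoint, model.state_dict())` followed by `model.load_state_dict(compatible, strict=False)`. Here `strict=False` only tolerates the keys in `missing` (such as a resized classifier), which then keep their random initialization and are learned during fine-tuning; printing `skipped` shows which checkpoint entries were dropped.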