gulvarol / bsldict

Watch, read and lookup: learning to spot signs from multiple supervisors, ACCV 2020 (Best Application Paper)
http://www.robots.ox.ac.uk/~vgg/research/bsldict/

Loading state dict for I3d #5

Closed. Umdog786 closed this issue 1 year ago.

Umdog786 commented 1 year ago

I am attempting to retrain the I3D model to output 20 classes. I have defined the model with:

    model = InceptionI3d(NUM_CLASSES, in_channels=3, num_in_frames=16)

I am attempting to load the provided state dict with:

    model.load_state_dict(torch.load("...bsldicti3d.pth.tar"))

I received an error about missing/unexpected keys in the state dict:

    RuntimeError: Error(s) in loading state_dict for InceptionI3d:
        Missing key(s) in state_dict: "logits.conv3d.weight", "logits.conv3d.bias", "Conv3d_1a_7x7.conv3d.weight", "Conv3d_1a_7x7.bn.weight"...
        Unexpected key(s) in state_dict: "module.logits.conv3d.weight", "module.logits.conv3d.bias", "module.Conv3d_1a_7x7.conv3d.weight", "module.Conv3d_1a_7x7.bn.weight"...

Any ideas on how to fix this? The unexpected keys mostly seem to be the correct keys with "module." prepended. There are some additional ones, though.
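One way to check whether the prefix is the only difference, and to list the additional keys, is to compare the two key sets after stripping `module.`. The helper below is hypothetical (`diff_state_dict_keys` is not part of bsldict or PyTorch). It works on any iterables of key strings, such as `model.state_dict().keys()` and the keys of the loaded checkpoint, assuming the checkpoint is a flat mapping from parameter names to tensors.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    model_keys: Iterable[str],
    ckpt_keys: Iterable[str],
    prefix: str = "module.",
) -> Tuple[List[str], List[str]]:
    """Compare model and checkpoint keys after stripping a leading `prefix`.

    Returns (missing, unexpected), both sorted:
      missing    : keys the model expects that the checkpoint does not provide
      unexpected : checkpoint keys (prefix removed) that the model does not have
    """
    expected = set(model_keys)
    # Remove the prefix only where it appears at the start of a key.
    found = {k[len(prefix):] if k.startswith(prefix) else k for k in ckpt_keys}
    missing = sorted(expected - found)
    unexpected = sorted(found - expected)
    return missing, unexpected


# Illustrative keys; "module.extra.weight" is a made-up stand-in for an extra key.
model_keys = ["logits.conv3d.weight", "logits.conv3d.bias", "Conv3d_1a_7x7.conv3d.weight"]
ckpt_keys = ["module." + k for k in model_keys] + ["module.extra.weight"]
print(diff_state_dict_keys(model_keys, ckpt_keys))
# → ([], ['extra.weight'])
```

Anything left in `unexpected` after stripping is one of the additional keys; anything in `missing` is a layer the checkpoint does not provide.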

gulvarol commented 1 year ago

This may be due to nn.DataParallel, which stores the wrapped model as its `module` attribute and therefore prepends `module.` to all layer names in the saved state dict.

Some possible quick fixes: https://github.com/gulvarol/bsldict/blob/8629693ae210062a7cdb41373123ba2ca4e5d856/demo/utils.py#L314 or https://github.com/gulvarol/bsl1k/blob/bfeaa3a8463152a79bcfd0fafd358c4d203b4c16/utils/misc.py#L71
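A minimal sketch of the prefix-stripping idea, assuming the checkpoint loads as a flat mapping from parameter names to tensors. `strip_prefix` is a hypothetical name, and this is not the code behind either link. Unlike `str.replace`, it only removes `module.` at the start of a key, so a layer name that happens to contain `module.` elsewhere is left intact.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Return a copy of `state_dict` with a leading `prefix` removed from every key.

    Keys that do not start with `prefix` are copied unchanged, and the original
    key order is preserved.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Illustrative input: plain integers stand in for tensors.
example = {"module.logits.conv3d.weight": 0, "module.Conv3d_1a_7x7.bn.weight": 1, "head.module.x": 2}
print(list(strip_prefix(example)))
# → ['logits.conv3d.weight', 'Conv3d_1a_7x7.bn.weight', 'head.module.x']
```

It would then be used as `model.load_state_dict(strip_prefix(torch.load(checkpoint_path)))`, where `checkpoint_path` is a placeholder for the path to the downloaded checkpoint. Another option is to wrap the model in `nn.DataParallel` before loading, so that its own keys carry the same prefix.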

Umdog786 commented 1 year ago

Hiya there, I solved this by removing the `module.` prefix from each key in the state dict:

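    def remove_module_from_keys(state_dict):
        # Copy every entry, dropping "module." from the key name.
        # Note: str.replace removes every occurrence, not just a leading prefix.
        new_state_dict = {}
        for key, value in state_dict.items():
            new_key = key.replace("module.", "")
            new_state_dict[new_key] = value
        return new_state_dict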

Thanks for the help :)
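Not covered in the thread: if the checkpoint's classifier was trained for a different number of classes than the 20 used here (an assumption; the thread does not say how many classes the checkpoint has), removing the prefix would likely be followed by a size-mismatch error for `logits.conv3d.weight` and `logits.conv3d.bias`. Passing `strict=False` to `load_state_dict` tolerates missing and unexpected keys, but not shape mismatches. One common workaround, sketched here with a hypothetical `drop_head` helper, is to drop the classifier weights so the new `logits` layer keeps its fresh initialization.

```python
from typing import Any, Dict, Mapping


def drop_head(state_dict: Mapping[str, Any], head_prefix: str = "logits.") -> Dict[str, Any]:
    """Return a copy of `state_dict` without the classifier entries under `head_prefix`.

    The remaining entries can be loaded with `strict=False`, leaving the new
    classifier at its fresh initialization.
    """
    return {key: value for key, value in state_dict.items() if not key.startswith(head_prefix)}


# Illustrative input with the prefix already removed; integers stand in for tensors.
example = {"logits.conv3d.weight": 0, "logits.conv3d.bias": 1, "Conv3d_1a_7x7.conv3d.weight": 2}
print(drop_head(example))
# → {'Conv3d_1a_7x7.conv3d.weight': 2}
```

Combined with the helper from the thread, loading could look like `missing, unexpected = model.load_state_dict(drop_head(remove_module_from_keys(torch.load(checkpoint_path))), strict=False)`, where `checkpoint_path` is a placeholder. `missing` should then list only the `logits.*` keys.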