Hi Lucas!
Thanks so much for sharing this implementation! :)
Ran into a snag while loading the pre-trained model. It seems the pretrained models were saved using `DataParallel`, so naively loading the model throws an error. Just calling `model = torch.nn.DataParallel(model)` and then calling `load_part_of_model` fixes the issue.
It's a minor thing, but I'm posting it just in case others run into the same issue.
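A common alternative, if you'd rather not wrap the model, is to strip the prefix from the checkpoint keys instead. `nn.DataParallel` stores the wrapped network as `self.module`, so every key in its `state_dict` starts with `module.` (e.g. `module.conv1.weight`). Below is a minimal sketch of that approach. The helper name `strip_module_prefix` is my own and is not part of this repo or of PyTorch. It also assumes the checkpoint file holds a bare `state_dict`; I haven't checked how `load_part_of_model` reads it.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def strip_module_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a copy of state_dict with the DataParallel prefix removed from each key.

    Keys that do not start with the prefix are copied unchanged, and the
    original key order is preserved.
    """
    stripped: Dict[str, V] = {}
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


if __name__ == "__main__":
    # Plain dict standing in for a DataParallel checkpoint (values would be tensors).
    saved = {"module.conv1.weight": 1, "module.fc.bias": 2}
    print(strip_module_prefix(saved))
```

With PyTorch this would be used roughly as `model.load_state_dict(strip_module_prefix(torch.load(path, map_location="cpu")))`. Recent PyTorch versions also include `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which does the same thing in place.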
Best,