Open BumjunPark opened 3 years ago
import torch
from torchvision.models.mobilenetv2 import mobilenet_v2
from torchvision.edgeailite import xnn

model = mobilenet_v2(False, False, num_classes=6)
dummy_input = torch.rand((1, 3, 224, 224))
model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
checkpoint = torch.load('pretrain.pth')
model.module.load_state_dict(checkpoint['model'])
This is the code I tried. Thank you.
Hi,
What you did is correct. But the num_classes is different in your case - that must be the reason why the weights were not loaded. I guess the weights of only one (the last) layer would not have loaded.
Try without providing the num_classes argument. By default, it uses 1000 internally - that should be fine if your number of classes is less than or equal to 1000.
Hi, thank you for fast response.
I used num_classes=6 when pre-training the model on a custom dataset, so I have to use the same option when creating the model. I think there must be another reason, as I have no problem loading the parameters without wrapping.
I'll wait for your response. Thank you.
Please use the following:
model.module.load_state_dict(checkpoint['model'], strict=False)
The strict=False argument tells PyTorch to ignore missing keys. The clips_act entries are new fields inserted by QuantTrainModule() to capture the ranges of activations - that's why they are not found in the checkpoint.
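Conceptually, strict=False just partitions the keys before copying: checkpoint entries that match the model are loaded, and the rest are reported rather than raised as errors. A minimal pure-Python sketch of that matching step (function and variable names here are illustrative, not the actual PyTorch internals):

```python
def split_matched_keys(checkpoint, model_keys):
    """Partition checkpoint entries the way load_state_dict(strict=False)
    does conceptually: entries whose key exists in the model are loadable;
    model keys absent from the checkpoint (e.g. the clips_act buffers added
    by QuantTrainModule) are reported as 'missing' instead of erroring."""
    loadable = {k: v for k, v in checkpoint.items() if k in model_keys}
    missing = [k for k in model_keys if k not in checkpoint]
    unexpected = [k for k in checkpoint if k not in model_keys]
    return loadable, missing, unexpected

# Float checkpoint saved before wrapping: no clips_act entries.
ckpt = {"features.0.0.weight": "w0", "classifier.1.weight": "w1"}
# The wrapped model has an extra activation-range buffer.
model_keys = ["features.0.0.weight",
              "features.0.0.activation_in.clips_act",
              "classifier.1.weight"]

loadable, missing, unexpected = split_matched_keys(ckpt, model_keys)
# missing contains only the new clips_act buffer; nothing is unexpected.
```

This is also roughly what the load_weights utility linked below handles for you, with some extra bookkeeping.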
We have a utility function that takes care of this scenario: https://github.com/TexasInstruments/edgeai-torchvision/blob/master/torchvision/edgeailite/xnn/utils/load_weights.py#L49
You can use it like this:
from torchvision.edgeailite import xnn
xnn.utils.load_weights(model, pretrained_path)
@mathmanu Thank you for your fast response.
It seems to be working, though the training speed is slow.
I also found that the error doesn't occur, and the result is the same, if I load the weights first and then wrap the model with xnn.
ex)
model = mobilenet_v2(False, False, num_classes=6)
checkpoint = torch.load('pretrain.pth')
model.load_state_dict(checkpoint['model'])
dummy_input = torch.rand((1, 3, 224, 224))
model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
Thanks again.
I also found that the error doesn't occur, and the result is the same, if I load the weights first and then wrap the model with xnn.
That is perfectly fine - but there is a fine detail that you should know.
QuantTrainModule adds some additional parameters - such as the clips_act that we saw earlier. If you want to load a quantized checkpoint (i.e. if you save the module from QuantTrainModule() and then load it back to evaluate its accuracy), then you have to load it after wrapping. That's when one of the methods that I suggested becomes useful. In fact, if you use one of the methods that I suggested, you can load both a float checkpoint and a quantized checkpoint.
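The ordering detail above comes down to the key namespace: the wrapper exposes the original network as its `module` attribute, so a float checkpoint's keys only line up with the wrapped model once the `module.` prefix is accounted for (which is why model.module.load_state_dict works). A hedged pure-dict sketch of that prefix handling (illustrative only, not the library's actual code):

```python
def add_wrapper_prefix(state_dict, prefix="module."):
    """Rewrite float-checkpoint keys so they match the key namespace of a
    wrapped model whose original network sits under `module`."""
    return {prefix + k: v for k, v in state_dict.items()}

def strip_wrapper_prefix(state_dict, prefix="module."):
    """Inverse operation: recover plain keys from a wrapped (quantized)
    checkpoint, leaving already-plain keys untouched."""
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

# A float checkpoint saved from the unwrapped model:
float_ckpt = {"features.0.0.weight": "w0"}
# After prefixing, the keys match the wrapped model's state_dict:
wrapped_ckpt = add_wrapper_prefix(float_ckpt)
```

Combined with the strict=False / load_weights handling of the extra clips_act buffers, this is why the same loading call can accept either kind of checkpoint.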
@mathmanu Thank you for your careful suggestion. I'll keep that in mind.
Have you ever run quantized training of edgeai-yolov5? Quantization does not converge.
I have forwarded the other issue that you filed (https://github.com/TexasInstruments/edgeai-torchvision/issues/5) to the expert who worked on edgeai-yolov5.
I'm trying to use QAT and following the instructions. However, I can't load my trained parameters. I can load them when I don't wrap the model with xnn.quantize.QuantTrainModule, but an error about missing keys in the state_dict occurs when it is wrapped.
RuntimeError: Error(s) in loading state_dict for MobileNetV2: Missing key(s) in state_dict: "features.0.0.activation_in.clips_act", "features.0.0.activation_in.num_batches_tracked", .....
It seems the keys of the state_dict change when the model is wrapped. Please help me with this issue.
p.s. I used torchvision.models.mobilenetv2 both for pre-training and loading to follow the guidelines.