TexasInstruments / edgeai-torchvision

This repository has been moved. Its new location is https://github.com/TexasInstruments/edgeai-tensorlab
https://github.com/TexasInstruments/edgeai

Cannot load trained parameters when wrapped with xnn.quantize.QuantTrainModule #1

Open BumjunPark opened 3 years ago

BumjunPark commented 3 years ago

I'm trying to use QAT and following the instructions. However, I can't load my trained parameters. I can load them when I don't wrap the model with xnn.quantize.QuantTrainModule, but an error about missing keys in the state_dict occurs when the model is wrapped.

```
RuntimeError: Error(s) in loading state_dict for MobileNetV2: Missing key(s) in state_dict: "features.0.0.activation_in.clips_act", "features.0.0.activation_in.num_batches_tracked", .....
```

It seems the keys of the state_dict change when the model is wrapped. Please help me with this issue.

p.s. I used torchvision.models.mobilenetv2 both for pre-training and loading to follow the guidelines.

BumjunPark commented 3 years ago

```python
import torch
from torchvision.models.mobilenetv2 import mobilenet_v2
from torchvision.edgeailite import xnn

model = mobilenet_v2(False, False, num_classes=6)
dummy_input = torch.rand((1, 3, 224, 224))
model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
checkpoint = torch.load('pretrain.pth')
model.module.load_state_dict(checkpoint['model'])
```

This is the code I tried. Thank you.

mathmanu commented 3 years ago

Hi,

What you did is correct. But the num_classes is different in your case - that must be the reason why the weights were not loaded - I guess only the weights of the last layer would have failed to load.

Try without providing the num_classes argument. By default it uses 1000 internally - that should be fine if your number of classes is less than or equal to 1000.
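To make the suggested cause concrete: even when a key exists in both state_dicts, loading fails if the tensor shapes differ, which is exactly what happens in the final classifier layer when num_classes differs. A framework-free sketch (tuples standing in for tensor shapes; the `classifier.1.weight` key follows torchvision's MobileNetV2 naming):

```python
# Shapes stand in for tensor shapes in the two state_dicts.
model_shapes = {"classifier.1.weight": (1000, 1280)}  # model built with default num_classes=1000
ckpt_shapes = {"classifier.1.weight": (6, 1280)}      # checkpoint trained with num_classes=6

# Keys that match by name but clash by shape cannot be loaded.
mismatched = [k for k in model_shapes
              if k in ckpt_shapes and model_shapes[k] != ckpt_shapes[k]]
```

Here `mismatched` contains the classifier weight, which is why only the last layer would fail to load in this scenario.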

BumjunPark commented 3 years ago

Hi, thank you for fast response.

I used num_classes=6 when pre-training the model on a custom dataset, so I have to use the same option when creating the model. I think there must be another reason, as I have no problem loading the parameters without wrapping.

I'll wait for your response. Thank you.

mathmanu commented 3 years ago

Please use the following:

```python
model.module.load_state_dict(checkpoint['model'], strict=False)
```

The strict=False argument tells PyTorch to ignore missing keys. clips_act are new fields inserted by QuantTrainModule() to capture the ranges of activations - that's why they are not found in the checkpoint.
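A minimal, framework-free sketch of what strict vs. non-strict loading does (plain-Python dict matching, not PyTorch's actual implementation):

```python
def load_state_dict(model_state, checkpoint_state, strict=True):
    """Copy values for keys present in both dicts.

    With strict=True, any missing or unexpected key raises; with
    strict=False, mismatched keys are reported and skipped.
    """
    missing = [k for k in model_state if k not in checkpoint_state]
    unexpected = [k for k in checkpoint_state if k not in model_state]
    if strict and (missing or unexpected):
        raise RuntimeError(f"Missing keys: {missing}, unexpected keys: {unexpected}")
    for key in model_state:
        if key in checkpoint_state:
            model_state[key] = checkpoint_state[key]
    return missing, unexpected

# The wrapped model has an extra quantization field not in the float checkpoint,
# so strict loading would raise, while strict=False just skips it.
wrapped = {"features.0.weight": None, "features.0.activation_in.clips_act": None}
ckpt = {"features.0.weight": 1.23}
missing, unexpected = load_state_dict(wrapped, ckpt, strict=False)
```

With strict=False, the shared weight is copied and the quantization field is simply left at its initial value, to be learned during QAT.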

mathmanu commented 3 years ago

Another method:

We have a utility function that takes care of this scenario: https://github.com/TexasInstruments/edgeai-torchvision/blob/master/torchvision/edgeailite/xnn/utils/load_weights.py#L49

You can use it like this:

```python
from torchvision.edgeailite import xnn
xnn.utils.load_weights(model, pretrained_path)
```

BumjunPark commented 3 years ago

@mathmanu Thank you for your fast response.

It seems to be working, though the training speed is slow.

And I also found that the error doesn't occur, and the result is the same, if I load the weights first and then wrap the model with xnn.

e.g.

```python
model = mobilenet_v2(False, False, num_classes=6)
checkpoint = torch.load('pretrain.pth')
model.load_state_dict(checkpoint['model'])
dummy_input = torch.rand((1, 3, 224, 224))
model = xnn.quantize.QuantTrainModule(model, dummy_input=dummy_input)
```

Thanks again.

mathmanu commented 3 years ago

> And I also found that the error doesn't occur, and the result is the same, if I load the weights first and then wrap the model with xnn.

That is perfectly fine - but there is a fine detail that you should know.

QuantTrainModule adds some additional parameters - such as the clips_act that we saw earlier. If you want to load a quantized checkpoint (i.e. if you save the module from QuantTrainModule() and then load it back to evaluate its accuracy), then you have to load it after wrapping. That's when one of the methods I suggested becomes useful. In fact, if you use one of those methods, you can load both float checkpoints and quantized checkpoints.
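To make the distinction concrete, a plain-Python illustration with dicts standing in for state_dicts (the key names are illustrative, modeled on the clips_act field from the error above):

```python
# Float checkpoint: only the original model's parameters.
float_ckpt = {"features.0.weight": 1.0}

# Quantized checkpoint, saved from a wrapped model: also carries the
# activation-range field that QuantTrainModule inserted.
quant_ckpt = {"features.0.weight": 1.0,
              "features.0.activation_in.clips_act": 6.0}

# Key sets of the plain (unwrapped) model vs. the wrapped model.
plain_keys = {"features.0.weight"}
wrapped_keys = {"features.0.weight", "features.0.activation_in.clips_act"}

# Loading the quantized checkpoint into the plain model leaves clips_act
# with nowhere to go - so it must be loaded after wrapping.  Loading the
# float checkpoint into the wrapped model only leaves clips_act missing,
# which a non-strict load tolerates.
unexpected_for_plain = set(quant_ckpt) - plain_keys
missing_for_wrapped = wrapped_keys - set(float_ckpt)
```

This is why a non-strict (or key-aware) load after wrapping handles both checkpoint types.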

BumjunPark commented 3 years ago

@mathmanu Thank you for your careful suggestion. I'll keep that in mind.

Serissa commented 2 years ago

Have you ever run quantization-aware training on edgeai-yolov5? The quantization does not converge.

mathmanu commented 2 years ago

I have forwarded the other issue that you filed (https://github.com/TexasInstruments/edgeai-torchvision/issues/5) to the expert who worked on edgeai-yolov5.