pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Error(s) in loading state_dict in ..., unexpected keys #2562

Open alex96295 opened 4 years ago

alex96295 commented 4 years ago

🐛 Bug

Hello, I am trying to quantize a model. I have done post-training static quantization following the tutorial. For the conversion, I do:

mymodel = model(cfg)

mymodel.load_state_dict(torch.load('weights.pt'))

# ... post-training static quantization produces mymodel_q ...

torch.save(mymodel_q.state_dict(), 'weights_q.pt')

When I load it, I use the same code as before, i.e. I define the model and then load the weights with load_state_dict.

But the quantized state_dict contains scale and zero_point entries, and those keys are missing from the original model definition. That is to be expected, since 'model(cfg)' was never changed after quantization. But how do I include the scale and zero_point information when loading?

Am I saving the quantized model in the wrong way?
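
For completeness, the full flow I followed looks roughly like this (a sketch; the 'fbgemm' backend and the calibration loader below are stand-ins, not my exact code):

    import torch

    mymodel = model(cfg)                                # build the fp32 model
    mymodel.load_state_dict(torch.load('weights.pt'))   # load the fp32 weights
    mymodel.eval()

    mymodel.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(mymodel, inplace=True)   # insert observers

    with torch.no_grad():                               # calibrate on representative data
        for images, _ in calibration_loader:
            mymodel(images)

    mymodel_q = torch.quantization.convert(mymodel)     # swap in quantized modules
    torch.save(mymodel_q.state_dict(), 'weights_q.pt')  # keys now include scale / zero_point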

Thank you!

Additional context

fmassa commented 4 years ago

@raghuramank100 can you have a look?

972461099 commented 3 years ago

RuntimeError: Error(s) in loading state_dict for BMNet: Unexpected key(s) in state_dict: "conv1a.0.scale", "conv1a.0.zero_point", "conv1aa.0.scale", "conv1aa.0.zero_point", "conv1b.0.scale", "conv1b.0.zero_point", "conv2a.0.scale", "conv2a.0.zero_point", "conv2aa.0.scale", "conv2aa.0.zero_point", "conv2b.0.scale", "conv2b.0.zero_point", "conv3a.0.scale", "conv3a.0.zero_point", "conv3aa.0.scale", "conv3aa.0.zero_point", "conv3b.0.scale", "conv3b.0.zero_point", "conv4a.0.scale", "conv4a.0.zero_point", "conv4aa.0.scale", "conv4aa.0.zero_point", "conv4b.0.scale", "conv4b

Karthik-S-EC commented 2 years ago

Hi @alex96295, one alternative way to load the quantized model in PyTorch is:

  1. model_fp32 = ...  # build the original (unquantized) model and load its fp32 state_dict
  2. model_quant_state_dict = torch.load(quant_model_path)  # load the quantized state_dict you saved after conversion
  3. quant_model = torch.quantization.convert(model_fp32)  # convert the fp32 model to its quantized counterpart
  4. quant_model.load_state_dict(model_quant_state_dict)  # load the quantized weights into the converted model

That's it. Now you can run the quantized model without any errors; a fleshed-out sketch of these steps is below.
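
Filling those steps in as runnable eager-mode code, a minimal sketch (assuming the 'fbgemm' backend and the 'weights_q.pt' file name from the original post; in eager mode the fp32 model usually needs its qconfig set and prepare() run before convert(), otherwise the converted module structure, and therefore the state_dict keys, will not match):

    import torch

    # Rebuild the architecture exactly as it was before quantization.
    model_fp32 = model(cfg)
    model_fp32.eval()

    # Prepare with the same qconfig used when the model was quantized,
    # then convert so the module types (and scale/zero_point keys) match.
    model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model_fp32, inplace=True)
    quant_model = torch.quantization.convert(model_fp32)

    model_quant_state_dict = torch.load('weights_q.pt')
    quant_model.load_state_dict(model_quant_state_dict)
    quant_model.eval()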

WangFengtu1996 commented 2 weeks ago

I am hitting the same error when loading a quantized model in torch:

    self._create_model()
    if quantized:
        self.model.qconfig = torch.ao.quantization.get_default_qconfig('x86')
        model_fp32_prepared = torch.ao.quantization.prepare_qat(self.model)
        model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
        # TODO: load the checkpoint name dynamically
        from pprint import pprint
        # pprint(torch.load('/home/wangft/workspace/weigh/train_model/output/run_289/TCN4_50_QAT_fp32_opset_12_2024-08-28-17-49-47_qat.pth', map_location=self.device).keys())
        # exit()
        pprint(model_int8)
        # exit()
        model_int8.load_state_dict(torch.load('/home/wangft/workspace/weigh/train_model/output/run_289/TCN4_50_QAT_fp32_opset_12_2024-08-28-17-49-47_qat.pth', map_location=self.device))
        exit()
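
A frequent cause of unexpected/missing key errors here is loading a checkpoint saved at one stage (for example the prepared fp32 QAT model) into a model at a different stage (the converted int8 model); also, for QAT the qconfig is normally obtained from get_default_qat_qconfig rather than get_default_qconfig. A hedged sketch of the full QAT round trip (the MyModel class, the 'x86' backend, and the file name are illustrative, not taken from this issue):

    import torch

    # --- training side: prepare for QAT, fine-tune, convert, then save ---
    model = MyModel()
    model.train()
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')  # QAT qconfig
    model_prepared = torch.ao.quantization.prepare_qat(model)
    # ... run quantization-aware fine-tuning on model_prepared ...
    model_prepared.eval()
    model_int8 = torch.ao.quantization.convert(model_prepared)
    torch.save(model_int8.state_dict(), 'model_int8.pth')                 # save *after* convert

    # --- loading side: rebuild the same prepared-then-converted structure ---
    model = MyModel()
    model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')
    model_prepared = torch.ao.quantization.prepare_qat(model.train())
    model_int8 = torch.ao.quantization.convert(model_prepared.eval())
    model_int8.load_state_dict(torch.load('model_int8.pth'))
    model_int8.eval()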