Xilinx / brevitas

Brevitas: neural network quantization in PyTorch
https://xilinx.github.io/brevitas/

Error while loading a model #620

Closed: MohamedA95 closed this issue 1 year ago

MohamedA95 commented 1 year ago

Hi everyone, I am getting the following error while loading a model trained with Brevitas:

Traceback (most recent call last):
  File "tools/test_widerface.py", line 258, in <module>
    main()
  File "tools/test_widerface.py", line 132, in main
    model.load_state_dict(checkpoint['state_dict'])
  File "/home/user/.conda/envs/scrfd3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1468, in load_state_dict
    load(self)
  File "/home/user/.conda/envs/scrfd3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1466, in load
    load(child, prefix + name + '.')
  File "/home/user/.conda/envs/scrfd3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1466, in load
    load(child, prefix + name + '.')
  File "/home/user/.conda/envs/scrfd3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1466, in load
    load(child, prefix + name + '.')
  [Previous line repeated 5 more times]
  File "/home/user/.conda/envs/scrfd3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1462, in load
    module._load_from_state_dict(
  File "/home/user/.conda/envs/scrfd3/lib/python3.8/site-packages/brevitas/core/scaling/standalone.py", line 299, in _load_from_state_dict
    missing_keys.remove(prefix + 'buffer')
ValueError: list.remove(x): x not in list

The key that the code tries to remove from the list is:

backbone.stem.0.2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.buffer

I loaded the checkpoint manually and confirmed that this key exists in the state dict.

This is how this part of the model is defined:

import brevitas.quant
import brevitas.nn as qnn
import torch.nn as nn

if quant == 2:
    assert bre_quant_cfg is not None
    for layer_type in bre_quant_cfg.keys():
        for k in bre_quant_cfg[layer_type].keys():
            if isinstance(bre_quant_cfg[layer_type][k],str):
                if bre_quant_cfg[layer_type][k].lower() == 'none':
                    bre_quant_cfg[layer_type][k] = None
                else:
                    bre_quant_cfg[layer_type][k] = getattr(brevitas.quant, bre_quant_cfg[layer_type][k])
    def conv_bn(inp, oup, stride):
        return nn.Sequential(
            qnn.QuantConv2d(inp, oup, 3, stride, 1, bias=False, **bre_quant_cfg['conv']),
            nn.BatchNorm2d(oup),
            qnn.QuantReLU(inplace=True, **bre_quant_cfg['relu'])
        )
    def conv_dw(inp, oup, stride):
        return nn.Sequential(
            qnn.QuantConv2d(inp, inp, 3, stride, 1, groups=inp, bias=False, **bre_quant_cfg['conv']),
            nn.BatchNorm2d(inp),
            qnn.QuantReLU(inplace=True, **bre_quant_cfg['relu']),

            qnn.QuantConv2d(inp, oup, 1, 1, 0, bias=False, **bre_quant_cfg['conv']),
            nn.BatchNorm2d(oup),
            qnn.QuantReLU(inplace=True, **bre_quant_cfg['relu']))
self.stem = nn.Sequential(
    conv_bn(3, stage_planes[0], 2),
    conv_dw(stage_planes[0], stage_planes[1], 1))

where bre_quant_cfg is a dict defined as follows:

bre_quant_cfg = dict(
    conv=dict(weight_quant='Int8WeightPerTensorFixedPoint', weight_bit_width=8),
    relu=dict(input_quant='None', act_quant='Uint8ActPerTensorFixedPoint'))
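The string-to-quantizer resolution performed by the loop in the model definition can be tried in isolation. This is a minimal sketch, assuming a stand-in namespace (fake_quant, hypothetical) in place of brevitas.quant so it runs without Brevitas installed; the two classes below are empty placeholders that only borrow the names used in the config.

```python
import copy
from types import SimpleNamespace


# Hypothetical placeholders for the quantizer classes that brevitas.quant provides.
class Int8WeightPerTensorFixedPoint:
    pass


class Uint8ActPerTensorFixedPoint:
    pass


fake_quant = SimpleNamespace(
    Int8WeightPerTensorFixedPoint=Int8WeightPerTensorFixedPoint,
    Uint8ActPerTensorFixedPoint=Uint8ActPerTensorFixedPoint,
)


def resolve_quant_cfg(cfg, namespace):
    """Return a copy of cfg where 'None' strings become None and other
    strings are looked up as attributes of namespace (e.g. brevitas.quant)."""
    resolved = copy.deepcopy(cfg)
    for layer_cfg in resolved.values():
        for key, value in layer_cfg.items():
            if isinstance(value, str):
                if value.lower() == 'none':
                    layer_cfg[key] = None
                else:
                    layer_cfg[key] = getattr(namespace, value)
    return resolved


bre_quant_cfg = dict(
    conv=dict(weight_quant='Int8WeightPerTensorFixedPoint', weight_bit_width=8),
    relu=dict(input_quant='None', act_quant='Uint8ActPerTensorFixedPoint'))

resolved = resolve_quant_cfg(bre_quant_cfg, fake_quant)
print(resolved['relu'])
```

Unlike the in-place loop in the model code, this version returns a copy, so the original string config stays intact and can still be saved or logged as plain text.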

I modified _load_from_state_dict at line 298 of brevitas/core/scaling/standalone.py to check whether the key is in the list before attempting to remove it. This solves the issue: the model loads normally and gives the expected results. Since you know Brevitas's flow in depth, why does the error occur? Why is the buffer key supposed to always be missing?
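The workaround described above amounts to guarding the list.remove call. The following minimal pure-Python sketch contrasts the two behaviors; the helper names are hypothetical and are not Brevitas functions. In the real code, missing_keys is the list that PyTorch's load_state_dict passes to each module's _load_from_state_dict.

```python
def remove_buffer_key_strict(missing_keys, prefix):
    # Mirrors the original behavior: assumes the buffer key is always missing.
    missing_keys.remove(prefix + 'buffer')


def remove_buffer_key_guarded(missing_keys, prefix):
    # Workaround: only remove the key if the loader actually reported it missing.
    key = prefix + 'buffer'
    if key in missing_keys:
        missing_keys.remove(key)


prefix = 'backbone.stem.0.2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.'

# Buffer key absent from the checkpoint: the loader reports it as missing.
missing_when_absent = [prefix + 'buffer']
# Buffer key present in the checkpoint: nothing is reported missing.
missing_when_present = []

remove_buffer_key_guarded(missing_when_absent, prefix)
remove_buffer_key_guarded(missing_when_present, prefix)

strict_failed = False
try:
    remove_buffer_key_strict([], prefix)
except ValueError as exc:
    strict_failed = True
    print('strict version fails:', exc)
```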

Giuseppe5 commented 1 year ago

Hello,

Thanks for pointing out this issue!

Could you provide a minimal reproducible example for me to try? May I also ask which Brevitas and PyTorch versions you are using?

Thanks, Giuseppe

MohamedA95 commented 1 year ago

Hi Giuseppe, thanks for your response. I am using torch 1.10.0+cu102 and Brevitas 0.9.1. This part of the model belongs to a bigger project with many dependencies; I will try to create a minimal reproducible example and get back to you. Regards, Mohamed.

Giuseppe5 commented 1 year ago

Hi,

I was wondering if there were any updates on this. I'm happy to look into the issue if I manage to replicate it on my side.

MohamedA95 commented 1 year ago

Hi @Giuseppe5, thanks for your patience and for following up. It has been a busy week, and I am still working on the example. That said, I have started to believe that the issue is not directly related to Brevitas but to the training framework, since the network was trained using the mmcv runner. I will confirm this, create an example, and get back to you very soon.

MohamedA95 commented 1 year ago

Hi @Giuseppe5, I created a gist with a small script that reproduces the issue. At the end of the file you will find two functions: train_using_torch, which trains the model using plain PyTorch and works correctly, and train_using_mmcv, which trains it using the mmcv runner and produces the error when the model is reloaded. A bit of context: MobileNet is used here as the backbone of the SCRFD face detection flow. The original SCRFD model defines three methods that I had to add to the MobileNet class so that mmcv.runner would accept it as usual: _parse_losses, train_step and val_step. Also, mmcv.runner expects the model to return the loss rather than the direct output.
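For readers unfamiliar with the mmcv runner contract mentioned above: the runner calls model.train_step(data, optimizer) and expects a dict containing the total loss, the logged variables and the sample count, rather than raw network outputs. The following is a simplified, framework-free sketch of that contract, loosely modeled on how MMDetection-style detectors implement _parse_losses and train_step. Plain floats stand in for tensors, and ToyDetector and its compute_losses method are hypothetical.

```python
from collections import OrderedDict


class ToyDetector:
    """Hypothetical model illustrating the train_step/_parse_losses contract;
    floats stand in for torch tensors."""

    def compute_losses(self, data):
        # Hypothetical: a real model would run its forward pass and return loss tensors.
        return {'loss_cls': data['cls_err'], 'loss_bbox': [0.5, 0.25], 'acc': 0.9}

    def _parse_losses(self, losses):
        log_vars = OrderedDict()
        for name, value in losses.items():
            # Lists of per-level losses are summed; scalars are kept as-is.
            log_vars[name] = sum(value) if isinstance(value, list) else value
        # Only entries whose name contains 'loss' contribute to the total.
        total = sum(v for k, v in log_vars.items() if 'loss' in k)
        log_vars['loss'] = total
        return total, log_vars

    def train_step(self, data, optimizer=None):
        loss, log_vars = self._parse_losses(self.compute_losses(data))
        return dict(loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))

    def val_step(self, data, optimizer=None):
        # Same output structure as train_step; the optimizer is not used here.
        return self.train_step(data, optimizer)


outputs = ToyDetector().train_step({'cls_err': 1.0, 'img_metas': [{}, {}]})
print(outputs['loss'], outputs['num_samples'])  # 1.75 2
```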

Honestly, I would understand if you are not interested in investigating this in the near term, as you may have higher priorities on your list. But could you kindly let me know whether my workaround has any side effects?

Giuseppe5 commented 1 year ago

Which version of MMCV am I supposed to use?

MohamedA95 commented 1 year ago

Hi, I used 1.2.6 (not the most recent, but the version supported by SCRFD), built from source with cd mmcv && MMCV_WITH_OPS=1 pip install -e . -v --user && cd .. as part of the full environment setup script below:

export PYTHONNOUSERSITE=True  && export CUDA_HOME=/software/nvidia/10.2.89/
echo "creating conda environment"
conda create -n scrfd python=3.8 -y
conda activate scrfd
echo "installing packages"
python -m pip install torch==1.10.0 torchvision torchaudio torchinfo numpy cython pycuda albumentations more-itertools 'Click!=8.0.0,>=7.0' pathtools 'click>=7.0' 'cloudpickle>=1.5.0' 'jinja2>=2.10.3' 'sortedcontainers>=2.0.5' 'toolz>=0.10.0' onnxruntime-gpu cupy-cuda102 anaconda scipy opencv-python cityscapesscripts imagecorruptions mmlvis 'urllib3<1.27,>=1.21.1' 'charset-normalizer<4,>=2' 'idna<4,>=2.5' --extra-index-url https://download.pytorch.org/whl/cu102 --user
echo "installing mmcv"
cd mmcv && MMCV_WITH_OPS=1 pip install -e . -v --user  && cd ..

Let me know if you face any issues.

Giuseppe5 commented 1 year ago

I couldn't manage to install it because of CUDA compatibility issues, but I think I have found the cause of the error. Looking inside MMCV 1.2.6, I noticed that it doesn't call model.state_dict() to retrieve the state dict; instead, it implements its own recursive function. This bypasses Brevitas' override of the state_dict method, where we remove the buffer key. As a result, the buffer key is saved into the checkpoint, so when the state dict is reloaded the key is not among the missing keys and missing_keys.remove(prefix + 'buffer') fails.
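The mechanism described above can be reproduced with a small pure-Python model of how PyTorch saves and loads state dicts. This is a simplified sketch under stated assumptions: ToyModule, ToyScaling and custom_recursive_state_dict are hypothetical names, plain dicts and floats stand in for tensors, and the real torch.nn.Module and MMCV code paths do more work. What it shows is that a per-module override of state_dict only takes effect when the saver calls state_dict on that module; a saver that walks the module tree and copies each module's own entries directly never sees the override.

```python
class ToyModule:
    """Minimal stand-in for torch.nn.Module: own entries plus child modules."""

    def __init__(self, **children):
        self.tensors = {}
        self.children = children

    def save_own(self, dest, prefix):
        # Analogous to nn.Module._save_to_state_dict: copy only this module's entries.
        for name, value in self.tensors.items():
            dest[prefix + name] = value

    def state_dict(self, dest=None, prefix=''):
        # Analogous to nn.Module.state_dict: recurses via child.state_dict(),
        # so a child's override of state_dict is honored.
        dest = {} if dest is None else dest
        self.save_own(dest, prefix)
        for name, child in self.children.items():
            child.state_dict(dest, prefix + name + '.')
        return dest

    def load_own(self, state, prefix, missing_keys):
        # Analogous to nn.Module._load_from_state_dict: report absent keys.
        for name in self.tensors:
            key = prefix + name
            if key in state:
                self.tensors[name] = state[key]
            else:
                missing_keys.append(key)

    def load_state_dict(self, state):
        missing_keys = []
        self._load_recursive(state, '', missing_keys)
        return missing_keys

    def _load_recursive(self, state, prefix, missing_keys):
        self.load_own(state, prefix, missing_keys)
        for name, child in self.children.items():
            child._load_recursive(state, prefix + name + '.', missing_keys)


class ToyScaling(ToyModule):
    """Mimics the pattern described in the thread: drop 'buffer' on save,
    then assume it is always missing on load."""

    def __init__(self):
        super().__init__()
        self.tensors = {'value': 1.0, 'buffer': 0.5}

    def state_dict(self, dest=None, prefix=''):
        dest = super().state_dict(dest, prefix)
        del dest[prefix + 'buffer']
        return dest

    def load_own(self, state, prefix, missing_keys):
        super().load_own(state, prefix, missing_keys)
        missing_keys.remove(prefix + 'buffer')  # ValueError if the buffer was saved


def custom_recursive_state_dict(module, dest=None, prefix=''):
    # Hypothetical analogue of a framework saver that walks the tree itself
    # and never calls module.state_dict(), bypassing ToyScaling's override.
    dest = {} if dest is None else dest
    module.save_own(dest, prefix)
    for name, child in module.children.items():
        custom_recursive_state_dict(child, dest, prefix + name + '.')
    return dest


model = ToyModule(act=ToyScaling())
torch_ckpt = model.state_dict()
custom_ckpt = custom_recursive_state_dict(model)
print(sorted(torch_ckpt))   # ['act.value']
print(sorted(custom_ckpt))  # ['act.buffer', 'act.value']
print(ToyModule(act=ToyScaling()).load_state_dict(torch_ckpt))  # []

load_failed = False
try:
    ToyModule(act=ToyScaling()).load_state_dict(custom_ckpt)
except ValueError as exc:
    load_failed = True
    print('load failed:', exc)
```

Loading the checkpoint produced by model.state_dict() succeeds, while loading the one produced by the custom saver raises the same kind of ValueError as in the original traceback.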

If you need to use MMCV, I would suggest applying the fix you mentioned above locally, and we'll discuss whether and how to tackle this.

One thing to note: if you load a checkpoint from the state dict, the buffer will no longer be used to collect statistics, and the scale factor immediately becomes a learned parameter.
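The two phases described above can be illustrated with a toy scale module. This is a hypothetical sketch: ToyStatsScaling is not a Brevitas class, and it uses a simple running mean where the real implementation's statistic and update rule may differ. For the first few steps the module averages observed statistics into a buffer; afterwards it switches to a stored value that, in a real model, the optimizer would keep training. Loading a saved value skips the collection phase entirely.

```python
class ToyStatsScaling:
    """Hypothetical sketch: collect statistics for a few steps, then use a learned value."""

    def __init__(self, collect_stats_steps=3):
        self.collect_stats_steps = collect_stats_steps
        self.counter = 0
        self.buffer = 0.0  # running mean of observed statistics
        self.value = None  # stands in for the learned scale parameter

    def __call__(self, stat):
        if self.counter < self.collect_stats_steps:
            # Statistics-collection phase: update the running mean in the buffer.
            self.counter += 1
            self.buffer += (stat - self.buffer) / self.counter
            return self.buffer
        if self.value is None:
            # End of collection: initialize the learned scale from the buffer.
            self.value = self.buffer
        # In a real model the optimizer would update this value from here on.
        return self.value

    def load(self, saved_value):
        # Loading a checkpoint: take the saved scale and skip stats collection.
        self.value = saved_value
        self.counter = self.collect_stats_steps


fresh = ToyStatsScaling()
fresh_outputs = [fresh(s) for s in (2.0, 4.0, 6.0, 100.0)]
print(fresh_outputs)  # [2.0, 3.0, 4.0, 4.0]

loaded = ToyStatsScaling()
loaded.load(1.5)
loaded_outputs = [loaded(s) for s in (2.0, 4.0)]
print(loaded_outputs)  # [1.5, 1.5]; the buffer is never touched
```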

MohamedA95 commented 1 year ago

Hi, thanks for looking into this, I really appreciate it. Regarding the last point, I am not retraining models right now, so I should be fine, but I will keep it in mind for future work. Best regards.