Closed MohamedA95 closed 1 year ago
Hello,
Thanks for pointing out this issue!
Could you provide a minimal reproducible example for me to try? May I also ask which Brevitas/Torch versions you are using?
Thanks, Giuseppe
Hi Giuseppe, thanks for your response. I am using torch 1.10.0+cu102 and Brevitas 0.9.1. This part of the model belongs to a bigger project with many dependencies; I will try to create a minimum reproducible example and get back to you. Regards, Mohamed.
Hi,
I was wondering if there were any updates on this. I'm happy to look into the issue if I manage to replicate it on my side.
Hi @Giuseppe5, thanks for your patience and for following up. It has been a busy week and I am still working on the example. Nevertheless, I am starting to believe that it is not directly related to Brevitas but to the training framework, since the network was trained using the mmcv runner. I will confirm this, create an example, and get back to you very soon.
Hi @Giuseppe5,
I created a gist with a small script that reproduces the issue. At the end of the file you will find two functions: `train_using_torch`, which trains the model using PyTorch and works correctly, and `train_using_mmcv`, which trains using the mmcv runner and produces an error when the model is reloaded.
A bit of context:
MobileNet here is used as a backbone in the SCRFD face detection flow. The original SCRFD model has three functions that I had to recreate under the MobileNet class so that `mmcv.runner` would accept it as usual: `_parse_losses`, `train_step`, and `val_step`. Also, `mmcv.runner` expects the model to return the loss, not the direct output.
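For context, the hooks the mmcv runner expects can be sketched roughly like this. This is a minimal toy sketch, not the actual SCRFD/MobileNet code: `BackboneWrapper`, the tiny backbone, and the MSE loss are illustrative stand-ins, and only the method names and return conventions follow mmcv's interface.

```python
import torch
import torch.nn as nn

class BackboneWrapper(nn.Module):
    """Toy backbone exposing the hooks mmcv.runner calls (illustrative)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.loss_fn = nn.MSELoss()

    def forward(self, img, target):
        # mmcv.runner expects the model to return losses, not raw outputs
        out = self.backbone(img)
        return {"loss_total": self.loss_fn(out, target)}

    def _parse_losses(self, losses):
        # Sum the individual loss terms and build a log dict
        loss = sum(v for k, v in losses.items() if "loss" in k)
        log_vars = {k: float(v) for k, v in losses.items()}
        return loss, log_vars

    def train_step(self, data, optimizer):
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)
        return {"loss": loss, "log_vars": log_vars,
                "num_samples": data["img"].shape[0]}

    def val_step(self, data, optimizer=None):
        return self.train_step(data, optimizer)
```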
Honestly, I would understand if you are not interested in investigating this in the near term, as you might have more valuable goals on your list. But could you kindly let me know whether my workaround has any unintended consequences?
Which version of MMCV am I supposed to use?
Hi,
I used 1.2.6 (not the most recent, but it is the version supported by SCRFD), built from source using `cd mmcv && MMCV_WITH_OPS=1 pip install -e . -v --user && cd ..`.
Below is the full script I use to build the environment:

```shell
export PYTHONNOUSERSITE=True && export CUDA_HOME=/software/nvidia/10.2.89/
echo "creating conda environment"
conda create -n scrfd python=3.8 -y
conda activate scrfd
echo "installing packages"
python -m pip install torch==1.10.0 torchvision torchaudio torchinfo numpy cython pycuda albumentations more-itertools 'Click!=8.0.0,>=7.0' pathtools 'click>=7.0' 'cloudpickle>=1.5.0' 'jinja2>=2.10.3' 'sortedcontainers>=2.0.5' 'toolz>=0.10.0' onnxruntime-gpu cupy-cuda102 anaconda scipy opencv-python cityscapesscripts imagecorruptions mmlvis 'urllib3<1.27,>=1.21.1' 'charset-normalizer<4,>=2' 'idna<4,>=2.5' --extra-index-url https://download.pytorch.org/whl/cu102 --user
echo "installing mmcv"
cd mmcv && MMCV_WITH_OPS=1 pip install -e . -v --user && cd ..
```
Let me know if you face any issues.
I couldn't manage to install it due to CUDA compatibility issues, but I suspect I have found the cause of the error.
Investigating inside MMCV 1.2.6, I noticed that they don't call `model.state_dict()` to retrieve the state dict; instead, they implement a custom recursive function to do it. However, this skips Brevitas' overloading of the `state_dict` method, where we remove the `buffer` key so that, when the state dict is later reloaded, that key is expected to be missing.
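The bypass described above can be reproduced in miniature. The sketch below is illustrative, not Brevitas' or MMCV's actual code: a module overrides `state_dict` to hide a `buffer` key, while an external recursive walker built on `_save_to_state_dict` never triggers that override, so the key ends up in the checkpoint anyway.

```python
import torch
import torch.nn as nn

class TracksBuffer(nn.Module):
    """Loosely mimics a module that hides a buffer from its state dict."""

    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.zeros(1))

    def state_dict(self, *args, **kwargs):
        # The override drops the 'buffer' key before returning
        sd = super().state_dict(*args, **kwargs)
        return {k: v for k, v in sd.items() if not k.endswith("buffer")}

def custom_recursive_state_dict(module, prefix=""):
    """Collects tensors the way an external recursive walker would,
    bypassing any state_dict() override on the module."""
    sd = {}
    module._save_to_state_dict(sd, prefix, False)
    for name, child in module._modules.items():
        if child is not None:
            sd.update(custom_recursive_state_dict(child, prefix + name + "."))
    return sd

m = TracksBuffer()
print("buffer" in m.state_dict())                  # False: override strips it
print("buffer" in custom_recursive_state_dict(m))  # True: walker bypasses it
```

A checkpoint written by the walker therefore still contains the key that the loading code assumes was removed.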
If you need to use MMCV, I would suggest applying the fix you mentioned above locally, and we'll discuss if/how to tackle this.
One thing to note is that if you load a checkpoint from the state dict, the buffer will no longer be used to collect statistics, and the scale factor immediately becomes a learned parameter.
Hi, thanks for looking into this, I really appreciate it. Regarding the last point: I am not retraining models right now, so I should be fine, but I will keep it in mind for future work. Best regards.
Hi everyone, I am getting the following error while loading a model trained with Brevitas.
The key that the code is trying to remove from the list is:

`backbone.stem.0.2.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.buffer`
I loaded the checkpoint manually and confirmed that it exists in the state dict. This is how this part of the model is defined:

Where `bre_quant_cfg` is a `dict` defined this way:

I modified `_load_from_state_dict` at line 298 to check whether the key is in the list before attempting to remove it, and this solves the issue: the model loads normally and gives the expected results. Since you deeply understand Brevitas's flow, why does the error occur? Why is the buffer supposed to always be missing?
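For reference, the defensive check described above can be sketched as follows. The module and names are hypothetical, not Brevitas' actual implementation: a `_load_from_state_dict` override that removes the `buffer` key from `missing_keys` only when it is actually there, so checkpoints both with and without the key load cleanly.

```python
import torch
import torch.nn as nn

class ScalingImplSketch(nn.Module):
    """Illustrative module that tolerates the tracking buffer being
    either present in or absent from a checkpoint (hypothetical names)."""

    def __init__(self):
        super().__init__()
        self.register_buffer("buffer", torch.zeros(1))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        key = prefix + "buffer"
        # Guard the removal: an unconditional missing_keys.remove(key)
        # raises ValueError when the checkpoint actually contains the key.
        if key in missing_keys:
            missing_keys.remove(key)

m = ScalingImplSketch()
# Checkpoint WITH the buffer key (as produced by an external walker):
m.load_state_dict({"buffer": torch.ones(1)}, strict=True)
# Checkpoint WITHOUT the buffer key (as produced by an overridden state_dict):
m.load_state_dict({}, strict=True)  # no error: the guard handles the absence
```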