open-mmlab / mmrazor

OpenMMLab Model Compression Toolbox and Benchmark.
https://mmrazor.readthedocs.io/en/latest/
Apache License 2.0

[Bug] (suggested fix) `mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams()` fails if there are modules present in other modes but not in forward `mode='tensor'` #634

Status: Open. elisa-aleman opened this issue 5 months ago.

elisa-aleman commented 5 months ago

Describe the bug

In models where some modules exist only in mode 'predict' or 'loss' but not in 'tensor', the following code fails with a KeyError when it looks those modules up in the state dict of the 'tensor'-mode model. This happens, for example, when one mode's traced model contains duplicated modules and the other's does not.

mmrazor.models.algorithms.quantization.mm_architecture.MMArchitectureQuant.sync_qparams()#L124--L148

        def traverse(module, prefix):
            for name, child in module._modules.items():
                if module is None:
                    continue
                child_name = f'{prefix}{name}'
                if isinstance(child, FakeQuantizeBase):
                    for name, param in child.named_parameters():
                        param_name = f'{child_name}.{name}'
                        src_param = src_state_dict[param_name]  # <- KeyError if param_name is missing from the source mode
                        if src_param.shape == param.shape:
                            param.data.copy_(src_param)
                        else:
                            requirs_grad = param.requires_grad
                            param.requires_grad = False
                            param.resize_(src_param.shape)
                            param.requires_grad = requirs_grad
                            param.data.copy_(src_param)
                    for name, buffer in child.named_buffers():
                        buffer_name = f'{child_name}.{name}'
                        src_buffer = src_state_dict[buffer_name]  # <- KeyError if buffer_name is missing from the source mode
                        if src_buffer.shape == buffer.shape:
                            buffer.data.copy_(src_buffer)
                        else:
                            buffer.resize_(src_buffer.shape)
                            buffer.data.copy_(src_buffer)
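To see in advance which entries will hit this KeyError, one can diff the key sets of the two state dicts before traversing. The sketch below is a hypothetical helper (`unsyncable_keys` is not part of mmrazor) that accepts any mappings, e.g. `self.qmodels[mode].state_dict()` and `src_state_dict` inside a patched `sync_qparams()`. It reports every missing key, including non-fake-quant entries such as a duplicated conv's weight, so its output is a superset of what `traverse()` would actually trip on.

```python
from typing import Iterable, List, Mapping


def unsyncable_keys(dst_keys: Iterable[str],
                    src_state_dict: Mapping[str, object]) -> List[str]:
    """Hypothetical helper: keys present in the target mode but absent from
    the source mode's state dict (each one would raise KeyError in traverse())."""
    return sorted(key for key in dst_keys if key not in src_state_dict)


# Toy state dicts standing in for the 'tensor' (source) and 'predict' modes.
tensor_sd = {'backbone.stem.0.conv.weight_fake_quant.scale': 0.1}
predict_sd = {
    'backbone.stem.0.conv.weight_fake_quant.scale': 0.1,
    'backbone.stem.0.conv_dup1.weight_fake_quant.fake_quant_enabled': 1,
    'activation_post_process_123.fake_quant_enabled': 1,
}
print(unsyncable_keys(predict_sd, tensor_sd))
```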

Additional Context

I have been trying to quantize mmpose's TopdownPoseEstimator, applying the fix for the torch 2.0.0 incompatibility suggested in mmrazor #632, a fix for nn.Parameters inside TopdownPoseEstimator not being traced (mmrazor #633), and a fix for untraceable methods of mmpose.TopdownPoseEstimator (mmpose #3012).

Because the flip test adds a pass on the flipped input to the 'predict' forward graph, the 'predict' model contains not only duplicated modules but also extra loose (leaf) numbered activation_post_process_xyz modules, and these make the syncing fail.

Reproduces the error - code sample

I cannot share the configuration at the moment, but the code being executed is:

from mmrazor.models.algorithms.quantization.mm_architecture import MMArchitectureQuant
from mmengine import Config

cfg = Config.fromfile('qat_rtmpose-t_8xb256-420e_coco-256x192.py')

qtopdown = MMArchitectureQuant(
    data_preprocessor=cfg.data_preprocessor,
    architecture=cfg.architecture,
    quantizer=cfg.model.quantizer,
    input_shapes=cfg.model.input_shapes
)

Reproduces the problem - error message

Traceback (most recent call last):
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 91, in __init__
    self.sync_qparams('tensor')
   File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 156, in sync_qparams
    .....redacted
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 143, in traverse
    src_buffer = src_state_dict[buffer_name]
                 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
KeyError: 'backbone.stem.0.conv_dup1.weight_fake_quant.fake_quant_enabled'

And once that was patched:

Traceback (most recent call last):
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 91, in __init__
    self.sync_qparams('tensor')
   File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 156, in sync_qparams
    .....redacted
  File "..../site-packages/mmrazor/models/algorithms/quantization/mm_architecture.py", line 157, in traverse
    raise KeyError(f"{buffer_name} in mode '{mode}' but not found in source mode '{tensor}', sync_qparams() failed.")
KeyError: "activation_post_process_123.fake_quant_enabled in mode 'predict' but not found in source mode 'tensor', sync_qparams() failed."

Post related information - suggested fix


*EDIT: while this fix lets syncing proceed when nodes are missing from other modes, it causes model deployment to fail later down the line.

For duplicated modules, I figure one can copy the state_dict entry stored under the name without the `_dup` suffix, but I don't have a suggestion yet for modules that don't exist in the source mode at all.
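The renaming idea can be sketched as a small hypothetical helper (`strip_dup_suffix` is an illustrative name, not an mmrazor function). It assumes duplicated modules carry a `_dup<N>` suffix on one dotted component, as in `conv_dup1` from the KeyError above. Unlike `section.split('_dup')[0]` in the patch below, the regular expression only strips `_dup` followed by digits, so a component that merely contains `_dup` (e.g. `layer_duplicate`) is left alone.

```python
import re

# Assumption: duplicates are named `<name>_dup<N>`, e.g. `conv_dup1`.
_DUP_SUFFIX = re.compile(r'_dup\d+(?=\.|$)')


def strip_dup_suffix(name: str) -> str:
    """Map a duplicated module's state_dict key to its original (source) key."""
    return _DUP_SUFFIX.sub('', name)


print(strip_dup_suffix(
    'backbone.stem.0.conv_dup1.weight_fake_quant.fake_quant_enabled'))
```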

For the activation post-processing leaf nodes, I can skip most of the copying, since many of these values are reset in MMArchitectureQuant.__init__() anyway.
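A stricter variant of the skip check in the patch below compares only the last dotted component of the buffer name against the skip list, so a substring such as `scale` cannot accidentally match an unrelated name like `scale_factor`. This is a hypothetical helper; the buffer names are copied from the patch's `skip_buffer_sync` list.

```python
# Buffer names copied from `skip_buffer_sync` in the suggested patch.
SKIPPABLE_BUFFERS = frozenset({
    'fake_quant_enabled', 'observer_enabled', 'scale', 'zero_point',
    'min_val', 'max_val', 'eps',
})


def is_skippable(buffer_name: str) -> bool:
    """True if the buffer's own name (after the last '.') is in the skip list."""
    return buffer_name.rsplit('.', 1)[-1] in SKIPPABLE_BUFFERS


for name in ('activation_post_process_123.fake_quant_enabled',
             'conv.weight_fake_quant.activation_post_process.min_val',
             'conv.weight_fake_quant.scale_factor'):
    print(name, is_skippable(name))
```

With the patch's `any([s in buffer_name ...])`, the last name would be skipped because it contains `scale`; matching the final component avoids that.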

mmrazor/models/algorithms/quantization/mm_architecture.py

@@ -121,7 +121,7 @@ class MMArchitectureQuant(BaseAlgorithm):
                 in some subtle ways, so we need to sync them here.
         """

-        def traverse(module, prefix):
+        def traverse(module, prefix, mode, src_mode):
             for name, child in module._modules.items():
                 if module is None:
                     continue
@@ -129,7 +129,14 @@ class MMArchitectureQuant(BaseAlgorithm):
                 if isinstance(child, FakeQuantizeBase):
                     for name, param in child.named_parameters():
                         param_name = f'{child_name}.{name}'
-                        src_param = src_state_dict[param_name]
+                        src_param = src_state_dict.get(param_name)
+                        if '_dup' in param_name and src_param is None:
+                            param_name = '.'.join([section.split('_dup')[0] for section in param_name.split('.')])
+                            src_param = src_state_dict.get(param_name)
+                        if src_param is None:
+                            print(src_state_dict)
+                            print(child)
+                            raise KeyError(f"{param_name} in mode: '{mode}' but not found in source mode: '{src_mode}', sync_qparams() failed.")
                         if src_param.shape == param.shape:
                             param.data.copy_(src_param)
                         else:
@@ -138,22 +145,42 @@ class MMArchitectureQuant(BaseAlgorithm):
                             param.resize_(src_param.shape)
                             param.requires_grad = requirs_grad
                             param.data.copy_(src_param)
+                    # These are either reset after sync_qparams() is called, or are left as default (eps)
+                    # so there's no need to sync them if there's not a match
+                    skip_buffer_sync = [
+                        "fake_quant_enabled",
+                        "observer_enabled",
+                        "scale",
+                        "zero_point",
+                        "min_val",
+                        "max_val",
+                        "eps",
+                    ]
                     for name, buffer in child.named_buffers():
                         buffer_name = f'{child_name}.{name}'
-                        src_buffer = src_state_dict[buffer_name]
+                        src_buffer = src_state_dict.get(buffer_name)
+                        if '_dup' in buffer_name and src_buffer is None:
+                            buffer_name = '.'.join([section.split('_dup')[0] for section in buffer_name.split('.')])
+                            src_buffer = src_state_dict.get(buffer_name)
+                        if any([s in buffer_name for s in skip_buffer_sync]) and src_buffer is None:
+                            continue
+                            src_buffer = torch.tensor([1], dtype=torch.uint8)
+                        if src_buffer is None:
+                            print(src_state_dict)
+                            print(child)
+                            raise KeyError(f"{buffer_name} in mode: '{mode}' but not found in source mode: '{src_mode}', sync_qparams() failed.")
                         if src_buffer.shape == buffer.shape:
                             buffer.data.copy_(src_buffer)
                         else:
                             buffer.resize_(src_buffer.shape)
                             buffer.data.copy_(src_buffer)
                 else:
-                    traverse(child, f'{child_name}.')
+                    traverse(child, f'{child_name}.', mode, src_mode)
         src_state_dict = self.qmodels[src_mode].state_dict()
         for mode in self.forward_modes:
             if mode == src_mode:
                 continue
-            traverse(self.qmodels[mode], '')
+            traverse(self.qmodels[mode], '', mode, src_mode)

     def _get_rewriter_context_in_mmdeploy(self, deploy_cfg):
         """Get rewriter context in mmdeploy according to the deploy related
elisa-aleman commented 4 months ago

Added more context and suggested a fix

elisa-aleman commented 4 months ago

After trying to deploy the quantized model, I realized the suggested fix may be unnecessary and can cause further issues: mmdeploy/tools/deploy.py forces model.architecture.test_cfg.flip_test=False for pose estimators, which means the quantized state_dict contains extra weights and the model deployment fails.

I then tried:

python /tools/train.py \
    ${qat_topdown_cfg} \
    --cfg-options \
        model.architecture.test_cfg.flip_test=False \
    --work-dir /path/here/

But the model still fails to sync without my patch.

elisa-aleman commented 4 months ago

I realized that sync_qparams() is also called during the training loop with 'loss' as the source mode, so my previous fix actually discards any progress made during training. I suggest this new fix, which leaves fake-quant values untouched (skipping them) when no matching source entry is found. I haven't finished deploying this model yet, so it's subject to change.

@@ -121,7 +121,7 @@ class MMArchitectureQuant(BaseAlgorithm):
                 in some subtle ways, so we need to sync them here.
         """

-        def traverse(module, prefix):
+        def traverse(module, prefix, mode, src_mode):
             for name, child in module._modules.items():
                 if module is None:
                     continue
@@ -129,7 +129,13 @@ class MMArchitectureQuant(BaseAlgorithm):
                 if isinstance(child, FakeQuantizeBase):
                     for name, param in child.named_parameters():
                         param_name = f'{child_name}.{name}'
-                        src_param = src_state_dict[param_name]
+                        src_param = src_state_dict.get(param_name)
+                        if '_dup' in param_name and src_param is None:
+                            param_name = '.'.join([section.split('_dup')[0] for section in param_name.split('.')])
+                            src_param = src_state_dict.get(param_name)
+                        if src_param is None:
+                            print(f"{param_name} in mode: '{mode}' but not found in source mode: '{src_mode}', skipping sync.")
+                            continue
                         if src_param.shape == param.shape:
                             param.data.copy_(src_param)
                         else:
@@ -140,20 +146,26 @@ class MMArchitectureQuant(BaseAlgorithm):
                             param.data.copy_(src_param)
                     for name, buffer in child.named_buffers():
                         buffer_name = f'{child_name}.{name}'
-                        src_buffer = src_state_dict[buffer_name]
+                        src_buffer = src_state_dict.get(buffer_name)
+                        if '_dup' in buffer_name and src_buffer is None:
+                            buffer_name = '.'.join([section.split('_dup')[0] for section in buffer_name.split('.')])
+                            src_buffer = src_state_dict.get(buffer_name)
+                        if src_buffer is None:
+                            print(f"{buffer_name} in mode: '{mode}' but not found in source mode: '{src_mode}', skipping sync.")
+                            continue
                         if src_buffer.shape == buffer.shape:
                             buffer.data.copy_(src_buffer)
                         else:
                             buffer.resize_(src_buffer.shape)
                             buffer.data.copy_(src_buffer)
                 else:
-                    traverse(child, f'{child_name}.')
+                    traverse(child, f'{child_name}.', mode, src_mode)
         src_state_dict = self.qmodels[src_mode].state_dict()
         for mode in self.forward_modes:
             if mode == src_mode:
                 continue
-            traverse(self.qmodels[mode], '')
+            traverse(self.qmodels[mode], '', mode, src_mode)

     def _get_rewriter_context_in_mmdeploy(self, deploy_cfg):
         """Get rewriter context in mmdeploy according to the deploy related
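The skip-instead-of-raise behaviour of this second patch can be sketched on plain dictionaries. This is a simplified, hypothetical model rather than mmrazor code: `tolerant_sync` and `resolve_src_key` are illustrative names, values are copied by assignment instead of `param.data.copy_()`/`resize_()`, duplicates are assumed to be named `<name>_dup<N>`, and `logging` stands in for `print`. Entries with no match keep their current value instead of being reset.

```python
import logging
import re
from typing import List, Mapping, MutableMapping, Optional

logger = logging.getLogger(__name__)

# Assumption: duplicated modules are named `<name>_dup<N>`, e.g. `conv_dup1`.
_DUP_SUFFIX = re.compile(r'_dup\d+(?=\.|$)')


def resolve_src_key(name: str, src: Mapping[str, object]) -> Optional[str]:
    """Return the source key for `name`: exact match first, then the name
    with `_dup<N>` suffixes removed; None if neither exists."""
    if name in src:
        return name
    stripped = _DUP_SUFFIX.sub('', name)
    return stripped if stripped in src else None


def tolerant_sync(dst: MutableMapping[str, object], src: Mapping[str, object],
                  mode: str, src_mode: str) -> List[str]:
    """Copy matched entries from `src` into `dst`; leave unmatched entries
    untouched (instead of raising or resetting them) and return their names."""
    skipped: List[str] = []
    for name in list(dst):
        key = resolve_src_key(name, src)
        if key is None:
            logger.warning("%s in mode '%s' but not found in source mode '%s', "
                           'skipping sync.', name, mode, src_mode)
            skipped.append(name)
            continue
        dst[name] = src[key]
    return skipped


# Toy example: 'tensor' is the source mode; 'predict' has a duplicate and an
# extra activation_post_process leaf.
src_sd = {'conv.weight_fake_quant.scale': 0.5}
dst_sd = {'conv_dup1.weight_fake_quant.scale': 1.0,
          'activation_post_process_123.fake_quant_enabled': 1}
print(tolerant_sync(dst_sd, src_sd, 'predict', 'tensor'), dst_sd)
```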
elisa-aleman commented 3 weeks ago

After some more fixing, I found that the solution to this issue is to refactor the model so that FX tracing succeeds in every mode, up to the wrapped methods that differ between modes. As long as the only difference in tracing comes after the .forward() method, syncing won't fail.