open-mmlab / mmpose

OpenMMLab Pose Estimation Toolbox and Benchmark.
https://mmpose.readthedocs.io/en/latest/
Apache License 2.0

[Bug] (suggested fix) `mmpose.models.pose_estimators.topdown.TopdownPoseEstimator` is unable to be symbolically traced because of untraceable `add_pred_to_datasample()` and `loss()` #3012

Open elisa-aleman opened 5 months ago

elisa-aleman commented 5 months ago

Prerequisite

Environment

Full environment output not available (computer not accessible at the time of writing). Versions used:

torch: 2.0.0+cu118
torchvision: 0.15.0+cu118
mmengine: 0.10.3
mmrazor: 1.0.0
mmpose: 1.3.1

Reproduces the problem - code sample

While using mmrazor to quantize this model, I ran into an error when the symbolic trace for the FX graph was being built.

I had already applied the fixes for the torch 2.0.0 incompatibility suggested in mmrazor #632, as well as the fix for nn.Parameter attributes inside TopdownPoseEstimator not being traced from mmrazor #633.

from mmrazor.models.task_modules.tracer.fx.custom_tracer import CustomTracer
from mmpose.models.pose_estimators.topdown import TopdownPoseEstimator
from mmengine.config import Config

cfg = Config.fromfile('/mmpose/configs/body_2d_keypoint/rtmpose/coco/rtmpose-t_8xb256-420e_coco-256x192.py')

rtmpose = TopdownPoseEstimator(
    backbone=cfg.model.backbone,
    neck=cfg.model.neck,
    head=cfg.model.head,
    train_cfg=cfg.train_cfg,
    data_preprocessor=cfg.model.data_preprocessor,
)

tracer = CustomTracer(
    skipped_methods=[
        'mmpose.models.heads.RTMCCHead.loss',
        'mmpose.models.heads.RTMCCHead.predict',
    ]
)
traced_graph = tracer.trace(rtmpose)

Reproduces the problem - error message

Traceback (most recent call last):
  File "..../site-packages/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 421, in trace
    'output', (self.create_arg(fn(*args)), ), {},

  File "..../site-packages/mmpose/models/pose_estimators/base.py", line 161, in forward
    return self.predict(inputs, data_samples)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "..../site-packages/mmpose/models/pose_estimators/topdown.py", line 117, in predict
    results = self.add_pred_to_datasample(batch_pred_instances,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "..../site-packages/mmpose/models/pose_estimators/topdown.py", line 138, in predict
    assert len(batch_pred_instances) == len(batch_data_samples)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "..../site-packages/torch/fx/proxy.py", line 420, in _len_
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
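
For context, the failure can be reproduced in isolation: during symbolic tracing the data-sample lists become torch.fx Proxy objects, and calling len() on a Proxy raises exactly this error unless len is registered with torch.fx.wrap. A minimal standalone sketch (toy module, not mmpose code):

import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        n = len(x)      # len() on a Proxy raises the RuntimeError above
        return x * n

try:
    torch.fx.symbolic_trace(Toy())
except RuntimeError as e:
    print(e)

torch.fx.wrap('len')    # registered at module scope, len() is recorded as a graph node
print(torch.fx.symbolic_trace(Toy()).graph)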

Additional information

For loss(), I suggest the following patch:

@@ -68,8 +68,8 @@ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
         feats = self.extract_feat(inputs)

-        losses = dict()
-
         if self.with_head:
-            losses.update(
-                self.head.loss(feats, data_samples, train_cfg=self.train_cfg))
+            losses = self.head.loss(
+                feats, data_samples, train_cfg=self.train_cfg)
+        else:
+            losses = {}

         return losses
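
The reasoning behind this patch, as I understand it: when the head's loss() call is skipped by the tracer it returns a Proxy, and dict().update(proxy) then has to iterate that Proxy, which FX cannot record. Returning the head's dict directly avoids the iteration. A standalone sketch of the pattern under plain torch.fx (head_loss here is a made-up stand-in, not the real mmpose API):

import torch
import torch.fx

# hypothetical stand-in for self.head.loss(); wrapping it keeps its body out
# of the graph, similar in effect to CustomTracer's skipped_methods
@torch.fx.wrap
def head_loss(feats):
    return {'loss_kpt': feats.sum()}

class Toy(torch.nn.Module):
    def forward(self, x):
        feats = x * 2
        # old pattern: losses = dict(); losses.update(head_loss(feats))
        # fails during tracing because .update() has to iterate the Proxy
        losses = head_loss(feats)  # patched pattern: take the dict directly
        return losses

print(torch.fx.symbolic_trace(Toy()).graph)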

For add_pred_to_datasample(), I suggest the following (the @torch.fx.wrap decorator additionally needs torch to be imported in topdown.py if it is not already):

@@ -138,48 +138,62 @@ def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
-        assert len(batch_pred_instances) == len(batch_data_samples)
-        if batch_pred_fields is None:
-            batch_pred_fields = []
         output_keypoint_indices = self.test_cfg.get('output_keypoint_indices',
                                                    None)
-
-        for pred_instances, pred_fields, data_sample in zip_longest(
-                batch_pred_instances, batch_pred_fields, batch_data_samples):
-
-            gt_instances = data_sample.gt_instances
-
-            # convert keypoint coordinates from input space to image space
-            input_center = data_sample.metainfo['input_center']
-            input_scale = data_sample.metainfo['input_scale']
-            input_size = data_sample.metainfo['input_size']
-
-            pred_instances.keypoints[..., :2] = \
-                pred_instances.keypoints[..., :2] / input_size * input_scale \
-                + input_center - 0.5 * input_scale
-            if 'keypoints_visible' not in pred_instances:
-                pred_instances.keypoints_visible = \
-                    pred_instances.keypoint_scores
-
-            if output_keypoint_indices is not None:
-                # select output keypoints with given indices
-                num_keypoints = pred_instances.keypoints.shape[1]
-                for key, value in pred_instances.all_items():
-                    if key.startswith('keypoint'):
-                        pred_instances.set_field(
-                            value[:, output_keypoint_indices], key)
-
-            # add bbox information into pred_instances
-            pred_instances.bboxes = gt_instances.bboxes
-            pred_instances.bbox_scores = gt_instances.bbox_scores
-
-            data_sample.pred_instances = pred_instances
-
-            if pred_fields is not None:
-                if output_keypoint_indices is not None:
-                    # select output heatmap channels with keypoint indices
-                    # when the number of heatmap channel matches num_keypoints
-                    for key, value in pred_fields.all_items():
-                        if value.shape[0] != num_keypoints:
-                            continue
-                        pred_fields.set_field(value[output_keypoint_indices],
-                                              key)
-                data_sample.pred_fields = pred_fields
+        batch_data_samples = _add_pred_to_datasample(
+            output_keypoint_indices,
+            batch_pred_instances,
+            batch_pred_fields,
+            batch_data_samples
+            )
         return batch_data_samples
+
+
+@torch.fx.wrap
+def _add_pred_to_datasample(
+        output_keypoint_indices,
+        batch_pred_instances: InstanceList,
+        batch_pred_fields: Optional[PixelDataList],
+        batch_data_samples: SampleList) -> SampleList:
+    assert len(batch_pred_instances) == len(batch_data_samples)
+    if batch_pred_fields is None:
+        batch_pred_fields = []
+
+    for pred_instances, pred_fields, data_sample in zip_longest(
+            batch_pred_instances, batch_pred_fields, batch_data_samples):
+
+        gt_instances = data_sample.gt_instances
+
+        # convert keypoint coordinates from input space to image space
+        input_center = data_sample.metainfo['input_center']
+        input_scale = data_sample.metainfo['input_scale']
+        input_size = data_sample.metainfo['input_size']
+
+        pred_instances.keypoints[..., :2] = \
+            pred_instances.keypoints[..., :2] / input_size * input_scale \
+            + input_center - 0.5 * input_scale
+        if 'keypoints_visible' not in pred_instances:
+            pred_instances.keypoints_visible = \
+                pred_instances.keypoint_scores
+
+        if output_keypoint_indices is not None:
+            # select output keypoints with given indices
+            num_keypoints = pred_instances.keypoints.shape[1]
+            for key, value in pred_instances.all_items():
+                if key.startswith('keypoint'):
+                    pred_instances.set_field(
+                        value[:, output_keypoint_indices], key)
+
+        # add bbox information into pred_instances
+        pred_instances.bboxes = gt_instances.bboxes
+        pred_instances.bbox_scores = gt_instances.bbox_scores
+
+        data_sample.pred_instances = pred_instances
+
+        if pred_fields is not None:
+            if output_keypoint_indices is not None:
+                # select output heatmap channels with keypoint indices
+                # when the number of heatmap channel matches num_keypoints
+                for key, value in pred_fields.all_items():
+                    if value.shape[0] != num_keypoints:
+                        continue
+                    pred_fields.set_field(value[output_keypoint_indices],
+                                          key)
+            data_sample.pred_fields = pred_fields
+
+    return batch_data_samples

This solves the FX tracing issue, although there are still other issues I have yet to solve.
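
For completeness, the reason the post-processing moves into a module-level helper rather than just decorating the method: torch.fx.wrap registers plain functions looked up from module globals, not bound methods, and everything inside the wrapped function (the assert, len(), the zip_longest() loop) is then invisible to the tracer, collapsing into a single call_function node. A toy sketch of that pattern (the names below are illustrative only, not the real mmpose signatures):

import torch
import torch.fx

# toy stand-in for the module-level _add_pred_to_datasample()
@torch.fx.wrap
def _postprocess(preds, data_samples):
    # asserts and len() live in here and are never seen by the tracer
    assert len(preds) == len(data_samples)
    return data_samples

class Toy(torch.nn.Module):
    def forward(self, x):
        preds = [x + 1]
        samples = [x]
        return _postprocess(preds, samples)

# _postprocess appears as a single call_function node in the printed graph
print(torch.fx.symbolic_trace(Toy()).graph)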

elisa-aleman commented 5 months ago

Added reproducing code and full fix suggestion