[Bug] (suggested fix) `mmpose.models.pose_estimators.topdown.TopdownPoseEstimator` cannot be symbolically traced because of the untraceable `add_pred_to_datasample()` and `loss()` #3012
While using mmrazor to quantize this model, I hit an error when `symbolic_trace` was building the FX graph.
I had already applied the fixes for the torch 2.0.0 incompatibility suggested in mmrazor #632, as well as a fix for `nn.Parameter`s inside `TopdownPoseEstimator` not being traced (mmrazor #633).
Traceback (most recent call last):
File "..../site-packages/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 421, in trace
'output', (self.create_arg(fn(*args)), ), {},
File "..../site-packages/mmpose/models/pose_estimators/base.py", line 161, in forward
return self.predict(inputs, data_samples)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "..../site-packages/mmpose/models/pose_estimators/topdown.py", line 117, in predict
results = self.add_pred_to_datasample(batch_pred_instances,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "..../site-packages/mmpose/models/pose_estimators/topdown.py", line 138, in predict
assert len(batch_pred_instances) == len(batch_data_samples)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "..../site-packages/torch/fx/proxy.py", line 420, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported by default. If you want this call to be recorded, please call 'torch.fx.wrap('len') at module scope
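The error message suggests calling `torch.fx.wrap('len')`. The sketch below uses plain `torch.fx` only (no mmpose or mmrazor; `CountRows` and `AssertSameLength` are hypothetical modules) to show that this does make `len()` on a traced value recordable, but also that it is probably not enough on its own for `add_pred_to_datasample()`: an `assert len(a) == len(b)` then forces `bool()` on a `Proxy`, which FX rejects as data-dependent control flow.

```python
import torch
import torch.fx

# Ask FX to record calls to the builtin len() as graph nodes, as the
# RuntimeError suggests. torch.fx.wrap must be called at module (top-level) scope.
torch.fx.wrap('len')


class CountRows(torch.nn.Module):
    def forward(self, x):
        # Without the wrap above, len(x) on a Proxy raises RuntimeError.
        return x * len(x)


class AssertSameLength(torch.nn.Module):
    def forward(self, preds, samples):
        # len() is now traceable, but `assert` calls bool() on the resulting
        # Proxy, which FX treats as control flow and refuses to trace.
        assert len(preds) == len(samples)
        return preds


gm = torch.fx.symbolic_trace(CountRows())
print(gm(torch.ones(3, 2)))  # every entry equals 3.0

try:
    torch.fx.symbolic_trace(AssertSameLength())
    assert_traced = True
except torch.fx.proxy.TraceError:
    assert_traced = False
print(assert_traced)  # False
```

In other words, wrapping `len` moves the failure from `len()` to the `assert`, so the length check itself has to be kept out of the traced region.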
Prerequisite
Environment
The computer was not available at the time of writing. Versions in use: torch 2.0.0+cu118, torchvision 0.15.0+cu118, mmengine 0.10.3, mmrazor 1.0.0, mmpose 1.3.1
Reproduces the problem - code sample
Reproduces the problem - error message
Additional information
For `loss()`, I suggest the following patch:

For `add_pred_to_datasample()`, I suggest the following:

This solves the issue with FX tracing, although there are still other issues I have yet to solve.
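As an illustration only (this is not the patch proposed in this issue, and not mmpose code), one standard `torch.fx` technique for untraceable post-processing like `add_pred_to_datasample()` is to move it into a module-level function decorated with `torch.fx.wrap`. The tracer then records a single opaque call node, and the `len`/`assert`/`zip` logic only runs at execution time on concrete values. `attach_predictions` and `TinyEstimator` are hypothetical stand-ins.

```python
import torch
import torch.fx


# Hypothetical post-processing step. @torch.fx.wrap (applied at module scope)
# makes FX record a call to this function instead of tracing its body,
# so len(), assert and zip() only ever see concrete values.
@torch.fx.wrap
def attach_predictions(preds, scales):
    assert len(preds) == len(scales)
    return [p * s for p, s in zip(preds, scales)]


class TinyEstimator(torch.nn.Module):
    def forward(self, x, scales):
        preds = x * 2  # stands in for the head's tensor predictions
        return attach_predictions(preds, scales)


gm = torch.fx.symbolic_trace(TinyEstimator())
print([n.op for n in gm.graph.nodes])

out = gm(torch.ones(3, 2), [1.0, 2.0, 3.0])
print(out[2])  # each entry is 6.0
```

The trade-off is that anything inside the wrapped function is invisible to the graph, so a quantizer cannot insert observers there; that is usually acceptable for bookkeeping such as attaching predictions to data samples.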