Open Levi-zhan opened 10 months ago
This might be tangentially related to what I encountered in the mmpose TopdownEstimator in issue #3012
You might need to refactor the model so that there are no self-referencing methods within it, and instead point to wrapped outer methods.
I haven't checked whether that's the case for mmseg, but it might point you in the right direction.
Hi, I have the same problem with the class EncoderDecoder from the segmentors of MMSegmentation (line 208). Did you manage to refactor your model and how?
Yes, I haven't posted an issue yet, but you should mimic the structure in mmpretrain.models.heads.cls_head.ClsHead, where additional _get_loss and _get_predict methods handle all the untraceable code, so that only the code where forward is called on the input gets traced.
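Roughly, that structure looks like this (a sketch with placeholder bodies, not the actual mmpretrain code):

import torch
import torch.nn as nn


class MyHead(nn.Module):
    """Sketch of the ClsHead-style split; placeholder logic only."""

    def __init__(self, in_channels: int = 64, num_classes: int = 10):
        super().__init__()
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # Pure tensor ops only: this is the part torch.fx traces.
        return self.fc(feats)

    def loss(self, feats: torch.Tensor, data_samples: list) -> dict:
        cls_score = self(feats)  # traced forward
        return self._get_loss(cls_score, data_samples)  # listed in skipped_methods

    def _get_loss(self, cls_score: torch.Tensor, data_samples: list) -> dict:
        # Dict building and loops over data_samples are untraceable: skipped.
        return {'loss_cls': cls_score.mean()}  # placeholder

    def predict(self, feats: torch.Tensor, data_samples: list = None) -> list:
        cls_score = self(feats)  # traced forward
        return self._get_predict(cls_score, data_samples)  # skipped

    def _get_predict(self, cls_score: torch.Tensor, data_samples: list = None) -> list:
        # Post-processing over data samples is untraceable: skipped.
        return [cls_score.argmax(dim=-1)]  # placeholder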
Thank you. I have changed the following argument of the MMRazor CustomTracer to fit the EncoderDecoder class:
skipped_methods=[
'mmseg.models.decode_heads.decode_head.BaseDecodeHead.predict_by_feat',
'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat']
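For context, skipped_methods sits inside the tracer entry of the quantizer config; a trimmed sketch (the TensorRT quantizer type is just an example, matching the configs later in this thread):

quantizer=dict(
    type='mmrazor.TensorRTQuantizer',
    global_qconfig=global_qconfig,
    tracer=dict(
        type='mmrazor.CustomTracer',
        skipped_methods=[
            'mmseg.models.decode_heads.decode_head.BaseDecodeHead.predict_by_feat',
            'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat',
        ]))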
Both the auxiliary head (FCNHead) and the decode head (PSPHead) use the same predict and loss functions.
Moreover, I have taken the whole body of the EncoderDecoder predict method out of the class (except for the self.inference() call) by creating functions decorated with @torch.fx.wrap.
def predict(self,
            inputs: Tensor,
            data_samples: OptSampleList = None) -> SampleList:
    """Predict results from a batch of inputs and data samples with post-
    processing.

    Args:
        inputs (Tensor): Inputs with shape (N, C, H, W).
        data_samples (List[:obj:`SegDataSample`], optional): The seg data
            samples. It usually includes information such as `metainfo`
            and `gt_sem_seg`.

    Returns:
        list[:obj:`SegDataSample`]: Segmentation results of the
        input images. Each SegDataSample usually contain:

        - ``pred_sem_seg``(PixelData): Prediction of semantic segmentation.
        - ``seg_logits``(PixelData): Predicted logits of semantic
          segmentation before normalization.
    """
    batch_img_metas = _prepare_batch(inputs, data_samples)
    seg_logits = self.inference(inputs, batch_img_metas)
    return postprocess_result(self.decode_head, seg_logits, data_samples)
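The _prepare_batch helper referenced above isn't shown in the comment; a sketch of what it would contain, based on the metainfo-extraction block of the original predict (the helper name is the commenter's, the body mirrors the code shown further down in this thread):

import torch
from torch import Tensor
from mmseg.utils import OptSampleList


@torch.fx.wrap
def _prepare_batch(inputs: Tensor, data_samples: OptSampleList = None) -> list:
    """Build per-image meta dicts, pulled out of EncoderDecoder.predict.

    The if/else and the list comprehension are untraceable, hence the wrap:
    the tracer records a single opaque node instead of tracing the body.
    """
    if data_samples is not None:
        return [data_sample.metainfo for data_sample in data_samples]
    return [
        dict(
            ori_shape=inputs.shape[2:],
            img_shape=inputs.shape[2:],
            pad_shape=inputs.shape[2:],
            padding_size=[0, 0, 0, 0])
    ] * inputs.shape[0]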
The problem now is that calling the EncoderDecoder loss function calls the EncoderDecoder _decode_head_forward_train and _auxiliary_head_forward_train methods, which try to update a dictionary of losses. I can't make the same changes you made for the mmpose TopdownEstimator loss function, as these two methods update the dictionary.
Do I have to pass the EncoderDecoder loss function entirely to skipped_methods, or is this a bigger issue?
Here is the full log of the issue:
/opt/conda/lib/python3.10/site-packages/mmseg/models/backbones/resnet.py:431: UserWarning: DeprecationWarning: pretrained is a deprecated, please use "init_cfg" instead
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
/opt/conda/lib/python3.10/site-packages/mmseg/models/builder.py:36: UserWarning: ``build_loss`` would be deprecated soon, please use ``mmseg.registry.MODELS.build()``
warnings.warn('``build_loss`` would be deprecated soon, please use '
/opt/conda/lib/python3.10/site-packages/mmseg/models/losses/cross_entropy_loss.py:235: UserWarning: Default ``avg_non_ignore`` is False, if you would like to ignore the certain label and average loss over non-ignore labels, which is the same with PyTorch official cross_entropy, set ``avg_non_ignore=True``.
warnings.warn(
Loads checkpoint by local backend from path: /workspace/mmlab/MMR/qat/seg/pspnet_r18-d8_512x1024_80k_cityscapes_20201225_021458-09ffa746.pth
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
cli.main()
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="__main__")
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "/workspace/mmlab/mmrazor/tools/train.py", line 121, in <module>
main()
File "/workspace/mmlab/mmrazor/tools/train.py", line 114, in main
runner = Runner.from_cfg(cfg)
File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 462, in from_cfg
runner = cls(
File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 429, in __init__
self.model = self.build_model(model)
File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 836, in build_model
model = MODELS.build(model)
File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/registry.py", line 570, in build
return self.build_func(cfg, *args, **kwargs, registry=self)
File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 232, in build_model_from_cfg
return build_from_cfg(cfg, registry, default_args)
File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 121, in build_from_cfg
obj = obj_cls(**args) # type: ignore
File "/workspace/mmlab/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 90, in __init__
self.qmodels = self._build_qmodels(self.architecture)
File "/workspace/mmlab/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 300, in _build_qmodels
observed_module = self.quantizer.prepare(model, concrete_args)
File "/workspace/mmlab/mmrazor/mmrazor/models/quantizers/native_quantizer.py", line 231, in prepare
traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
File "/workspace/mmlab/mmrazor/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 422, in trace
'output', (self.create_arg(fn(*args)), ), {},
File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/base.py", line 94, in forward
return self.loss(inputs, data_samples)
File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/encoder_decoder.py", line 179, in loss
loss_decode = self._decode_head_forward_train(x, data_samples)
File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/encoder_decoder.py", line 143, in _decode_head_forward_train
losses.update(add_prefix(loss_decode, 'decode'))
File "/opt/conda/lib/python3.10/site-packages/mmseg/utils/misc.py", line 24, in add_prefix
for name, value in inputs.items():
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 274, in __iter__
return self.tracer.iter(self)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 183, in iter
raise TraceError('Proxy object cannot be iterated. This can be '
torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors
@Veccoy
Passing the entire loss function to skipped_methods will prevent the fake-quantize observers from being calibrated, but anything inside the loss function that does not call the head's forward can be refactored into another method which you can then skip. Basically, you want the tracer to trace all nodes that are common to forward, predict, and loss, but not necessarily anything else.
In this case something like this should work:
def _get_loss(self, x: Tensor, data_samples: SampleList) -> dict:
    """Calculate losses from a batch of inputs and data samples.

    Args:
        x (Tensor): forward call result.
        data_samples (list[:obj:`SegDataSample`]): The seg data samples.
            It usually includes information such as `metainfo` and
            `gt_sem_seg`.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    losses = dict()
    loss_decode = self._decode_head_forward_train(x, data_samples)
    losses.update(loss_decode)
    if self.with_auxiliary_head:
        loss_aux = self._auxiliary_head_forward_train(x, data_samples)
        losses.update(loss_aux)
    return losses


def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
    """Calculate losses from a batch of inputs and data samples.

    Args:
        inputs (Tensor): Input images.
        data_samples (list[:obj:`SegDataSample`]): The seg data samples.
            It usually includes information such as `metainfo` and
            `gt_sem_seg`.

    Returns:
        dict[str, Tensor]: a dictionary of loss components
    """
    x = self.extract_feat(inputs)
    losses = self._get_loss(x, data_samples)
    return losses
with a config that skips _get_loss.
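For instance, extending the list from the earlier comment:

skipped_methods=[
    'mmseg.models.decode_heads.decode_head.BaseDecodeHead.predict_by_feat',
    'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat',
    'mmseg.models.segmentors.encoder_decoder.EncoderDecoder._get_loss',
]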
Thank you for your answer. Unfortunately, this doesn't work (see the traceback below). It seems to be a malfunction of the trace function when dealing with the 'loss' mode.
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
cli.main()
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="__main__")
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/XXX/.vscode-server/extensions/ms-python.debugpy-2024.0.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "/workspace/mmlab/mmrazor/tools/train.py", line 121, in <module>
main()
File "/workspace/mmlab/mmrazor/tools/train.py", line 114, in main
runner = Runner.from_cfg(cfg)
File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 462, in from_cfg
runner = cls(
File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 429, in __init__
self.model = self.build_model(model)
File "/opt/conda/lib/python3.10/site-packages/mmengine/runner/runner.py", line 836, in build_model
model = MODELS.build(model)
File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/registry.py", line 570, in build
return self.build_func(cfg, *args, **kwargs, registry=self)
File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 232, in build_model_from_cfg
return build_from_cfg(cfg, registry, default_args)
File "/opt/conda/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 121, in build_from_cfg
obj = obj_cls(**args) # type: ignore
File "/workspace/mmlab/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 90, in __init__
self.qmodels = self._build_qmodels(self.architecture)
File "/workspace/mmlab/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 300, in _build_qmodels
observed_module = self.quantizer.prepare(model, concrete_args)
File "/workspace/mmlab/mmrazor/mmrazor/models/quantizers/native_quantizer.py", line 231, in prepare
traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
File "/workspace/mmlab/mmrazor/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 422, in trace
'output', (self.create_arg(fn(*args)), ), {},
File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/base.py", line 94, in forward
return self.loss(inputs, data_samples)
File "/opt/conda/lib/python3.10/site-packages/mmseg/models/segmentors/encoder_decoder.py", line 205, in loss
losses = self._get_loss(x, data_samples)
File "/workspace/mmlab/mmrazor/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 72, in wrapped_method
return self.tracer.call_method(mod, self.name, method, args,
File "/workspace/mmlab/mmrazor/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 317, in call_method
return self.create_proxy('call_method', name, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 66, in create_proxy
args_ = self.create_arg(args)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 344, in create_arg
return super().create_arg(a)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 140, in create_arg
return type(a)(self.create_arg(elem) for elem in a)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 140, in <genexpr>
return type(a)(self.create_arg(elem) for elem in a)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 298, in create_arg
return self.create_node("get_attr", n_, (), {})
File "/opt/conda/lib/python3.10/site-packages/torch/ao/quantization/fx/tracer.py", line 114, in create_node
node = super().create_node(kind, target, args, kwargs, name, type_expr)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/proxy.py", line 46, in create_node
return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/graph.py", line 777, in create_node
name = self._graph_namespace.create_name(candidate, None)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/graph.py", line 137, in create_name
if candidate[0].isdigit():
IndexError: string index out of range
When the trace function of CustomTracer is called, it calls the create_arg method of torch.fx for the forward method of EncoderDecoder and several of its modules. However, one of these modules is the EncoderDecoder itself (not a submodule), which it should not be. It enters create_arg and crashes on this condition because the EncoderDecoder module has no name n_ (an empty string).
I think the problem comes from the fact that the _get_loss function is still in the EncoderDecoder class: this makes the EncoderDecoder model appear in the arguments of the create_arg method. I had the same issue and traceback when tracing the 'predict' mode and I made some changes (see this comment). I took the _prepare_batch and postprocess_result functions out of the class and put the @torch.fx.wrap decorator on top, which enables tracing in the 'predict' mode.
@Veccoy
The above traceback makes me think that you didn't add EncoderDecoder._get_loss to skipped_methods. Can you tell me if that is the case?
EDIT: I see, so EncoderDecoder is not a submodule, sorry. In that case you'll need to refactor the loss function so it doesn't use .update on dicts, since that is what makes it untraceable.
EDIT 2: Or, alternatively, factor the dict handling out of the class and decorate it with @torch.fx.wrap.
EDIT 3: You might also need to refactor and skip the refactored code from the decode head and auxiliary head losses where they also handle dictionaries.
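A minimal sketch of what EDIT 2 suggests (the helper name _update_losses and the prefix handling are hypothetical):

import torch


@torch.fx.wrap
def _update_losses(losses: dict, new_losses: dict, prefix: str) -> dict:
    # Iterating and updating a dict of Proxy values is untraceable;
    # wrapping turns this whole call into a single opaque graph node.
    for name, value in new_losses.items():
        losses[f'{prefix}.{name}'] = value
    return losses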
Thank you! Indeed, it works by refactoring the dict handling and the batch preparation in the loss and predict methods of the EncoderDecoder class, respectively, as well as the postprocess_result method of the BaseSegmentor class, and decorating them with @torch.fx.wrap. I also put the BaseDecodeHead.predict_by_feat, PSPHead.loss_by_feat and FCNHead.loss methods in the skipped_methods argument.
What is the difference between the @torch.fx.wrap decorator and the skipped_methods argument, if both are meant to handle untraceable code? When should one be used instead of the other?
> Thank you! Indeed, it works by refactoring the dict handling and the batch preparation [...] I also put the BaseDecodeHead.predict_by_feat, PSPHead.loss_by_feat and FCNHead.loss methods in the skipped_methods argument.
Do make sure that FCNHead.loss doesn't have any nodes in common with EncoderDecoder.forward, or the fake quants won't calibrate correctly.
> What is the difference between the @torch.fx.wrap decorator and the skipped_methods argument [...]?
@torch.fx.wrap is mainly for functions, and I use it for things that either repeat across classes or that are on the root class I'm trying to trace. In contrast, skipped_methods works only on submodule methods, but theoretically, if you can skip a method without refactoring, it is more convenient.
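Concretely (illustrative names):

import torch


# torch.fx.wrap: a free function that the tracer records as a single
# call_function node instead of tracing into its body.
@torch.fx.wrap
def postprocess(logits, metas):
    return [logits[i] for i in range(len(metas))]  # loops: never traced


# skipped_methods: qualified names of *methods*; CustomTracer patches
# them so each call becomes a single call_method node in the graph.
skipped_methods = [
    'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat',
]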
How do you check if methods have nodes in common? FCNHead and PSPHead both inherit the same loss method from the BaseDecodeHead class, which only does the forward of the head and the computation of the loss. But these heads are submodules of the EncoderDecoder model.
Anything that has a forward calculation would need to not be skipped. One way to check is adding a printout of the JIT graph within mmrazor's CustomTracer.
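For example (a sketch; tracer, model and concrete_args are the objects from native_quantizer.prepare in the tracebacks above, and print_tabular needs the tabulate package installed):

import torch.fx

graph = tracer.trace(model, concrete_args=concrete_args)
gm = torch.fx.GraphModule(model, graph)
gm.graph.print_tabular()  # one row per node: opcode, name, target, args
# Any call_module / call_method nodes belonging to a head should show up
# here; if a head's forward nodes are missing, that head is being skipped
# and its fake quants won't see data during calibration.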
Hi, I encountered the same error and modified the predict and loss functions as outlined in this comment and this comment. I also added BaseDecodeHead.predict_by_feat and BaseDecodeHead.loss_by_feat to the skipped functions.
Could you provide more context on FX tracing? I want to ensure I'm not missing any critical steps from the solution you mentioned above. Also, I assume the _prepare_batch() function includes the if-else block from the original script; please confirm whether this is correct.
Thank you in advance.
Hi, tracing with Torch FX requires you to make the tracer skip every untraceable part of the code. Untraceable parts are all for, while and if structures. Hence, you can make the tracer skip:
- entire class methods, using the skipped_methods argument provided by MMRazor's CustomTracer;
- standalone functions, using the torch.fx.wrap decorator.
The last point is useful, for example, if you want to skip only the postprocessing part after the forward of your detection head, even though both are originally in the same class method. The untraceable parts are still executed but not traced, thanks to the decorator.
Based on your previous response and this comment, I have refactored the dict handling of both the loss and predict methods outside the EncoderDecoder class. The skipped_methods argument looks like this:
skipped_methods=[
    'mmseg.models.decode_heads.decode_head.BaseDecodeHead.predict_by_feat',
    'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat',
    'mmseg.models.segmentors.encoder_decoder.EncoderDecoder._get_predictions',
    'mmseg.models.segmentors.encoder_decoder.EncoderDecoder._get_loss',
    'mmseg.models.segmentors.encoder_decoder.PSPHead.loss_by_feat',
    'mmseg.models.segmentors.encoder_decoder.FCNHead.loss'
]
Can you please confirm if this is correct?
Also, regarding the postprocess_result method of the BaseSegmentor class, did you just move the entire function out of the class and wrap it with the torch.fx.wrap decorator?
I guess you can delete some methods from your skipped_methods argument:
skipped_methods=[
    'mmseg.models.decode_heads.decode_head.BaseDecodeHead.predict_by_feat',
    'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat',
    # 'mmseg.models.segmentors.encoder_decoder.EncoderDecoder._get_predictions', --> no such method in native EncoderDecoder
    # 'mmseg.models.segmentors.encoder_decoder.EncoderDecoder._get_loss', --> no such method in native EncoderDecoder
    # 'mmseg.models.segmentors.encoder_decoder.PSPHead.loss_by_feat', --> already skipped with the second element
    'mmseg.models.segmentors.encoder_decoder.FCNHead.loss'
]
Then, yes, I moved the entire postprocess_result method out of the BaseSegmentor class by calling a function carrying the decorator and the code of the original method.
@torch.fx.wrap
def postprocess_result(seg_logits: Tensor,
                       decode_head_threshold: float,
                       align_corners: bool,
                       data_samples: OptSampleList = None) -> SampleList:
    # code of the base postprocess_result method
    ...
Also, the handling of the loss dictionaries will be problematic, so you will have to apply the same trick there.
You can see this repository of MMDetection for MMRazor for more examples: https://github.com/HIT-cwh/mmdetection/tree/for_mmrazor
As we discussed, and following the mmdet-for-mmrazor logic, I have refactored the dict handling (_get_loss and _get_predictions) of both the loss and predict methods outside the EncoderDecoder class, as well as the postprocess_result method of the BaseSegmentor, as follows:
@torch.fx.wrap
def _get_loss(self, x: Tensor, data_samples: SampleList) -> dict:
    losses = dict()
    loss_decode = self._decode_head_forward_train(x, data_samples)
    losses.update(loss_decode)
    if self.with_auxiliary_head:
        loss_aux = self._auxiliary_head_forward_train(x, data_samples)
        losses.update(loss_aux)
    return losses


@torch.fx.wrap
def _get_predictions(self, data_samples, inputs):
    if data_samples is not None:
        batch_img_metas = [
            data_sample.metainfo for data_sample in data_samples
        ]
    else:
        batch_img_metas = [
            dict(
                ori_shape=inputs.shape[2:],
                img_shape=inputs.shape[2:],
                pad_shape=inputs.shape[2:],
                padding_size=[0, 0, 0, 0])
        ] * inputs.shape[0]
    return batch_img_metas


def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
    x = self.extract_feat(inputs)
    losses = _get_loss(self, x, data_samples)
    return losses


def predict(self,
            inputs: Tensor,
            data_samples: OptSampleList = None) -> SampleList:
    batch_img_metas = _get_predictions(self, data_samples, inputs)
    seg_logits = self.inference(inputs, batch_img_metas)
    return self.postprocess_result(seg_logits, data_samples)


@torch.fx.wrap
def _postprocess_result(seg_logits: Tensor,
                        decode_head_threshold: float,
                        align_corners: bool,
                        data_samples: OptSampleList = None) -> SampleList:
    # code of the base postprocess_result method
    ...


def postprocess_result(self,
                       seg_logits: Tensor,
                       data_samples: OptSampleList = None) -> SampleList:
    return _postprocess_result(seg_logits, self.decode_head.threshold,
                               self.align_corners, data_samples)
However, I'm still encountering the torch.fx.proxy.TraceError when predict is called:
Traceback (most recent call last):
File "C:\Users\user\Documents\IPC\mmrazor\tools\ptq.py", line 73, in <module>
main()
File "C:\Users\user\Documents\IPC\mmrazor\tools\ptq.py", line 66, in main
runner = Runner.from_cfg(cfg)
File "C:\Users\user\anaconda3\envs\mmlab\lib\site-packages\mmengine\runner\runner.py", line 462, in from_cfg
runner = cls(
File "C:\Users\user\anaconda3\envs\mmlab\lib\site-packages\mmengine\runner\runner.py", line 429, in __init__
self.model = self.build_model(model)
File "C:\Users\user\anaconda3\envs\mmlab\lib\site-packages\mmengine\runner\runner.py", line 836, in build_model
model = MODELS.build(model)
File "C:\Users\user\anaconda3\envs\mmlab\lib\site-packages\mmengine\registry\registry.py", line 570, in build
return self.build_func(cfg, *args, **kwargs, registry=self)
File "C:\Users\user\anaconda3\envs\mmlab\lib\site-packages\mmengine\registry\build_functions.py", line 232, in build_model_from_cfg
return build_from_cfg(cfg, registry, default_args)
File "C:\Users\user\anaconda3\envs\mmlab\lib\site-packages\mmengine\registry\build_functions.py", line 121, in build_from_cfg
obj = obj_cls(**args) # type: ignore
File "c:\users\user\documents\ipc\mmrazor\mmrazor\models\algorithms\quantization\mm_architecture.py", line 90, in __init__
self.qmodels = self._build_qmodels(self.architecture)
File "c:\users\user\documents\ipc\mmrazor\mmrazor\models\algorithms\quantization\mm_architecture.py", line 300, in _build_qmodels
observed_module = self.quantizer.prepare(model, concrete_args)
File "c:\users\user\documents\ipc\mmrazor\mmrazor\models\quantizers\native_quantizer.py", line 231, in prepare
traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
File "c:\users\user\documents\ipc\mmrazor\mmrazor\models\task_modules\tracer\fx\custom_tracer.py", line 424, in trace
'output', (self.create_arg(fn(*args)), ), {},
File "c:\users\user\documents\ipc\mmsegmentation\mmseg\models\segmentors\base.py", line 153, in forward
return self.predict(inputs, data_samples)
File "c:\users\user\documents\ipc\mmsegmentation\mmseg\models\segmentors\encoder_decoder.py", line 259, in predict
seg_logits = self.inference(inputs, batch_img_metas)
File "c:\users\user\documents\ipc\mmsegmentation\mmseg\models\segmentors\encoder_decoder.py", line 375, in inference
if not all(_['ori_shape'] == ori_shape for _ in batch_img_metas):
File "C:\Users\user\anaconda3\envs\mmlab\lib\site-packages\torch\fx\proxy.py", line 274, in __iter__
return self.tracer.iter(self)
File "C:\Users\user\anaconda3\envs\mmlab\lib\site-packages\torch\fx\proxy.py", line 183, in iter
raise TraceError('Proxy object cannot be iterated. This can be '
torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors
Am I missing something in the batch preparation (_get_predictions)?
After following the suggestions from the previous responses, I added skipped_methods, which resolved the issue in the backbone. However, I can't figure out how to solve the current error.
This is my config:
_base_ = [
    'mmdet::rsprompter/samseg-maskrcnn-nwpu.py',
    '../../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py'  # noqa: E501
]

_base_.val_dataloader.batch_size = 4

test_cfg = dict(
    type='mmrazor.PTQLoop',
    calibrate_dataloader=_base_.val_dataloader,
    calibrate_steps=32,
)

float_checkpoint = '/home/user/RSPrompter_train/mmrazor/pth/seg_mask_base.pth'  # noqa: E501

global_qconfig = dict(
    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
    w_qscheme=dict(
        qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
    a_qscheme=dict(
        qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1),
)

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='MMArchitectureQuant',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=32),
    architecture=_base_.model,
    deploy_cfg=_base_.deploy_cfg,
    float_checkpoint=float_checkpoint,
    quantizer=dict(
        type='mmrazor.TensorRTQuantizer',
        global_qconfig=global_qconfig,
        tracer=dict(
            type='mmrazor.CustomTracer',
            skipped_methods=[
                'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.predict_by_feat',  # noqa: E501
                # 'mmdet.models.dense_heads.anchor_head.AnchorHead.loss_by_feat',
                # test
                # 'mmdet.models.dense_heads.rpn_head.RPNHead.loss_by_feat',
                # 'mmdet.models.dense_heads.rpn_head.RPNHead._predict_by_feat_single',
                'transformers.models.sam.modeling_sam.SamVisionAttention.get_rel_pos',
                'mmdet.rsprompter.models.RSFeatureAggregator.change',
            ])))

model_wrapper_cfg = dict(
    type='mmrazor.MMArchitectureQuantDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)
And this is the traceback:
Traceback (most recent call last):
File "/home/user/RSPrompter_train/mmrazor/tools/ptq.py", line 73, in <module>
main()
File "/home/user/RSPrompter_train/mmrazor/tools/ptq.py", line 66, in main
runner = Runner.from_cfg(cfg)
File "/home/user/.conda/envs/rsptest/lib/python3.10/site-packages/mmengine/runner/runner.py", line 462, in from_cfg
runner = cls(
File "/home/user/.conda/envs/rsptest/lib/python3.10/site-packages/mmengine/runner/runner.py", line 429, in __init__
self.model = self.build_model(model)
File "/home/user/.conda/envs/rsptest/lib/python3.10/site-packages/mmengine/runner/runner.py", line 836, in build_model
model = MODELS.build(model)
File "/home/user/.conda/envs/rsptest/lib/python3.10/site-packages/mmengine/registry/registry.py", line 570, in build
return self.build_func(cfg, *args, **kwargs, registry=self)
File "/home/user/.conda/envs/rsptest/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 232, in build_model_from_cfg
return build_from_cfg(cfg, registry, default_args)
File "/home/user/.conda/envs/rsptest/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 121, in build_from_cfg
obj = obj_cls(**args) # type: ignore
File "/home/user/RSPrompter_train/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 90, in __init__
self.qmodels = self._build_qmodels(self.architecture)
File "/home/user/RSPrompter_train/mmrazor/mmrazor/models/algorithms/quantization/mm_architecture.py", line 297, in _build_qmodels
observed_module = self.quantizer.prepare(
File "/home/user/RSPrompter_train/mmrazor/mmrazor/models/quantizers/native_quantizer.py", line 231, in prepare
traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
File "/home/user/RSPrompter_train/mmrazor/mmrazor/models/task_modules/tracer/fx/custom_tracer.py", line 421, in trace
'output', (self.create_arg(fn(*args)), ), {},
File "/home/user/RSPrompter_train/mmdetection-for_razor/mmdet/models/detectors/base.py", line 103, in forward
return self._forward(inputs, data_samples)
File "/home/user/RSPrompter_train/mmdetection-for_razor/mmdet/models/detectors/two_stage.py", line 134, in _forward
rpn_results_list = self.rpn_head.predict(
File "/home/user/RSPrompter_train/mmdetection-for_razor/mmdet/models/dense_heads/base_dense_head.py", line 208, in predict
predictions = self.predict_by_feat(
File "/home/user/.conda/envs/rsptest/lib/python3.10/site-packages/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py", line 72, in rpn_head__predict_by_feat
mlvl_anchors = self.anchor_generator.grid_anchors(
File "/home/user/RSPrompter_train/mmdetection-for_razor/mmdet/models/task_modules/prior_generators/anchor_generator.py", line 362, in grid_anchors
anchors = self.single_level_grid_anchors(
File "/home/user/RSPrompter_train/mmdetection-for_razor/mmdet/models/task_modules/prior_generators/anchor_generator.py", line 399, in single_level_grid_anchors
shift_x = torch.arange(0, feat_w, device=device) * stride[0]
TypeError: arange() received an invalid combination of arguments - got (int, Proxy, device=Attribute), but expected one of:
* (Number end, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
* (Number start, Number end, *, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
* (Number start, Number end, Number step, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
My understanding is that adding 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.predict_by_feat' to skipped_methods should resolve this issue, but in fact, it hasn't. So I would like to ask those who have solved this problem for some advice.
@torch.fx.wrap
def _get_loss(self, x: Tensor, data_samples: SampleList) -> dict:
    losses = dict()
    loss_decode = self._decode_head_forward_train(x, data_samples)
    losses.update(loss_decode)
    if self.with_auxiliary_head:
        loss_aux = self._auxiliary_head_forward_train(x, data_samples)
        losses.update(loss_aux)
    return losses
You can't wrap the entire loss part, as it is then not going to trace the forward of the head (self._decode_head_forward_train). You need to skip tracing the parts with for, if and while structures while keeping the forward methods in the tracing.
The predict and postprocessing parts seem ok, so the only thing you have to change in your code is this loss part.
- The loss method of the auxiliary head is already handled by the skipped_methods argument of your configuration file, so you don't have to worry about it.
- The loss method of the main head should be traced, so you have to handle it the same way you did with the inference method in the predict method you have shown here.
- Finally, as the update method of the dictionary is untraceable, you have to create a specific function with the torch.fx.wrap decorator that updates the dictionary outside the class.
With this done, it should be working.
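Putting the three points together, the refactored loss would look roughly like this (a sketch; update_losses is the hypothetical wrapped helper):

import torch


@torch.fx.wrap
def update_losses(losses: dict, new_losses: dict) -> dict:
    # dict iteration/update on Proxy values is untraceable, hence wrapped.
    for key, value in new_losses.items():
        losses[key] = value
    return losses


def loss(self, inputs, data_samples):
    x = self.extract_feat(inputs)  # traced
    losses = dict()
    # Decode head: its forward is traced, its loss_by_feat is in
    # skipped_methods, so only the dict merging is wrapped. Note that
    # add_prefix inside _decode_head_forward_train also iterates a dict
    # and needs the same wrap treatment (see the last traceback below).
    loss_decode = self._decode_head_forward_train(x, data_samples)
    losses = update_losses(losses, loss_decode)
    if self.with_auxiliary_head:
        # The auxiliary head's loss is already covered by skipped_methods.
        loss_aux = self._auxiliary_head_forward_train(x, data_samples)
        losses = update_losses(losses, loss_aux)
    return losses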
Hi @psychedelicosisyphus, I have never faced this issue before. But as I said above, you need to skip tracing the parts with for, if and while structures while keeping the forward methods in the tracing.
So check carefully what you are skipping by passing methods to the skipped_methods argument. There may be some lines that are part of the forward of the model, and you don't want to skip those. When a method mixes traceable and untraceable parts, you have to use the torch.fx.wrap decorator.
> You can't wrap the entire loss part, as it is then not going to trace the forward of the head [...] With this done, it should be working.
So, based on your last response, I replaced update() with update_losses() and modified loss() in encoder_decoder.py as follows:
@torch.fx.wrap
def update_losses(losses_dict, new_losses):
    for key, value in new_losses.items():
        if key in losses_dict:
            # If the key already exists, sum the losses (or apply any logic you want)
            losses_dict[key] += value
        else:
            # If the key does not exist, add the new key-value pair
            losses_dict[key] = value
    return losses_dict


@torch.fx.wrap
def _get_loss(self, x: Tensor, data_samples: SampleList) -> dict:
    loss_aux = {}
    if self.with_auxiliary_head:
        loss_aux = self._auxiliary_head_forward_train(x, data_samples)
    return loss_aux


def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
    x = self.extract_feat(inputs)
    losses = dict()
    loss_decode = self._decode_head_forward_train(x, data_samples)
    losses = update_losses(losses, loss_decode)
    loss_aux = _get_loss(self, x, data_samples)
    losses = update_losses(losses, loss_aux)
    return losses
skipped_methods=[
    'mmseg.models.decode_heads.decode_head.BaseDecodeHead.predict_by_feat',
    'mmseg.models.decode_heads.decode_head.BaseDecodeHead.loss_by_feat',
    'mmseg.models.decode_heads.FCNHead.loss'
]
However, I'm still encountering a TraceError in the add_prefix method, which I also refactored using @torch.fx.wrap:
Traceback (most recent call last):
File "C:\Users\kosta\Documents\IPC\Quantization\mmrazor\tools\ptq.py", line 73, in <module>
main()
File "C:\Users\kosta\Documents\IPC\Quantization\mmrazor\tools\ptq.py", line 66, in main
runner = Runner.from_cfg(cfg)
File "C:\Users\kosta\anaconda3\envs\mmlab\lib\site-packages\mmengine\runner\runner.py", line 462, in from_cfg
runner = cls(
File "C:\Users\kosta\anaconda3\envs\mmlab\lib\site-packages\mmengine\runner\runner.py", line 429, in __init__
self.model = self.build_model(model)
File "C:\Users\kosta\anaconda3\envs\mmlab\lib\site-packages\mmengine\runner\runner.py", line 836, in build_model
model = MODELS.build(model)
File "C:\Users\kosta\anaconda3\envs\mmlab\lib\site-packages\mmengine\registry\registry.py", line 570, in build
return self.build_func(cfg, *args, **kwargs, registry=self)
File "C:\Users\kosta\anaconda3\envs\mmlab\lib\site-packages\mmengine\registry\build_functions.py", line 232, in build_model_from_cfg
return build_from_cfg(cfg, registry, default_args)
File "C:\Users\kosta\anaconda3\envs\mmlab\lib\site-packages\mmengine\registry\build_functions.py", line 121, in build_from_cfg
obj = obj_cls(**args) # type: ignore
File "c:\users\kosta\documents\ipc\quantization\mmrazor\mmrazor\models\algorithms\quantization\mm_architecture.py", line 90, in __init__
self.qmodels = self._build_qmodels(self.architecture)
File "c:\users\kosta\documents\ipc\quantization\mmrazor\mmrazor\models\algorithms\quantization\mm_architecture.py", line 300, in _build_qmodels
observed_module = self.quantizer.prepare(model, concrete_args)
File "c:\users\kosta\documents\ipc\quantization\mmrazor\mmrazor\models\quantizers\native_quantizer.py", line 231, in prepare
traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
File "c:\users\kosta\documents\ipc\quantization\mmrazor\mmrazor\models\task_modules\tracer\fx\custom_tracer.py", line 421, in trace
'output', (self.create_arg(fn(*args)), ), {},
File "c:\users\kosta\documents\ipc\quantization\mmsegmentation_for_mmrazor\mmseg\models\segmentors\base.py", line 151, in forward
return self.loss(inputs, data_samples)
File "c:\users\kosta\documents\ipc\quantization\mmsegmentation_for_mmrazor\mmseg\models\segmentors\encoder_decoder.py", line 241, in loss
loss_decode = self._decode_head_forward_train(x, data_samples)
File "c:\users\kosta\documents\ipc\quantization\mmsegmentation_for_mmrazor\mmseg\models\segmentors\encoder_decoder.py", line 201, in _decode_head_forward_train
losses = update_losses(losses, add_prefix(loss_decode, 'decode'))
File "c:\users\kosta\documents\ipc\quantization\mmsegmentation_for_mmrazor\mmseg\utils\misc.py", line 24, in add_prefix
for name, value in inputs.items():
File "C:\Users\kosta\anaconda3\envs\mmlab\lib\site-packages\torch\fx\proxy.py", line 274, in __iter__
return self.tracer.iter(self)
File "C:\Users\kosta\anaconda3\envs\mmlab\lib\site-packages\torch\fx\proxy.py", line 183, in iter
raise TraceError('Proxy object cannot be iterated. This can be '
torch.fx.proxy.TraceError: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors
Describe the bug
torch.fx.proxy.TraceError: class MMArchitectureQuant in mmrazor/models/algorithms/quantization/mm_architecture.py: Proxy object cannot be iterated. This can be attempted when the Proxy is used in a loop or as a *args or **kwargs function argument. See the torch.fx docs on pytorch.org for a more detailed explanation of what types of control flow can be traced, and check out the Proxy docstring for help troubleshooting Proxy iteration errors.
I am currently trying to quantize a segmentation model with the configuration file below, and I get the error above. Can you help me figure out how to solve it? Thank you.
The base configuration file is for a segmentation model I modified based on DDRNet, with only 3 categories; all other configurations are unchanged.
_base_ = [
    'mmseg::ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k-1024x1024_label3.py',
    '../../deploy_cfgs/mmseg/set_tensorrt-int8-explicit-1024x1024_label3.py'
]

_base_.val_dataloader.batch_size = 32

test_cfg = dict(
    type='mmrazor.PTQLoop',
    calibrate_dataloader=_base_.val_dataloader,
    calibrate_steps=32,
)

float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth'  # noqa: E501

global_qconfig = dict(
    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
    w_qscheme=dict(
        qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
    a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
)

crop_size = (1024, 1024)
model = dict(
    _delete_=True,
    type='mmrazor.MMArchitectureQuant',
    data_preprocessor=dict(
        type='mmseg.SegDataPreProcessor',
        size=crop_size,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_val=0,
        seg_pad_val=255),
    architecture=_base_.model,
    deploy_cfg=_base_.deploy_cfg,
    float_checkpoint=float_checkpoint,
    quantizer=dict(
        type='mmrazor.TensorRTQuantizer',
        global_qconfig=global_qconfig,
        tracer=dict(
            type='mmrazor.CustomTracer',
            skipped_methods=[
                'mmseg.models.decode_heads.ddr_head.DDRHead.loss_by_feat',
            ])))

model_wrapper_cfg = dict(
    type='mmrazor.MMArchitectureQuantDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)

custom_hooks = []
May I ask what is written incorrectly in my configuration file? Thank you!