Open · lishiyu005 opened this issue 9 months ago
Currently, RTMO only supports single-class pose estimation. For a multi-class detection task, have you tried increasing `num_classes` in the config? BTW, `category_id` should start from 1 instead of 0.
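For illustration, a minimal COCO-style `categories` block with ids starting at 1 might look like this (names and keypoint lists are made-up placeholders, not from this thread):

```python
# Hypothetical COCO-style categories: ids start at 1, as suggested above.
categories = [
    dict(id=1, name='class_a', keypoints=['tip', 'base'], skeleton=[[1, 2]]),
    dict(id=2, name='class_b', keypoints=['tip', 'mid', 'base'],
         skeleton=[[1, 2], [2, 3]]),
]
```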
Hello! Are there any models in MMPose that support multi-class pose estimation? From my understanding, it is possible to use a detector to classify each instance and then apply a separate top-down model such as RTMPose per class. However, I would prefer not to train and load multiple pose models, especially since my classes are quite similar to each other, albeit with differences in the number and type of keypoints.
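For reference, a rough sketch of that detector-plus-per-class-top-down pipeline (config paths, checkpoints, and the detector output format here are placeholders; it assumes `init_model`/`inference_topdown` from `mmpose.apis`):

```python
from mmpose.apis import inference_topdown, init_model

# One top-down pose model per class; paths are placeholders.
pose_models = {
    0: init_model('rtmpose_class_a_config.py', 'class_a.pth'),
    1: init_model('rtmpose_class_b_config.py', 'class_b.pth'),
}

def estimate_poses(img, detections):
    """detections: iterable of (bbox_xyxy, class_id) from any detector."""
    results = []
    for bbox, cls in detections:
        # Route each detected box to the pose model trained for its class.
        results += inference_topdown(pose_models[cls], img, bboxes=[bbox])
    return results
```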
Hi! I experienced the same problem. The error seems to be in the `SimOTAAssigner` class in `mmpose/models/task_modules/assigners/sim_ota_assigner.py`, line 158:

```python
F.one_hot(gt_labels.to(torch.int64),
          pred_scores.shape[-1]).float().unsqueeze(0).repeat(num_valid, 1, 1)
```

The error is thrown here because `pred_scores.shape[-1]` is always 1, while `gt_labels` contains labels for your number of classes. When I changed `pred_scores.shape[-1]` to the actual number of classes (8 in my case), it passes.

However, the next error arises at line 166:

```python
F.binary_cross_entropy(
    valid_pred_scores.to(dtype=torch.float32),
    gt_onehot_label,
    reduction='none',
).sum(-1).to(dtype=valid_pred_scores.dtype)
```

I believe this is due to `valid_pred_scores` not matching the dimension of `gt_onehot_label`.
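A minimal sketch of the first failure, with shapes assumed from the description above rather than taken from the assigner's actual code:

```python
import torch
import torch.nn.functional as F

pred_scores = torch.rand(4, 1)        # RTMO predicts one class score per prior
gt_labels = torch.tensor([0, 3, 7])   # but the dataset has 8 classes

try:
    # one_hot with num_classes=1 requires every label < 1; any label > 0 fails
    # (RuntimeError on CPU, the device-side "index out of bounds" assert on CUDA).
    F.one_hot(gt_labels.to(torch.int64), pred_scores.shape[-1])
except RuntimeError as e:
    print(e)  # on CPU: "Class values must be smaller than num_classes."
```

Even after widening `num_classes`, `binary_cross_entropy` requires the prediction tensor and target tensor to have identical shapes, which would explain the second error.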
Can anybody explain what should be implemented to make RTMO work for multiple classes?
I managed to solve this problem and got RTMO to work for multiple classes.
Here is my solution. In `mmpose/models/heads/hybrid_heads/yoloxpose_head.py`, change the order of arguments of the YOLOXPose head modules:

```python
# line 22
class YOLOXPoseHeadModule(BaseModule):
    """YOLOXPose head module for one-stage human pose estimation."""

    def __init__(
        self,
        num_keypoints: int,
        num_classes: int,
        in_channels: Union[int, Sequence],
```

```python
# line 224
class YOLOXPoseHead(BaseModule):

    def __init__(
        self,
        num_keypoints: int,
        num_classes: int,
        head_module_cfg: Optional[ConfigType] = None,
```

In `mmpose/models/heads/hybrid_heads/rtmo_head.py`:

```python
# line 25
class RTMOHeadModule(BaseModule):

    def __init__(
        self,
        num_keypoints: int,
        num_classes: int,
        in_channels: int,
```

```python
# line 662
class RTMOHead(YOLOXPoseHead):

    def __init__(
        self,
        num_keypoints: int,
        num_classes: int,
        head_module_cfg: ConfigType,
```

```python
# line 732
head_module_cfg['featmap_strides'] = featmap_strides
head_module_cfg['num_keypoints'] = num_keypoints
head_module_cfg['num_classes'] = num_classes
```

```python
# line 811
extra_info = dict(num_samples=num_total_samples)
losses = dict()
cls_preds_all = flatten_cls_scores.view(-1, self.num_classes)
cls_target_all = torch.zeros_like(cls_preds_all).to(obj_targets)
```

```python
# line 854
with torch.no_grad():
    diff_cc = torch.norm(kpt_cc_preds - kpt_targets, dim=-1)
    diff_reg = torch.norm(kpt_reg_preds - kpt_targets, dim=-1)
    mask = (diff_reg > diff_cc).float()
    kpt_weights_reg = vis_targets * mask

# cls_targets = oks.unsqueeze(1)
losses['loss_oks'] = self.loss_oks(kpt_reg_preds,
                                   kpt_cc_preds.detach(),
                                   kpt_weights_reg, pos_areas)
```

```python
# line 878
extra_info['overlaps'] = cls_targets
cls_targets = cls_targets.pow(self.overlaps_power).detach()
cls_target_all[pos_masks] = cls_targets

losses['loss_cls'] = self.loss_cls(cls_preds_all, cls_target_all,
                                   obj_weights) / num_total_samples
```
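With a patch like this, the head config would then need the extra argument; a hypothetical fragment, where 17 keypoints and 8 classes are example values rather than anything from this thread:

```python
# Hypothetical config fragment for the patched head.
model = dict(
    head=dict(
        type='RTMOHead',
        num_keypoints=17,
        num_classes=8,  # upstream RTMO implicitly assumes a single class
    ))
```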
Is it possible to make it like KAPAO, in the sense that there is one person class plus a class for each keypoint box? I want boxes for the keypoints as well.
@lishiyu005 May I ask how your multi-category image data is constructed? Do you train with targets of multiple categories in a single image? In my data, each image corresponds to a single category, and I encounter `grad_norm: nan` during training. Could the problem be in how my training data is constructed?
What is the feature?
Current RTMO is designed for one-class human keypoint detection. I tried to train it for multi-class keypoint detection, but got this error:

```
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:365: operator(): block: [0,0,0], thread: [1,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
Traceback (most recent call last):
  File "D:\code\mmpose\tools\train_rtm6d.py", line 163, in <module>
    main()
  File "D:\code\mmpose\tools\train_rtm6d.py", line 159, in main
    runner.train()
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\mmengine\runner\runner.py", line 1777, in train
    model = self.train_loop.run()  # type: ignore
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\mmengine\runner\loops.py", line 96, in run
    self.run_epoch()
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\mmengine\runner\loops.py", line 112, in run_epoch
    self.run_iter(idx, data_batch)
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\mmengine\runner\loops.py", line 128, in run_iter
    outputs = self.runner.model.train_step(
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\mmengine\model\base_model\base_model.py", line 114, in train_step
    losses = self._run_forward(data, mode='loss')  # type: ignore
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\mmengine\model\base_model\base_model.py", line 361, in _run_forward
    results = self(data, mode=mode)
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\code\mmpose\mmpose\models\pose_estimators\base.py", line 155, in forward
    return self.loss(inputs, data_samples)
  File "D:\code\mmpose\mmpose\models\pose_estimators\bottomup.py", line 70, in loss
    self.head.loss(feats, data_samples, train_cfg=self.train_cfg))
  File "D:\code\mmpose\mmpose\models\heads\hybrid_heads\rtmo_head.py", line 793, in loss
    targets = self._get_targets(flatten_priors,
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "D:\code\mmpose\mmpose\models\heads\hybrid_heads\yoloxpose_head.py", line 410, in _get_targets
    target = self._get_targets_single(priors, batch_cls_scores[i],
  File "C:\Users.conda\envs\openmmlab\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "D:\code\mmpose\mmpose\models\heads\hybrid_heads\yoloxpose_head.py", line 532, in _get_targets_single
    assign_result = self.assigner.assign(
  File "D:\code\mmpose\mmpose\models\task_modules\assigners\sim_ota_assigner.py", line 158, in assign
    F.one_hot(gt_labels.to(torch.int64),
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

Is there any way to adapt RTMO for multiple-class keypoint detection?
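A minimal sketch of the debugging step the log itself suggests, assuming the variable is set before CUDA initializes (e.g. at the very top of the training script), so the failing op runs synchronously and the traceback points at the real frame:

```python
import os

# Must run before torch initializes CUDA for the setting to take effect.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
```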
Another bug of RTMO is that the category id can only be 1. If I define a category id of 0, the training labels become -1 and I eventually get an out-of-bounds error. In my case, I defined category ids from 0 to 9 and received this error. Is there any way to fix this?
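A sketch of one reading of that symptom: if the loader derives training labels as `category_id - 1` (an assumption, not a confirmed trace through MMPose's dataset code), ids defined from 0 would yield -1, which later indexes out of bounds in ops like `F.one_hot` or scatter:

```python
# Assumed label mapping: label = category_id - 1.
category_ids = [0, 1, 9]
labels = [cid - 1 for cid in category_ids]
print(labels)  # [-1, 0, 8] -- the -1 later triggers the out-of-bounds error
```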
Thanks in advance.
Best regards
Any other context?
No response