fire717 / movenet.pytorch

A Pytorch implementation of MoveNet from Google. Include training code and pre-trained model.
MIT License
374 stars 87 forks source link

关于reg操作在训练时体现 #14

Closed cassie101 closed 2 years ago

cassie101 commented 2 years ago

https://github.com/fire717/movenet.pytorch/blob/95ec8535245228aa4335243e68722810e50bcaf8/lib/task/task_tools.py#L124-L144

您好,在prediction 阶段可以到 reg 进行了这样的一个处理, 也是TF的处理方式 48x48 根据x,y轴 以0到47的生成的坐标 x' = (rangeweight - reg)^2 tmp_reg = (x'+y')^0.5 + 1.8 keypoint_heatmap/tmp_reg -> 在取 maxpoint 的 reg_x, reg_y

这样的一个操作请问在训练哪部分体现呢? 另外,这样得到的reg_x, reg_y 和 reg heatmap 直接得到的坐标有什么区别呢? 多谢

edit:请问在微信交流一下其他细节方便吗?

fire717 commented 2 years ago

这相当于输出后处理了,在训练的时候没有,因为训练的时候只是针对四个head训练。只要四个head训练准了,后处理输出的最终坐标也就准了。

有问题就在这里问吧,也方便后来者查阅。

cassie101 commented 2 years ago

您好,请教一下在训练centre heatmap的时候加了这样的权重,相当于官方提的 "weighted by the inverse-distance from the frame center" https://github.com/fire717/movenet.pytorch/blob/95ec8535245228aa4335243e68722810e50bcaf8/lib/loss/movenet_loss.py#L303-L305 但对于官方提到 “Each pixel in the keypoint heatmap is multiplied by a weight which is inversely proportional to the distance from the corresponding regressed keypoint” 也就是prediction 中的

48x48 根据x,y轴 以0到47的生成的坐标
x' = (rangeweight - reg)^2
tmp_reg = (x'+y')^0.5 + 1.8
keypoint_heatmap/tmp_reg -> 在取 maxpoint 的 reg_x, reg_y

却没在keypoint heatmap 的训练做相应的处理,请问这其中有什么考量吗?

ref: https://blog.tensorflow.org/2021/05/next-generation-pose-detection-with-movenet-and-tensorflowjs.html

fire717 commented 2 years ago

这都属于后处理操作,没必要加入到训练中。 加入到训练写起来操作很复杂(比如不同特征图的不同index操作),而且不一定可导(比如argmax操作之类的)需要转换等价实现比较麻烦,而且就算实现了也只是相当于一个相关性很强的multi-task learning,训出来可能区别不大。

还是上面那句话,只要保证四个head的输出是准确的,那么最终结果肯定也是准确的。

因此我不是很明白为什么你会纠结这个问题,很多任务中都是训练head然后后处理来输出最终目标(比如目标检测输出相对于anchor的偏移值而不是直接输出检测框的绝对坐标),当然如果你感兴趣也可以尝试下加入到训练中看是否有提升,到时候可以反馈~

cassie101 commented 2 years ago

好的多谢答复