AliaksandrSiarohin / monkey-net

Animating Arbitrary Objects via Deep Motion Transfer

Questions about kp2gaussian and gaussian2kp #20

Open zhuhaozh opened 4 years ago

zhuhaozh commented 4 years ago

Hi, thanks for your interesting work! I have some questions about kp2gaussian and gaussian2kp in keypoint_detector.py:

  1. Are these functions invertible? I use a pretrained hourglass network to extract heatmaps, but when I feed them through gaussian2kp and then kp2gaussian, the output looks weird.


  2. To compute the mean of the heatmap, why is a sum applied to it? `mean = (heatmap * grid).sum(dim=(3, 4))`

  3. If I use a pretrained landmark detector, for example a facial landmark detector, how should I modify the code?
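To make questions 1 and 2 concrete, here is a simplified 2D sketch of the two operations. The names kp2gaussian_2d and gaussian2kp_2d are hypothetical, and this is not the repository code: the real functions also handle variances and an extra time dimension. Soft-argmax computes the expectation mean = Σ_p H(p)·p over grid coordinates p in [-1, 1], which is a valid coordinate only when Σ_p H(p) = 1. In this sketch, kp2gaussian_2d returns exp(-0.5·‖p − mean‖²/var), whose peak is about 1 rather than whose sum is 1.

```python
import torch


def make_coordinate_grid(h, w):
    """(h, w, 2) grid of (x, y) coordinates, each normalized to [-1, 1]."""
    x = 2 * torch.arange(w, dtype=torch.float32) / (w - 1) - 1
    y = 2 * torch.arange(h, dtype=torch.float32) / (h - 1) - 1
    yy, xx = torch.meshgrid(y, x, indexing="ij")
    return torch.stack([xx, yy], dim=-1)


def kp2gaussian_2d(mean, h, w, var=0.01):
    """Keypoints [num_kp, 2] -> heatmaps [num_kp, h, w].

    Each map peaks at roughly 1 near its keypoint, so it does NOT sum to 1."""
    grid = make_coordinate_grid(h, w)                    # [h, w, 2]
    diff = grid.unsqueeze(0) - mean.view(-1, 1, 1, 2)    # [num_kp, h, w, 2]
    return torch.exp(-0.5 * (diff ** 2).sum(-1) / var)


def gaussian2kp_2d(heatmap):
    """Soft-argmax: heatmaps [num_kp, h, w] -> expected (x, y) [num_kp, 2].

    The result is a meaningful coordinate only if every map sums to 1."""
    grid = make_coordinate_grid(*heatmap.shape[-2:])     # [h, w, 2]
    # Weight every grid coordinate by the heatmap, then sum over height and width
    return (heatmap.unsqueeze(-1) * grid).sum(dim=(-3, -2))


kp = torch.tensor([[0.2, -0.3]])
heat = kp2gaussian_2d(kp, 64, 64)
print(gaussian2kp_2d(heat))      # far from kp: scaled by the map's total mass
normed = heat / heat.sum(dim=(-2, -1), keepdim=True)
print(gaussian2kp_2d(normed))    # close to kp once each map sums to 1
```

Running gaussian2kp_2d on an unnormalized map returns coordinates scaled by the map's total mass. This is one way to get "weird" outputs from heatmaps that were never normalized. Dividing each map by its sum first recovers the keypoint, but the reverse trip still returns an unnormalized Gaussian, not the original map.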

AliaksandrSiarohin commented 4 years ago

Hi,

1) kp2gaussian(gaussian2kp(x)) != x, is this what you are asking? This is the case because kp2gaussian produces unnormalized heatmaps, i.e. heatmaps that do not sum to one, while gaussian2kp requires heatmaps that do sum to one. Why do you need to invert it?

2) This operation is called soft-argmax: it computes the weighted mean of the coordinate grid. More formally, the heatmap defines a probability distribution over the image coordinates, and we compute the mean coordinate under this distribution. That is why the product is summed over the spatial (height and width) dimensions.

3) Do you have a PyTorch landmark detector? In that case, follow these steps:

  1. Replace the architecture of the keypoint_detector with your own. Modify its forward method so that it returns a dict with a single key 'mean', holding a tensor of shape [bs, num_kp, 2].
  2. In config.yaml, set num_kp to the number of keypoints in your case and change kp_var from 'matrix' to 0.01.
  3. In train.py, load your keypoint detector weights and comment out the lines that call optimizer_kp_detector.step().
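For step 1, a minimal sketch of such a replacement follows. It assumes your pretrained detector (called landmark_net here, a hypothetical name) maps a batch of images to pixel coordinates of shape [bs, num_kp, 2] in (x, y) order. The rescaling to [-1, 1] is also an assumption, based on the coordinate grid in keypoint_detector.py being normalized that way. Check it against your copy of the code before relying on it.

```python
import torch
from torch import nn


class PretrainedLandmarkKPDetector(nn.Module):
    """Hypothetical adapter: wraps an external landmark detector so that
    forward() returns {'mean': tensor of shape [bs, num_kp, 2]}."""

    def __init__(self, landmark_net: nn.Module, image_size):
        super().__init__()
        self.landmark_net = landmark_net
        self.image_size = image_size  # (h, w) of the input frames (assumed)
        # Freeze the pretrained weights so they are never updated
        for p in self.landmark_net.parameters():
            p.requires_grad_(False)

    def forward(self, x):
        with torch.no_grad():
            # Assumed output: pixel coordinates [bs, num_kp, 2] in (x, y) order
            kp_pixels = self.landmark_net(x)
        h, w = self.image_size
        scale = kp_pixels.new_tensor([w - 1, h - 1])
        # Map pixel 0 -> -1 and pixel (size - 1) -> 1 (assumed grid convention)
        return {'mean': 2 * kp_pixels / scale - 1}
```

Freezing the parameters complements step 3. With requires_grad off and the forward pass under torch.no_grad(), the detector stays fixed: optimizers skip parameters whose gradient is None.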

If you don't have access to the detector's architecture, or it is complicated, you can do the following:

  1. Compute the keypoints for all videos and all frames offline and save them to a file.
  2. In the frames dataset, load this file and, for each frame, load the corresponding keypoints. Alternatively, you can simply run your keypoint detector on the frames at this point, without offline computation or saving to a file.

Put the keypoints in the out dict under the key 'kp'.

  3. Replace the keypoint detector with a dummy class that simply forwards your keypoints. In other words, just do out['mean'] = x['kp'].
  4. Modify the config and train.py as in the previous method.
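For the offline route, here is a minimal sketch of steps 2 and 3. It rests on several hypothetical assumptions. The keypoints are saved with np.savez as one array per video name, shaped [num_frames, num_kp, 2] and already in the normalized coordinate range. Videos are tensors [c, num_frames, h, w]. The detector is called with the whole batch dict. FramesWithKPDataset and DummyKPDetector are illustrative names, not classes from the repository, and the dataset is a stand-in for the repository's frames dataset rather than a patch to it.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset


class FramesWithKPDataset(Dataset):
    """Hypothetical dataset pairing each video with keypoints precomputed offline."""

    def __init__(self, videos, kp_file):
        self.videos = videos                 # dict: name -> tensor [c, num_frames, h, w]
        self.names = sorted(videos)
        self.kp = np.load(kp_file)           # .npz archive: name -> [num_frames, num_kp, 2]

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        video = self.videos[name]
        kp = torch.from_numpy(self.kp[name]).float()
        # Index video and keypoints with the same frames so they stay aligned;
        # every frame is kept here, but a frame-sampling scheme would go here.
        frame_idx = torch.arange(video.shape[1])
        return {'video': video[:, frame_idx], 'kp': kp[frame_idx]}


class DummyKPDetector(nn.Module):
    """Forwards the precomputed keypoints unchanged: out['mean'] = x['kp']."""

    def forward(self, x):
        return {'mean': x['kp']}
```

Step 1 then amounts to running your detector once per video and calling np.savez(kp_file, **{name: keypoints}). The sketch keeps a per-frame axis in 'kp'. If your training code works on the [bs, num_kp, 2] shape mentioned above, select the frame before passing the keypoints on.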