vlad-filin opened this issue 5 years ago.
@vlad-filin Hi, @penincillin and I discussed questions similar to yours in #200. Specifically:

Yes. In line 26 of `body_uv_rcnn_heads.py`, `15` corresponds to the semantically meaningful body parts used to sample points for annotators (14 parts plus background), as mentioned in the CVPR 2018 paper. To make this easier to understand, I made the number of parts a configurable parameter named `BODY_UV_RCNN.NUM_SEMANTIC_PARTS` (like `BODY_UV_RCNN.NUM_PATCHES`) in `config.py`.
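A minimal sketch of what such a parameter could look like. The `BodyUVConfig` dataclass is hypothetical (Detectron's real `config.py` uses a global `AttrDict`-style config), and it assumes `NUM_SEMANTIC_PARTS` counts only the 14 foreground parts, so the background channel is added explicitly, mirroring the `NUM_PATCHES + 1` pattern.

```python
from dataclasses import dataclass


@dataclass
class BodyUVConfig:
    # Hypothetical stand-in for cfg.BODY_UV_RCNN; values assumed from the
    # DensePose-COCO setup (24 surface patches, 14 coarse body parts).
    NUM_PATCHES: int = 24
    NUM_SEMANTIC_PARTS: int = 14

    @property
    def ann_index_channels(self) -> int:
        # Coarse part segmentation head: 14 parts + 1 background = 15
        return self.NUM_SEMANTIC_PARTS + 1

    @property
    def index_uv_channels(self) -> int:
        # Patch index / U / V heads: 24 patches + 1 background = 25
        return self.NUM_PATCHES + 1


cfg = BodyUVConfig()
# Instead of a bare literal 15, the head could read the channel count, e.g.
# model.ConvTranspose(blob_in, 'AnnIndex_lowres' + pref, dim, cfg.ann_index_channels, ...)
print(cfg.ann_index_channels, cfg.index_uv_channels)  # 15 25
```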
I have verified that the maximum number of annotated points within a person bounding box is 184, not 196, over all datasets (`train`, `valminusminival`, `minival`). In my opinion, 196 is the number of pixels in the feature map output by `RoIAlign`, which is taken as input by the `body_uv_rcnn` head network. One piece of evidence is that this feature map is exactly `14 x 14` (it was `7 x 7` in Faster R-CNN and Mask R-CNN). Further evidence can be found in the function `add_body_uv_rcnn_blobs()`, where target blobs of shape `(num_fg_rois, 196)` holding the body UV supervision for a minibatch are constructed, and in `pool_points_interp.cu`, the GPU implementation of the `PoolPointsInterp` operator. That operator is used in `body_uv_rcnn_heads.py` to bilinearly interpolate points from the estimated heatmaps (`Index_UV`, `U_estimated` and `V_estimated`) before computing the `SoftmaxLoss` for patch index classification and the `SmoothL1Loss` for UV coordinate regression.
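To make the 196-slot argument concrete, here is a NumPy sketch (not the CUDA `PoolPointsInterp` kernel itself) that packs a variable number of annotated points into a fixed 196-slot target with a 0/1 weight mask, then bilinearly samples a per-channel heatmap at those points. The helper names `pad_points` and `bilinear_sample` are hypothetical, and zero-weight padding is an assumption about how unused slots are ignored.

```python
import numpy as np

MAX_POINTS = 14 * 14  # 196 slots per RoI; real annotations use at most 184


def pad_points(xs, ys, max_points=MAX_POINTS):
    """Pack up to max_points (x, y) coordinates into fixed-size arrays.

    Returns the coordinates plus a 0/1 weight so that padded slots
    can be ignored by the loss.
    """
    n = len(xs)
    assert n <= max_points
    px = np.zeros(max_points, dtype=np.float32)
    py = np.zeros(max_points, dtype=np.float32)
    w = np.zeros(max_points, dtype=np.float32)
    px[:n], py[:n], w[:n] = xs, ys, 1.0
    return px, py, w


def bilinear_sample(heatmap, xs, ys):
    """Sample a (C, H, W) heatmap at continuous pixel coords -> (C, N)."""
    C, H, W = heatmap.shape
    xs = np.clip(xs, 0, W - 1)
    ys = np.clip(ys, 0, H - 1)
    x0 = np.floor(xs).astype(int)
    y0 = np.floor(ys).astype(int)
    x1 = np.minimum(x0 + 1, W - 1)
    y1 = np.minimum(y0 + 1, H - 1)
    dx = xs - x0
    dy = ys - y0
    # Interpolate along x on the two neighbouring rows, then along y
    top = heatmap[:, y0, x0] * (1 - dx) + heatmap[:, y0, x1] * dx
    bottom = heatmap[:, y1, x0] * (1 - dx) + heatmap[:, y1, x1] * dx
    return top * (1 - dy) + bottom * dy


# A linear ramp heatmap, so interpolated values are easy to check by hand
heat = np.arange(2 * 56 * 56, dtype=np.float32).reshape(2, 56, 56)
px, py, w = pad_points([10.5, 3.0], [20.0, 7.25])
vals = bilinear_sample(heat, px, py)  # (2, 196); only 2 slots have weight 1
```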
For your third question, my last comment in #200 may give you a hint.
So far, I am still not sure why the number of output channels of `AnnIndex` must equal `BODY_UV_RCNN.NUM_PATCHES + 1`. Since its output is used to compute a `SpatialSoftmaxLoss`, I guess the number of output channels could also be set to `BODY_UV_RCNN.NUM_SEMANTIC_PARTS + 1`. That said, I did find one piece of evidence for this design: during DensePose inference, there is a post-processing step that multiplies (element-wise) the estimated patch index heatmaps `Index_UV` with binary dense masks computed from `AnnIndex`, which requires the two tensors to have the same shape.
Last but not least, I think the author left a lot of experimental code in this released version, which inevitably makes some related and important parts of the implementation hard for many of us to understand. Therefore, I spent a few days studying the original code in this repo, along with some of the Detectron and Caffe2 API functions it depends on. Based on that, I refined the related code and fixed some minor bugs related to this issue and others (#191, #194, #200, #202, #203, #206, #211) in my repo. I haven't finished training the baseline model with the refined version yet, so I will post more updates in a new issue or open a PR once my modifications are fully verified.
All modifications and refinements can be seen in PR #215. Here is a comparison of results on the `densepose_coco_minival` dataset for the baseline model (`ResNet50_FPN_s1x`) before and after the modifications. (I only have 2 GPUs, so some results are slightly lower than those reported by the author here.)
Before:
INFO json_dataset_evaluator.py: 227: ~~~~ Summary metrics ~~~~
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.533
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.844
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.574
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.259
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.520
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.667
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.207
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.549
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.595
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.336
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.580
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.727
INFO json_dataset_evaluator.py: 194: Wrote json eval results to: tmp/detectron-output/test/dense_coco_2014_minival/generalized_rcnn/detection_results.pkl
INFO task_evaluation.py: 55: Evaluating bounding boxes is done!
INFO task_evaluation.py: 151: Evaluating body uv
INFO json_dataset_evaluator.py: 470: Collecting person results (1/1)
INFO json_dataset_evaluator.py: 484: Writing body uv results pkl to: /home/qinchuan.zqc/densepose/tmp/detectron-output/test/dense_coco_2014_minival/generalized_rcnn/body_uv_dense_coco_2014_minival_results.pkl
Loading and preparing results...
DONE (t=0.05s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *uv*
Loading densereg GT..
/home/qinchuan.zqc/densepose/detectron/datasets/../../DensePoseData/eval_data/
Loaded
DONE (t=178.65s).
Accumulating evaluation results...
('Categories:', [1])
('Final', 1.0, 0.0)
DONE (t=0.13s).
Average Precision (AP) @[ OGPS=0.50:0.95 | area= all | maxDets= 20 ] = 0.466
Average Precision (AP) @[ OGPS=0.50 | area= all | maxDets= 20 ] = 0.840
Average Precision (AP) @[ OGPS=0.55 | area= all | maxDets= 20 ] = 0.801
Average Precision (AP) @[ OGPS=0.60 | area= all | maxDets= 20 ] = 0.751
Average Precision (AP) @[ OGPS=0.65 | area= all | maxDets= 20 ] = 0.685
Average Precision (AP) @[ OGPS=0.70 | area= all | maxDets= 20 ] = 0.596
Average Precision (AP) @[ OGPS=0.75 | area= all | maxDets= 20 ] = 0.468
Average Precision (AP) @[ OGPS=0.80 | area= all | maxDets= 20 ] = 0.321
Average Precision (AP) @[ OGPS=0.85 | area= all | maxDets= 20 ] = 0.159
Average Precision (AP) @[ OGPS=0.90 | area= all | maxDets= 20 ] = 0.038
Average Precision (AP) @[ OGPS=0.95 | area= all | maxDets= 20 ] = 0.002
Average Precision (AP) @[ OGPS=0.50:0.95 | area=medium | maxDets= 20 ] = 0.410
Average Precision (AP) @[ OGPS=0.50:0.95 | area= large | maxDets= 20 ] = 0.486
Average Recall (AR) @[ OGPS=0.50:0.95 | area= all | maxDets= 20 ] = 0.560
Average Recall (AR) @[ OGPS=0.50 | area= all | maxDets= 20 ] = 0.905
Average Recall (AR) @[ OGPS=0.75 | area= all | maxDets= 20 ] = 0.590
Average Recall (AR) @[ OGPS=0.50:0.95 | area=medium | maxDets= 20 ] = 0.440
Average Recall (AR) @[ OGPS=0.50:0.95 | area= large | maxDets= 20 ] = 0.568
INFO task_evaluation.py: 67: Evaluating body uv is done!
INFO task_evaluation.py: 194: copypaste: Dataset: dense_coco_2014_minival
INFO task_evaluation.py: 196: copypaste: Task: box
INFO task_evaluation.py: 199: copypaste: AP,AP50,AP75,APs,APm,APl
INFO task_evaluation.py: 200: copypaste: 0.5332,0.8442,0.5741,0.2587,0.5200,0.6670
INFO task_evaluation.py: 196: copypaste: Task: body_uv
INFO task_evaluation.py: 199: copypaste: AP,AP50,AP75,APm,APl
INFO task_evaluation.py: 200: copypaste: 0.4662,0.8401,0.4679,0.4105,0.4864
After:
INFO json_dataset_evaluator.py: 227: ~~~~ Summary metrics ~~~~
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.538
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.846
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.587
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.265
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.523
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.673
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.208
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.554
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.598
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.340
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.580
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.732
INFO json_dataset_evaluator.py: 194: Wrote json eval results to: tmp/repro_ResNet50_FPN_2GPUs/test/dense_coco_2014_minival/generalized_rcnn/detection_results.pkl
INFO task_evaluation.py: 55: Evaluating bounding boxes is done!
INFO task_evaluation.py: 151: Evaluating body uv
INFO json_dataset_evaluator.py: 467: Collecting person results (1/1)
INFO json_dataset_evaluator.py: 475: Writing body uv results pkl to: /home/qinchuan.zqc/densepose/tmp/repro_ResNet50_FPN_2GPUs/test/dense_coco_2014_minival/generalized_rcnn/body_uv_dense_coco_2014_minival_results.pkl
Loading and preparing results...
DONE (t=0.04s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *uv*
Loading densereg GT from /home/qinchuan.zqc/densepose/detectron/datasets/../../DensePoseData/eval_data/
densereg GT loaded
DONE (t=163.10s).
Accumulating evaluation results...
('Categories ids:', [1])
Final precisions, max: 1.00, min: 0.00
DONE (t=0.14s).
Average Precision (AP) @[ GPS=0.50:0.95 | area= all | maxDets= 20 ] = 0.470
Average Precision (AP) @[ GPS=0.50 | area= all | maxDets= 20 ] = 0.838
Average Precision (AP) @[ GPS=0.55 | area= all | maxDets= 20 ] = 0.809
Average Precision (AP) @[ GPS=0.60 | area= all | maxDets= 20 ] = 0.756
Average Precision (AP) @[ GPS=0.65 | area= all | maxDets= 20 ] = 0.692
Average Precision (AP) @[ GPS=0.70 | area= all | maxDets= 20 ] = 0.596
Average Precision (AP) @[ GPS=0.75 | area= all | maxDets= 20 ] = 0.466
Average Precision (AP) @[ GPS=0.80 | area= all | maxDets= 20 ] = 0.331
Average Precision (AP) @[ GPS=0.85 | area= all | maxDets= 20 ] = 0.159
Average Precision (AP) @[ GPS=0.90 | area= all | maxDets= 20 ] = 0.048
Average Precision (AP) @[ GPS=0.95 | area= all | maxDets= 20 ] = 0.002
Average Precision (AP) @[ GPS=0.50:0.95 | area=medium | maxDets= 20 ] = 0.418
Average Precision (AP) @[ GPS=0.50:0.95 | area= large | maxDets= 20 ] = 0.488
Average Recall (AR) @[ GPS=0.50:0.95 | area= all | maxDets= 20 ] = 0.563
Average Recall (AR) @[ GPS=0.50 | area= all | maxDets= 20 ] = 0.901
Average Recall (AR) @[ GPS=0.75 | area= all | maxDets= 20 ] = 0.585
Average Recall (AR) @[ GPS=0.50:0.95 | area=medium | maxDets= 20 ] = 0.443
Average Recall (AR) @[ GPS=0.50:0.95 | area= large | maxDets= 20 ] = 0.571
INFO task_evaluation.py: 67: Evaluating body uv is done!
INFO task_evaluation.py: 194: copypaste: Dataset: dense_coco_2014_minival
INFO task_evaluation.py: 196: copypaste: Task: box
INFO task_evaluation.py: 199: copypaste: AP,AP50,AP75,APs,APm,APl
INFO task_evaluation.py: 200: copypaste: 0.5377,0.8460,0.5872,0.2646,0.5227,0.6726
INFO task_evaluation.py: 196: copypaste: Task: body_uv
INFO task_evaluation.py: 199: copypaste: AP,AP50,AP75,APm,APl
INFO task_evaluation.py: 200: copypaste: 0.4697,0.8380,0.4661,0.4180,0.4880
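For a quick side-by-side reading of the two runs, the `copypaste` lines can be diffed programmatically. A small sketch; `parse_copypaste` is a hypothetical helper, and it assumes exactly the comma-separated header/value layout shown in these logs (the numbers below are the `body_uv` lines copied from above).

```python
def parse_copypaste(header: str, values: str) -> dict:
    """Turn 'AP,AP50,...' and '0.53,0.84,...' copypaste lines into a dict."""
    return dict(zip(header.split(","), map(float, values.split(","))))


# body_uv metrics from the "Before" and "After" logs
before = parse_copypaste("AP,AP50,AP75,APm,APl", "0.4662,0.8401,0.4679,0.4105,0.4864")
after = parse_copypaste("AP,AP50,AP75,APm,APl", "0.4697,0.8380,0.4661,0.4180,0.4880")

for metric in before:
    delta = after[metric] - before[metric]
    print(f"{metric:>5}: {before[metric]:.4f} -> {after[metric]:.4f} ({delta:+.4f})")
```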
Hi @Johnqczhang, I think there is a problem with `AnnIndex_lowres` and `AnnIndex`. `AnnIndex_lowres` has 15 output channels. However, look at the code: `blob_Ann_Index = model.BilinearInterpolation('AnnIndex_lowres'+pref, 'AnnIndex'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)`.
`cfg.BODY_UV_RCNN.NUM_PATCHES+1` equals 25, so for `AnnIndex` the input and output dims are both 25. Moreover, in the pre-trained weights file, the shape of `AnnIndex`'s weight is `(25, 25, 4, 4)`. So it seems unlikely that `AnnIndex_lowres`, whose output dim is 15, can be used directly as the input of `AnnIndex`.
Hi @lizenan, sorry for the late reply. As I discussed with @penincillin in #200 (see my last comment there), the current implementation of `BilinearInterpolation` requires the input and output dims of `AnnIndex` to be the same. `BilinearInterpolation` is essentially a `ConvTranspose` layer whose output dim is determined by its kernels, not by the size of the input blob (here, `AnnIndex_lowres`). So as long as you specify the same input and output dim for `AnnIndex`, as in the following code, you can use any number for the output dim of `AnnIndex_lowres`; the only consequence is that the values in the extra channels are all 0s.
`blob_Ann_Index = model.BilinearInterpolation('AnnIndex_lowres'+pref, 'AnnIndex'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)`
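The "extra channels are all 0s" behaviour follows from the bilinear kernel being diagonal across channels. Below is a NumPy sketch modeled on the standard FCN-style bilinear upsampling initialization; whether Detectron's `BilinearInterpolation` builds exactly these weights is an assumption, although the resulting `(25, 25, 4, 4)` shape matches the weight shape reported above. It does not model how Caffe2's `ConvTranspose` handles a 15-channel input against 25 x 25 filters; instead it zero-pads the input to 25 channels explicitly. All function names are hypothetical.

```python
import numpy as np


def upsample_filt(size):
    """2D bilinear kernel used for FCN-style upsampling initialization."""
    factor = (size + 1) // 2
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    og = np.ogrid[:size, :size]
    return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)


def bilinear_weights(dim, up_scale):
    """Fixed (dim_in, dim_out, k, k) weights: channel c only feeds channel c."""
    k = 2 * up_scale
    w = np.zeros((dim, dim, k, k), dtype=np.float32)
    w[np.arange(dim), np.arange(dim)] = upsample_filt(k)
    return w


def conv_transpose(x, w, stride, pad):
    """Naive transposed convolution for one (C, H, W) image."""
    c_in, c_out, k, _ = w.shape
    assert x.shape[0] == c_in, "input channels must match the weights"
    _, H, W = x.shape
    full = np.zeros((c_out, (H - 1) * stride + k, (W - 1) * stride + k), dtype=np.float32)
    for i in range(H):
        for j in range(W):
            # Scatter each input pixel, weighted by the kernel, into the output
            full[:, i * stride:i * stride + k, j * stride:j * stride + k] += np.einsum(
                "c,cokl->okl", x[:, i, j], w)
    h_out = (H - 1) * stride - 2 * pad + k
    w_out = (W - 1) * stride - 2 * pad + k
    return full[:, pad:pad + h_out, pad:pad + w_out]


up_scale = 2
w = bilinear_weights(25, up_scale)  # shape (25, 25, 4, 4)
x15 = np.random.rand(15, 28, 28).astype(np.float32)  # e.g. a 15-channel AnnIndex_lowres
x25 = np.concatenate([x15, np.zeros((10, 28, 28), np.float32)])  # extra channels zeroed
y = conv_transpose(x25, w, stride=up_scale, pad=up_scale // 2)
print(y.shape)  # (25, 56, 56)
print(np.abs(y[15:]).max())  # 0.0: channels 15..24 receive nothing
```

Because output channel c depends only on input channel c, the first 15 output channels are identical to what a 15 x 15 bilinear layer would produce.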
For `AnnIndex_lowres`, I still can't find any clear explanation of why its output dim is 15 instead of 25. So I changed its output dim from 15 to 25 to match the input dim of `AnnIndex`, so that all channels now have responses. I trained `ResNet50_FPN_s1x` models with the same hyperparameter settings as the author but with different output dims for `AnnIndex_lowres`; here are the results I got on the test set (`densepose_coco_minival2014`):
Model | AP | AP50 | AP75 | APm | APl |
---|---|---|---|---|---|
model in MODEL_ZOO (out_dim=15) | 0.4748 | 0.8368 | 0.4820 | 0.4262 | 0.4948 |
my reproduced model (out_dim=15) | 0.4717 | 0.8377 | 0.4821 | 0.4044 | 0.4928 |
my reproduced model (out_dim=25) | 0.4764 | 0.8422 | 0.4928 | 0.4207 | 0.4950 |
The 15 comes from the 14 segmented body parts (plus background) that were used to collect annotations before any surface correspondence was imposed. This segmentation is used as an auxiliary loss during training, supervised with `dp_masks` (the segmentation masks from annotation stage 1). More details here:
https://github.com/facebookresearch/DensePose/blob/master/notebooks/DensePose-COCO-Visualize.ipynb
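As a rough illustration of that auxiliary term, here is a NumPy sketch of a per-pixel (spatial) softmax cross-entropy over 15 coarse-part channels, with targets from a label map whose values 0..14 stand for background plus the 14 parts. It assumes the `dp_masks` have already been decoded and resized to the heatmap resolution; `spatial_softmax_ce` is a hypothetical name, and this is not Caffe2's own loss operator.

```python
import numpy as np


def spatial_softmax_ce(logits, labels):
    """Mean per-pixel cross-entropy.

    logits: (N, C, H, W) scores, e.g. C = 15 coarse parts incl. background
    labels: (N, H, W) integer class ids in [0, C)
    """
    # Numerically stable log-softmax over the channel axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the target class at every pixel
    picked = np.take_along_axis(log_probs, labels[:, None, :, :], axis=1)
    return -picked.mean()


logits = np.zeros((2, 15, 56, 56))  # uninformative predictions
labels = np.random.randint(0, 15, size=(2, 56, 56))  # stand-in for decoded dp_masks
print(spatial_softmax_ce(logits, labels))  # log(15) ≈ 2.708
```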
Hello! Thank you for providing the code; it gives us a chance to fully understand how the model works.
I have several questions about constants in the `body_uv_rcnn_heads.py` file. They have no description or even a name; they are just numbers in the code (e.g. the number 15 on line 26). Questions:

1) Line 26: `model.ConvTranspose(blob_in, 'AnnIndex_lowres'+pref, dim, 15,...`. I guess that 15 stands for the number of annotation classes (14) + 1 (background). It would be nice to make it a config parameter (like `BODY_UV_RCNN.NUM_PATCHES`), or at least explain the meaning of this constant in a comment in `body_uv_rcnn_heads.py`.
2) Line 65: `### Now reshape UV blobs, such that they are 1x1x(196 NumSamples)xNUM_PATCHES` and line 70: `... , shape=(-1,cfg.BODY_UV_RCNN.NUM_PATCHES+1,196))`. The paper "Dense Human Pose Estimation In The Wild" mentions that there are at most 14 points per body part, and there are 14 semantic body parts in the COCO DensePose dataset, so I guess 196 (= 14 x 14) is the maximum number of points over all semantic parts, but I am not sure about this. It would be nice to document this constant (196) as well.
I also have a question about the transformation from `AnnIndex_lowres` to `AnnIndex`. This transformation is done via bilinear interpolation and semantically shouldn't change the number of channels of the tensor (for the transformations from `Index_UV_lowres` to `Index_UV`, `U_lowres` to `U_estimated`, and `V_lowres` to `V_estimated`, the number of channels stays the same). But at the same time:
at line 26: `model.ConvTranspose(blob_in, 'AnnIndex_lowres'+pref, dim, 15, cfg.BODY_UV_RCNN.DECONV_KERNEL, pad=int(cfg.BODY_UV_RCNN.DECONV_KERNEL / 2 - 1), stride=2, weight_init=(cfg.BODY_UV_RCNN.CONV_INIT, {'std': 0.001}), bias_init=('ConstantFill', {'value': 0.}))`

at line 46: `blob_Ann_Index = model.BilinearInterpolation('AnnIndex_lowres'+pref, 'AnnIndex'+pref, cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.NUM_PATCHES+1, cfg.BODY_UV_RCNN.UP_SCALE)`
So, my questions are:

3) The docs of `detector.BilinearInterpolation` (detector.py, lines 330-334) say that the number of input channels equals the number of output channels, yet the input blob `AnnIndex_lowres` has 15 channels and the output blob `AnnIndex` has 25 channels. How is this possible? I am not familiar with Caffe2, but `BilinearInterpolation` in this project is implemented as a `ConvTranspose` layer with fixed weights.
4) Why must the number of output channels of `AnnIndex` equal `cfg.BODY_UV_RCNN.NUM_PATCHES+1`, given that the COCO DensePose dataset has 14 semantic classes for masks?
I am also including part of the log in which this change in the number of channels is visible. The log was created by running `python2 tools/train_net.py --cfg configs/DensePose_ResNet50_FPN_single_GPU.yaml OUTPUT_DIR /tmp/detectron-output`.
INFO net.py: 241: body_conv_fcn8 : (3, 512, 14, 14) => AnnIndex_lowres : (3, 15, 28, 28) ------- (op: ConvTranspose)
INFO net.py: 241: body_conv_fcn8 : (3, 512, 14, 14) => Index_UV_lowres : (3, 25, 28, 28) ------- (op: ConvTranspose)
INFO net.py: 241: body_conv_fcn8 : (3, 512, 14, 14) => U_lowres : (3, 25, 28, 28) ------- (op: ConvTranspose)
INFO net.py: 241: body_conv_fcn8 : (3, 512, 14, 14) => V_lowres : (3, 25, 28, 28) ------- (op: ConvTranspose)
INFO net.py: 241: AnnIndex_lowres : (3, 15, 28, 28) => AnnIndex : (3, 25, 56, 56) ------- (op: ConvTranspose)
INFO net.py: 241: Index_UV_lowres : (3, 25, 28, 28) => Index_UV : (3, 25, 56, 56) ------- (op: ConvTranspose)
INFO net.py: 241: U_lowres : (3, 25, 28, 28) => U_estimated : (3, 25, 56, 56) ------- (op: ConvTranspose)
INFO net.py: 241: V_lowres : (3, 25, 28, 28) => V_estimated : (3, 25, 56, 56) ------- (op: ConvTranspose)
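The spatial sizes in this log follow from the usual transposed-convolution output formula, out = (in - 1) * stride - 2 * pad + kernel. A quick check, assuming `DECONV_KERNEL = 4` with the stride and padding from line 26, and assuming `BilinearInterpolation` uses kernel `2 * UP_SCALE`, stride `UP_SCALE` and pad `UP_SCALE / 2` with `UP_SCALE = 2`:

```python
def conv_transpose_out(size, kernel, stride, pad):
    """Spatial output size of a transposed convolution (no output padding)."""
    return (size - 1) * stride - 2 * pad + kernel


DECONV_KERNEL = 4  # assumed value of cfg.BODY_UV_RCNN.DECONV_KERNEL
UP_SCALE = 2       # assumed value of cfg.BODY_UV_RCNN.UP_SCALE

# body_conv_fcn8 (14x14) -> *_lowres via the ConvTranspose at line 26
lowres = conv_transpose_out(14, DECONV_KERNEL, stride=2, pad=DECONV_KERNEL // 2 - 1)
# *_lowres -> AnnIndex / Index_UV / U_estimated / V_estimated via BilinearInterpolation
full = conv_transpose_out(lowres, 2 * UP_SCALE, stride=UP_SCALE, pad=UP_SCALE // 2)
print(lowres, full)  # 28 56
```

The formula only covers height and width: the output channel count of a transposed convolution is fixed by its filter shape, which is why `AnnIndex` can report 25 channels while `AnnIndex_lowres` has 15.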
Thank you for your time; I hope to hear from you soon!