cherubicXN / hawp

Holistically-Attracted Wireframe Parsing [TPAMI'23] & [CVPR'20]
MIT License

Can't reproduce the sAP10 result #3

alwc closed this issue 4 years ago

alwc commented 4 years ago

Hi @cherubicXN,

After training the model with the provided code and data, I tested it with

$ CUDA_VISIBLE_DEVICES=0 python scripts/test.py --config-file config-files/hawp.yaml
2020-05-26 11:02:32,633 hawp INFO: Namespace(config_file='config-files/hawp.yaml', display=False, opts=[])
2020-05-26 11:02:32,633 hawp INFO: Loaded configuration file config-files/hawp.yaml
2020-05-26 11:02:34,451 hawp.testing INFO: Loading checkpoint from outputs/hawp/model_00030.pth
2020-05-26 11:02:34,630 hawp.testing INFO: Testing on wireframe_test dataset
100%|█████████| 462/462 [00:14<00:00, 32.98it/s]
2020-05-26 11:02:48,641 hawp.testing INFO: Writing the results of the wireframe_test dataset into outputs/hawp/wireframe_test.json

and evaluated with

$ python scripts/eval_sap.py --path outputs/hawp/wireframe_test.json --threshold 10
sAP10.0 = 58.4
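For readers unfamiliar with the metric: as we understand the structural-AP definition introduced with L-CNN, sAP10 counts a predicted segment as correct when the summed squared distance between its endpoints and those of a not-yet-matched ground-truth segment is below 10 (the `--threshold` above), with coordinates on a 128×128 grid. The sketch below follows that definition for a single image only. It is not HAWP's `eval_sap.py`, which pools and ranks predictions across the whole test set, and the function names are hypothetical.

```python
def line_distance(pred, gt):
    """Sum of squared endpoint distances, using the better of the two endpoint pairings."""
    (p1, p2), (g1, g2) = pred, gt

    def sq(a, b):
        return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2

    return min(sq(p1, g1) + sq(p2, g2), sq(p1, g2) + sq(p2, g1))


def match_lines(preds, gts, threshold):
    """Flag each prediction (already sorted by descending score) as a true positive or not."""
    hit = [False] * len(gts)
    flags = []
    for pred in preds:
        if not gts:
            flags.append(False)
            continue
        dists = [line_distance(pred, gt) for gt in gts]
        j = min(range(len(gts)), key=dists.__getitem__)
        # Only the nearest ground-truth line is considered, and each one can be matched once.
        if dists[j] < threshold and not hit[j]:
            hit[j] = True
            flags.append(True)
        else:
            flags.append(False)
    return flags


def average_precision(flags, n_gt):
    """Area under the interpolated precision-recall curve."""
    if n_gt == 0:
        return 0.0
    recall, precision = [0.0], [0.0]
    tp = 0
    for k, is_tp in enumerate(flags, start=1):
        tp += is_tp
        recall.append(tp / n_gt)
        precision.append(tp / k)
    recall.append(1.0)
    precision.append(0.0)
    # Replace each precision value with the best precision reached at any higher recall.
    for i in range(len(precision) - 2, -1, -1):
        precision[i] = max(precision[i], precision[i + 1])
    return sum(
        (recall[i + 1] - recall[i]) * precision[i + 1] for i in range(len(recall) - 1)
    )


if __name__ == "__main__":
    gts = [((0, 0), (10, 0)), ((0, 5), (10, 5))]
    preds = [((0, 0), (10, 1)), ((10, 5), (0, 5)), ((50, 50), (60, 60))]
    flags = match_lines(preds, gts, threshold=10)
    print(flags, average_precision(flags, len(gts)))
```

Changing `threshold` to 5 or 15 gives the sAP5 and sAP15 variants reported later in this thread.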

Note that I'm getting sAP10.0 = 58.4, which is much lower than the result reported in the paper (66.5). When I run the same commands with your provided pre-trained model, I do get sAP10.0 = 66.5.

FYI, I'm using PyTorch 1.4.0 with Python 3.6, and I trained on a single 2080 Ti GPU.

Here are the settings I used:

2020-05-25 14:36:35,208 hawp INFO: Namespace(clean=False, config_file='config-files/hawp.yaml', opts=[], seed=2)
2020-05-25 14:36:35,208 hawp INFO: Loaded configuration file config-files/hawp.yaml
2020-05-25 14:36:35,208 hawp INFO:
SOLVER:
  IMS_PER_BATCH: 6
  BASE_LR: 0.0004
  MAX_EPOCH: 30
  STEPS: (25,)
  WEIGHT_DECAY: 0.0001
  CHECKPOINT_PERIOD: 1
  OPTIMIZER: "ADAM"
  AMSGRAD: True

DATALOADER:
  NUM_WORKERS: 8
DATASETS:
  IMAGE:
    PIXEL_MEAN: [109.730, 103.832, 98.681]
    PIXEL_STD: [22.275, 22.124, 23.229]
    TO_255: True

  TEST: ("wireframe_test","york_test")

MODEL:
  NAME: "Hourglass"
  HEAD_SIZE: [[3], [1], [1], [2], [2]] #Order: ang, dis, dis_residual, jloc, joff
  OUT_FEATURE_CHANNELS: 256
  HGNETS:
    DEPTH: 4
    NUM_STACKS: 2
    NUM_BLOCKS: 1
    INPLANES: 64
    NUM_FEATS: 128

  PARSING_HEAD:
    USE_RESIDUAL: True
    MAX_DISTANCE: 5.0
    N_DYN_JUNC:   300
    N_DYN_POSL:   300
    N_DYN_NEGL:   0
    N_DYN_OTHR:   0
    N_DYN_OTHR2:  300
    N_PTS0: 32
    N_PTS1: 8
    DIM_LOI: 128
    DIM_FC: 1024
    N_OUT_JUNC: 250
    N_OUT_LINE: 2500

  LOSS_WEIGHTS:
    loss_md: 1.0 # angle regression
    loss_dis: 1.0 # dis   regression
    loss_res: 1.0      # residual regression
    loss_joff: 0.25    # joff  regression
    loss_jloc: 8.0     # jloc  classification
    loss_pos: 1.0      # pos   classification
    loss_neg: 1.0      # neg   classification

OUTPUT_DIR: "outputs/hawp"

2020-05-25 14:36:35,208 hawp INFO: Running with config:
DATALOADER:
  NUM_WORKERS: 8
DATASETS:
  DISTANCE_TH: 0.02
  IMAGE:
    HEIGHT: 512
    PIXEL_MEAN: [109.73, 103.832, 98.681]
    PIXEL_STD: [22.275, 22.124, 23.229]
    TO_255: True
    WIDTH: 512
  NUM_STATIC_NEGATIVE_LINES: 40
  NUM_STATIC_POSITIVE_LINES: 300
  TARGET:
    HEIGHT: 128
    WIDTH: 128
  TEST: ('wireframe_test', 'york_test')
  TRAIN: ('wireframe_train',)
  VAL: ('wireframe_test',)
ENCODER:
  ANG_TH: 0.1
  DIS_TH: 5
  NUM_STATIC_NEG_LINES: 40
  NUM_STATIC_POS_LINES: 300
MODEL:
  DEVICE: cuda
  HEAD_SIZE: [[3], [1], [1], [2], [2]]
  HGNETS:
    DEPTH: 4
    INPLANES: 64
    NUM_BLOCKS: 1
    NUM_FEATS: 128
    NUM_STACKS: 2
  LOSS_WEIGHTS:
    loss_dis: 1.0
    loss_jloc: 8.0
    loss_joff: 0.25
    loss_md: 1.0
    loss_neg: 1.0
    loss_pos: 1.0
    loss_res: 1.0
  NAME: Hourglass
  OUT_FEATURE_CHANNELS: 256
  PARSING_HEAD:
    DIM_FC: 1024
    DIM_LOI: 128
    MATCHING_STRATEGY: junction
    MAX_DISTANCE: 5.0
    N_DYN_JUNC: 300
    N_DYN_NEGL: 0
    N_DYN_OTHR: 0
    N_DYN_OTHR2: 300
    N_DYN_POSL: 300
    N_OUT_JUNC: 250
    N_OUT_LINE: 2500
    N_PTS0: 32
    N_PTS1: 8
    N_STC_NEGL: 40
    N_STC_POSL: 300
    USE_RESIDUAL: True
  SCALE: 1.0
  WEIGHTS:
OUTPUT_DIR: outputs/hawp
SOLVER:
  AMSGRAD: True
  BACKBONE_LR_FACTOR: 1.0
  BASE_LR: 0.0004
  BIAS_LR_FACTOR: 1
  CHECKPOINT_PERIOD: 1
  GAMMA: 0.1
  IMS_PER_BATCH: 6
  MAX_EPOCH: 30
  MOMENTUM: 0.9
  OPTIMIZER: ADAM
  STEPS: (25,)
  WEIGHT_DECAY: 0.0001
  WEIGHT_DECAY_BIAS: 0

Log from the last epoch:

2020-05-26 02:57:25,525 hawp.trainer INFO: eta: 0:00:00 epoch: 30 iter: 3333 data: 0.0080 (0.0091) loss: 0.8587 (0.8574) loss_dis: 0.1496 (0.1461) loss_jloc: 0.0318 (0.0327) loss_joff: 0.3663 (0.3664) loss_md: 0.1448 (0.1417) loss_neg: 0.0771 (0.0809) loss_pos: 0.0451 (0.0456) loss_res: 0.0915 (0.0904) time: 0.4403 (0.4446) lr: 0.000040 max mem: 6232

cherubicXN commented 4 years ago

That's weird. Let me retrain the network.

alwc commented 4 years ago

Thanks @cherubicXN!

cherubicXN commented 4 years ago

As I recall, sAP rises quickly to 60+ early in training. Did you check sAP10 at epochs 10, 15, 20, and 25?

alwc commented 4 years ago

Here are the sAP10 evaluations from my trained models:

  • model_00001.pth: 49.0
  • model_00005.pth: 57.0
  • model_00010.pth: 58.5
  • model_00015.pth: 58.0
  • model_00020.pth: 58.6
  • model_00025.pth: 58.0

Also note that the .png files I'm using were created with https://github.com/zhou13/lcnn/blob/master/dataset/wireframe.py, since the images you provide are .jpg files.

cherubicXN commented 4 years ago

OK, I am retraining the network. It may take about 12 hours; I am watching the training logs now.

cherubicXN commented 4 years ago

The sAP log indicates that the network reaches its best sAP around epoch 10 and the remaining epochs bring no further improvement. That's really weird.

cherubicXN commented 4 years ago

After checking the previous model weights, I suspect there is a bug in the training code. The sAP should exceed 60 after 10 epochs and approach 63.0 after 25 epochs. Once the learning rate decays, sAP10 should jump to 66+.

alwc commented 4 years ago

@cherubicXN Interesting, thanks for your insights. I hope you can track down the bug!

If you need more GPU machines for experiments, please let me know. I can train the model on my side to see whether I can reproduce the results.

cherubicXN commented 4 years ago

Thanks very much :). I have enough GPU machines and the training time is not too long. Maybe I introduced a mistake while refactoring the code.

cherubicXN commented 4 years ago

@alwc, I think I have found the bug. It is caused by incorrect use of the lr_scheduler in train.py. At line 116 of train.py, I mistakenly called the learning rate scheduler as

scheduler.step(epoch)

which makes the learning rate decay after the first epoch of training.

The correct implementation should be

scheduler.step()

You can check the training log on your machine to see whether the learning rate has already decayed to 4e-5 after the first epoch of training.
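For reference, here is a minimal sketch of the intended per-epoch schedule under the SOLVER settings above (Adam with AMSGRAD, BASE_LR 0.0004, WEIGHT_DECAY 0.0001, STEPS (25,), GAMMA 0.1). It assumes the scheduler is PyTorch's `MultiStepLR` stepped once per epoch; that is an assumption, not something stated in this thread, and `torch.nn.Linear` is only a placeholder for the HAWP network. With a plain `scheduler.step()` after each epoch, the learning rate should stay at 4e-4 for epochs 1 to 25 and drop to 4e-5 from epoch 26 on.

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR


def epoch_learning_rates(max_epoch=30, base_lr=4e-4, steps=(25,), gamma=0.1):
    """Return the learning rate in effect during each epoch (epochs are 1-indexed)."""
    model = torch.nn.Linear(4, 2)  # hypothetical stand-in for the HAWP model
    optimizer = torch.optim.Adam(
        model.parameters(), lr=base_lr, weight_decay=1e-4, amsgrad=True
    )
    # Milestones count scheduler.step() calls, which here means epochs.
    scheduler = MultiStepLR(optimizer, milestones=list(steps), gamma=gamma)

    lrs = []
    for _ in range(max_epoch):
        lrs.append(optimizer.param_groups[0]["lr"])
        # One dummy iteration stands in for a full epoch of training.
        loss = model(torch.randn(6, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Step once per epoch, after the optimizer, with no epoch argument.
        scheduler.step()
    return lrs


if __name__ == "__main__":
    for epoch, lr in enumerate(epoch_learning_rates(), start=1):
        print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

Calling `scheduler.step()` without arguments lets the scheduler keep its own epoch counter, so the decay point depends only on how many times it has been stepped.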

alwc commented 4 years ago

I think you are right. In my old training log, the lr at the beginning of epoch 2 is 0.000040, whereas in the new training log with the scheduler fix it is 0.0004. Thanks @cherubicXN!
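A quick way to run this check over a whole log is to pull the first logged `lr` for each epoch. The sketch below assumes every trainer line follows the format of the last-epoch log line quoted earlier (`... epoch: 30 iter: 3333 ... lr: 0.000040 max mem: 6232`); the function name is hypothetical and the pattern may need adjusting if the log format differs.

```python
import re

# Matches trainer lines such as
# "... hawp.trainer INFO: eta: 0:00:00 epoch: 30 iter: 3333 ... lr: 0.000040 max mem: 6232"
LOG_PATTERN = re.compile(r"epoch: (\d+) iter: \d+ .*?\blr: ([0-9.eE+-]+)")


def first_lr_per_epoch(lines):
    """Map each epoch number to the learning rate on its first logged iteration."""
    lrs = {}
    for line in lines:
        match = LOG_PATTERN.search(line)
        if match is None:
            continue  # skip config dumps and other non-trainer lines
        epoch, lr = int(match.group(1)), float(match.group(2))
        lrs.setdefault(epoch, lr)  # keep only the first occurrence per epoch
    return lrs
```

Passing an open log file (any iterable of lines) works; with the buggy scheduler, epoch 2 should already report 4e-05 instead of 0.0004.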

Here are the sAP10 evaluations for the first 5 epochs from the new model:

  • model_00001.pth: 48.4
  • model_00002.pth: 51.9
  • model_00003.pth: 54.2
  • model_00004.pth: 56.9
  • model_00005.pth: 56.9

So far the results are roughly the same as the previous model's for the first 5 epochs. I'll update you with the final result once training is done.

On a side note, I was surprised to see that model_00001.pth gives a different result (old 49.0 vs. new 48.4). It seems the seed doesn't work properly.

cherubicXN commented 4 years ago

That is normal: the model weights after the first epoch are not yet stable. Today I obtained 49.7 for model_00001.pth. I am also training the network; I hope it works well.
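If closer run-to-run agreement is wanted, a common PyTorch seeding setup looks like the sketch below. It is a general pattern, not taken from HAWP's train.py (which does accept a seed, as `seed=2` in the training Namespace shows), and the helper names are hypothetical. Even with all of these settings, some CUDA operations remain nondeterministic, so small first-epoch differences such as 49.0 vs. 48.4 vs. 49.7 can persist.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # deferred until CUDA is used; harmless without a GPU
    # Trade speed for repeatability in cuDNN algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def seed_worker(worker_id: int) -> None:
    """Pass as DataLoader(worker_init_fn=seed_worker) so each worker's augmentation is seeded."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```

Call `seed_everything(seed)` once before building the model and data loaders; with NUM_WORKERS: 8, the worker hook matters because each loader worker otherwise reseeds NumPy independently.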

alwc commented 4 years ago

Hi @cherubicXN,

I just completed the training with the bug fixed, and here are the sAP10 results for epochs 25 to 30:

  • model_00025.pth: 63.3
  • model_00026.pth: 65.8
  • model_00027.pth: 66.0
    • sAP5: 62.2
    • sAP15: 67.7
  • model_00028.pth: 65.8
  • model_00029.pth: 65.4
  • model_00030.pth: 65.5

The results are much better than before the bug fix, but still slightly below the paper (66.0 vs. 66.5). I'm not sure whether the discrepancy is due to randomness or some other minor issue.

cherubicXN commented 4 years ago

I think it is due to randomness. I also completed the training and obtained an sAP10 of 66.4 after 27 epochs. Let me train it again.