TUI-NICR / ESANet

ESANet: Efficient RGB-D Semantic Segmentation for Indoor Scene Analysis
Other

load_weight problem #48

Closed zhouqunbing closed 2 years ago

zhouqunbing commented 2 years ago

Thank you for your great work! I am a beginner, and I have run into some problems while analyzing your code.

1: While preparing to train, I downloaded the weights from the link, a file named train_r34_NBt1D.pth. When I open the file, it looks like this:

encoder.conv1.weight <class 'str'>
encoder.bn1.weight <class 'str'>
encoder.bn1.bias <class 'str'>
encoder.bn1.running_mean <class 'str'>
encoder.bn1.running_var <class 'str'>
encoder.bn1.num_batches_tracked <class 'str'>
encoder.layer1.0.conv3x1_1.weight <class 'str'>

Meanwhile, when I print the model's state_dict, it looks like this:

encoder_rgb.conv1.weight
encoder_rgb.bn1.weight
encoder_rgb.bn1.bias
encoder_rgb.bn1.running_mean
encoder_rgb.bn1.running_var
encoder_rgb.bn1.num_batches_tracked
encoder_rgb.layer1.0.conv1.weight
encoder_rgb.layer1.0.bn1.weight
encoder_rgb.layer1.0.bn1.bias

The keys in the weight file and the keys in model.state_dict() are not the same, so they do not match. How should model.load_state_dict be used in build_model.py?
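A minimal sketch of how to see exactly which keys fail to match, assuming both key sets have already been collected (in practice, from the dict returned by torch.load(...) and from model.state_dict().keys()). The key lists below are shortened copies of the ones printed above, and compare_keys is a hypothetical helper, not part of ESANet. It reports the same two categories that torch.nn.Module.load_state_dict(..., strict=False) returns as missing_keys and unexpected_keys:

```python
def compare_keys(checkpoint_keys, model_keys):
    """Hypothetical helper: return (missing, unexpected) key lists.

    missing    -> keys the model expects but the file does not provide
    unexpected -> keys the file provides but the model does not have
    """
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    missing = sorted(model - ckpt)
    unexpected = sorted(ckpt - model)
    return missing, unexpected


# Shortened copies of the keys printed in the question.
checkpoint_keys = [
    "encoder.conv1.weight",
    "encoder.bn1.weight",
    "encoder.layer1.0.conv3x1_1.weight",
]
model_keys = [
    "encoder_rgb.conv1.weight",
    "encoder_rgb.bn1.weight",
    "encoder_rgb.layer1.0.conv1.weight",
]

missing, unexpected = compare_keys(checkpoint_keys, model_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

Here every key ends up in one of the two lists, because the prefixes differ (encoder. vs. encoder_rgb.). The layer names also differ (conv3x1_1 vs. conv1), which suggests the checkpoint and the model may have been built with different block types; that part cannot be fixed by renaming alone.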

2: There are two branches in the structure, i.e., two ResNet34 branches, but I only provide one weight file. How do both branches use the same weights?

mona0809 commented 2 years ago

Did you download the weights pretrained on ImageNet or the weights for the full model trained on NYUv2?

The ImageNet weights are automatically loaded by each branch in this function: https://github.com/TUI-NICR/ESANet/blob/a271ac55ca9d0801fd32c557789f2b908b8d9651/src/models/resnet.py#L469-L509

This function also renames the keys accordingly. Just save the weights in ./trained_models/imagenet and provide the argument --pretrained_dir ./trained_models/imagenet when training.
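The linked function in resnet.py is what ESANet actually uses. As a rough illustration of the idea only (not that function), the sketch below reuses one ImageNet state dict for both branches by rewriting its key prefix once per branch. The prefix names encoder_rgb. and encoder_depth. are assumptions based on the model keys printed in the question; rename_for_branch and the toy values are hypothetical. Real values would be tensors, and assuming the depth branch takes a single-channel input, its first convolution weight would also need a shape adaptation that this sketch does not perform:

```python
def rename_for_branch(weights, old_prefix, new_prefix):
    """Hypothetical helper: copy `weights`, replacing old_prefix with new_prefix.

    Keys that do not start with old_prefix are dropped.
    """
    return {
        new_prefix + key[len(old_prefix):]: value
        for key, value in weights.items()
        if key.startswith(old_prefix)
    }


# Toy stand-in for an ImageNet checkpoint (strings instead of tensors).
imagenet_weights = {
    "encoder.conv1.weight": "W_conv1",
    "encoder.bn1.weight": "W_bn1",
    "fc.weight": "W_fc",  # hypothetical classifier key, unused by the encoders
}

# The same source dict is reused for both encoder branches.
rgb_weights = rename_for_branch(imagenet_weights, "encoder.", "encoder_rgb.")
depth_weights = rename_for_branch(imagenet_weights, "encoder.", "encoder_depth.")

# One dict covering both branches.
merged = {**rgb_weights, **depth_weights}
for key in sorted(merged):
    print(key, "<-", merged[key])
```

With real tensors, a dict like merged could then be passed to model.load_state_dict(merged, strict=False), and the returned missing_keys and unexpected_keys would show whatever still does not line up.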

zhouqunbing commented 2 years ago

Thank you very much!