TUI-NICR / ESANet

ESANet: Efficient RGB-D Semantic Segmentation for Indoor Scene Analysis

load_weight problem #48

Closed: zhouqunbing closed this issue 1 year ago.

zhouqunbing commented 1 year ago

Thank you for your great work! I am a beginner and have run into some problems while analysing your code.

1: When preparing to train, I downloaded the weights from the link; the file is named train_r34_NBt1D.pth. When I open the file and print its keys (with their types), I get:
encoder.conv1.weight <class 'str'>
encoder.bn1.weight <class 'str'>
encoder.bn1.bias <class 'str'>
encoder.bn1.running_mean <class 'str'>
encoder.bn1.running_var <class 'str'>
encoder.bn1.num_batches_tracked <class 'str'>
encoder.layer1.0.conv3x1_1.weight <class 'str'>

Meanwhile, when I print the model's state_dict, the keys look like this:
encoder_rgb.conv1.weight
encoder_rgb.bn1.weight
encoder_rgb.bn1.bias
encoder_rgb.bn1.running_mean
encoder_rgb.bn1.running_var
encoder_rgb.bn1.num_batches_tracked
encoder_rgb.layer1.0.conv1.weight
encoder_rgb.layer1.0.bn1.weight
encoder_rgb.layer1.0.bn1.bias

The keys in the weight file and the keys in model.state_dict() are different, so I think they do not match. How should model.load_state_dict be used in build_model.py?
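The prefix part of the mismatch can be illustrated with a small pure-Python sketch. The checkpoint stores the encoder under `encoder.`, while the full model nests the RGB branch under `encoder_rgb.`; a standalone branch names its parameters without any prefix. The helper `strip_encoder_prefix` is hypothetical (not ESANet's function), and placeholder strings stand in for tensors.

```python
from collections import OrderedDict


def strip_encoder_prefix(checkpoint_state, prefix="encoder."):
    """Rename checkpoint keys the way a standalone ResNet branch names them.

    Hypothetical helper: keys starting with `prefix` lose it, every other key
    (e.g. a classification head, if the file has one) is dropped.
    """
    renamed = OrderedDict()
    for key, value in checkpoint_state.items():
        if key.startswith(prefix):
            renamed[key[len(prefix):]] = value
    return renamed


# Keys as printed from the checkpoint; the values are placeholders, not tensors.
ckpt = OrderedDict([
    ("encoder.conv1.weight", "w0"),
    ("encoder.bn1.weight", "w1"),
    ("encoder.layer1.0.conv3x1_1.weight", "w2"),
])

branch_keys = strip_encoder_prefix(ckpt)
print(list(branch_keys))
# ['conv1.weight', 'bn1.weight', 'layer1.0.conv3x1_1.weight']

# Inside the full model, the same branch parameters carry the branch prefix:
print(["encoder_rgb." + k for k in branch_keys])
# ['encoder_rgb.conv1.weight', 'encoder_rgb.bn1.weight', 'encoder_rgb.layer1.0.conv3x1_1.weight']
```

Note that the block-level names in the two printouts also differ (`conv3x1_1` vs `conv1`), which suggests the model and the checkpoint use different encoder block types; renaming the prefix cannot fix that, so the branch presumably has to be built with the block type the checkpoint was trained with.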

2: There are two branches in the architecture, i.e. two ResNet34 encoders, but I only provide a single weight file. How do the two branches use one set of weights?
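One way to picture the two-branch case: each branch is its own ResNet34 instance, so the same ImageNet file can be read once per branch. Both start from identical weights and then train as separate parameters. A depth branch with a single input channel cannot take a 3-channel first convolution as-is. Summing the kernel over the input-channel axis is one common adaptation; it is an assumption here, so check the linked resnet.py function for what ESANet actually does. `make_branch_weights` is a hypothetical helper, and nested lists stand in for tensors.

```python
import copy


def make_branch_weights(shared, input_channels):
    """Return an independent copy of the shared weights for one encoder branch.

    Hypothetical helper. For a 1-channel (depth) input, the first conv kernel,
    stored as nested lists shaped [out][in][kh][kw], is reduced from 3 input
    channels to 1 by summing over the input-channel axis (an assumption).
    """
    weights = copy.deepcopy(shared)  # each branch gets its own copy
    if input_channels == 1:
        kernel = weights["conv1.weight"]
        weights["conv1.weight"] = [
            [[[sum(kernel[o][c][i][j] for c in range(len(kernel[o])))
               for j in range(len(kernel[o][0][0]))]
              for i in range(len(kernel[o][0]))]]
            for o in range(len(kernel))
        ]
    return weights


# Toy "ImageNet" weights: 2 output channels, 3 input channels, 1x1 kernel.
shared = {
    "conv1.weight": [[[[1.0]], [[2.0]], [[3.0]]],
                     [[[0.5]], [[0.5]], [[0.5]]]],
    "bn1.weight": [1.0],
}

rgb = make_branch_weights(shared, input_channels=3)    # unchanged copy
depth = make_branch_weights(shared, input_channels=1)  # conv1 collapsed to 1 channel
print(depth["conv1.weight"])  # [[[[6.0]]], [[[1.5]]]]
```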

mona0809 commented 1 year ago

Did you download the weights pretrained on ImageNet or the weights for the full model trained on NYUv2?

The ImageNet weights are automatically loaded by each branch in this function: https://github.com/TUI-NICR/ESANet/blob/a271ac55ca9d0801fd32c557789f2b908b8d9651/src/models/resnet.py#L469-L509

This function also renames the keys accordingly. Just save the weights in ./trained_models/imagenet and provide the argument --pretrained_dir ./trained_models/imagenet when training.
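A hypothetical pre-flight check can confirm that the downloaded file actually sits in the directory passed via `--pretrained_dir` before a long training run starts. This is only a sketch using `argparse` and `os.path`, not part of ESANet. The helper name and the checkpoint file name are assumptions; the real file name is whatever the linked loader function builds from its arguments.

```python
import argparse
import os


def find_pretrained_checkpoint(argv, filename):
    """Return the checkpoint path under --pretrained_dir, or None if it is missing.

    Hypothetical pre-flight check, not part of ESANet. Only --pretrained_dir is
    parsed; any other training arguments in argv are ignored.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_dir", default="./trained_models/imagenet")
    args, _unknown = parser.parse_known_args(argv)
    path = os.path.join(args.pretrained_dir, filename)
    return path if os.path.isfile(path) else None


if __name__ == "__main__":
    # The file name is an assumption; use the name the loader actually expects.
    found = find_pretrained_checkpoint(
        ["--pretrained_dir", "./trained_models/imagenet"], "r34_NBt1D.pth")
    print(found or "checkpoint not found in --pretrained_dir")
```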

zhouqunbing commented 1 year ago

Thank you very much!