sacmehta / ESPNet

ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation
https://sacmehta.github.io/ESPNet/
MIT License

Train on a different dataset #14

Closed. DenisN03 closed this issue 6 years ago.

DenisN03 commented 6 years ago

Hello, @sacmehta! I'm trying to train the network on my own dataset, which has 5 classes.

To train the encoder, I use:

CUDA_VISIBLE_DEVICES=1 python3 main.py --data_dir=./DataBase --inWidth=480 --inHeight=360 --classes=5 --cached_data_file=data.p --batch_size=10

To train the decoder, I use:

CUDA_VISIBLE_DEVICES=1 python3 main.py --data_dir=./DataBase --inWidth=480 --inHeight=360 --classes=5 --cached_data_file=data.p --batch_size=5 --decoder=True --pretrained=./results_enc__enc_2_8_long/model_161.pth --scaleIn=1 --savedir=./results_dec_

After training finishes, I test the network with:

CUDA_VISIBLE_DEVICES=1 python3 VisualizeResults.py --modelType=1 --inWidth=480 --inHeight=360 --scaleIn=1 --weightsDir=../pretrained/decoder/ --classes=5 --cityFormat=False

The ../pretrained/decoder/ folder contains the new weights of the trained network (espnet_p_2_q_8.pth). Running the test command fails with the following error:

RuntimeError: Error(s) in loading state_dict for ESPNet: While copying the parameter named "conv.conv.weight", whose dimensions in the model are torch.Size([5, 21, 3, 3]) and whose dimensions in the checkpoint are torch.Size([5, 24, 3, 3]).
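To see exactly which parameters disagree, one option is to compare the parameter shapes of the model you built with those stored in the checkpoint. This is a minimal sketch, assuming you have already turned both into plain dicts of shape tuples, for example `{k: tuple(v.shape) for k, v in model.state_dict().items()}` for the model and the same comprehension over `torch.load(path, map_location="cpu")` for the checkpoint. `find_shape_mismatches` is a hypothetical helper, not part of ESPNet.

```python
def find_shape_mismatches(model_shapes, ckpt_shapes):
    """Compare two {parameter_name: shape_tuple} dicts (hypothetical helper).

    Returns (mismatched, missing, unexpected):
      mismatched -- names present in both whose shapes differ,
                    mapped to (model_shape, ckpt_shape)
      missing    -- names the model expects but the checkpoint lacks
      unexpected -- names in the checkpoint that the model does not have
    """
    mismatched = {
        name: (tuple(shape), tuple(ckpt_shapes[name]))
        for name, shape in model_shapes.items()
        if name in ckpt_shapes and tuple(shape) != tuple(ckpt_shapes[name])
    }
    missing = sorted(set(model_shapes) - set(ckpt_shapes))
    unexpected = sorted(set(ckpt_shapes) - set(model_shapes))
    return mismatched, missing, unexpected


# Shapes copied from the error message in this issue (all other keys omitted).
model_shapes = {"conv.conv.weight": (5, 21, 3, 3)}
ckpt_shapes = {"conv.conv.weight": (5, 24, 3, 3)}

mismatched, missing, unexpected = find_shape_mismatches(model_shapes, ckpt_shapes)
for name, (m, c) in mismatched.items():
    # Prints: conv.conv.weight: model (5, 21, 3, 3) vs checkpoint (5, 24, 3, 3)
    print(f"{name}: model {m} vs checkpoint {c}")
```

Here the output-channel dimension (5, the number of classes) matches, and only the input-channel dimension differs (21 vs 24). That points to the two model definitions feeding a different number of channels into this layer, rather than to a wrong `--classes` value.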

How can I fix this error?

sacmehta commented 6 years ago

Hi,

The Model.py in the test folder is slightly different from the one in the train folder, which is why you are getting this error. Replace the Model.py in the test folder with the one you used for training, and it should work.
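If you would rather not overwrite the test copy, another option is to import the training Model.py directly by file path, so the network you evaluate is built from exactly the definition used for training. This is a minimal sketch. The helper name `load_module_from_path`, the path `../train/Model.py`, and the constructor arguments in the commented usage are assumptions about your checkout (p=2, q=8 are inferred from the weight file name), not part of ESPNet's scripts.

```python
import importlib.util


def load_module_from_path(name, path):
    """Import a Python source file as a module called `name`, independent of sys.path."""
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # executes the file's top-level code
    return module


# Hypothetical usage inside the test script (paths and arguments are assumptions):
# import torch
# train_model = load_module_from_path("train_Model", "../train/Model.py")
# net = train_model.ESPNet(classes=5, p=2, q=8)
# net.load_state_dict(torch.load("../pretrained/decoder/espnet_p_2_q_8.pth"))
```

Loading by path avoids accidentally picking up whichever Model.py happens to be first on `sys.path`, which is the kind of silent mismatch that produced the error above.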

A side note:

Since you are using 5 classes, the ESP block may not be very effective in the decoder, because it projects the feature maps into a low-dimensional space whose size equals the number of classes. To learn better representations, you can follow this issue, where we provided a workaround for datasets with few classes:
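To make the side note concrete: the ESP paper describes each block as reducing its N output channels to d = N/K channels and processing them with K parallel dilated convolutions (K = 5 in ESPNet), whose outputs are then concatenated back to N channels. The sketch below computes per-branch widths under the assumption that the division is floored and any remainder goes to the first branch. It illustrates the idea and is not necessarily ESPNet's exact code; `esp_branch_widths` is a hypothetical name.

```python
def esp_branch_widths(n_out, k=5):
    """Split n_out channels across k parallel dilated branches (hypothetical helper).

    Assumption: each branch gets n_out // k channels and the first branch
    also absorbs the remainder, so the widths always sum to n_out.
    """
    d = n_out // k
    first = n_out - (k - 1) * d
    return [first] + [d] * (k - 1)


# A decoder block whose output width equals the number of classes:
print(esp_branch_widths(5))   # 5 classes  -> [1, 1, 1, 1, 1]
print(esp_branch_widths(20))  # 20 classes -> [4, 4, 4, 4, 4]
```

With only 5 output channels, each dilated branch works on a single channel, which leaves very little capacity per branch; that is the limitation the side note refers to.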

https://github.com/sacmehta/ESPNet/issues/13