samleoqh / MSCG-Net

Multi-view Self-Constructing Graph Convolutional Networks with Adaptive Class Weighting Loss for Semantic Segmentation
MIT License
68 stars, 28 forks

Saving trained model #21

Closed: czarmanu closed this issue 2 years ago

czarmanu commented 2 years ago

Where is the model trained with train_R50.py (or its main checkpoints) stored?

czarmanu commented 2 years ago

I saved the trained model after some epochs using the following command:

torch.save(net.state_dict(), TRAINED_MODEL_SAVE_PATH)

Then I updated ckpt.py with the following:

ckpt4 = {
'net': 'MSCG-ETHZ',
'data': 'Agriculture',
'bands': ['NIR','RGB'],
'nodes': (32,32),
'snapshot': 'ckpt/new_model.pth'
}

However, when I try to load it, I get the following error:

not found the net
Traceback (most recent call last):
  File "./tools/test_submission.py", line 244, in <module>
    main()
  File "./tools/test_submission.py", line 28, in main
    net4 = get_net(ckpt4) # MSCG-NET-ETHZ
  File "/scratch/manu/MSCG-Net-master_selftrained/tools/ckpt.py", line 53, in get_net
    net.load_state_dict(torch.load(ckpt['snapshot']))
AttributeError: 'int' object has no attribute 'load_state_dict'

czarmanu commented 2 years ago

The model is saved in train_R50.py as follows:

if updated or (train_args.best_record['val_loss'] > avg_loss):
    torch.save(net.state_dict(), os.path.join(train_args.save_path, snapshot_name + '.pth'))

I use the same call to save the model, with TRAINED_MODEL_SAVE_PATH = "ckpt/new_model.pth". Even though the model is saved the same way, I don't understand why loading this newly trained model throws the error.

According to this forum post, I am not doing anything wrong: https://discuss.pytorch.org/t/torch-has-not-attribute-load-state-dict/21781/26

Any help is highly appreciated. Thanks!

samleoqh commented 2 years ago

The snapshot name is created dynamically during training, following a pattern like: epoch_8_loss_0.99527_acc_0.82278_acc-cls_0.60967_mean-iu_0.48098_fwavacc_0.70248_f1_0.62839_lr_0.0000829109.pth. All trained model snapshots are saved in the ./ckpt folder. Don't change the name pattern unless you know how to modify the file /config/configs_kf.py, because when you want to resume training from a specific snapshot, the function resume_train() decodes the name according to this predefined pattern.
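To make the naming convention concrete, here is a minimal sketch that splits such a snapshot file name back into its fields. The helper parse_snapshot_name is hypothetical and is not the repository's resume_train(); it only assumes the alternating name_value pattern shown above, where multi-word keys such as acc-cls and mean-iu use hyphens rather than underscores.

```python
import os


def parse_snapshot_name(path):
    """Decode a snapshot name such as
    epoch_8_loss_0.99527_..._lr_0.0000829109.pth into {field: value}.

    Hypothetical helper: it assumes keys and values alternate and are
    separated by underscores (multi-word keys use hyphens, e.g. mean-iu).
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    parts = stem.split('_')
    if len(parts) % 2 != 0:
        raise ValueError(f'unexpected snapshot name: {stem}')
    record = {}
    for key, value in zip(parts[0::2], parts[1::2]):
        # The epoch counter is an integer; every metric and the lr are floats.
        # float() raises ValueError if the name does not follow the pattern.
        record[key] = int(value) if key == 'epoch' else float(value)
    return record


if __name__ == '__main__':
    name = ('epoch_8_loss_0.99527_acc_0.82278_acc-cls_0.60967_mean-iu_0.48098'
            '_fwavacc_0.70248_f1_0.62839_lr_0.0000829109.pth')
    info = parse_snapshot_name(os.path.join('..', 'ckpt', name))
    print(info['epoch'], info['mean-iu'], info['lr'])
```

A name like new_model.pth cannot be decoded this way, which illustrates why renaming snapshots interferes with resuming.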

czarmanu commented 2 years ago

I trained for more than 100 epochs, but no newly trained model was stored in the ckpt folder. That's why I stopped training after a couple of epochs and saved the model at that point myself. Do you know why the model was not saved automatically? No error was raised either.

samleoqh commented 2 years ago

It should be saved in the ./ckpt folder if you didn't change the default setting in ./config/configs_kf.py: ckpt_path = '../ckpt'
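A relative setting such as ckpt_path = '../ckpt' is resolved against the current working directory of the process, not against the location of configs_kf.py. The sketch below uses plain os.path (the helper name is hypothetical) to show how the same setting points to different folders depending on where the script is launched from; this is one possible reason a checkpoint folder does not appear where you expect it.

```python
import os


def resolve_ckpt_path(ckpt_path, cwd):
    """Return the folder a ckpt_path setting refers to when the process
    runs in `cwd` (hypothetical helper; absolute paths pass through)."""
    return os.path.normpath(os.path.join(cwd, ckpt_path))


if __name__ == '__main__':
    repo = '/scratch/manu/MSCG-Net-master_selftrained'
    for cwd in (repo, os.path.join(repo, 'tools')):
        print(f'{cwd} -> {resolve_ckpt_path("../ckpt", cwd)}')
    # What '../ckpt' means for the process you are actually running:
    print(os.path.abspath('../ckpt'))
```

Launched from the repository root, '../ckpt' points one level above the repository; launched from tools/, it points to the repository's own ckpt folder.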

czarmanu commented 2 years ago

configs_kf.py is unchanged.

Could the pre-trained model already in the ckpt folder be the issue?

samleoqh commented 2 years ago

The trained .pth files are normally saved in subfolders created under the ./ckpt folder. Were any subfolders created?
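To check whether any snapshots were written at all, you can search the checkpoint folder and all of its subfolders for .pth files. A minimal standard-library sketch; the function name find_snapshots is hypothetical.

```python
import os


def find_snapshots(root):
    """Return every .pth file under `root` (including subfolders), newest first."""
    found = []
    # os.walk yields nothing for a folder that does not exist, so an empty
    # result can mean "no snapshots" or "wrong root path".
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.endswith('.pth'):
                found.append(os.path.join(dirpath, name))
    # Sort by modification time so the most recent snapshot comes first.
    return sorted(found, key=os.path.getmtime, reverse=True)


if __name__ == '__main__':
    for path in find_snapshots('../ckpt'):
        print(path)
```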

czarmanu commented 2 years ago

No subfolders were created either. Could this be a folder write-permission issue? But in that case an error message should appear, right?

samleoqh commented 2 years ago

Could you print save_path by adding this line to train_R50.py before main():

print(train_args.save_path)
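Along the same lines, a few extra lines can report where save_path actually points, whether it exists, and whether its nearest existing parent is writable, which helps separate a path problem from a permission problem. This is a sketch using only os.path and os.access; the helper name is hypothetical, and in train_R50.py you would pass it train_args.save_path.

```python
import os


def describe_save_path(save_path):
    """Summarise where `save_path` points and whether checkpoints could be written there."""
    abs_path = os.path.abspath(save_path)
    # Walk up to the nearest existing ancestor: its permissions decide
    # whether the missing subfolders could be created at all.
    probe = abs_path
    while not os.path.exists(probe):
        parent = os.path.dirname(probe)
        if parent == probe:  # reached the filesystem root
            break
        probe = parent
    return {
        'abs_path': abs_path,
        'exists': os.path.isdir(abs_path),
        'nearest_existing': probe,
        'writable': os.access(probe, os.W_OK),
    }


if __name__ == '__main__':
    print(describe_save_path(
        '../ckpt/MSCG-Rx50/Agriculture_NIR-RGB_kf-0-0-reproduce_ACW_loss2_adax'))
```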

czarmanu commented 2 years ago

../ckpt/MSCG-Rx50/Agriculture_NIR-RGB_kf-0-0-reproduce_ACW_loss2_adax

However, this subfolder is not created during training.

samleoqh commented 2 years ago

You could try modifying train_R50.py to set train_args.save_pred = True, and comment out or delete this line: train_args.ckpt_path=os.path.abspath(os.curdir)

czarmanu commented 2 years ago

I found the directory where the models are saved. However, I still get the model load error:

not found the net
ckpt/epoch_2_loss_0.68043_acc_0.81885_acc-cls_0.83440_mean-iu_0.67651_fwavacc_0.70391_f1_0.80475_lr_0.0000863686.pth
Traceback (most recent call last):
  File "/scratch/manu/MSCG-Net-master_selftrained/./tools/test_submission.py", line 239, in <module>
    main()
  File "/scratch/manu/MSCG-Net-master_selftrained/./tools/test_submission.py", line 28, in main
    net4 = get_net(ckpt4)
  File "/scratch/manu/MSCG-Net-master_selftrained/tools/ckpt.py", line 54, in get_net
    net.load_state_dict(torch.load(ckpt['snapshot']))
AttributeError: 'int' object has no attribute 'load_state_dict'

samleoqh commented 2 years ago

Copy epoch_2_loss_0.68043_acc_0.81885_acc-cls_0.83440_mean-iu_0.67651_fwavacc_0.70391_f1_0.80475_lr_0.0000863686.pth to the ./ckpt folder, and then update ckpt.py as follows:

ckpt4 = {
'net': 'MSCG-ETHZ',
'data': 'Agriculture',
'bands': ['NIR','RGB'],
'nodes': (32,32),
'snapshot': '../ckpt/epoch_2_loss_0.68043_acc_0.81885_acc-cls_0.83440_mean-iu_0.67651_fwavacc_0.70391_f1_0.80475_lr_0.0000863686.pth'
}

Or just use the absolute path to your saved model, like: 'snapshot': 'path/to/your/model_saved/ckpt/epoch_2_loss_0.68043_acc_0.81885_acc-cls_0.83440_mean-iu_0.67651_fwavacc_0.70391_f1_0.80475_lr_0.0000863686.pth'
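If you go with the absolute-path option, it can help to build the path once and fail early when the file is missing, rather than finding out inside torch.load. A minimal sketch with a hypothetical helper; the ckpt dictionary mirrors ckpt4 above, and snapshot_dir is a placeholder for wherever your snapshots were actually written.

```python
import os


def with_absolute_snapshot(ckpt, snapshot_dir):
    """Return a copy of `ckpt` whose 'snapshot' entry is an absolute path.

    A relative entry (e.g. a bare snapshot file name) is joined to
    `snapshot_dir`; a missing file raises FileNotFoundError before any
    model is built or loaded. The input dict is left unchanged.
    """
    path = ckpt['snapshot']
    if not os.path.isabs(path):
        path = os.path.join(snapshot_dir, path)
    path = os.path.abspath(path)
    if not os.path.isfile(path):
        raise FileNotFoundError(f'snapshot not found: {path}')
    return {**ckpt, 'snapshot': path}


# Usage (placeholder folder, 'snapshot' set to the bare file name):
# ckpt4 = with_absolute_snapshot(ckpt4, 'path/to/your/model_saved/ckpt')
```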

samleoqh commented 2 years ago

It seems that your net was not built correctly. You probably need to use 'MSCG-Rx50' instead of 'MSCG-ETHZ'. If you want to use 'MSCG-ETHZ', you must modify tools/model.py accordingly and change MSCG-RX50 to MSCG-ETHZ.
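The traceback is consistent with this diagnosis: get_net() printed "not found the net" and then called load_state_dict on an int, which suggests the name lookup fell through to a non-model fallback value. That is an inference from the error output, not from the code of tools/ckpt.py. Below is a minimal sketch of a name-to-builder lookup that fails with a clear message instead; the registry NET_BUILDERS, the builder function, and the returned string are hypothetical stand-ins for the real constructors in tools/model.py.

```python
def _build_mscg_rx50(nodes):
    # Placeholder for the real MSCG-Rx50 constructor in tools/model.py.
    return f'MSCG-Rx50 with nodes={nodes}'


# Hypothetical registry: only names listed here can be built.
NET_BUILDERS = {
    'MSCG-Rx50': _build_mscg_rx50,
}


def build_net(ckpt):
    """Build the network named in ckpt['net'] or raise a descriptive error."""
    name = ckpt['net']
    try:
        builder = NET_BUILDERS[name]
    except KeyError:
        known = ', '.join(sorted(NET_BUILDERS))
        raise ValueError(f'unknown net {name!r}; expected one of: {known}') from None
    return builder(ckpt['nodes'])
```

With a lookup like this, 'MSCG-ETHZ' stops at a ValueError that names the valid options, instead of surfacing later as an AttributeError on an int. Supporting another name then means registering a builder for it.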

czarmanu commented 2 years ago

Updating to 'MSCG-Rx50' worked. Thanks a lot :)