czarmanu closed this issue 2 years ago.
I saved the trained model after a few epochs using the following command: torch.save(net.state_dict(), TRAINED_MODEL_SAVE_PATH)
Then I updated ckpt.py with the following: ckpt4 = { 'net': 'MSCG-ETHZ', 'data': 'Agriculture', 'bands': ['NIR','RGB'], 'nodes': (32,32), 'snapshot': 'ckpt/new_model.pth' }
However, when I try to load it, I get the following error:
not found the net
Traceback (most recent call last):
File "./tools/test_submission.py", line 244, in
The model is saved in train_R50.py as follows:
if updated or (train_args.best_record['val_loss'] > avg_loss): torch.save(net.state_dict(), os.path.join(train_args.save_path, snapshot_name + '.pth'))
I used the same syntax to save the model, with TRAINED_MODEL_SAVE_PATH = "ckpt/new_model.pth". Even though I saved the model the same way, I don't understand why loading this newly trained and saved model throws an error.
According to this forum post, I am not doing anything wrong: https://discuss.pytorch.org/t/torch-has-not-attribute-load-state-dict/21781/26
Any help is highly appreciated. Thanks!
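The save condition quoted from train_R50.py means a snapshot is written only when `updated` is true or when the validation loss drops below the best value recorded so far. A minimal stdlib sketch of that best-record logic follows; `should_save()` is a hypothetical helper, a plain dict stands in for `train_args.best_record`, and the loss values are made up for illustration.

```python
import math


def should_save(best_record, avg_loss, updated=False):
    """Return True when a snapshot would be written.

    Mirrors the condition quoted from train_R50.py:
        updated or best_record['val_loss'] > avg_loss
    Updating best_record here is an assumption made for this sketch.
    """
    improved = best_record["val_loss"] > avg_loss
    if improved:
        best_record["val_loss"] = avg_loss  # remember the new best validation loss
    return updated or improved


best_record = {"val_loss": math.inf}  # assumption: the best loss starts very high
saved_epochs = []
for epoch, avg_loss in enumerate([0.90, 0.80, 0.85, 0.70, 0.75]):
    if should_save(best_record, avg_loss):
        saved_epochs.append(epoch)

print(saved_epochs)  # [0, 1, 3]
```

If the best loss starts at a large value, the first epoch always passes this check, so a run that never produces a snapshot file is more likely writing it somewhere unexpected than skipping the save.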
The snapshot name is created dynamically during training; the name pattern looks like this: epoch_8_loss_0.99527_acc_0.82278_acc-cls_0.60967_mean-iu_0.48098_fwavacc_0.70248_f1_0.62839_lr_0.0000829109.pth
All trained model snapshots are saved in the ./ckpt folder. Don't change the name pattern unless you know how to modify the file ./config/configs_kf.py, because when you resume training from a specific snapshot, the function resume_train() decodes the name according to this predefined pattern.
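To see why the name pattern matters, here is a minimal sketch of decoding such a filename into its fields. `parse_snapshot_name()` is a hypothetical helper written for illustration; it assumes the name is strictly alternating `key_value` pairs separated by underscores, and the actual resume_train() in the repository may decode it differently.

```python
import os


def parse_snapshot_name(filename):
    """Decode 'key_value_key_value....pth' into a dict (epoch as int, the rest as float)."""
    stem, _ext = os.path.splitext(os.path.basename(filename))  # drop folder and '.pth'
    parts = stem.split("_")
    if len(parts) % 2:
        raise ValueError(f"unexpected snapshot name: {filename}")
    fields = {}
    for key, value in zip(parts[0::2], parts[1::2]):
        fields[key] = int(value) if key == "epoch" else float(value)
    return fields


name = ("epoch_8_loss_0.99527_acc_0.82278_acc-cls_0.60967_mean-iu_0.48098_"
        "fwavacc_0.70248_f1_0.62839_lr_0.0000829109.pth")
info = parse_snapshot_name(name)
print(info["epoch"], info["loss"], info["mean-iu"])  # 8 0.99527 0.48098
```

A hand-picked name such as new_model.pth has no such fields, which is why renaming snapshots breaks any code that relies on this pattern.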
I tried running for 100+ epochs; however, no newly trained model was stored in the ckpt folder. That's why I decided to stop the training after a couple of epochs and save the model at that point myself. Do you know why the model was not saved automatically? It didn't even throw an error.
It should be saved in the ./ckpt folder if you didn't change the default settings in the file ./config/configs_kf.py:
ckpt_path = '../ckpt'
The configs_kf.py was unchanged.
Is the presence of the pre-trained model in the ckpt folder the issue?
The trained model .pth files are normally saved in subfolders created under the ./ckpt folder. Were any subfolders created?
No subfolders were created either. Could that be due to folder write permissions? But in that case, error messages should pop up, right?
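In general, a permission problem does surface as an exception: os.makedirs() and opening a file for writing raise PermissionError when access is denied. A quick stdlib check like the sketch below can rule that out; `ensure_writable_dir()` is a hypothetical helper, and the temporary directory only stands in for the real checkpoint folder.

```python
import os
import tempfile


def ensure_writable_dir(path):
    """Create `path` if needed and fail loudly if it cannot be written to."""
    os.makedirs(path, exist_ok=True)  # raises PermissionError if creation is denied
    if not os.access(path, os.W_OK):
        raise PermissionError(f"cannot write to {os.path.abspath(path)}")
    return os.path.abspath(path)


with tempfile.TemporaryDirectory() as tmp:
    # Stand-in for the real checkpoint folder, e.g. ckpt/MSCG-Rx50
    target = os.path.join(tmp, "ckpt", "MSCG-Rx50")
    print(ensure_writable_dir(target))
```

If this succeeds on the real folder but no snapshots appear there, the files are most likely being written to a different path.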
Could you print out the save path by adding this line to train_R50.py before main()?
print(train_args.save_path)
../ckpt/MSCG-Rx50/Agriculture_NIR-RGB_kf-0-0-reproduce_ACW_loss2_adax
However, no such subfolder is created during training.
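The printed save path is relative ('../ckpt/...'), so it is resolved against the current working directory at run time, not against the repository folder. If the script is launched from the repository root, '..' points one level above it, so the subfolder would be created outside the repository's own ./ckpt. The sketch below uses a hypothetical `resolve_save_path()` helper; the working directory is taken from the later traceback and is only an assumption about where training was launched.

```python
import os


def resolve_save_path(save_path, cwd):
    """Return where a relative save path really points when the script runs from `cwd`."""
    return os.path.normpath(os.path.join(cwd, save_path))


rel = "../ckpt/MSCG-Rx50/Agriculture_NIR-RGB_kf-0-0-reproduce_ACW_loss2_adax"
print(resolve_save_path(rel, "/scratch/manu/MSCG-Net-master_selftrained"))
# /scratch/manu/ckpt/MSCG-Rx50/Agriculture_NIR-RGB_kf-0-0-reproduce_ACW_loss2_adax

# The same resolution, against the directory you are actually running from:
print(os.path.abspath(rel))
```

Printing os.path.abspath(train_args.save_path) instead of the raw value shows the exact folder to look in.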
You could try modifying train_R50.py to set train_args.save_pred = True, and comment out or delete this line: train_args.ckpt_path = os.path.abspath(os.curdir)
I found the directory where the models are saved. However, I still get the model loading error:
"not found the net
ckpt/epoch_2_loss_0.68043_acc_0.81885_acc-cls_0.83440_mean-iu_0.67651_fwavacc_0.70391_f1_0.80475_lr_0.0000863686.pth
Traceback (most recent call last):
File "/scratch/manu/MSCG-Net-master_selftrained/./tools/test_submission.py", line 239, in
Copy epoch_2_loss_0.68043_acc_0.81885_acc-cls_0.83440_mean-iu_0.67651_fwavacc_0.70391_f1_0.80475_lr_0.0000863686.pth to the ./ckpt folder, and then update ckpt.py with the following:
ckpt4 = {
'net': 'MSCG-ETHZ',
'data': 'Agriculture',
'bands': ['NIR','RGB'],
'nodes': (32,32),
'snapshot': '../ckpt/epoch_2_loss_0.68043_acc_0.81885_acc-cls_0.83440_mean-iu_0.67651_fwavacc_0.70391_f1_0.80475_lr_0.0000863686.pth'
}
Or just use the absolute path to your saved model, like this:
'snapshot': 'path/to/your/model_saved/ckpt/epoch_2_loss_0.68043_acc_0.81885_acc-cls_0.83440_mean-iu_0.67651_fwavacc_0.70391_f1_0.80475_lr_0.0000863686.pth'
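Before loading, it can help to confirm that the configured snapshot path actually exists from the directory the test script runs in. Below is a minimal sketch; `resolve_snapshot()` is a hypothetical helper that takes a ckpt dict like ckpt4 above, and the temporary empty file is only a placeholder, not a real model.

```python
import os
import tempfile


def resolve_snapshot(ckpt):
    """Return the absolute snapshot path from a ckpt dict, or raise a clear error."""
    path = os.path.abspath(os.path.expanduser(ckpt["snapshot"]))
    if not os.path.isfile(path):
        raise FileNotFoundError(f"snapshot not found: {path} (cwd: {os.getcwd()})")
    return path


with tempfile.TemporaryDirectory() as tmp:
    snap = os.path.join(tmp, "epoch_2_placeholder.pth")
    with open(snap, "wb"):
        pass  # empty placeholder file standing in for a trained model
    ckpt = {"net": "MSCG-Rx50", "snapshot": snap}
    print(resolve_snapshot(ckpt) == snap)  # True
```

Relative entries such as '../ckpt/...' are resolved against the current working directory, so the error message includes it to make a wrong launch directory obvious.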
It seems that your net was not built correctly. You might need to use 'MSCG-Rx50' instead of 'MSCG-ETHZ'. If you want to use 'MSCG-ETHZ', you need to modify the file tools/model.py accordingly, changing MSCG-Rx50 to MSCG-ETHZ.
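The "not found the net" message fits a pattern where the 'net' entry of the ckpt dict is looked up against a fixed set of known model names. The sketch below illustrates that pattern with a hypothetical `NETS` registry and `build_net()` helper; it is not the actual code in tools/model.py, and the lambda is a stand-in for a real model constructor.

```python
# Hypothetical registry: model name -> constructor (a string-returning stand-in here)
NETS = {
    "MSCG-Rx50": lambda nodes: f"MSCG-Rx50 with nodes={nodes}",
}


def build_net(ckpt):
    """Build the model named in a ckpt dict, failing clearly on unknown names."""
    name = ckpt["net"]
    if name not in NETS:
        raise KeyError(f"not found the net: {name!r}; known names: {sorted(NETS)}")
    return NETS[name](ckpt["nodes"])


print(build_net({"net": "MSCG-Rx50", "nodes": (32, 32)}))  # MSCG-Rx50 with nodes=(32, 32)
```

In this picture, using the existing name needs no code change, while supporting 'MSCG-ETHZ' means registering it in the model-building code, which corresponds to the tools/model.py edit described above.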
Updating to 'MSCG-Rx50' worked. Thanks a lot :)
Where is the model (or the major checkpoints) trained using train_R50.py stored?