TorchEnsemble-Community / Ensemble-Pytorch

A unified ensemble framework for PyTorch to improve the performance and robustness of your deep learning model.
https://ensemble-pytorch.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Cannot fine-tune after reloading the model weights #168

Open wubizhi opened 2 months ago

wubizhi commented 2 months ago
    from torchensemble import SoftGradientBoostingRegressor
    from torchensemble.utils import io

    n_bases = 2
    softGBM = SoftGradientBoostingRegressor(
        estimator=MLP,            # user-defined base estimator (same class as in the original run)
        n_estimators=n_bases,
        shrinkage_rate=1.00,
        cuda=True,
    )

    # Reload the weights saved by the earlier 20-epoch run
    io.load(softGBM, save_dir='./torch_ensemble_results/softGBM/')

    # Re-configure training
    criterion = StepwiseMSELoss()   # user-defined criterion
    softGBM.set_criterion(criterion)
    softGBM.set_optimizer('Adam', lr=0.001, weight_decay=5e-4)
    softGBM.set_scheduler("ReduceLROnPlateau")

    # Re-training (fine-tuning for more epochs)
    softGBM.fit(
        train_loader=new_train_loader,
        log_interval=128,
        epochs=20,
        test_loader=new_vali_loader,
        save_model=True,
        save_dir='./torch_ensemble_results/softGBM/',
    )

I want to know whether my code above can work or not. If I have just trained the model for 20 epochs, can I reload the model weights and continue training for more epochs? If that makes sense, why does it report an error like the one below?

    sKAN_softGBM.fit(train_loader=new_train_loader,

    File "/home/WuBizhi/anaconda3/envs/torch-ensemble/lib/python3.9/site-packages/torchensemble/soft_gradient_boosting.py", line 514, in fit
        super().fit(
    File "/home/WuBizhi/anaconda3/envs/torch-ensemble/lib/python3.9/site-packages/torchensemble/soft_gradient_boosting.py", line 261, in fit
        loss += criterion(output[idx], rets[idx])
    IndexError: list index out of range

xuyxu commented 1 month ago

Hi @wubizhi, what is the value of n_bases in your code snippet and in the model located at ./torch_ensemble_results/softGBM/ ?

wubizhi commented 1 month ago

n_bases = 2

xuyxu commented 1 month ago

Is it the same as the model located at ./torch_ensemble_results/softGBM/ ?
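
A quick way to check this, assuming the restored sub-estimators are exposed through the wrapper's estimators_ module list (an illustrative diagnostic sketch; MLP is the user-defined estimator):

    from torchensemble import SoftGradientBoostingRegressor
    from torchensemble.utils import io

    # Rebuild the wrapper exactly as in the snippet above, then reload.
    probe = SoftGradientBoostingRegressor(estimator=MLP, n_estimators=2, cuda=True)
    io.load(probe, save_dir='./torch_ensemble_results/softGBM/')

    # If the checkpoint was written with a different ensemble size, these two
    # numbers can disagree, and fit() may then index past its per-estimator lists.
    print("configured n_estimators:", probe.n_estimators)
    print("restored sub-estimators:", len(probe.estimators_))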

wubizhi commented 1 month ago

Yes, the trained model path and the model name are kept the same.

First, I trained the model for 20 epochs; the trained weights were stored in the path ./torch_ensemble_results/softGBM/.

Then, I used io.load to reload the model weights from the path ./torch_ensemble_results/softGBM/.

Third, I just wanted to run the same model for another 50 epochs, but it reports the issue I pasted above.

Does reloading and re-running work well for you with torch-ensemble? If yes, can you give me a demo so that I can find out what happened in my own code? Or do you have any tips or ideas about this issue?

Best wishes, thanks very much!
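
In case it helps for comparison, below is a self-contained sketch of the save → reload → continue-training round trip on synthetic data. Everything in it (the toy MLP, the random data, the save_dir) is made up for illustration, and nn.MSELoss stands in for the custom StepwiseMSELoss:

    import os
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from torchensemble import SoftGradientBoostingRegressor
    from torchensemble.utils import io

    # Toy regression data, only for illustration.
    X = torch.randn(512, 8)
    y = X.sum(dim=1, keepdim=True) + 0.1 * torch.randn(512, 1)
    loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

    # Tiny stand-in for the user-defined MLP base estimator.
    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1))

        def forward(self, x):
            return self.net(x)

    save_dir = './softGBM_demo/'
    os.makedirs(save_dir, exist_ok=True)

    # 1) Initial run: train briefly and write the checkpoint.
    model = SoftGradientBoostingRegressor(estimator=MLP, n_estimators=2, cuda=False)
    model.set_criterion(nn.MSELoss())
    model.set_optimizer('Adam', lr=1e-3)
    model.fit(train_loader=loader, epochs=5, save_model=True, save_dir=save_dir)

    # 2) Rebuild an identical wrapper, reload the weights, and keep training.
    model2 = SoftGradientBoostingRegressor(estimator=MLP, n_estimators=2, cuda=False)
    io.load(model2, save_dir=save_dir)
    model2.set_criterion(nn.MSELoss())
    model2.set_optimizer('Adam', lr=1e-3)
    model2.fit(train_loader=loader, epochs=5, save_model=True, save_dir=save_dir)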

xuyxu commented 1 month ago

Sure, I will try to reproduce your problem first, and then get back to you.