vineetk1 closed this issue 1 year ago.
Hi! Thanks for your contribution, great first issue!
I ran into this too, and refactored my code to initialize two separate trainers and models and run tune() on the batch_size and learning_rate separately. Then, out of an abundance of caution, I re-initialized the trainer and model once more before calling fit().
In case it helps as a clue: while debugging this, it seemed that the auto_scale_batch_size functionality does not properly restore the original model weights after tuning. I believe the LR finder does restore them, so its code might serve as an example.
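If that hypothesis is right, one workaround is to snapshot the weights before a tuning step and put them back afterwards. Below is a minimal sketch of that idea, not Lightning's own code: `restored_weights` is a hypothetical helper that works with any object exposing `state_dict()` / `load_state_dict()` (as `torch.nn.Module` does), and `ToyModel` is a stand-in so the sketch runs without PyTorch.

```python
import copy
from contextlib import contextmanager


@contextmanager
def restored_weights(model):
    """Snapshot the model's state on entry and restore it on exit.

    Works with any object exposing state_dict()/load_state_dict(), such as a
    torch.nn.Module. deepcopy is used because a PyTorch state_dict holds
    references to the live parameter tensors, not copies.
    """
    snapshot = copy.deepcopy(model.state_dict())
    try:
        yield model
    finally:
        # Runs even if the tuning step raises.
        model.load_state_dict(snapshot)


class ToyModel:
    """Hypothetical stand-in for a real model, for illustration only."""

    def __init__(self):
        self.weights = {"w": [1.0, 2.0]}

    def state_dict(self):
        return self.weights

    def load_state_dict(self, state):
        self.weights = copy.deepcopy(state)


m = ToyModel()
with restored_weights(m):
    m.weights["w"][0] = 99.0  # simulate a tuning step mutating the weights
print(m.weights)  # {'w': [1.0, 2.0]}
```

With a real LightningModule, the same wrapper would go around each tuning call so that the next step starts from the untouched initial weights.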
Working code that runs tune() separately for each setting:
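import pytorch_lightning as pl
# BoringModel / BoringDataModule come from the bug-report template; `...` stands for your own arguments.

dm = BoringDataModule()

# 1) Tune the batch size with its own model and trainer (the result is stored on dm).
model = BoringModel(...)
trainer = pl.Trainer(..., auto_scale_batch_size=True)
trainer.tune(model, datamodule=dm)
print('Suggested batch size:', dm.batch_size)

# 2) Tune the learning rate with a fresh model and trainer (the result is stored on this model).
model = BoringModel(...)
trainer = pl.Trainer(..., auto_lr_find=True)
trainer.tune(model, datamodule=dm)
print('Suggested learning rate:', model.hparams.learning_rate)

# 3) Train with a fresh model and trainer; the tuned learning rate must be passed in again via the constructor arguments.
model = BoringModel(...)
trainer = pl.Trainer(...)
trainer.fit(model, datamodule=dm)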
@SkafteNicki
🐛 Bug
trainer.tune() works fine when either Trainer.__init__(auto_lr_find=False, auto_scale_batch_size=True) or Trainer.__init__(auto_lr_find=True, auto_scale_batch_size=False) is used.
However, trainer.tune() fails when Trainer.__init__(auto_lr_find=True, auto_scale_batch_size=True), with the message:
LR finder stopped early due to diverging loss.
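For context, LR range tests usually declare divergence by comparing a bias-corrected, exponentially smoothed loss against the best smoothed loss seen so far. The sketch below models that check; `lr_range_stop_step` is a hypothetical helper, and `beta=0.98` and `early_stop_threshold=4.0` are assumed to match Lightning's defaults rather than taken from this issue. Under such a rule, a model whose weights were already changed by an earlier tuning pass might trip the check sooner.

```python
def lr_range_stop_step(losses, beta=0.98, early_stop_threshold=4.0):
    """Return the step at which an LR range test would stop for divergence, or None."""
    avg = 0.0
    best = None
    for step, loss in enumerate(losses):
        # Exponential moving average of the loss, with bias correction.
        avg = beta * avg + (1 - beta) * loss
        smoothed = avg / (1 - beta ** (step + 1))
        # Diverged: smoothed loss exceeds threshold * best smoothed loss so far.
        if best is not None and smoothed > early_stop_threshold * best:
            return step
        if best is None or smoothed < best:
            best = smoothed
    return None


print(lr_range_stop_step([1.0] * 10))         # None
print(lr_range_stop_step([1.0, 1.0, 100.0]))  # 2
```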
To Reproduce
In your own environment, trainer.tune() should fail when Trainer.__init__(auto_lr_find=True, auto_scale_batch_size=True). If you want to reproduce the bug with my code, fork https://github.com/vineetk1/conversational-transaction-bot on GitHub, then run the following on the command line:
python3 ctbMain.py input_param_files/distilgpt2_params
Expected behavior
trainer.tune() should find both the batch size and the initial learning rate.
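As background on the batch-size half of tune(): with auto_scale_batch_size=True, Lightning's default "power" mode is understood to keep doubling the batch size until a trial run fails (typically out of memory) and then keep the last size that worked. The sketch below models only that search; `find_batch_size` and the `fits` predicate are hypothetical stand-ins for running a few training steps, and `init_val=2` / `max_trials=25` are assumed defaults.

```python
from typing import Callable


def find_batch_size(fits: Callable[[int], bool], init_val: int = 2, max_trials: int = 25) -> int:
    """Double the batch size until `fits` fails; return the last size that fit."""
    batch_size = init_val
    last_ok = None
    for _ in range(max_trials):
        if not fits(batch_size):
            break
        last_ok = batch_size
        batch_size *= 2
    if last_ok is None:
        raise RuntimeError(f"even batch_size={init_val} does not fit")
    return last_ok


# Pretend anything above 48 samples per batch runs out of memory.
print(find_batch_size(lambda bs: bs <= 48))  # 32
```

Each trial in the real tuner runs actual training steps, which is why the model's weights can change during this search.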
Environment
Installed with (conda, pip, source): pip
Additional context