ejm714 opened this issue 3 years ago
It turns out that `auto_lr_find` was broken in PTL and has since been fixed. However, upgrading to a more recent version hasn't worked yet.
Moving context from Slack into this issue: there was a bug where `auto_lr_find` wasn’t actually resetting the model to the best learning rate (it only reported that it had), so training proceeded with a wildly high learning rate that caused immediate divergence.
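To make the failure mode concrete, here is a toy, pure-Python sketch of the bug (not Lightning's actual internals): the finder reports a suggested learning rate but never writes it back to the model, so the model keeps the last (divergent) rate from the range sweep.

```python
class ToyModel:
    def __init__(self):
        self.lr = 0.001


def buggy_lr_find(model, suggestion=0.025, last_tried=10.0):
    # the range sweep leaves the model at the last lr it tried...
    model.lr = last_tried
    # ...but only *reports* that it reset to the suggestion
    print(f"Learning rate set to {suggestion}")
    return suggestion


model = ToyModel()
suggested = buggy_lr_find(model)
# model.lr is now 10.0 (divergent), not the suggested 0.025
```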
A couple of helpful links:
Running with an updated version of PTL (1.7.5) yields the following error:
```
Finding best initial lr:  81%|████████████████████████▎     | 81/100 [08:13<01:55,  6.09s/it]
`Trainer.fit` stopped: `max_steps=81` reached.
LR finder stopped early after 81 steps due to diverging loss.
Learning rate set to 0.025118864315095822
Restoring states from the checkpoint path at /home/ubuntu/zamba/.lr_find_26c5855d-e088-4e32-8a37-75f5b9ff1c9f.ckpt
2022-09-12 21:01:42.928 | INFO     | zamba.models.model_manager:train_model:303 - Writing out full configuration to /home/ubuntu/zamba/version_1/train_configuration.yaml.
2022-09-12 21:01:42.933 | INFO     | zamba.models.model_manager:train_model:307 - Starting training...
/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:616: UserWarning: Checkpoint directory /home/ubuntu/zamba/version_1 exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Adjusting learning rate of group 0 to 2.5119e-02.
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/zamba/bin/zamba", line 33, in <module>
    sys.exit(load_entry_point('zamba', 'console_scripts', 'zamba')())
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/typer/main.py", line 214, in __call__
    return get_command(self)(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/typer/main.py", line 500, in wrapper
    return callback(**use_params)  # type: ignore
  File "/home/ubuntu/zamba/zamba/cli.py", line 181, in train
    manager.train()
  File "/home/ubuntu/zamba/zamba/models/model_manager.py", line 431, in train
    train_model(
  File "/home/ubuntu/zamba/zamba/models/model_manager.py", line 308, in train_model
    trainer.fit(model, data_module)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1151, in _run
    self._call_callback_hooks("on_fit_start")
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1597, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py", line 373, in on_fit_start
    return super().on_fit_start(trainer, pl_module)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py", line 105, in on_fit_start
    self._internal_optimizer_metadata[opt_idx], named_parameters
KeyError: 0
```
This looks like it is related to https://github.com/Lightning-AI/lightning/pull/8501/files.
It's possible that the `BackboneFinetuning` fix for resuming training from a checkpoint broke `auto_lr_find`. The callback may be incorrectly treating the `auto_lr_find` checkpoint as a model checkpoint, since `on_fit_start` runs after this restore:

```
Restoring states from the checkpoint path at /home/ubuntu/zamba/.lr_find_26c5855d-e088-4e32-8a37-75f5b9ff1c9f.ckpt
```
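A minimal stand-in (not Lightning's actual code) for why restoring callback state from the lr_find checkpoint produces `KeyError: 0`: the finetuning callback's per-optimizer metadata dict comes back empty, yet `on_fit_start` indexes it with optimizer index 0. Variable names here are illustrative.

```python
# Stand-in for the callback's internal state after restoring from the
# lr_find checkpoint: the per-optimizer metadata dict is empty.
internal_optimizer_metadata = {}

opt_idx = 0
try:
    internal_optimizer_metadata[opt_idx]
except KeyError as err:
    print(f"KeyError: {err}")  # mirrors the `KeyError: 0` in the traceback
```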
Confirmed that if we override that `on_fit_start` change, things work as expected. In `zamba.pytorch.finetuning`, add:
```python
from torch.nn import Module
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def on_fit_start(self, trainer, pl_module):
    """
    Raises:
        MisconfigurationException:
            If LightningModule has no nn.Module `backbone` attribute.
    """
    if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
        # revert the change so this just doesn't call the fit start
        # method of BaseFinetuning
        return
    raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")
```
However, this is not the best fix, as it sacrifices the ability to resume training from a model checkpoint with the optimizer state loading correctly. Filed https://github.com/Lightning-AI/lightning/issues/14674.
Right now, the only option is whether or not to use `auto_lr_find`. If it is False, the default learning rate is 0.001. We should let users specify this learning rate. `ZambaVideoClassificationLightningModule` has an `lr` param: https://github.com/drivendataorg/zamba/blob/master/zamba/pytorch_lightning/utils.py#L143, but this can't be set via configs.
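A hypothetical sketch of how the train config could expose the existing `lr` param (the field and function names here are assumptions, not zamba's actual config schema):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainConfig:
    # hypothetical fields; zamba's real config schema may differ
    auto_lr_find: bool = False
    lr: Optional[float] = None


def resolve_lr(config: TrainConfig) -> Optional[float]:
    """Return the lr to use, falling back to the current 0.001 default."""
    if config.auto_lr_find:
        return None  # the lr finder would set it
    return config.lr if config.lr is not None else 0.001
```

With something like this, a user-provided `lr` would flow through to the module instead of always landing on the hardcoded 0.001 default.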