drivendataorg / zamba

A Python package for identifying 42 kinds of animals, training custom models, and estimating distance from camera trap videos
https://zamba.drivendata.org/docs/stable/
MIT License

Expose ability to set learning rate for models #157

Open ejm714 opened 3 years ago

ejm714 commented 3 years ago

Right now, the only option is to use auto_lr_find or not. If that is False, the default learning rate is 0.001. We should let users specify this learning rate.

ZambaVideoClassificationLightningModule has an lr param (https://github.com/drivendataorg/zamba/blob/master/zamba/pytorch_lightning/utils.py#L143), but this can't be set via the configs.
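
For illustration, a minimal sketch of what exposing this could look like, assuming a pydantic-style TrainConfig like zamba's (the field name and plumbing here are assumptions, not the current code):

# Hypothetical sketch: surface `lr` on the training config and pass it
# through to the LightningModule instead of hardcoding the 0.001 default.
from typing import Optional

from pydantic import BaseModel


class TrainConfig(BaseModel):
    auto_lr_find: bool = False
    # None means "use the module's default of 0.001"
    lr: Optional[float] = None


def module_hparams(config: TrainConfig) -> dict:
    """Build kwargs for ZambaVideoClassificationLightningModule."""
    hparams = {}
    if config.lr is not None:
        hparams["lr"] = config.lr
    return hparams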

ejm714 commented 2 years ago

It turns out that auto_lr_find was broken in PTL and has since been fixed. However, upgrading to a more recent version hasn't worked yet.

ejm714 commented 2 years ago

Moving context from Slack into the issue:

There was a bug where auto_lr_find wasn't resetting to the best learning rate (it just said it was), so it was using a crazy learning rate that caused immediate divergence.

A couple of helpful links:
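
As an aside, a minimal sketch of working around the bug manually: run the LR finder yourself and apply its suggestion explicitly instead of trusting the automatic reset (using the PTL 1.x Tuner API; `model` and `data_module` stand in for the zamba objects):

# Sketch: run the LR finder by hand and apply the suggested rate explicitly,
# since the buggy auto_lr_find reported the best lr without actually setting it.
# `model` and `data_module` are placeholders for the zamba objects.
import pytorch_lightning as pl

trainer = pl.Trainer()
lr_finder = trainer.tuner.lr_find(model, datamodule=data_module)

suggested_lr = lr_finder.suggestion()
print(f"suggested lr: {suggested_lr}")

# Set the rate on the module ourselves rather than relying on the reset.
model.lr = suggested_lr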

ejm714 commented 2 years ago

Running with an updated version of PTL (1.7.5) yields the following error:

Finding best initial lr:  81%|████████████████████████▎     | 81/100 [08:13<01:55,  6.09s/it]
`Trainer.fit` stopped: `max_steps=81` reached.
LR finder stopped early after 81 steps due to diverging loss.
Learning rate set to 0.025118864315095822
Restoring states from the checkpoint path at /home/ubuntu/zamba/.lr_find_26c5855d-e088-4e32-8a37-75f5b9ff1c9f.ckpt
2022-09-12 21:01:42.928 | INFO     | zamba.models.model_manager:train_model:303 - Writing out full configuration to /home/ubuntu/zamba/version_1/train_configuration.yaml.
2022-09-12 21:01:42.933 | INFO     | zamba.models.model_manager:train_model:307 - Starting training...
/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:616: UserWarning: Checkpoint directory /home/ubuntu/zamba/version_1 exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Adjusting learning rate of group 0 to 2.5119e-02.
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/zamba/bin/zamba", line 33, in <module>
    sys.exit(load_entry_point('zamba', 'console_scripts', 'zamba')())
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/typer/main.py", line 214, in __call__
    return get_command(self)(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/typer/main.py", line 500, in wrapper
    return callback(**use_params)  # type: ignore
  File "/home/ubuntu/zamba/zamba/cli.py", line 181, in train
    manager.train()
  File "/home/ubuntu/zamba/zamba/models/model_manager.py", line 431, in train
    train_model(
  File "/home/ubuntu/zamba/zamba/models/model_manager.py", line 308, in train_model
    trainer.fit(model, data_module)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1151, in _run
    self._call_callback_hooks("on_fit_start")
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1597, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py", line 373, in on_fit_start
    return super().on_fit_start(trainer, pl_module)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/callbacks/finetuning.py", line 105, in on_fit_start
    self._internal_optimizer_metadata[opt_idx], named_parameters
KeyError: 0

This looks related to https://github.com/Lightning-AI/lightning/pull/8501/files.

It's possible that the BackboneFinetuning fix for resuming training from a checkpoint broke auto_lr_find: it may be incorrectly treating the auto_lr_find checkpoint as the model checkpoint, since it calls on_fit_start here:

Restoring states from the checkpoint path at /home/ubuntu/zamba/.lr_find_26c5855d-e088-4e32-8a37-75f5b9ff1c9f.ckpt
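
For context, BaseFinetuning.on_fit_start in PTL 1.7 does roughly the following (paraphrased from the traceback above, not an exact copy of the upstream source). Nothing ever recorded optimizer metadata for the lr_find checkpoint, so indexing into it fails:

# Paraphrase of PTL 1.7's BaseFinetuning.on_fit_start. After restoring from
# the .lr_find_*.ckpt, the callback believes it is restarting, but
# self._internal_optimizer_metadata is empty, so [opt_idx] raises KeyError: 0.
def on_fit_start(self, trainer, pl_module):
    if self._restarting:
        named_parameters = dict(pl_module.named_parameters())
        for opt_idx, optimizer in enumerate(trainer.optimizers):
            param_groups = self._apply_mapping_to_param_groups(
                self._internal_optimizer_metadata[opt_idx], named_parameters
            )
            optimizer.param_groups = param_groups
        self._restarting = False
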
ejm714 commented 2 years ago

Confirmed that if we override that on_fit_start change, things work as expected.

In zamba.pytorch.finetuning, add:

from torch.nn import Module
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def on_fit_start(self, trainer, pl_module):
    """
    Raises:
        MisconfigurationException:
            If LightningModule has no nn.Module `backbone` attribute.
    """
    if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
        # Revert the upstream change: return here instead of calling
        # BaseFinetuning.on_fit_start, which would try to restore optimizer
        # metadata from the lr_find checkpoint and raise KeyError: 0.
        return
    raise MisconfigurationException(
        "The LightningModule should have a nn.Module `backbone` attribute"
    )

However, this is not the best fix, as it sacrifices the ability to resume training from a model checkpoint with the optimizers loading correctly. Filed https://github.com/Lightning-AI/lightning/issues/14674.