drivendataorg / zamba

A Python package for identifying 42 kinds of animals, training custom models, and estimating distance from camera trap videos
MIT License

Expose ability to set learning rate for models #157

Open ejm714 opened 3 years ago

ejm714 commented 3 years ago

Right now, the only option is to use auto_lr_find or not. If that is False, the default learning rate is 0.001. We should let users specify this learning rate.

ZambaVideoClassificationLightningModule has an lr param, but this can't be set via configs.
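
One way the lr param could be surfaced is a config field that gets forwarded to the module's constructor. The sketch below is illustrative only: `TrainConfigSketch` and `module_kwargs` are hypothetical names, not zamba's actual config classes, and the 0.001 default is simply the value quoted above.

```python
from dataclasses import dataclass
from typing import Optional

DEFAULT_LR = 1e-3  # default quoted in this issue when auto_lr_find is False


@dataclass
class TrainConfigSketch:
    """Hypothetical stand-in for a training config; not zamba's real class."""

    auto_lr_find: bool = False
    lr: Optional[float] = None  # None means "fall back to the default"

    def __post_init__(self) -> None:
        # Reject values that can never be a valid learning rate.
        if self.lr is not None and self.lr <= 0:
            raise ValueError("lr must be a positive float")


def module_kwargs(config: TrainConfigSketch) -> dict:
    """Build the keyword arguments that would be passed to the Lightning module."""
    return {"lr": DEFAULT_LR if config.lr is None else config.lr}


print(module_kwargs(TrainConfigSketch()))
print(module_kwargs(TrainConfigSketch(lr=3e-4)))
```

Keeping `lr` optional means existing configs that never mention it would behave as they do today.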

ejm714 commented 2 years ago

Turns out that auto_lr_find was broken in PTL and has since been fixed. However, using a more recent version hasn't worked yet.

ejm714 commented 2 years ago

Moving context from slack into issue:

There was a bug where auto_lr_find wasn’t resetting to the best learning rate (it just said it was), so it was using a crazy learning rate that caused immediate divergence.
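
A cheap guard against this kind of silent mismatch is to compare the learning rate the finder reports with what the optimizer actually holds before training starts. This is a minimal sketch: `check_lr_applied` is a hypothetical helper, and it only assumes the optimizer exposes `param_groups` as a list of dicts with an "lr" key, as torch optimizers do.

```python
import math
from types import SimpleNamespace


def check_lr_applied(optimizer, expected_lr: float, rel_tol: float = 1e-6) -> None:
    """Raise if any param group's lr differs from the lr the finder reported."""
    for idx, group in enumerate(optimizer.param_groups):
        if not math.isclose(group["lr"], expected_lr, rel_tol=rel_tol):
            raise RuntimeError(
                f"param group {idx} has lr={group['lr']}, expected {expected_lr}"
            )


# Stand-in for a torch optimizer so the sketch runs without torch installed.
fake_optimizer = SimpleNamespace(param_groups=[{"lr": 0.001}])
check_lr_applied(fake_optimizer, expected_lr=0.001)  # passes silently
```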

couple helpful links:

ejm714 commented 2 years ago

Running with an updated version of PTL (1.7.5) yields the following error:

Finding best initial lr:  81%|████████████████████████▎     | 81/100 [08:13<01:55,  6.09s/it]
`` stopped: `max_steps=81` reached.
LR finder stopped early after 81 steps due to diverging loss.
Learning rate set to 0.025118864315095822
Restoring states from the checkpoint path at /home/ubuntu/zamba/.lr_find_26c5855d-e088-4e32-8a37-75f5b9ff1c9f.ckpt
2022-09-12 21:01:42.928 | INFO     | zamba.models.model_manager:train_model:303 - Writing out full configuration to /home/ubuntu/zamba/version_1/train_configuration.yaml.
2022-09-12 21:01:42.933 | INFO     | zamba.models.model_manager:train_model:307 - Starting training...
/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/callbacks/ UserWarning: Checkpoint directory /home/ubuntu/zamba/version_1 exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Adjusting learning rate of group 0 to 2.5119e-02.
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/zamba/bin/zamba", line 33, in <module>
    sys.exit(load_entry_point('zamba', 'console_scripts', 'zamba')())
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/typer/", line 214, in __call__
    return get_command(self)(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/", line 1055, in main
    rv = self.invoke(ctx)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/", line 1657, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/click/", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/typer/", line 500, in wrapper
    return callback(**use_params)  # type: ignore
  File "/home/ubuntu/zamba/zamba/", line 181, in train
  File "/home/ubuntu/zamba/zamba/models/", line 431, in train
  File "/home/ubuntu/zamba/zamba/models/", line 308, in train_model, data_module)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 696, in fit
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 650, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 1151, in _run
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/trainer/", line 1597, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/callbacks/", line 373, in on_fit_start
    return super().on_fit_start(trainer, pl_module)
  File "/home/ubuntu/anaconda3/envs/zamba/lib/python3.8/site-packages/pytorch_lightning/callbacks/", line 105, in on_fit_start
    self._internal_optimizer_metadata[opt_idx], named_parameters
KeyError: 0

which looks like it has to do with

Is it possible that the Backbone Finetuning fix to resume training from a checkpoint broke auto_lr_find? It may be incorrectly treating the auto_lr_find checkpoint as the model checkpoint, since it calls on_fit_start here

Restoring states from the checkpoint path at /home/ubuntu/zamba/.lr_find_26c5855d-e088-4e32-8a37-75f5b9ff1c9f.ckpt
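
To make the suspected mechanism concrete, here is a toy reproduction. It is not PyTorch Lightning's code: `FinetuningCallbackSketch` and its `_restarting` flag are hypothetical, and only the `_internal_optimizer_metadata[opt_idx]` lookup is taken from the traceback above. The assumption is that restoring any checkpoint marks the callback as "resuming", even when that checkpoint carries no optimizer metadata.

```python
class FinetuningCallbackSketch:
    """Toy model of the suspected mechanism; not PyTorch Lightning's code."""

    def __init__(self):
        self._restarting = False
        self._internal_optimizer_metadata = {}

    def load_state_dict(self, state_dict):
        # Any checkpoint restore flips the flag, even if no metadata was saved.
        self._restarting = True
        self._internal_optimizer_metadata = dict(state_dict)

    def on_fit_start(self, optimizers):
        if self._restarting:
            for opt_idx, _ in enumerate(optimizers):
                # Fails when the restored checkpoint has no entry for this optimizer.
                _ = self._internal_optimizer_metadata[opt_idx]
            self._restarting = False


callback = FinetuningCallbackSketch()
callback.load_state_dict({})  # e.g. a temporary checkpoint without finetuning metadata
try:
    callback.on_fit_start(optimizers=["optimizer-0"])
except KeyError as err:
    print(f"KeyError: {err}")
```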

ejm714 commented 2 years ago

Confirmed that if we override that on_fit_start change, things work as expected.

in zamba.pytorch.finetuning, add:

from torch.nn import Module
from pytorch_lightning.utilities.exceptions import MisconfigurationException

def on_fit_start(self, trainer, pl_module):
    """Raise MisconfigurationException if LightningModule has no nn.Module `backbone` attribute."""
    if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
        # revert the change so this just doesn't call the fit start method of BaseFinetuning
        return
    raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")

However, this is not the best fix as it sacrifices being able to resume training from a model checkpoint and have the optimizers load correctly. Filed
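
A narrower alternative might be to skip the param-group restore only when no optimizer metadata was actually loaded, so that a genuine resume still restores the optimizers. The helper below is a hypothetical sketch of that decision and has not been tested against PTL; its name and arguments are assumptions for illustration.

```python
def should_restore_param_groups(restarting: bool, metadata: dict, num_optimizers: int) -> bool:
    """Restore only if a resume is in progress and every optimizer index has metadata."""
    return restarting and all(idx in metadata for idx in range(num_optimizers))


# Real resume with saved metadata: restore the param groups.
print(should_restore_param_groups(True, {0: ["backbone.weight"]}, num_optimizers=1))
# Restore from a checkpoint without finetuning metadata: skip instead of failing.
print(should_restore_param_groups(True, {}, num_optimizers=1))
```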