pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai

FastaiLRFinder usage #3152

Closed · leej3 closed this issue 11 months ago

leej3 commented 11 months ago

📚 FastaiLRFinder Documentation

I was trying out FastaiLRFinder and got a confusing error message. Either the error message or the documentation could probably be improved for this case. I'm happy to contribute a fix, but I'm not sure whether the docs or the error report would be the better place for it.

The error can be reproduced by modifying the output of the ignite code generator for a vision classification task. I made some changes to use the finder to set the learning rate.

See diff
``` diff --git a/main.py b/main.py index d177d6b..244f869 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ from typing import Any import ignite.distributed as idist from data import setup_data from ignite.engine import Events -from ignite.handlers import PiecewiseLinear +from ignite.handlers import PiecewiseLinear, FastaiLRFinder from ignite.metrics import Accuracy, Loss from ignite.utils import manual_seed from models import setup_model @@ -13,7 +13,7 @@ from trainers import setup_evaluator, setup_trainer from utils import * -def run(local_rank: int, config: Any): +def run(local_rank: int, config: Any,use_finder=True): # make a certain seed rank = idist.get_rank() manual_seed(config.seed + rank) @@ -33,19 +33,41 @@ def run(local_rank: int, config: Any): model = idist.auto_model(setup_model(config.model)) optimizer = idist.auto_optim(optim.Adam(model.parameters(), lr=config.lr)) loss_fn = nn.CrossEntropyLoss().to(device=device) + # trainer and evaluator + trainer = setup_trainer(config, model, optimizer, loss_fn, device, dataloader_train.sampler) + evaluator = setup_evaluator(config, model, device) + + if use_finder: + lr_finder = FastaiLRFinder() + to_save = {"model": model, "optimizer": optimizer} + + with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder: + trainer_with_lr_finder.run(dataloader_train) + + # Get lr_finder results + lr_finder.get_results() + + # Plot lr_finder results (requires matplotlib) + lr_finder.plot() + + # get lr_finder suggestion for lr + finder_rate = lr_finder.lr_suggestion() + print(f"{finder_rate}!!!!!!!!!!") + optimizer = idist.auto_optim(optim.Adam(model.parameters()),lr=finder_rate) + + learning_rate = finder_rate if use_finder else config.lr + milestones_values = [ (0, 0.0), ( len(dataloader_train), - config.lr, + learning_rate, ), (config.max_epochs * len(dataloader_train), 0.0), ] lr_scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values) - # trainer and evaluator - trainer = setup_trainer(config, model, optimizer, loss_fn, device, dataloader_train.sampler) - evaluator = setup_evaluator(config, model, device) + # attach metrics to evaluator accuracy = Accuracy(device=device) ```

This fails because the training loss is returned as a dict. The error reported is:


```
Current run is terminating due to exception: output of the engine should be of type float or 0d torch.Tensor or 1d torch.Tensor with 1 element, but got output of type dict
Engine run is terminating due to exception: output of the engine should be of type float or 0d torch.Tensor or 1d torch.Tensor with 1 element, but got output of type dict
Traceback (most recent call last):
  File "/workspace/ignite/ignite-template-vision-classification/main.py", line 159, in <module>
    main()
  File "/workspace/ignite/ignite-template-vision-classification/main.py", line 155, in main
    p.run(run, config=config)
  File "/opt/conda/lib/python3.10/site-packages/ignite/distributed/launcher.py", line 316, in run
    func(local_rank, *args, **kwargs)
  File "/workspace/ignite/ignite-template-vision-classification/main.py", line 45, in run
    trainer_with_lr_finder.run(dataloader_train)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 898, in run
    return self._internal_run()
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 941, in _internal_run
    return next(self._internal_run_generator)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 999, in _internal_run_as_gen
    self._handle_exception(e)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 644, in _handle_exception
    raise e
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 965, in _internal_run_as_gen
    epoch_time_taken += yield from self._run_once_on_dataset_as_gen()
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 1093, in _run_once_on_dataset_as_gen
    self._handle_exception(e)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 644, in _handle_exception
    raise e
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 1075, in _run_once_on_dataset_as_gen
    self._fire_event(Events.ITERATION_COMPLETED)
  File "/opt/conda/lib/python3.10/site-packages/ignite/engine/engine.py", line 431, in _fire_event
    func(*first, *(event_args + others), **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/ignite/handlers/lr_finder.py", line 166, in _log_lr_and_loss
    raise TypeError(
TypeError: output of the engine should be of type float or 0d torch.Tensor or 1d torch.Tensor with 1 element, but got output of type dict
```

The finder does not expect the dict format returned by the trainer ({"train_loss": xx.xxx}). This can be fixed by passing an output_transform when attaching the finder:

```python
with lr_finder.attach(trainer, to_save=to_save, output_transform=lambda x: x["train_loss"]) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(dataloader_train)
```
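
For context, here is a minimal self-contained sketch of the failure mode and the fix outside of the template; the toy model, data, and train_step below are illustrative stand-ins, not the template's actual code:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import Engine
from ignite.handlers import FastaiLRFinder

# Toy model and data (illustrative only)
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
data = DataLoader(
    TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,))),
    batch_size=8,
)

def train_step(engine, batch):
    model.train()
    x, y = batch
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    # Returning a dict, as the template's trainer does, is what
    # trips up the finder unless an output_transform is given.
    return {"train_loss": loss.item()}

trainer = Engine(train_step)

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}

# Without output_transform, running this raises:
#   TypeError: output of the engine should be of type float or ...
with lr_finder.attach(
    trainer, to_save=to_save, output_transform=lambda out: out["train_loss"]
) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(data)

print(lr_finder.lr_suggestion())
```

attach also accepts knobs such as num_iter, start_lr and end_lr to control the learning-rate sweep.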
vfdev-5 commented 11 months ago

Thanks for reporting, John! I would say it would be helpful to improve the error message so that users who see it know directly what to do. Changing the API reference docs is optional, I would say, but I will be happy to approve if you change both. So, if you are eager to contribute this improvement, you are more than welcome ;)
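
For illustration, a friendlier check could look roughly like the sketch below; the function name and wording are hypothetical, not the actual change to lr_finder.py:

```python
import torch

def _to_scalar_loss(output):
    # Hypothetical sketch of a more actionable check for
    # FastaiLRFinder._log_lr_and_loss -- not the merged change.
    if isinstance(output, float):
        return output
    if isinstance(output, torch.Tensor) and output.numel() == 1:
        return output.item()
    raise TypeError(
        "output of the engine should be of type float or 0d torch.Tensor "
        "or 1d torch.Tensor with 1 element, but got output of type "
        f"{type(output).__name__}. If the engine returns something else, "
        "e.g. a dict like {'train_loss': loss}, pass an output_transform "
        "to FastaiLRFinder.attach, e.g. "
        "output_transform=lambda out: out['train_loss']."
    )
```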