facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0

hooks-callbacks #867

Closed. JavierClearImageAI closed this issue 4 years ago.

JavierClearImageAI commented 4 years ago

How to implement some callbacks?

I need to implement three callbacks: lrt_finder, early_stopping, and model_checkpoint (saving the model with the "best" accuracy if possible, otherwise the one with the minimum loss).

  1. Is it possible to get these callbacks?

NOTE:

  1. I have checked fastai, but since the whole model is built in PyTorch, I would prefer to use pure detectron2, or to modify as little as possible. I guess we could get trainer.model, trainer.loss, etc. and pass this data to the lrt_finder or the other callbacks?

  2. Also, any hints on how to build these callbacks using hooks? I only have a couple of days to add all three to detectron2 training, so I would like to find the fastest, most efficient way.

This seems like a nice starting point: https://detectron2.readthedocs.io/tutorials/extend.html

I hope I made myself clear. Thanks a lot

JavierClearImageAI commented 4 years ago

Just adding:

cfg.SOLVER.CHECKPOINT_PERIOD = 20

to the default training config, we get a checkpoint saved every 20 iterations. Then we can just select the best one based on "metrics.json", which is in the output directory.
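A minimal sketch of that selection step, assuming metrics.json holds one JSON object per line with an "iteration" key (the format detectron2's JSONWriter produces) and that the iteration you pick was also checkpointed (with CHECKPOINT_PERIOD = 20 and the default logging period of 20 they should line up). `best_iteration` is a hypothetical helper, and the lowest training loss is only a proxy: a validation metric such as bbox/AP is usually a better criterion if the file contains one.

```python
import json


def best_iteration(metrics_path, key="total_loss", lower_is_better=True):
    """Return (iteration, value) for the best logged value of `key`, or None if absent."""
    best = None
    with open(metrics_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            # Not every line has every key (e.g. evaluation results are logged separately).
            if key not in record:
                continue
            value = record[key]
            if best is None or (value < best[1] if lower_is_better else value > best[1]):
                best = (record.get("iteration"), value)
    return best
```

For example, `best_iteration("output/metrics.json", key="bbox/AP", lower_is_better=False)` would pick the iteration with the highest logged box AP, and the matching periodic checkpoint is the one to keep.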

I just found out that we could create our own trainer based on DefaultTrainer in detectron2/engine/defaults.py, similar to how it is done in detectron2/tools/train_net.py.

So if we want to keep only the best model, we could probably do it somewhere around the build_hooks() method.

ppwwyyxx commented 4 years ago

It's always easier to use your own loop to implement custom training logic; one example is https://github.com/facebookresearch/detectron2/blob/master/tools/plain_train_net.py.
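In your own loop, early stopping is just a `break` and keeping the best model is one save call. The sketch below shows only that control flow, not the code of plain_train_net.py: `train_loop`, `train_step`, `evaluate` and `save_checkpoint` are hypothetical names standing in for one optimization step, a validation pass returning a higher-is-better score (e.g. AP), and a checkpointer save.

```python
def train_loop(train_step, evaluate, save_checkpoint, max_iter, eval_period, patience):
    """Run up to max_iter steps, evaluating every eval_period steps.

    Saves "model_best" whenever the validation score improves and stops early after
    `patience` evaluations without improvement. Returns (best_score, steps_run).
    """
    best_score = float("-inf")
    evals_without_improvement = 0
    steps_run = 0
    for iteration in range(max_iter):
        train_step(iteration)  # forward, loss.backward(), optimizer.step(), ...
        steps_run = iteration + 1
        if steps_run % eval_period != 0:
            continue
        score = evaluate()  # higher is better, e.g. AP on a validation set
        if score > best_score:
            best_score = score
            evals_without_improvement = 0
            save_checkpoint("model_best")
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                break  # early stopping is a plain break in a custom loop
    return best_score, steps_run
```

Because the loop owns the iteration, there is no need for a special stop signal or for parsing log files afterwards.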

You can abstract this logic into hooks, but it's often more work. For example, for early stopping and saving the best model, you can obtain the metrics from self.trainer.storage (no need to parse the json file). I don't know how lrt_finder works, but I suppose you'll need to run training a few times, so it's not about hooks.
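A sketch of such a hook that combines best-model saving and early stopping. It assumes the trainer exposes `iter`, `storage` and `checkpointer` (as detectron2's DefaultTrainer does) and that `metric` has already been recorded in the storage (training losses such as total_loss are; validation metrics only appear after an evaluation has run). Since there does not appear to be a built-in stop flag in the training loop, it ends training by raising a hypothetical `StopTraining` exception. `BestModelEarlyStopHook` is also a hypothetical name, and the ImportError fallback exists only so the class can be exercised without detectron2 installed.

```python
try:
    from detectron2.engine import HookBase
except ImportError:
    # Stand-in with the same no-op interface, only so the sketch runs without detectron2.
    class HookBase:
        trainer = None

        def before_train(self):
            pass

        def after_train(self):
            pass

        def before_step(self):
            pass

        def after_step(self):
            pass


class StopTraining(Exception):
    """Hypothetical exception used to end training early."""


class BestModelEarlyStopHook(HookBase):
    """Every `period` iterations, compare a smoothed metric from the trainer's storage
    with the best value so far: save "model_best" on improvement, and raise
    StopTraining after `patience` consecutive checks without improvement."""

    def __init__(self, metric="total_loss", period=20, patience=5, window=20,
                 higher_is_better=False):
        self.metric = metric
        self.period = period
        self.patience = patience
        self.window = window
        self.higher_is_better = higher_is_better
        self.best = float("-inf") if higher_is_better else float("inf")
        self.checks_without_improvement = 0

    def after_step(self):
        iteration = self.trainer.iter
        if (iteration + 1) % self.period != 0:
            return
        # Median over the last `window` recorded values smooths per-iteration noise.
        current = self.trainer.storage.history(self.metric).median(self.window)
        improved = current > self.best if self.higher_is_better else current < self.best
        if improved:
            self.best = current
            self.checks_without_improvement = 0
            # Extra keyword arguments are stored inside the checkpoint file.
            self.trainer.checkpointer.save("model_best", iteration=iteration)
        else:
            self.checks_without_improvement += 1
            if self.checks_without_improvement >= self.patience:
                raise StopTraining(
                    f"{self.metric} did not improve for {self.patience} checks"
                )
```

To use it, call `trainer.register_hooks([BestModelEarlyStopHook()])` before `trainer.train()` and wrap `trainer.train()` in `try: ... except StopTraining: pass`; detectron2 logs the exception but still runs `after_train()` for the other hooks. For a validation metric recorded once per evaluation (e.g. bbox/AP with `higher_is_better=True`), `window=1` makes the hook use the latest value. Overriding `build_hooks()` in a DefaultTrainer subclass and appending the hook to the list from `super().build_hooks()` should work as well.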