deeperlearner / pytorch-template

A PyTorch template file generator.

Questions about iteration-based training #1

Open fabriziojpiva opened 3 years ago

fabriziojpiva commented 3 years ago

Hi, first of all, thanks for contributing such a template; it is very useful, and I am trying to use it for a domain adaptation algorithm.

I have a couple of questions regarding iteration-based training, since the template is supposed to work for both epoch-based and iteration-based training. In the following lines:

https://github.com/deeperlearner/pytorch-template/blob/e73871dab70468f15615c95bf3974f7a09cc8240/trainers/trainer.py#L28-L33

you define the logic for both training methods. However, in the _train_epoch function you only run validation after the epoch is finished. That is fine for epoch-based training, but if I want to evaluate the model every certain number of iterations (val_steps), the function will never perform evaluation: the iterator declared in line 33 is infinite, so the for-loop in _train_epoch never stops as long as the condition:

https://github.com/deeperlearner/pytorch-template/blob/e73871dab70468f15615c95bf3974f7a09cc8240/trainers/trainer.py#L110-L111

is met. Another issue is that you step the learning rate scheduler only after each epoch:

https://github.com/deeperlearner/pytorch-template/blob/e73871dab70468f15615c95bf3974f7a09cc8240/trainers/trainer.py#L123-L124

but for iteration-based training that is supposed to happen every step. Shouldn't lines 123-124 go inside the for-loop?
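To make the first point concrete, this is roughly the behavior I am after (a sketch only; val_steps is a hypothetical config value, and the loop has to be bounded explicitly because the iterator is infinite):

def _train_epoch(self, epoch):
    """Sketch: validate every val_steps iterations instead of once per epoch."""
    self.model.train()
    for batch_idx, (data, target) in enumerate(self.data_loader):
        if batch_idx == self.len_epoch:
            break  # bound the loop: self.data_loader is an infinite iterator
        # ... forward pass, loss, backward, self.optimizer.step() ...
        if (batch_idx + 1) % self.val_steps == 0:
            self._valid_epoch(epoch)  # periodic evaluation
            self.model.train()        # restore training mode after validation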

deeperlearner commented 3 years ago

Thank you for pointing out this issue. I didn't notice this since I always use epoch-based training. A pull request is very welcome. Thanks for reviewing the repo.

deeperlearner commented 3 years ago

Sorry, let me correct my comment. Iteration-based training here uses the len_epoch variable to define how many iterations make up an epoch. It still works like epoch-based training; it just changes the number of iterations in one epoch (the default is the length of the dataloader). Thus, when you define len_epoch, it will still run _valid_epoch() after

https://github.com/deeperlearner/pytorch-template/blob/e73871dab70468f15615c95bf3974f7a09cc8240/trainers/trainer.py#L110-L111
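In other words, the pattern is roughly the following (a simplified sketch; the infinite-iterator helper here mirrors the one declared around line 33):

from itertools import repeat

def inf_loop(data_loader):
    """Yield batches from data_loader endlessly (restart when exhausted)."""
    for loader in repeat(data_loader):
        yield from loader

def setup_epoch(data_loader, len_epoch=None):
    # Epoch-based: one epoch = one full pass over the dataloader.
    if len_epoch is None:
        return data_loader, len(data_loader)
    # Iteration-based: wrap the loader so it never runs out, and bound the
    # training loop by len_epoch instead, so _valid_epoch() still runs
    # every len_epoch iterations.
    return inf_loop(data_loader), len_epoch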

If you want to validate every val_step iterations, you can set len_epoch equal to val_step in your configs/config.json, for example:
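Something like this, to validate every 500 iterations (key names are illustrative; the exact nesting depends on your config.json):

{
    "trainer": {
        "len_epoch": 500
    }
}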

You can also move the lr_scheduler part into the for-loop.
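Something like this (a sketch; whether stepping per iteration is correct depends on the scheduler you configure, e.g. OneCycleLR steps per iteration while StepLR steps per epoch):

for batch_idx, (data, target) in enumerate(self.data_loader):
    if batch_idx == self.len_epoch:
        break
    # ... forward / loss / backward / self.optimizer.step() ...
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()  # moved inside the loop: step per iteration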