fastai / fastai

The fastai deep learning library
http://docs.fast.ai
Apache License 2.0

Change default behavior GANLearner switcher #3390

Closed: StFroese closed this issue 3 years ago

StFroese commented 3 years ago

Describe the bug: By default, the GANLearner switcher is a FixedGANSwitcher with n_crit=5 and n_gen=1, but the original GAN paper (https://arxiv.org/pdf/1406.2661.pdf) uses an alternation with k=n_crit=n_gen=1 (one discriminator step per generator step) in its experiments.

To Reproduce: The default FixedGANSwitcher is set at

https://github.com/fastai/fastai/blob/301016c5d3de2bdb5269121bd0716538d85f7409/fastai/vision/gan.py#L305

within

https://github.com/fastai/fastai/blob/301016c5d3de2bdb5269121bd0716538d85f7409/fastai/vision/gan.py#L299-L309

Expected behavior: I suggest that the default be k=n_crit=n_gen=1, to be consistent with the original work of Ian Goodfellow et al.:

if switcher is None: switcher = FixedGANSwitcher()
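To make the difference between the two settings concrete, here is a minimal, hypothetical simulation of a fixed alternation schedule. It does not import fastai; `fixed_switch_schedule` is an illustrative name, and its counting logic is only assumed to mirror FixedGANSwitcher (train the critic for `n_crit` batches, then the generator for `n_gen` batches, starting with the critic).

```python
def fixed_switch_schedule(n_crit: int, n_gen: int, n_batches: int, gen_first: bool = False) -> list[str]:
    """Hypothetical sketch (not fastai code): return which model is trained
    on each batch, 'C' for critic and 'G' for generator."""
    gen_mode = gen_first  # start with the critic unless gen_first is True
    count = 0             # batches spent so far in the current phase
    schedule = []
    for _ in range(n_batches):
        schedule.append("G" if gen_mode else "C")
        count += 1
        target = n_gen if gen_mode else n_crit
        if count == target:  # phase finished: switch model and reset the counter
            gen_mode = not gen_mode
            count = 0
    return schedule


# Original GAN paper setting (k = 1): strict alternation.
print("".join(fixed_switch_schedule(n_crit=1, n_gen=1, n_batches=8)))   # CGCGCGCG
# Default described in this issue (n_crit=5, n_gen=1).
print("".join(fixed_switch_schedule(n_crit=5, n_gen=1, n_batches=12)))  # CCCCCGCCCCCG
```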

tmabraham commented 3 years ago

The GAN module is not very well documented, but my understanding is that it implements a WGAN, an improved version of the regular GAN. For WGANs, the critic is trained for 5 steps and then the generator for 1 step.
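One way to see what the 5:1 ratio means in practice is to count generator steps over a fixed number of batches. The sketch below uses a hypothetical helper, `generator_updates`, that is not part of fastai, and assumes the critic trains first in each cycle.

```python
def generator_updates(n_batches: int, n_crit: int, n_gen: int) -> int:
    """Hypothetical helper: number of generator steps in `n_batches` batches
    when the critic trains for n_crit batches, then the generator for n_gen,
    repeating."""
    cycle = n_crit + n_gen
    full_cycles, remainder = divmod(n_batches, cycle)
    # In a trailing partial cycle the critic runs first, so only batches
    # beyond the first n_crit of that cycle go to the generator.
    return full_cycles * n_gen + max(0, remainder - n_crit)


for n_crit, n_gen in [(1, 1), (5, 1)]:
    print(n_crit, n_gen, generator_updates(600, n_crit, n_gen))
# 1 1 300
# 5 1 100
```

With the same number of batches, the 5:1 schedule gives the generator a third as many updates as strict alternation, which is intended for a WGAN critic but not for a standard GAN.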

StFroese commented 3 years ago

Yeah, OK. That seems to be the case here, but the tutorial/notebook says to use GANLearner.wgan(...) for a WGAN. Wouldn't it be better for the GANLearner class method wgan() to set the number of iterations in this special case?

StFroese commented 3 years ago

So a nice solution might be to have __init__ default to

FixedGANSwitcher(n_crit=1, n_gen=1)

and to move the n_crit=5 case into the wgan() class method.

Change: https://github.com/fastai/fastai/blob/301016c5d3de2bdb5269121bd0716538d85f7409/fastai/vision/gan.py#L317-L320

To:

@classmethod
def wgan(cls, dls, generator, critic, switcher=None, clip=0.01, switch_eval=False, **kwargs):
    "Create a WGAN from `dls`, `generator` and `critic`."
    # WGAN convention: 5 critic steps per generator step, unless a switcher is passed
    if switcher is None: switcher = FixedGANSwitcher(n_crit=5, n_gen=1)
    return cls(dls, generator, critic, _tk_mean, _tk_diff, switcher=switcher, clip=clip, switch_eval=switch_eval, **kwargs)
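The defaulting logic of this proposal can be checked without fastai. In the minimal sketch below, ToySwitcher and ToyGANLearner are hypothetical stand-ins that keep only the switcher defaults from the proposal; everything else GANLearner does (losses, callbacks, training) is omitted.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToySwitcher:
    """Hypothetical stand-in for FixedGANSwitcher: only stores the step counts."""
    n_crit: int = 1
    n_gen: int = 1


class ToyGANLearner:
    """Hypothetical stand-in for GANLearner, reduced to the switcher defaults."""

    def __init__(self, switcher: Optional[ToySwitcher] = None, clip: Optional[float] = None):
        # Proposed generic default: one critic step per generator step (k = 1).
        self.switcher = switcher if switcher is not None else ToySwitcher()
        self.clip = clip

    @classmethod
    def wgan(cls, switcher: Optional[ToySwitcher] = None, clip: float = 0.01):
        # Proposed WGAN default: five critic steps per generator step.
        if switcher is None:
            switcher = ToySwitcher(n_crit=5, n_gen=1)
        return cls(switcher=switcher, clip=clip)


print(ToyGANLearner().switcher)                            # ToySwitcher(n_crit=1, n_gen=1)
print(ToyGANLearner.wgan().switcher)                       # ToySwitcher(n_crit=5, n_gen=1)
print(ToyGANLearner.wgan(ToySwitcher(n_crit=2)).switcher)  # ToySwitcher(n_crit=2, n_gen=1)
```

Because an explicitly passed switcher always takes precedence, anyone relying on the current 5:1 behavior through the plain constructor could still get it by passing FixedGANSwitcher(n_crit=5, n_gen=1).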