rdevon / cortex

A machine learning library for PyTorch
BSD 3-Clause "New" or "Revised" License

LossHandler does not handle two different sets on the same network key #198

Open tsirif opened 5 years ago

tsirif commented 5 years ago

I have come to realize from the implementation in built_ins/gan.py that two separate optimizer calls are made in order to update the discriminator: the first update uses the GAN loss and the second the gradient penalty.

This is because LossHandler will overwrite (s1) any value for a specific network key, which is inconvenient behaviour. I can see in s2 and s3 that there was an intention to implement a more convenient behaviour, but it seems it was never finished.
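To make the problem concrete, here is a minimal illustration (not cortex's actual code) of why a plain mapping from network key to a single loss value forces two separate optimizer updates; the numeric values are placeholders:

```python
# Illustrative only: a dict keyed by network name keeps one value per key,
# so a second loss term silently replaces the first.
losses = {}
losses['discriminator'] = 0.7   # GAN loss (placeholder value)
losses['discriminator'] = 0.1   # gradient penalty overwrites the GAN loss

# Only the last assignment survives; the GAN loss term is gone.
# This is why built_ins/gan.py has to perform two separate optimizer
# updates, one per loss term, instead of a single combined step.
print(losses['discriminator'])
```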

I propose the following; tell me what you think:

rdevon commented 5 years ago

So this was supposed to be the default behavior at some point, but it was removed. That said, the way losses and results are handled in the backend is quite messy and probably needs to be refactored. What about adding an `add_loss` method instead? Then all losses would be available via a dictionary.
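A minimal sketch of what such an accumulating handler might look like; the class and method names here are illustrative assumptions, not cortex's actual API:

```python
from collections import defaultdict

class LossHandler:
    """Sketch: accumulate multiple named loss terms per network key
    instead of overwriting, so e.g. a GAN loss and a gradient penalty
    can coexist for the same network. Hypothetical, not cortex's API."""

    def __init__(self):
        # network_key -> {loss_name: value}
        self._losses = defaultdict(dict)

    def add_loss(self, network_key, name, value):
        # Each loss term is stored under its own name, so adding a
        # second term does not clobber the first.
        self._losses[network_key][name] = value

    def losses(self, network_key):
        # Expose all terms for a network as a dictionary.
        return dict(self._losses[network_key])

    def total(self, network_key):
        # Sum all terms so a single optimizer step can use the
        # combined loss.
        return sum(self._losses[network_key].values())
```

With something like this, the discriminator's GAN loss and gradient penalty could be registered separately and summed once before a single backward/step, rather than driving two optimizer updates.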

tsirif commented 5 years ago

I have refactored a lot of things to comply with everything intended (including isolating a routine's contribution to losses and results, and reporting them prefixed in all_epoch_losses, which I think was the original intention), without breaking the API for ModelPlugin developers. I am going to open a separate PR for this after ICML.
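The prefixed reporting described above might look something like the following; this helper and its naming scheme are my guess at the intent, not the actual refactor:

```python
def prefix_losses(routine_name, losses):
    # Sketch: namespace a routine's loss terms by its name before
    # merging them into an epoch-level record such as all_epoch_losses,
    # so contributions from different routines stay distinguishable.
    # Hypothetical helper; the dotted naming is an assumption.
    return {f'{routine_name}.{key}': value for key, value in losses.items()}
```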

tsirif commented 5 years ago

By the way, as of now I am familiar with the whole codebase.

rdevon commented 5 years ago

OK! Looking forward to the PR!