victoresque / pytorch-template

PyTorch deep learning projects made easy.

Adding Configurable Arguments to Custom Loss and Metric(s) Functions #49

Closed · minghao4 closed this 5 years ago

minghao4 commented 5 years ago

Hey guys,

Thank you for the great template. One of the things I noticed when implementing custom loss and metric functions is that passing additional arguments (beyond the output and target tensors) to a function requires modifying the code in many places.

This PR allows users to set additional arguments for the loss and metric functions directly in the config JSON file.
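For example, a loss entry in the config could then look something like this (the exact key names here are illustrative, not necessarily the final schema):

    "loss": {
        "type": "nll_loss",
        "args": { "weight": [1.0, 2.0] }
    }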

This is my first time submitting a PR, hopefully it's in the correct format :). Let me know what you think.

Best, Michael

ag14774 commented 5 years ago

The problem is that you need to copy the loss arguments to self.device, which is defined in the Trainer, so it's not clear how functools.partial could be used here.
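For context, a partial-based setup (with a hypothetical weight argument) would bake the tensor in on the CPU:

    import functools
    import torch
    import torch.nn.functional as F

    # hypothetical extra argument: a per-class weight tensor, created on the CPU
    weight = torch.tensor([1.0, 2.0])
    loss_fn = functools.partial(F.nll_loss, weight=weight)
    # loss_fn has no to(device) method, so the Trainer has no hook for moving
    # `weight` to self.device alongside the model and the data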

SunQpark commented 5 years ago

I added two commits handling this issue with a new init_fnc method in the config parser. Can you check those changes?

ag14774 commented 5 years ago

This does not address the problem of having to move the arguments to the GPU, though.

ag14774 commented 5 years ago

My opinion is that there should be a BaseLoss class. It should define a to(device) method and a __call__ method, and loss functions should override these. Then you also don't need the extra function in the config parser, because the loss is now a class and its arguments are passed through __init__.

ag14774 commented 5 years ago
import torch.nn.functional as F

# BaseLoss is the base class proposed above, defining
# to(device) and __call__(output, target) hooks
class NLLLoss(BaseLoss):
    def __init__(self):
        pass  # a parameterized loss would store its extra arguments here

    def to(self, device):
        pass  # nothing to move: this loss has no tensor arguments

    def __call__(self, output, target):
        return F.nll_loss(output, target)

An example of how I would implement loss functions. This way you don't need anything extra in parse_config.py and you don't need partial. Also, if the loss takes parameters, __init__ can store them. In train(epoch) in BaseTrainer we can also call self.loss.to(self.device) to transfer the arguments to the correct device if the loss takes any extra arguments in __init__ (e.g. weights for a weighted loss).
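To make the device-transfer point concrete, here is a sketch of a parameterized loss under this proposal (the BaseLoss interface and the weight argument are illustrative, not code in the repo):

    import torch
    import torch.nn.functional as F

    class BaseLoss:
        def to(self, device):
            pass

        def __call__(self, output, target):
            raise NotImplementedError

    class WeightedNLLLoss(BaseLoss):
        def __init__(self, weight):
            # extra arguments from the config are stored at construction time
            self.weight = torch.tensor(weight)

        def to(self, device):
            # the trainer would call self.loss.to(self.device) before training
            self.weight = self.weight.to(device)

        def __call__(self, output, target):
            return F.nll_loss(output, target, weight=self.weight)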

SunQpark commented 5 years ago

If you want your loss to be a class, use the loss modules defined in torch.nn, like nn.NLLLoss, instead of the functions defined in torch.nn.functional. If you want more customization, you can make a custom loss class inheriting from nn.Module.
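As a minimal sketch of that suggestion (again with a hypothetical weight argument), registering the tensor as a buffer lets the usual nn.Module.to(device) call handle the device transfer:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class WeightedNLLLoss(nn.Module):
        def __init__(self, weight):
            super().__init__()
            # buffers are moved automatically by nn.Module.to(device)
            self.register_buffer('weight', torch.tensor(weight))

        def forward(self, output, target):
            return F.nll_loss(output, target, weight=self.weight)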

SunQpark commented 5 years ago

Whether to use base classes should be decided more carefully, since most of the time they restrict the user's choice of coding style. Closing this PR.

ag14774 commented 5 years ago

> If you want your loss to be a class, use the loss modules defined in torch.nn, like nn.NLLLoss, instead of the functions defined in torch.nn.functional. If you want more customization, you can make a custom loss class inheriting from nn.Module.

You are right, nn.Module seems the way to go here. Although it would also make sense to make that the default choice here, e.g. in train.py we could have something like:

    import model.loss as module_loss
    import torch.nn.modules.loss as torch_loss

    # prefer the losses built into PyTorch; fall back to the custom
    # losses defined in model/loss.py if the name is not found there
    try:
        loss = config.initialize('loss', torch_loss)
    except AttributeError:
        loss = config.initialize('loss', module_loss)

i.e. first look in the built-in losses, and if the name is not found there, look in loss.py. And we can make the custom loss functions inherit from nn.Module as you said.