Closed minghao4 closed 5 years ago
The problem is that you need to copy the loss args to `self.device`, which is defined in the `Trainer`, so it's not clear how to use `functools.partial` here.
I added two commits handling this issue with a new `init_fnc` method in the config parser. Can you check those changes?
This does not address the problem of having to move the arguments to the GPU, though.
My opinion is that there should be a `BaseLoss` class. It should define a `to(device)` method and a `__call__` method, and loss functions should override these. Then you also don't need the extra function in the config parser, because the loss is now a class and its arguments are passed through `__init__`:
```python
import torch.nn.functional as F

class NLLLoss(BaseLoss):
    def __init__(self):
        pass

    def to(self, device):
        pass

    def __call__(self, output, target):
        return F.nll_loss(output, target)
```
This is an example of how I would implement loss functions. This way you don't need anything extra in `parse_config.py` and you don't need `partial`. Also, if the loss takes parameters, `__init__` can store them. In `train(epoch)` in `BaseTrainer` we can also call `self.loss.to(self.device)` to transfer the arguments to the correct device, in case the loss takes any extra arguments in `__init__` (e.g. weights for a weighted loss).
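As a minimal sketch of this proposal (the `BaseLoss` interface shown here is the one suggested above; `WeightedNLLLoss` and its `weight` argument are illustrative, not part of the template):

```python
import torch
import torch.nn.functional as F

class BaseLoss:
    """Proposed base class: losses override to() and __call__."""
    def to(self, device):
        return self

    def __call__(self, output, target):
        raise NotImplementedError

class WeightedNLLLoss(BaseLoss):
    def __init__(self, weight):
        # store the per-class weights passed in via __init__
        self.weight = torch.as_tensor(weight, dtype=torch.float)

    def to(self, device):
        # move the stored arguments to the trainer's device
        self.weight = self.weight.to(device)
        return self

    def __call__(self, output, target):
        return F.nll_loss(output, target, weight=self.weight)
```

The trainer would then only ever call `self.loss.to(self.device)` once, and losses without extra arguments can simply inherit the no-op `to`.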
If you want your loss to be a class, use the loss modules defined in `torch.nn`, like `nn.NLLLoss`, instead of the functions defined in `torch.nn.functional`. If you want more customization, you can make a custom loss function inheriting from `nn.Module`.
Using base classes should be decided more carefully, since most of the time they restrict the user's choice of coding style. Closing this PR.
> If you want your loss to be a class, use the loss modules defined in `torch.nn`, like `nn.NLLLoss`, instead of the functions defined in `torch.nn.functional`. If you want more customization, you can make a custom loss function inheriting from `nn.Module`.
You are right, `nn.Module` seems the way to go here. Although it seems to make sense to make that the default choice too, e.g. in `train.py` we could have something like:
```python
import model.loss as module_loss
import torch.nn.modules.loss as torch_loss

try:
    # first look for the loss among the built-in torch losses
    loss = config.initialize('loss', torch_loss)
except AttributeError:
    # fall back to the custom losses defined in model/loss.py
    loss = config.initialize('loss', module_loss)
```
i.e. first look in the built-in losses; if not found there, look in `loss.py`. And we can make the custom loss functions inherit from `nn.Module`, as you said.
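A custom loss along those lines could look like the following sketch (`CustomNLLLoss` and its `weight` parameter are illustrative names, not part of the template):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomNLLLoss(nn.Module):
    """NLL loss with per-class weights, written as an nn.Module."""
    def __init__(self, weight):
        super().__init__()
        # registering the weights as a buffer means .to(device) and
        # .cuda() move them along with the rest of the module
        self.register_buffer('weight',
                             torch.as_tensor(weight, dtype=torch.float))

    def forward(self, output, target):
        return F.nll_loss(output, target, weight=self.weight)
```

Because the weights are a registered buffer, the trainer's existing device handling (`loss.to(self.device)`) moves them for free, with no special-case code.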
Hey guys,
Thank you for the great template. One of the things I noticed when implementing custom loss and metric functions is that if I ever wanted to pass additional arguments (that are not the output and target tensors) to the function, there are a lot of locations within the code that would need to be modified.
This PR allows users to set additional arguments for the loss and metric functions directly in the config JSON file.
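For example, a `loss` section carrying extra arguments might look like this (the `type`/`args` key names here are assumptions about the template's config layout, and `WeightedNLLLoss` is a hypothetical custom loss):

```json
{
    "loss": {
        "type": "WeightedNLLLoss",
        "args": {
            "weight": [1.0, 2.0]
        }
    }
}
```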
The `loss` section is identical to the `optimizer` and `lr_scheduler` sections. The `metrics` section has been changed.

This is my first time submitting a PR, hopefully it's in the correct format :). Let me know what you think.
Best, Michael