Closed tacchinotacchi closed 5 years ago
You should use a Callback to switch the loss function between train and eval; it seems the easiest way to deal with your problem.
Thanks, I solved it by using a Callback. I still have to use a wrapper, but I guess that if needed I could wrap decodes and activation too.
class TargetLoss():
    def __init__(self, loss_func, target_index, name=None):
        store_attr(self, "loss_func,target_index,name")
    def __call__(self, x, *targets, **kwargs):
        return self.loss_func(x, targets[self.target_index], **kwargs)
    @property
    def __name__(self):
        return self.name

class SwitchLoss(Callback):
    def __init__(self, train_loss, valid_loss):
        store_attr(self, "train_loss,valid_loss")
    def begin_train(self):
        self.learn.loss_func = self.train_loss
    def begin_validate(self):
        self.learn.loss_func = self.valid_loss

train_loss = TargetLoss(MSELossFlat(), 1)
valid_loss = TargetLoss(CrossEntropyLossFlat(), 0, name="valid_loss")
valid_accuracy = TargetLoss(accuracy, 0, name="valid_accuracy")

opt_func = partial(Adam, wd=0.1, eps=1e-7)
cb_funcs = [partial(RNNTrainer, alpha=2, beta=1), partial(SwitchLoss, train_loss, valid_loss)]
learn = Learner(dbch, model, loss_func=train_loss, opt_func=opt_func, cb_funcs=cb_funcs, metrics=[valid_accuracy])
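As a sanity check, the target-routing part of this wrapper can be exercised in plain PyTorch, with no fastai dependency. This is a hedged sketch: `store_attr` is replaced by ordinary attribute assignment, the tensors are made-up toy data, and the variable names (`mse_on_soft`, `ce_on_labels`) are illustrative only.

```python
import torch
import torch.nn.functional as F

class TargetLoss:
    """Plain-PyTorch stand-in for the wrapper above: picks one target
    out of several and forwards it to the wrapped loss function."""
    def __init__(self, loss_func, target_index, name=None):
        self.loss_func = loss_func  # ordinary assignment instead of store_attr
        self.target_index = target_index
        self.name = name

    def __call__(self, x, *targets, **kwargs):
        # Select the target this loss should see, ignore the others.
        return self.loss_func(x, targets[self.target_index], **kwargs)

# Toy batch: one set of predictions, two kinds of targets.
preds = torch.randn(4, 3)                     # model logits
class_targets = torch.tensor([0, 2, 1, 0])    # ground-truth labels (target 0)
soft_targets = torch.randn(4, 3)              # logits to regress onto (target 1)

mse_on_soft = TargetLoss(F.mse_loss, 1)        # training loss uses target 1
ce_on_labels = TargetLoss(F.cross_entropy, 0)  # validation loss uses target 0

mse = mse_on_soft(preds, class_targets, soft_targets)
ce = ce_on_labels(preds, class_targets, soft_targets)
```

Each wrapper receives the full tuple of targets but only ever compares the predictions against the one at its `target_index`, which is what lets the same batch feed both losses.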
If I understand correctly, there's no way in the library to select different loss functions for different targets. For my use case, I have two types of labels: MSE loss should be applied to the logits during training steps, while cross entropy (plus accuracy) should be applied to the ground-truth labels as metrics on the validation set.
For now, I solved the problem with the snippet above. But this way I get two columns named valid_loss: the first is the one created by the learner, which in my case is applied to the wrong targets, and the second is the one I specified. It also doesn't implement decodes and activation. Do you think there should be an in-library mechanism for handling these situations? Or is there already a way that I haven't realized?