fastai / fastai_dev

fast.ai early development experiments

Is there a way to select different loss functions for different targets? #279

Closed tacchinotacchi closed 5 years ago

tacchinotacchi commented 5 years ago

If I understand correctly, there's no way in the library to select different loss functions for different targets.

For my use case, I have two types of labels:

  1. Logits from a teacher model, available only for the training set
  2. Ground-truth labels, available only for the validation set

I need to apply MSE loss to the logits during training steps, and cross entropy (+accuracy) to the ground-truth labels as metrics on the validation set.

For now, I solved the problem with this snippet:

# wraps a loss function or metric so that it is applied to a single target;
# name is only needed when the wrapper is used as a metric
class TargetLoss():
    def __init__(self, loss_func, target_index, name=None):
        store_attr(self, "loss_func,target_index,name")
    def __call__(self, x, *targets, **kwargs):
        return self.loss_func(x, targets[self.target_index], **kwargs)
    @property
    def __name__(self):
        return self.name

train_loss = TargetLoss(MSELossFlat(), 1)  # MSE on the teacher logits
valid_loss = TargetLoss(CrossEntropyLossFlat(), 0, name="valid_loss")  # CE on the ground truth
valid_accuracy = TargetLoss(accuracy, 0, name="valid_accuracy")
learn = Learner(dbch, model, loss_func=train_loss, opt_func=opt_func, cb_funcs=cb_funcs, metrics=[valid_loss, valid_accuracy])

But this way I get two columns named valid_loss: the first is the one the learner computes itself, which in my case is applied to the wrong targets, and the second is the one I specified. The wrapper also doesn't implement decodes and activation.
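
As an aside, those missing pieces could be delegated to the wrapped loss. A minimal sketch (untested; TargetLossFull is a made-up name, and it assumes the usual fastai star import so noop is in scope):

class TargetLossFull(TargetLoss):
    # delegate to the wrapped loss when it defines activation/decodes
    # (as fastai's built-in losses do), falling back to noop (identity)
    # for plain metric functions like accuracy
    def activation(self, x):
        return getattr(self.loss_func, 'activation', noop)(x)
    def decodes(self, x):
        return getattr(self.loss_func, 'decodes', noop)(x)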

Do you think there should be an in-library mechanism for handling these situations? Or is there already a way and I didn't realize?

sgugger commented 5 years ago

You should use a Callback to switch the loss function between train and eval; it seems the easiest way to deal with your problem.

tacchinotacchi commented 5 years ago

Thanks, I solved it by using a Callback. I still have to use a wrapper, but I guess I could wrap decodes and activation too if needed (along the lines sketched above).

class TargetLoss():
    def __init__(self, loss_func, target_index, name=None):
        store_attr(self, "loss_func,target_index,name")
    def __call__(self, x, *targets, **kwargs):
        return self.loss_func(x, targets[self.target_index], **kwargs)
    @property
    def __name__(self):
        return self.name

# switches the learner's loss function at the start of the training and
# validation phases of each epoch
class SwitchLoss(Callback):
    def __init__(self, train_loss, valid_loss):
        store_attr(self, "train_loss,valid_loss")
    def begin_train(self):
        self.learn.loss_func = self.train_loss
    def begin_validate(self):
        self.learn.loss_func = self.valid_loss

train_loss = TargetLoss(MSELossFlat(), 1)
valid_loss = TargetLoss(CrossEntropyLossFlat(), 0, name="valid_loss")
valid_accuracy = TargetLoss(accuracy, 0, name="valid_accuracy")

opt_func = partial(Adam, wd=0.1, eps=1e-7)
cb_funcs = [partial(RNNTrainer, alpha=2, beta=1), partial(SwitchLoss, train_loss, valid_loss)]

# valid_loss is now applied by SwitchLoss during validation, so only accuracy
# remains in metrics
learn = Learner(dbch, model, loss_func=train_loss, opt_func=opt_func, cb_funcs=cb_funcs, metrics=[valid_accuracy])
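
With that in place, training runs as usual; an illustrative call (hypothetical epoch count and learning rate, assuming dbch and model are already built):

learn.fit(1, lr=1e-3)
# trains with MSE against the teacher logits, then reports valid_loss
# (cross entropy) and valid_accuracy against the ground-truth labels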