victoresque / pytorch-template

PyTorch deep learning projects made easy.

Any support for multiple loss functions? #93

Closed · itsnamgyu closed this issue 3 years ago

itsnamgyu commented 3 years ago

I'm using a simple aggregated loss function for now, but it'd be great if multiple losses were supported in the config, logging, etc.

Happy to contribute if you have any plans for this.

deeperlearner commented 3 years ago

You can refer to my repo. I let the config read multiple instances of dataset, dataloader, model, loss, optimizer, and lr_scheduler, with a structure like the following:

    "losses": {
        "loss1": {
            "is_ftn": true,
            "type": "nll_loss"
        },
        "loss2": {
            "balanced": true,
            "type": "BCEWithLogitsLoss"
        },
        "loss3": {
            "type": "MSELoss"
        }
    }
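
A minimal sketch of how such a "losses" block could be turned into a dict of callables, assuming the is_ftn flag distinguishes plain functions in model/loss.py from torch.nn loss classes (the init_losses helper and the handling of extra keys like "balanced" are illustrative, not part of the template):

import torch.nn as nn
import model.loss as module_loss  # the template's loss-function module

def init_losses(config):
    # Build {name: callable} from the "losses" section of the config.
    # Assumption: entries with "is_ftn": true name functions in model/loss.py,
    # otherwise "type" names a torch.nn loss class. Extra keys such as
    # "balanced" would still need to be mapped to constructor arguments.
    losses = {}
    for name, cfg in config["losses"].items():
        if cfg.get("is_ftn", False):
            losses[name] = getattr(module_loss, cfg["type"])   # e.g. nll_loss
        else:
            losses[name] = getattr(nn, cfg["type"])()          # e.g. BCEWithLogitsLoss()
    return losses
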
itsnamgyu commented 3 years ago

@deeperlearner Oh sorry, I was referring to using multiple losses at the same time.

deeperlearner commented 3 years ago

In trainers/trainer.py:

def __init__(self, ...):
    ...
    # self.losses is the dict of loss callables built from the config
    self.loss1 = self.losses['loss1']
    self.loss2 = self.losses['loss2']
    self.loss3 = self.losses['loss3']

You can customize the backward part in _train_epoch. For example,

total_loss = self.loss1(output, target) + self.loss2(output, target) + self.loss3(output, target)
total_loss.backward()
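
For placement, a rough sketch of where that would sit inside the template's _train_epoch loop (the surrounding names follow the existing trainer, but treat this as illustrative rather than a drop-in patch):

# inside _train_epoch (sketch)
for batch_idx, (data, target) in enumerate(self.data_loader):
    data, target = data.to(self.device), target.to(self.device)

    self.optimizer.zero_grad()
    output = self.model(data)

    # combine the configured losses; per-loss weights could be added here
    total_loss = (self.loss1(output, target)
                  + self.loss2(output, target)
                  + self.loss3(output, target))
    total_loss.backward()
    self.optimizer.step()

    self.train_metrics.update('loss', total_loss.item())
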
SunQpark commented 3 years ago

Define loss_total which sums up the auxiliary losses.

def first_loss(output, target):
    ...
    return loss

def second_loss(output, target):
    ...
    return loss

def loss_total(output, target):
    loss_1 = first_loss(output, target)
    loss_2 = second_loss(output, target)
    return loss_1 + loss_2

and use loss_total as the loss function in the config file and trainer.py. In this way, you should be able to use all the utils (TensorBoard, MetricTracker, etc.) as is.
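
For example, the relevant entry in config.json would then just name the combined function:

    "loss": "loss_total",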

If you want to print or log each of the losses, you can modify trainer.py by either using @deeperlearner's solution, or by returning the intermediate values as follows.

def loss_total(output, target):
    ...
    log = dict(
        loss_1=loss_1.item(),
        loss_2=loss_2.item(),
    )
    return loss_1 + loss_2, log
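
If you use the tuple-returning variant, _train_epoch has to unpack it and MetricTracker has to know about the extra keys; a rough sketch, assuming the template's existing train_metrics setup (the 'loss_1'/'loss_2' names match the log dict above):

# in __init__: register the per-component metrics alongside the usual ones
self.train_metrics = MetricTracker('loss', 'loss_1', 'loss_2',
                                   *[m.__name__ for m in self.metric_ftns],
                                   writer=self.writer)

# in _train_epoch: the criterion now returns (total_loss, log)
loss, log = self.criterion(output, target)
loss.backward()
self.optimizer.step()

self.train_metrics.update('loss', loss.item())
for name, value in log.items():
    self.train_metrics.update(name, value)
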
itsnamgyu commented 3 years ago

Thanks, both look like good solutions. I guess I'll have to modify a lot of code from the template.

@SunQpark are you interested in contributions to the template itself, or do you plan to leave the template as-is and allow people to make their own modifications? Just wondering if you have plans for further development.