chainer / chainer-chemistry

Chainer Chemistry: A Library for Deep Learning in Biology and Chemistry
MIT License

Sample weights #244

Open allchemist opened 6 years ago

allchemist commented 6 years ago

Hello,

Is there a built-in way to train models with weighted samples? In particular, I'd like to set specific weights for each output in multi-output tasks.

Thanks

corochann commented 6 years ago

I think you can multiply the loss by the weights. What is your task, regression or classification? Which loss function are you currently using?

allchemist commented 6 years ago

Both regression and classification tasks, using the mean_squared_error and sigmoid_cross_entropy loss functions. Do I need to modify my loss functions for this trick, or can I provide a modified metric via the "metrics_fun" argument?

allchemist commented 6 years ago

I suppose metrics do not participate in the backward pass, so I have to modify the loss functions.

mottodora commented 6 years ago

Yes. You need to modify the loss functions to set specific weights for each output in multi-output tasks.

I'm planning to implement a task_weight argument in mean_squared_error and mean_absolute_error. Maybe it will be useful for you. https://github.com/pfnet-research/chainer-chemistry/blob/master/chainer_chemistry/functions/mean_squared_error.py#L14-L15

allchemist commented 6 years ago

The weighted loss function is the easy part. It seems that passing only two arguments to the loss function is hard-coded inside models/prediction/regressor.py and classifier.py, and also in the GraphConvPredictor definition (which is not part of chainer_chemistry, though).

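    # Pass the per-sample weights (the last batch column) to the
    # loss function only when weighting is enabled.
    if self.weighted:
        loss_args = self.y, t, args[-1]
    else:
        loss_args = self.y, t

    self.loss = self.lossfun(*loss_args)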

This works fine, but I do not fully understand lines 91-109 of regressor.py, so I'm curious whether this is the right way to do it, as opposed to calling self.loss = self.lossfun(self.y, t).

If we insert the weights into the dataset like this:

    # Append per-sample weights as a fourth column (uses the private _datasets attribute).
    ds._datasets = ds._datasets[0], ds._datasets[1], ds._datasets[2], sample_weights

then args[-1] above is the batch of weights.