Jannoshh / simple-sam

Sharpness-Aware Minimization for Efficiently Improving Generalization
MIT License
41 stars, 9 forks

Too many values to unpack when including sample_weights #6

doubleapple123 closed this issue 3 years ago

doubleapple123 commented 3 years ago

If sample weights are included, I think `data` has a `len()` of 3, so it should be unpacked as `x, y, sample_weight = data`. My question is how to properly include the sample weights in the training step. When I train without passing them, the initial loss is wrong. Following the TensorFlow guide, I passed the weights as `loss = self.compiled_loss(y, y_pred, sample_weight=sample_weight, regularization_losses=self.losses)` in both calls to `self.compiled_loss`, one in each of the two `with tf.GradientTape() as tape:` blocks. Is this the correct implementation for both `class_weight` and `sample_weight`?
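A minimal, framework-free sketch of the unpacking the question describes: a batch can arrive as `x`, `(x, y)` or `(x, y, sample_weight)`, so a fixed two-way unpack fails with "too many values to unpack" once weights are present. Recent TensorFlow releases provide `tf.keras.utils.unpack_x_y_sample_weight` for this; `unpack_batch` below is a hypothetical pure-Python stand-in modelled on it. The second helper, `class_weight_to_sample_weight`, is also hypothetical and illustrates our understanding that Keras converts `class_weight` into per-sample weights (multiplying them with any explicit `sample_weight`) before `train_step` sees the batch. If that holds, a single `sample_weight` argument covers both cases.

```python
from typing import Any, Optional, Sequence, Tuple


def unpack_batch(data: Any) -> Tuple[Any, Optional[Any], Optional[Any]]:
    """Hypothetical helper: split a batch into (x, y, sample_weight).

    Missing parts are returned as None, so the caller can always write
    `x, y, sample_weight = unpack_batch(data)`.
    """
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        # A bare input tensor: no targets, no weights.
        return data, None, None
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data[0], data[1], data[2]
    raise ValueError(f"Expected 1, 2 or 3 elements per batch, got {len(data)}")


def class_weight_to_sample_weight(
    y: Sequence[int],
    class_weight: dict,
    sample_weight: Optional[Sequence[float]] = None,
) -> list:
    """Hypothetical helper: look up one weight per integer label.

    If explicit sample weights are also given, the two are multiplied
    element-wise. One-hot labels would first need an argmax (not handled here).
    """
    weights = [float(class_weight[label]) for label in y]
    if sample_weight is not None:
        weights = [cw * sw for cw, sw in zip(weights, sample_weight)]
    return weights
```

With this shape of unpacking, `train_step` can pass `sample_weight` unconditionally: when it is `None`, Keras losses and metrics treat every sample equally.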

Jannoshh commented 3 years ago

That looks like the correct implementation, as long as you also pass the sample weights to `self.compiled_metrics`, i.e. `self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)`.
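For intuition about why the weights must go into both loss computations and into the metric update, here is a NumPy sketch of one SAM step on a linear model with a weighted squared-error loss. It is not the simple-sam code: the linear model, `lr`, `rho` and all helper names are hypothetical. The loss divides by the batch size rather than by the sum of weights, which we assume matches Keras' default `SUM_OVER_BATCH_SIZE` reduction; the metric is a weighted mean, as a Keras `Mean`-style metric computes it.

```python
import numpy as np


def weighted_mse_and_grad(w, x, y, sample_weight):
    """Hypothetical helper: weighted MSE of y_hat = x @ w and its gradient in w.

    Divides by the batch size (assumed to mirror Keras' SUM_OVER_BATCH_SIZE).
    """
    n = x.shape[0]
    residual = x @ w - y
    loss = np.sum(sample_weight * residual**2) / n
    grad = 2.0 * x.T @ (sample_weight * residual) / n
    return loss, grad


def sam_step(w, x, y, sample_weight, lr=0.1, rho=0.05, eps=1e-12):
    """One SAM update; sample_weight is used in BOTH gradient computations."""
    # First pass: weighted gradient at the current weights.
    loss, grad = weighted_mse_and_grad(w, x, y, sample_weight)
    # Move to the approximate worst-case neighbour within radius rho.
    e_w = rho * grad / (np.linalg.norm(grad) + eps)
    # Second pass: gradient at the perturbed weights, with the SAME weights.
    _, sam_grad = weighted_mse_and_grad(w + e_w, x, y, sample_weight)
    # Apply the sharpness-aware gradient to the ORIGINAL weights.
    return w - lr * sam_grad, loss


def weighted_metric(values, sample_weight):
    """Weighted mean of per-sample values: sum(w * v) / sum(w)."""
    return float(np.sum(sample_weight * values) / np.sum(sample_weight))
```

If the weights were dropped from the second pass, a sample with weight 0 (for example a corrupted label) would still steer the update through the perturbed gradient; keeping them in both passes removes it from the step entirely.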

I just added support for sample weights to the implementation but have not tested it exhaustively yet.

doubleapple123 commented 3 years ago

cool, thanks.