jumelet opened this issue 3 years ago.
The old code for this was as follows; it supposedly didn't comply with skorch:
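```python
import torch
from torch import Tensor


def _set_class_weights(self, labels: Tensor) -> None:
    # Count how often each distinct label occurs (classes come back sorted).
    classes, class_freqs = torch.unique(labels, return_counts=True)
    norm = class_freqs.sum().item()  # total number of labels
    # Map each class to its relative frequency, so the weights sum to 1.
    class_weight = {
        classes[i].item(): class_freqs[i].item() / norm
        for i in range(len(class_freqs))
    }
    # Mutates the wrapped classifier in place.
    self.classifier.class_weight = class_weight
```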
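As a sketch only: the version below computes the same mapping as the method above but returns it instead of assigning to `self.classifier.class_weight`. The helper name `class_frequency_weights` is hypothetical, and the idea that mutating the wrapped classifier is what clashes with skorch is an assumption, not something confirmed in this issue. It uses only the standard library so it can be checked without torch.

```python
from collections import Counter
from typing import Dict, Hashable, Iterable


def class_frequency_weights(labels: Iterable[Hashable]) -> Dict[Hashable, float]:
    """Hypothetical helper: return {class: relative frequency}.

    Mirrors the old _set_class_weights computation, but leaves it to the
    caller to decide where the weights go (e.g. a constructor argument)
    instead of setting an attribute on a classifier.
    """
    counts = Counter(labels)     # class -> number of occurrences
    norm = sum(counts.values())  # total number of labels (0 for empty input)
    # Weight is proportional to frequency: frequent classes get MORE weight,
    # exactly as in the original torch code. Empty input yields {}.
    return {cls: freq / norm for cls, freq in counts.items()}


weights = class_frequency_weights([0, 0, 0, 1])
print(weights)  # {0: 0.75, 1: 0.25}
```

Note that these weights grow with class frequency. If the intent was to counter class imbalance, scikit-learn's `class_weight="balanced"` uses the inverse instead: `n_samples / (n_classes * count)` per class.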