i-machine-think / diagNNose

diagNNose is a Python library that provides a broad set of tools for analysing the hidden activations of neural models.
https://diagnnose.readthedocs.io
MIT License

Allow class_weights to be set for unbalanced data #67

Open jumelet opened 3 years ago

jumelet commented 3 years ago

The old code for this was as follows. It supposedly didn't comply with skorch:

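import torch
from torch import Tensor


def _set_class_weights(self, labels: Tensor) -> None:
    # Count how often each class label occurs in the training labels.
    classes, class_freqs = torch.unique(labels, return_counts=True)
    n_samples = class_freqs.sum().item()
    n_classes = len(classes)
    # Weight each class inversely to its frequency ("balanced" heuristic),
    # so minority classes are not drowned out by the majority class.
    class_weight = {
        cls.item(): n_samples / (n_classes * freq.item())
        for cls, freq in zip(classes, class_freqs)
    }
    self.classifier.class_weight = class_weight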
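If the weights need to live outside the scikit-learn classifier, for example to feed a PyTorch loss through skorch, they can be computed as a plain list indexed by class. This is a minimal sketch. It assumes integer labels in `0..num_classes-1`, and `balanced_class_weights` is a hypothetical helper, not part of diagNNose:

```python
from collections import Counter
from typing import List, Sequence


def balanced_class_weights(labels: Sequence[int], num_classes: int) -> List[float]:
    """Return one weight per class index, inversely proportional to its frequency.

    Uses the "balanced" heuristic n_samples / (num_classes * count), so every
    class ends up with the same total weight n_samples / num_classes.
    Classes that never occur in ``labels`` get weight 0.0 (an assumed choice).
    """
    counts = Counter(labels)
    n_samples = len(labels)
    weights = []
    for cls in range(num_classes):
        count = counts.get(cls, 0)
        # Avoid division by zero for classes absent from the data.
        weights.append(n_samples / (num_classes * count) if count else 0.0)
    return weights


weights = balanced_class_weights([0, 0, 0, 1], num_classes=2)
```

Assuming the probe is a skorch `NeuralNetClassifier`, the list could then be passed as `criterion__weight=torch.tensor(weights)`. skorch forwards `criterion__` arguments to the loss constructor, and PyTorch's `NLLLoss` (skorch's default criterion for classifiers) accepts a `weight` tensor. This has not been checked against diagNNose's own training code.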