Open songt96 opened 5 years ago
I am also curious about whether it can be extended to multi-label classification.
Thank you for sharing! This is really good work!
Here's a balanced sampler for multilabel datasets https://github.com/issamemari/pytorch-multilabel-balanced-sampler
@issamemari link is broken
Should be fixed now
But doesn't it work out of the box for multi-label classification? One of the examples in this repository is the MNIST dataset that is a multi-label classification problem itself.
The only thing needed to make it work is to define a suitable callback_get_label function:
import torch
from torch.utils.data import DataLoader, Dataset
from torchsampler import ImbalancedDatasetSampler


class DataPack(Dataset):
    """Class to generate a suitable structure for Dataloader."""

    def __init__(self, data, target):
        """Init, data are the features and target is the ground truth."""
        self.data = torch.FloatTensor(data)
        # may need to change targets to a LongTensor for one-hot vectors
        self.targets = torch.FloatTensor(target)

    def __len__(self):
        """Get length."""
        return len(self.data)

    def __getitem__(self, index):
        """Access an instance."""
        data_val = self.data[index]
        target = self.targets[index]
        return data_val, target


# X_train is the feature matrix (in some matrix form; e.g., numpy)
# y_train are the labels/classes in some list form
train_dataset = DataPack(X_train, y_train)
batch_size = 200

# the labels are numbers
trainloader = DataLoader(
    train_dataset,
    sampler=ImbalancedDatasetSampler(
        train_dataset, callback_get_label=lambda x, i: x[i][1].item()
    ),
    batch_size=batch_size,
)
callback_get_label must return a hashable type as the label (for instance, a tuple). If you have one-hot encoded classes:
trainloader = DataLoader(
    train_dataset,
    sampler=ImbalancedDatasetSampler(
        train_dataset, callback_get_label=lambda x, i: tuple(x[i][1].tolist())
    ),
    batch_size=batch_size,
)
Would that be enough?
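On the hashable-label requirement mentioned above: a sampler that tallies how often each label occurs presumably needs labels that can serve as dictionary keys. The following is a minimal sketch in plain Python, using made-up one-hot rows and a hypothetical helper try_count; it does not reproduce torchsampler's internals.

```python
from collections import Counter

# Hypothetical one-hot targets, as they would look after tensor.tolist()
rows = [[1, 0, 0], [0, 1, 0], [1, 0, 0]]


def try_count(labels):
    """Count label occurrences; return None if the labels are not hashable."""
    try:
        return Counter(labels)
    except TypeError:
        return None


# Lists cannot be used as dictionary keys, so counting them fails.
list_counts = try_count(rows)

# Converting each row to a tuple makes it hashable, so counting works.
tuple_counts = try_count(tuple(r) for r in rows)

print(list_counts)
print(tuple_counts)
```

This is why the one-hot variant wraps the label in tuple(...) before returning it.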
A note here: from what I understand, the sampler will treat each combination of labels as a kind of meta-label.
This could be a highly combinatorial setting, with a risk that each combination of labels is rare when actually only a fraction of the labels are rare individually.
I wonder whether a solution that, e.g., sums the inverse frequency of each individual label would work better.
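The concern and the suggested alternative can be sketched on a toy example. The targets below are invented for illustration, and sample_weight is a hypothetical helper, not part of torchsampler. The sketch first shows that every label combination occurs exactly once (so the meta-label view treats all samples as equally rare), then computes a per-sample weight as the sum of inverse frequencies of the labels active in that sample.

```python
from collections import Counter

# Hypothetical multi-hot targets: label 0 is common, labels 1-3 are rarer.
targets = [
    (1, 1, 0, 0),
    (1, 0, 1, 0),
    (1, 0, 0, 1),
    (1, 1, 1, 0),
]

# Meta-label view: each combination of labels is counted as one class.
combo_counts = Counter(targets)

# Per-label view: number of samples in which each label is active.
num_labels = len(targets[0])
label_counts = [sum(t[j] for t in targets) for j in range(num_labels)]


def sample_weight(target):
    """Sum of 1/frequency over the labels that are active in this sample."""
    return sum(1.0 / label_counts[j] for j, on in enumerate(target) if on)


weights = [sample_weight(t) for t in targets]

print(combo_counts)
print(label_counts)
print(weights)
```

Weights like these could, for example, be passed to torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True). One caveat of plain summation is that samples with many active labels receive larger weights; whether that is desirable, or whether some normalization is needed, depends on the dataset.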
This is good work! But I want to find out: does this work for multi-label classification, e.g. with BCELoss in PyTorch? Thanks.