facebookresearch / DomainBed

DomainBed is a suite to test domain generalization algorithms
MIT License

underlying_length of train_loader is 1 #1

Closed: ChongZhangZC closed this issue 4 years ago.

ChongZhangZC commented 4 years ago

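I noticed that you use `WeightedRandomSampler` during training, but it leads to a wrong `steps_per_epoch` and a wrong 'epoch' value in the logs. When `length == self.INFINITE`, `self.underlying_length` becomes 1, because the sampler is built with `num_samples=batch_size`, so the `BatchSampler` wrapping it holds exactly one batch.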

```python
import torch


class FastDataLoader(object):
    """DataLoader wrapper with slightly improved speed by not respawning worker
    processes at every epoch."""
    INFINITE = 'infinite'
    EPOCH = 'epoch'

    def __init__(self, dataset, weights, batch_size, num_workers,
                 length=EPOCH):
        super(FastDataLoader, self).__init__()

        if length == self.EPOCH and weights is not None:
            raise Exception("Specifying sampling weights with length=EPOCH is "
                "illegal: every datapoint would eventually get sampled exactly "
                "once.")

        if weights is None:
            weights = torch.ones(len(dataset))

        if length == self.INFINITE:
            # num_samples=batch_size: the sampler yields only batch_size
            # indices, so this BatchSampler contains exactly one batch.
            batch_sampler = torch.utils.data.BatchSampler(
                torch.utils.data.WeightedRandomSampler(weights,
                    replacement=True,
                    num_samples=batch_size),
                batch_size=batch_size,
                drop_last=True)
        else:
            batch_sampler = torch.utils.data.BatchSampler(
                torch.utils.data.SequentialSampler(dataset),
                batch_size=batch_size,
                drop_last=False
            )
        self.length = length
        # With length == INFINITE this evaluates to 1 (see comment above).
        self.underlying_length = len(batch_sampler)
        # ... (rest of the class omitted in this report)
```
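A minimal standalone check of the behavior described above, assuming a hypothetical dataset of 1000 uniformly weighted examples and `batch_size = 32` (neither value comes from the report). It rebuilds only the samplers from the `INFINITE` branch and prints their lengths:

```python
import torch
from torch.utils.data import BatchSampler, WeightedRandomSampler

batch_size = 32
dataset_size = 1000  # hypothetical dataset size, for illustration only
weights = torch.ones(dataset_size)

# Same construction as the INFINITE branch: num_samples=batch_size.
sampler = WeightedRandomSampler(weights, replacement=True,
                                num_samples=batch_size)
batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=True)

print(len(sampler))        # 32: one pass draws only batch_size indices
print(len(batch_sampler))  # 1: 32 // 32 full batches

# Sizing the sampler by the dataset gives the expected number of batches.
full_sampler = WeightedRandomSampler(weights, replacement=True,
                                     num_samples=dataset_size)
print(len(BatchSampler(full_sampler, batch_size=batch_size,
                       drop_last=True)))  # 31: 1000 // 32
```

`len(WeightedRandomSampler)` is its `num_samples`, and `len(BatchSampler)` with `drop_last=True` is `len(sampler) // batch_size`, so any `steps_per_epoch` derived from `underlying_length` in infinite mode will be 1 regardless of the dataset size.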

ENV: Python==3.7, PyTorch==1.3.1, Torchvision==0.4.2

igul222 commented 4 years ago

Thanks for reporting! I've taken a shot at fixing this in #5
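For readers hitting the same symptom, one possible workaround is sketched here; it is not necessarily what #5 does. The idea is that an infinite loader has no meaningful `len()`, so an "epoch" is measured from dataset sizes instead. `InfiniteSampler`, `make_infinite_loader`, and `steps_per_epoch` are hypothetical names introduced for illustration, and defining an epoch as one pass over the smallest dataset is an assumption:

```python
import itertools

import torch
from torch.utils.data import (BatchSampler, DataLoader, Sampler,
                              TensorDataset, WeightedRandomSampler)


class InfiniteSampler(Sampler):
    """Hypothetical helper: re-iterates a finite batch sampler forever."""

    def __init__(self, batch_sampler):
        self.batch_sampler = batch_sampler

    def __iter__(self):
        while True:
            for batch in self.batch_sampler:
                yield batch


def make_infinite_loader(dataset, weights, batch_size, num_workers=0):
    """Hypothetical: weighted sampling with replacement, never exhausted."""
    sampler = WeightedRandomSampler(weights, replacement=True,
                                    num_samples=batch_size)
    batch_sampler = BatchSampler(sampler, batch_size=batch_size,
                                 drop_last=True)
    return DataLoader(dataset, batch_sampler=InfiniteSampler(batch_sampler),
                      num_workers=num_workers)


def steps_per_epoch(dataset_sizes, batch_size):
    # Assumed definition: one epoch = one pass over the smallest dataset,
    # computed from dataset sizes rather than len() of the infinite loader.
    return min(n / batch_size for n in dataset_sizes)


# Usage with a toy dataset of 100 scalars.
dataset = TensorDataset(torch.arange(100).float())
loader = make_infinite_loader(dataset, torch.ones(100), batch_size=8)
first_batches = list(itertools.islice(loader, 20))  # more than 100 / 8 batches
print(len(first_batches), first_batches[0][0].shape)  # 20 torch.Size([8])
print(steps_per_epoch([100, 250], batch_size=8))  # 12.5
```

Keeping `num_samples=batch_size` is harmless once the sampler is wrapped to repeat forever; the only thing that must change is where `steps_per_epoch` comes from.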