gmberton / CosPlace

Official code for CVPR 2022 paper "Rethinking Visual Geo-localization for Large-Scale Applications"
MIT License

Why is iter used here during training iterations instead of directly using the dataloader? #32

Closed wpumain closed 1 year ago

wpumain commented 1 year ago

https://github.com/gmberton/CosPlace/blob/cea7de243abe1a59c69f1087b1b8001d01262c59/train.py#L114


dataloader = commons.InfiniteDataLoader(groups[current_group_num], num_workers=args.num_workers,
                                        batch_size=args.batch_size, shuffle=True,
                                        pin_memory=(args.device == "cuda"), drop_last=True)

dataloader_iterator = iter(dataloader)
model = model.train()

epoch_losses = np.zeros((0, 1), dtype=np.float32)
for iteration in tqdm(range(args.iterations_per_epoch), ncols=100):
    images, targets, _ = next(dataloader_iterator)

Why not do it directly like this?


dataloader = commons.InfiniteDataLoader(groups[current_group_num], num_workers=args.num_workers,
                                        batch_size=args.batch_size, shuffle=True,
                                        pin_memory=(args.device == "cuda"), drop_last=True)

model = model.train()

epoch_losses = np.zeros((0, 1), dtype=np.float32)
for images, targets, _ in tqdm(dataloader, ncols=100):

gmberton commented 1 year ago

You can iterate over the dataloader directly, but then the loop would continue until it runs out of data (or you would need a break after N iterations). Using an iterator is more dataset-agnostic, and it makes it clearer how long the "epoch" is going to last. A full epoch over the whole dataset could take a day, so that is not an option. Anyway, doing this

for iteration, (images, targets, _) in tqdm(enumerate(dataloader), ncols=100, total=args.iterations_per_epoch):
    if iteration >= args.iterations_per_epoch:
        break

is equivalent to this

dataloader_iterator = iter(dataloader)
for iteration in tqdm(range(args.iterations_per_epoch), ncols=100):
    images, targets, _ = next(dataloader_iterator)

I just find the second example cleaner than the first.
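To make the comparison concrete, here is a minimal pure-Python sketch of the two loop styles. `CountingSource`, `loop_with_break` and `loop_with_next` are hypothetical names invented for this illustration; `CountingSource` is only a stand-in for an endless dataloader and is not part of CosPlace or PyTorch.

```python
class CountingSource:
    """Hypothetical endless batch source that records how many batches were fetched."""

    def __init__(self):
        self.fetched = 0

    def __iter__(self):
        return self

    def __next__(self):
        batch = self.fetched
        self.fetched += 1
        return batch


def loop_with_break(iterations_per_epoch):
    # Style 1: iterate over the source directly and break after N iterations.
    source = CountingSource()
    seen = []
    for iteration, batch in enumerate(source):
        if iteration >= iterations_per_epoch:
            break
        seen.append(batch)
    return seen, source.fetched


def loop_with_next(iterations_per_epoch):
    # Style 2: loop over range(N) and pull each batch with next().
    source = CountingSource()
    iterator = iter(source)
    seen = []
    for iteration in range(iterations_per_epoch):
        seen.append(next(iterator))
    return seen, source.fetched


print(loop_with_break(5))  # ([0, 1, 2, 3, 4], 6)
print(loop_with_next(5))   # ([0, 1, 2, 3, 4], 5)
```

In this sketch both loops process the same batches. The only difference is that the enumerate version fetches one extra batch before the break triggers, and that batch is discarded.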

wpumain commented 1 year ago

https://github.com/gmberton/CosPlace/blob/804fc8c65871cab81f8efbc2c0cff04f1ab78534/train.py#L106

Why not use torch.utils.data.DataLoader directly?

gmberton commented 1 year ago

It is also okay to use the standard DataLoader class, but using an InfiniteDataLoader ensures that, if there are not enough samples, the __iter__ will not raise a StopIteration. This could happen when changing the parameters within the code (e.g. using higher values for the parameters --iterations_per_epoch, --batch_size, --L, --N might lead to this issue when using a standard DataLoader).
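The failure mode described above can be sketched without PyTorch. The function below is a hypothetical helper written for this illustration, and `range(...)` is only a stand-in for a finite, standard dataloader: it shows what happens when a training loop asks a finite iterator for more batches than it holds.

```python
def count_completed_iterations(batches, iterations_per_epoch):
    """Return (completed iterations, whether StopIteration was hit)."""
    iterator = iter(batches)
    completed = 0
    try:
        for _ in range(iterations_per_epoch):
            next(iterator)  # raises StopIteration once the batches run out
            completed += 1
    except StopIteration:
        return completed, True
    return completed, False


# Fewer batches than requested iterations: the loop dies early.
print(count_completed_iterations(range(745), 10_000))     # (745, True)
# More batches than requested iterations: no problem.
print(count_completed_iterations(range(20_000), 10_000))  # (10000, False)
```

In a real training script the StopIteration would not be caught like this; it would propagate out of the loop and stop training.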

wpumain commented 1 year ago

dataloader_iterator = iter(dataloader)
for iteration in tqdm(range(args.iterations_per_epoch), ncols=100):
    images, targets, _ = next(dataloader_iterator)

Do you mean that, with this approach, each epoch trains on only a portion of the data, running just args.iterations_per_epoch iterations instead of a full pass over the dataset?

https://github.com/gmberton/CosPlace/blob/543eaa5995c64d7e2ddbe2d8a1587ec4b2964728/train.py#L106

dataloader = commons.InfiniteDataLoader(groups[current_group_num], num_workers=args.num_workers,
                                        batch_size=args.batch_size, shuffle=True,
                                        pin_memory=(args.device == "cuda"), drop_last=True)

dataloader_iterator = iter(dataloader)

When I use the small dataset, it has only one group, containing 5,965 classes. If args.batch_size is set to 8, the dataloader yields 745 batches per pass (5,965 // 8, since drop_last=True), meaning it can iterate 745 times to train on all the data. The statement "images, targets, _ = next(dataloader_iterator)" can then also only be executed 745 times. However, in the training loop args.iterations_per_epoch is set to 10,000, so 10,000 iterations are expected, while in reality only 745 seem possible, because "images, targets, _ = next(dataloader_iterator)" cannot advance after the 745th batch. If this is the case, your intention to train on only a portion of the data in each epoch may not be achieved.
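The batch count in the comment above follows from simple arithmetic. Below is a small sketch, where `batches_per_pass` is a hypothetical helper written for this illustration, and `num_samples` is the dataset length (taken here to be the 5,965 classes mentioned above).

```python
import math


def batches_per_pass(num_samples, batch_size, drop_last=True):
    """Number of batches in one full pass over the data."""
    if drop_last:
        # The final incomplete batch is discarded.
        return num_samples // batch_size
    # The final incomplete batch is kept.
    return math.ceil(num_samples / batch_size)


print(batches_per_pass(5965, 8, drop_last=True))   # 745
print(batches_per_pass(5965, 8, drop_last=False))  # 746
```

With drop_last=True the 5 leftover samples (5,965 - 745 * 8) are skipped in each pass.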

wpumain commented 1 year ago

class InfiniteDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.dataset_iterator)
        except StopIteration:
            self.dataset_iterator = super().__iter__()
            batch = next(self.dataset_iterator)
        return batch

"using an InfiniteDataLoader ensures that if there are not enough samples the __iter__ will not raise a StopIteration."

When all samples are used up, __next__ creates a new iterator and iteration starts again from the beginning of the data. Is that correct?

When args.iterations_per_epoch=10,000 and the DataLoader has 800 batches, then after the 800th batch the statement "images, targets, _ = next(dataloader_iterator)" in the for loop starts over from the first batch. In this case, a single training loop goes through the data multiple times, and one epoch eventually completes 10,000 training iterations.

When args.iterations_per_epoch=10,000 and the DataLoader has 20,000 batches, a single training loop performs only 10,000 iterations, training on only a portion of the data.

This ensures that, regardless of whether the number of batches is greater than args.iterations_per_epoch, the training loop always runs exactly args.iterations_per_epoch iterations per epoch.

Right?
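The two scenarios above can be checked with a pure-Python sketch. `InfiniteLoader` and `run_epoch` are hypothetical names for this illustration; the class only mirrors the restart logic of the InfiniteDataLoader shown earlier and wraps a plain `range` instead of a real dataset, so there is no shuffling between passes.

```python
class InfiniteLoader:
    """Stand-in that restarts its underlying iterator when it is exhausted."""

    def __init__(self, batches):
        self.batches = batches  # any re-iterable collection of batches
        self.iterator = iter(self.batches)
        self.restarts = 0

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # Out of batches: begin a new pass instead of stopping.
            self.iterator = iter(self.batches)
            self.restarts += 1
            return next(self.iterator)


def run_epoch(num_batches, iterations_per_epoch):
    """Return (iterations done, restarts, number of distinct batches seen)."""
    loader = InfiniteLoader(range(num_batches))
    iterator = iter(loader)
    seen = [next(iterator) for _ in range(iterations_per_epoch)]
    return len(seen), loader.restarts, len(set(seen))


# 800 batches, 10,000 iterations: 12 restarts, every batch seen repeatedly.
print(run_epoch(800, 10_000))     # (10000, 12, 800)
# 20,000 batches, 10,000 iterations: no restart, half of the batches seen.
print(run_epoch(20_000, 10_000))  # (10000, 0, 10000)
```

In both cases the loop runs exactly 10,000 iterations, which matches the reasoning above.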

wpumain commented 1 year ago

Based on this setup, can training on 8 groups of the processed dataset lead to CosPlace achieving state-of-the-art performance?

gmberton commented 1 year ago

All your assumptions are correct, and the answer to all your questions is yes.

wpumain commented 1 year ago

Thank you for your help.