libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0

Inconsistent results when using FFCV dataloaders and PyTorch dataloaders #196

Closed mohamad-amin closed 2 years ago

mohamad-amin commented 2 years ago

Hello,

I recently started using FFCV and tried to switch my model training pipeline from PyTorch dataloaders to FFCV dataloaders, but I'm getting inconsistent results between the two.

Using a WideResNet with 2 block layers (I can share the detailed model code if needed), I'm training the model on a small subset (1K datapoints) of the SVHN dataset. Previously, with the PyTorch dataloader, I got ~88% test accuracy using a fixed seed for the 1K subset. I tried to switch my pipeline to FFCV, as suggested in the documentation, through the following steps:

  1. First, I wrote the dataset in the appropriate format:

    from ffcv.fields import IntField, RGBImageField
    from ffcv.writer import DatasetWriter

    writer = DatasetWriter(write_path, {
        # Tune options to optimize dataset size, throughput at train-time
        'image': RGBImageField(
            write_mode='raw'
        ),
        'label': IntField()
    })
    writer.from_indexed_dataset(dataset)
  2. Then, I changed my dataloader from:

    from torch.utils.data import DataLoader, SubsetRandomSampler

    dataloader = DataLoader(dataset, batch_size=batch_size,
                            sampler=SubsetRandomSampler(labeled_set),
                            num_workers=num_workers)

    to:

    from ffcv.loader import Loader, OrderOption

    dataloader = Loader(dataset_path, batch_size=batch_size, drop_last=False,
                        order=OrderOption.RANDOM, num_workers=num_workers,
                        pipelines=pipeline, os_cache=True, distributed=False, indices=labeled_set)

where my pipeline argument is defined as:

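import torch
from torchvision import transforms
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.transforms import Convert, Squeeze, ToDevice, ToTensor, ToTorchImage

mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
# Scale the [0, 1] statistics to [0, 255]: Convert(torch.float32) keeps raw pixel values
mean = [item * 255 for item in mean]
std = [item * 255 for item in std]

label_pipeline = [IntDecoder(), ToTensor(), ToDevice(torch.device('cuda')), Squeeze()]

image_pipeline = [SimpleRGBImageDecoder()]
image_pipeline.extend([
    ToTensor(),
    ToDevice(torch.device('cuda')),
    ToTorchImage(),
    Convert(torch.float32),
    transforms.Normalize(mean, std)
])

# Loader expects a dict mapping each field name to its pipeline
pipeline = {'image': image_pipeline, 'label': label_pipeline}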

Previously, when using the PyTorch dataloader, I passed these transformations as the transform argument of my torchvision.datasets.SVHN dataset:

from torchvision import transforms

mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)])
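The two normalizations above are meant to be equivalent: torchvision's ToTensor() scales uint8 pixels to [0, 1] before Normalize, while the FFCV pipeline keeps float values in [0, 255], which is why the mean and std are multiplied by 255. A minimal plain-Python check of that arithmetic, assuming a made-up RGB pixel and hypothetical helper names (torchvision_style, ffcv_style), using the per-channel constants from the post:

```python
# Per-channel constants from the post, in [0, 1] units.
MEAN = [0.4914, 0.4822, 0.4465]
STD = [0.2023, 0.1994, 0.2010]


def normalize(pixel, mean, std):
    """Apply (x - mean) / std channel by channel, like transforms.Normalize."""
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]


def torchvision_style(pixel_uint8):
    # ToTensor() divides uint8 values by 255, then Normalize uses [0, 1] constants.
    scaled = [x / 255 for x in pixel_uint8]
    return normalize(scaled, MEAN, STD)


def ffcv_style(pixel_uint8):
    # Convert(torch.float32) keeps values in [0, 255], so the constants are scaled by 255.
    mean_255 = [m * 255 for m in MEAN]
    std_255 = [s * 255 for s in STD]
    return normalize([float(x) for x in pixel_uint8], mean_255, std_255)


pixel = [200, 17, 128]  # an arbitrary RGB pixel
a = torchvision_style(pixel)
b = ffcv_style(pixel)
print(all(abs(x - y) < 1e-9 for x, y in zip(a, b)))  # True
```

This only checks the arithmetic of the two conventions, (x / 255 - m) / s = (x - 255m) / (255s); it says nothing about the rest of either pipeline.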

I'm using my ffcv.Loader the same way I used my torch.utils.data.DataLoader in the training loop, except that I no longer move the batches to the GPU there, since the FFCV pipeline already does that. However, I'm getting 25% test accuracy with the FFCV data loader, which surprised me. I don't expect exactly the same test accuracy, since PyTorch and FFCV may apply the transformations slightly differently, but I would at least expect similar training outcomes.

Is my expectation valid? If so, is there anything I'm doing wrong in my migration from PyTorch to FFCV? For example, am I using the corresponding FFCV transforms correctly and in the right order? I would appreciate it if you could help me figure out why there is such a discrepancy between these two dataloaders.

Thanks!

mohamad-amin commented 2 years ago

I noticed something unexpected. I ran the following code to inspect the test set labels produced by the two loaders.

First FFCV:

from collections import Counter

import torch

labels = []
for data in dataloader:
    labels.append(data[1])
torch.cat(labels).float().mean()
# outputs tensor(3.7744, device='cuda:0')
Counter(torch.cat(labels).tolist())
# outputs:
# Counter({2: 4192,
#          1: 5042,
#          0: 1695,
#          7: 2038,
#          5: 2500,
#          4: 2531,
#          8: 1699,
#          6: 1896,
#          3: 2830,
#          9: 1609})

Then, PyTorch:

labels = []
for data in dataloader:
    labels.append(data[1])
torch.cat(labels).float().mean()
# outputs tensor(3.7525)
Counter(torch.cat(labels).tolist())
# outputs:
# Counter({5: 2384,
#          2: 4149,
#          1: 5099,
#          0: 1744,
#          6: 1977,
#          9: 1595,
#          8: 1660,
#          3: 2882,
#          4: 2523,
#          7: 2019})

It seems the labels are not exactly the same. This is a bit worrisome, but it's more likely that I'm doing something wrong than that the decoders are broken (at least I hope so). However, I'm not applying any complex transformations to the labels, so I'm not sure why this is happening.

(P.S.: I checked the original torch data, and the statistics computed on the dataset are exactly the same as those computed on the data retrieved from the PyTorch dataloader.)

UPDATE: The FFCV loader gives different label counts and means (computed as above) depending on which OrderOption I use.

mohamad-amin commented 2 years ago

Apparently, this issue also appears in the CIFAR-10 example from the documentation: https://docs.ffcv.io/ffcv_examples/cifar10.html

from collections import Counter

import torch as ch

labels = []
for data in loaders['test']:
    labels.append(data[1])
Counter(ch.cat(labels).tolist())

Outputs:

Counter({5: 939,
         9: 1066,
         1: 952,
         2: 910,
         0: 934,
         3: 1062,
         8: 1060,
         4: 1045,
         6: 1037,
         7: 995})

while Counter(datasets['test'].targets) outputs:

Counter({3: 1000,
         8: 1000,
         0: 1000,
         6: 1000,
         1: 1000,
         9: 1000,
         5: 1000,
         7: 1000,
         4: 1000,
         2: 1000})

GuillaumeLeclerc commented 2 years ago

Hello,

You are doing something wrong:

labels = []
for data in dataloader:
    labels.append(data[1])

Here you do not "own" data; it belongs to FFCV. You are allowed to use it, but only during the current iteration. Once the iteration is over, FFCV can do whatever it wants with it, and you are expected to be done reading it. By appending it to labels, you keep a reference to something you do not own. Internally, to save memory and avoid frequent garbage collection, FFCV (unlike PyTorch) reuses arrays that are no longer needed, so data can and will be overwritten once your iteration is finished. If you want to keep something for later, you have to copy it into an array you own yourself (for example, by cloning it).
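To see why keeping references goes wrong, here is a toy, pure-Python stand-in for a loader that reuses a small pool of output buffers. It is not FFCV's implementation: the class name ReusingLoader, the pool size and the batch contents are all made up for illustration. Appending the yielded object keeps only references into the pool, while appending a copy (the analogue of cloning a tensor) keeps the actual values.

```python
from collections import Counter


class ReusingLoader:
    """Hypothetical loader that writes each batch into one of a few reused lists."""

    def __init__(self, batches, num_buffers=2):
        self.batches = batches
        self.buffers = [[None] * len(batches[0]) for _ in range(num_buffers)]

    def __iter__(self):
        for i, batch in enumerate(self.batches):
            buf = self.buffers[i % len(self.buffers)]
            buf[:] = batch  # overwrite a reused buffer instead of allocating a new one
            yield buf


batches = [[0, 0], [1, 1], [2, 2], [3, 3]]  # made-up "label" batches

# Buggy pattern: keep the yielded object itself.
refs = [data for data in ReusingLoader(batches)]

# Fixed pattern: copy the batch before the next iteration can overwrite it.
copies = [list(data) for data in ReusingLoader(batches)]

print(refs)  # [[2, 2], [3, 3], [2, 2], [3, 3]]
print(Counter(x for b in refs for x in b))    # skewed counts
print(Counter(x for b in copies for x in b))  # every label counted correctly
```

With real FFCV batches, the equivalent fix is to append data[1].clone() instead of data[1].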

In your example, your list is basically the last 5 batches repeated over and over (if you check the elements of your list for reference equality, you will see that they point to the same memory regions). As a result, you don't get the proper proportion of labels, and if you captured the images in a similar way, you would also see the same images over and over.
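The reference-equality check suggested above can be tried on a toy stand-in. Below, reusing_batches is a hypothetical generator that cycles through a fixed pool of 5 buffers; the pool size only mirrors the "last 5 batches" remark and is not a statement about how many buffers FFCV actually allocates. Comparing object identities with id() and is shows how many distinct objects the collected list really holds.

```python
def reusing_batches(num_batches, num_buffers=5, batch_size=4):
    """Hypothetical stand-in for a loader that cycles through a fixed pool of buffers."""
    pool = [[0] * batch_size for _ in range(num_buffers)]
    for i in range(num_batches):
        buf = pool[i % num_buffers]
        buf[:] = [i] * batch_size  # fill the reused buffer with this batch's "labels"
        yield buf


labels = [batch for batch in reusing_batches(num_batches=20)]

# Reference-equality check: how many distinct objects did we actually keep?
distinct = {id(b) for b in labels}
print(len(labels), len(distinct))  # 20 5

# Every kept entry aliases one of the last 5 batches.
print(sorted({b[0] for b in labels}))  # [15, 16, 17, 18, 19]
print(labels[0] is labels[5])  # True
```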

andrewilyas commented 2 years ago

Closing this!