libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0

Segmentation Fault when using FFCV alongside PyTorch DataLoader (Possible bug) #292

Open numpee opened 1 year ago

numpee commented 1 year ago

I've just encountered a weird issue where the PyTorch DataLoader will segfault under the following conditions:

- an FFCV Loader is constructed in the same process with num_workers >= 3 (in the snippet below it is only constructed, never iterated), and
- a PyTorch DataLoader (here with num_workers=8) is then iterated.

Basically, the FFCV loader is limited to at most 2 workers; with 3 or more, iterating the PyTorch DataLoader segfaults. Any ideas why this might occur?

Please check out the code snippet below to reproduce the error.

import numpy as np
import torch
from ffcv import Loader
from ffcv.fields.basics import IntDecoder
from ffcv.fields.rgb_image import SimpleRGBImageDecoder
from ffcv.loader import OrderOption
from ffcv.transforms import RandomHorizontalFlip, ToTensor, ToDevice, ToTorchImage, NormalizeImage, Squeeze
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor as TT
from tqdm import tqdm

def main():
    pytorch_workers = 8
    ffcv_workers = 3    # Code works when ffcv_workers <= 2
    beton_path = 'cifar10_val.beton'    # Your FFCV .beton file here
    image_pipeline = [SimpleRGBImageDecoder(), RandomHorizontalFlip(), ToTensor(),
                      ToDevice(torch.device(0), non_blocking=True), ToTorchImage(),
                      NormalizeImage(mean=np.array([0.5, 0.5, 0.5]) * 255, std=np.array([1., 1., 1.]) * 255,
                                     type=np.float32)]
    label_pipeline = [IntDecoder(), ToTensor(), Squeeze(), ToDevice(torch.device(0), non_blocking=True)]
    loader = Loader(beton_path, batch_size=16, num_workers=ffcv_workers, order=OrderOption.SEQUENTIAL, os_cache=True,
                    drop_last=False, pipelines={'image': image_pipeline, 'label': label_pipeline})
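    # Note: the FFCV loader is only constructed here; it is never iterated in this script.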
    dataset = MNIST(root="./", download=True, transform=TT())
    pytorch_dataloader = DataLoader(dataset, batch_size=16, num_workers=pytorch_workers, pin_memory=True)

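    # Iterating the PyTorch DataLoader is where the segfault occurs when ffcv_workers >= 3.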
    for i, *_ in enumerate(tqdm(pytorch_dataloader)):
        print(i)
        if i == 10:
            break

if __name__ == "__main__":
    main()
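
For reference, one mitigation worth trying (untested, and purely an assumption on my part) is to start the PyTorch DataLoader workers with the "spawn" multiprocessing context instead of the default "fork", in case forking a process in which the FFCV Loader has already started its worker threads is what triggers the crash. multiprocessing_context is a standard torch.utils.data.DataLoader argument; the sketch below just swaps it into the DataLoader construction from the snippet above.

    # Untested sketch: use "spawn" workers so the DataLoader does not fork a
    # process that may already hold FFCV worker threads. Whether this avoids
    # the segfault is an assumption, not something verified in this issue.
    pytorch_dataloader = DataLoader(
        dataset,
        batch_size=16,
        num_workers=pytorch_workers,
        pin_memory=True,
        multiprocessing_context="spawn",
    )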