jlevy44 / PathFlowAI

A High-Throughput Workflow for Preprocessing, Deep Learning Analytics and Interpretation in Digital Pathology
https://jlevy44.github.io/PathFlowAI/
MIT License

Dask DataLoader Speed (2.0 feature) #18

Open jlevy44 opened 4 years ago

jlevy44 commented 4 years ago

Background: the dataloader slows down over time, especially when training on a large number of slides. Data that persists in memory loads quickly (the case for a very small number of slides), but this no longer holds when training from a large number of slides. There are issues with having .compute() within getitem(). At the same time, loading data for the semantic segmentation task has to apply the data augmentations (albumentations) to the mask of the image as well, which makes a more daskified data-loading operation a bit more complex.

The issue is with getitem: once the data is loaded, it passes quickly through the DL model.
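The pattern described above can be sketched roughly as follows. This is a minimal illustration, not the PathFlowAI implementation: the class name `LazyPatchDataset`, the synthetic slide, and the coordinates are all hypothetical, and a plain NumPy random flip stands in for an albumentations pipeline so the example has no extra dependencies. The point is that each index triggers its own blocking `.compute()`, and that the image and mask must receive the same augmentation.

```python
import numpy as np
import dask.array as da


class LazyPatchDataset:
    """Hypothetical map-style dataset over lazily loaded slide and mask arrays."""

    def __init__(self, slide, mask, coords, patch_size, seed=0):
        self.slide = slide            # dask array, shape (H, W, 3)
        self.mask = mask              # dask array, shape (H, W)
        self.coords = coords          # list of (row, col) top-left patch corners
        self.patch_size = patch_size
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, i):
        r, c = self.coords[i]
        p = self.patch_size
        # Blocking calls: every item executes its own small dask graph.
        image = self.slide[r:r + p, c:c + p].compute()
        mask = self.mask[r:r + p, c:c + p].compute()
        # Stand-in for albumentations: one random draw decides the flip
        # for BOTH arrays so the mask stays aligned with the image.
        if self.rng.random() < 0.5:
            image, mask = image[:, ::-1], mask[:, ::-1]
        return np.ascontiguousarray(image), np.ascontiguousarray(mask)


# Synthetic stand-in for a slide and its segmentation mask (assumption).
slide_np = np.random.default_rng(1).integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
mask_np = slide_np[..., 0] > 127
slide = da.from_array(slide_np, chunks=(256, 256, 3))
mask = da.from_array(mask_np, chunks=(256, 256))

dataset = LazyPatchDataset(slide, mask, coords=[(0, 0), (100, 200), (300, 64)], patch_size=64)
image, patch_mask = dataset[0]
```

Because the class only implements `__len__` and `__getitem__`, it follows the map-style protocol that a PyTorch `DataLoader` expects, so every sample in every batch would pay the cost of those two `.compute()` calls.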

Potentially nice ideas:

@lvaickus, can you comment more here?

@sumanthratna

sumanthratna commented 4 years ago

I'll need to look into this more, but https://docs.dask.org/en/latest/caching.html might help.

It might also help if we daskify the entire model, so we don't need to compute inputs from the DataLoader each time. Ideally, we'd be able to use model.loss.compute() at each iteration to update the weights and biases accordingly. If I'm remembering correctly, this will essentially prune the model (pruning really isn't the right word here) and only calculate the loss instead of the final output. Of course, this comes with the disadvantage that you can't see the outputs for each image. If the user enables verbosity, we could compute the final output.

jlevy44 commented 4 years ago

That's an interesting approach; I will have to think about it more. I think the difficulty is in randomly reading from a few to tens of TB of slides at a time, and doing so quickly. Once the data is loaded, it passes through the model quickly, but you are correct that making many calls to compute from within the DataLoader is not something we should continue to do. Another solution could be a stacked HDF5 array, as many others use, that we simply memory-map, with each slice in one-to-one correspondence with the location annotations in the DB file. The drawback is that a separate stack has to be created for each patch size, which is less than ideal when scaling up.
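A minimal sketch of the stacked-array idea, under stated assumptions: it uses a memory-mapped NumPy `.npy` file as a stand-in for HDF5 (an `h5py` dataset would be indexed the same way), and the annotation table, file name, and pixel values are made up for illustration. Row `i` of the stack corresponds to row `i` of the annotations, and the patch size is baked into the file, which is the scaling drawback mentioned above.

```python
import os
import tempfile

import numpy as np

patch_size = 32
n_patches = 10
# One file per patch size (hypothetical naming scheme).
path = os.path.join(tempfile.mkdtemp(), f"patches_{patch_size}.npy")

# Build step: write every patch into one on-disk stack.
stack = np.lib.format.open_memmap(
    path, mode="w+", dtype=np.uint8,
    shape=(n_patches, patch_size, patch_size, 3),
)
for i in range(n_patches):
    stack[i] = i  # placeholder pixel data; real code would copy the slide patch
stack.flush()
del stack

# Hypothetical annotation table: row i describes stack[i].
annotations = [
    {"slide": "slide_A", "x": 0, "y": i * patch_size, "label": i % 2}
    for i in range(n_patches)
]

# Training-time access: memory-map the stack and read only the requested slice.
patches = np.load(path, mmap_mode="r")


def get_item(i):
    """Return (patch, label) for row i without loading the whole stack."""
    return np.array(patches[i]), annotations[i]["label"]


patch, label = get_item(3)
```

Random access then costs one contiguous read per patch and no graph execution, at the price of rebuilding the stack whenever the patch size changes.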